Update app.py
app.py (CHANGED)
@@ -14,6 +14,7 @@ import pickle
 import torch
 from transformers import pipeline
 from sentence_transformers import SentenceTransformer
+import requests
 from langchain_community.vectorstores import FAISS
 from langchain_community.embeddings import HuggingFaceEmbeddings
 from langchain.schema import Document
@@ -35,37 +36,116 @@ CONFIG = {
     "max_tokens": 350,
 }
 
+# Remote inference config (optional). If `HF_INFERENCE_API_KEY` is set in the
+# environment, the app will prefer calling the Hugging Face Inference API (remote
+# hosted model) which can generate longer outputs faster than a CPU-bound local
+# model. Set `HF_INFERENCE_MODEL` to choose the remote model (instruction-tuned
+# model recommended).
+USE_REMOTE_LLM = False
+REMOTE_LLM_MODEL = os.environ.get("HF_INFERENCE_MODEL", "tiiuae/falcon-7b-instruct")
+
+# Prefer the environment variable, but also allow a local token file for users
+# who don't know how to set env vars. Create a file named `hf_token.txt` in the
+# project root containing only the token (no newline is necessary). DO NOT
+# commit that file to version control. A .gitignore entry will be added.
+HF_INFERENCE_API_KEY = os.environ.get("HF_INFERENCE_API_KEY")
+if not HF_INFERENCE_API_KEY:
+    try:
+        token_path = Path("hf_token.txt")
+        if token_path.exists():
+            HF_INFERENCE_API_KEY = token_path.read_text(encoding="utf-8").strip()
+            logger.info("Loaded HF token from hf_token.txt (ensure this file is private and not committed)")
+    except Exception:
+        logger.warning("Could not read hf_token.txt for HF token")
+
+if HF_INFERENCE_API_KEY:
+    USE_REMOTE_LLM = True
+
 # ============================================================================
 # INITIALIZE MODELS
 # ============================================================================
 
 def initialize_llm():
+    # If a remote HF Inference API key is provided, we won't instantiate a local
+    # heavy model; instead generation will be performed via the HTTP API.
+    global USE_REMOTE_LLM, REMOTE_LLM_MODEL
+    if USE_REMOTE_LLM:
+        logger.info(f"Using remote Hugging Face Inference model: {REMOTE_LLM_MODEL}")
+        CONFIG["llm_model"] = REMOTE_LLM_MODEL
+        CONFIG["model_type"] = "remote"
+        return None
+
     logger.info("Initializing FREE local language model...")
     model_name = "google/flan-t5-large"
-
+
     try:
         logger.info(f" Loading {model_name}...")
         device = 0 if torch.cuda.is_available() else -1
-
+
         model_kwargs = {"low_cpu_mem_usage": True}
-
+
         llm_client = pipeline(
             "text2text-generation",
             model=model_name,
             device=device,
             model_kwargs=model_kwargs
         )
-
+
         CONFIG["llm_model"] = model_name
         CONFIG["model_type"] = "t5"
         logger.info(f"LLM initialized: {model_name}")
         logger.info(f" Device: {'GPU' if device == 0 else 'CPU'}")
         return llm_client
-
+
     except Exception as e:
         logger.error(f"Failed to load model: {str(e)}")
         raise Exception(f"Failed to initialize LLM: {str(e)}")
 
+
+def remote_generate(prompt: str, max_new_tokens: int = 512, temperature: float = 0.7, top_p: float = 0.9) -> str:
+    """Call the Hugging Face Inference API for remote generation. Requires
+    `HF_INFERENCE_API_KEY` env var to be set and a model name in
+    `REMOTE_LLM_MODEL`.
+    """
+    if not HF_INFERENCE_API_KEY:
+        raise Exception("HF_INFERENCE_API_KEY not set for remote generation")
+
+    url = f"https://api-inference.huggingface.co/models/{REMOTE_LLM_MODEL}"
+    headers = {"Authorization": f"Bearer {HF_INFERENCE_API_KEY}"}
+    payload = {
+        "inputs": prompt,
+        "parameters": {
+            "max_new_tokens": max_new_tokens,
+            "temperature": temperature,
+            "top_p": top_p,
+            "return_full_text": False
+        }
+    }
+
+    logger.info(f"Remote inference request to {REMOTE_LLM_MODEL} (tokens={max_new_tokens}, temp={temperature})")
+    r = requests.post(url, headers=headers, json=payload, timeout=60)
+    if r.status_code != 200:
+        logger.error(f"Remote inference error {r.status_code}: {r.text[:200]}")
+        return ""
+
+    result = r.json()
+    if isinstance(result, dict) and result.get("error"):
+        logger.error(f"Remote inference returned error: {result.get('error')}")
+        return ""
+
+    # The HF Inference API can return a list of generated outputs or text
+    if isinstance(result, list) and result:
+        # entries may be strings or dicts like {"generated_text": "..."}
+        first = result[0]
+        if isinstance(first, dict):
+            return first.get("generated_text", "").strip()
+        return str(first).strip()
+
+    if isinstance(result, dict) and "generated_text" in result:
+        return result["generated_text"].strip()
+
+    return str(result).strip()
+
 def initialize_embeddings():
     logger.info("Initializing embeddings model...")
 
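For reference, a minimal usage sketch of the remote path added in this hunk. It assumes the names defined in app.py (`remote_generate`, `HF_INFERENCE_API_KEY`), that the token was exported or `hf_token.txt` created before startup, and a hypothetical `app` import path; none of this is part of the commit.

    # Illustrative only: exercise the remote Hugging Face Inference path.
    # Assumes HF_INFERENCE_API_KEY was set before app.py was imported.
    from app import remote_generate  # hypothetical import path for this Space

    reply = remote_generate("Suggest a fall layering outfit.", max_new_tokens=200, temperature=0.7)
    if reply:
        print(reply)
    else:
        # remote_generate returns an empty string on HTTP or API errors
        print("Remote call failed; check the token and model name.")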
@@ -185,6 +265,109 @@ def load_vector_store(embeddings):
 # RAG PIPELINE FUNCTIONS
 # ============================================================================
 
+def generate_extractive_answer(query: str, retrieved_docs: List[Document]) -> Optional[str]:
+    """Build a long-form answer from retrieved documents using extractive
+    selection + templated transitions. This avoids calling the LLM when it
+    repeatedly fails or returns very short outputs.
+    """
+    logger.info(f"Running extractive fallback for: '{query}'")
+
+    # Collect text and split into sentences
+    import re
+
+    all_text = "\n\n".join([d.page_content for d in retrieved_docs])
+    # Basic sentence split (keeps punctuation)
+    sentences = re.split(r'(?<=[.!?])\s+', all_text)
+    sentences = [s.strip() for s in sentences if len(s.strip()) > 30]
+
+    if not sentences:
+        logger.warning("No sentences found in retrieved documents for extractive fallback")
+        return None
+
+    # Scoring: keyword overlap with query and fashion terms
+    query_tokens = set(re.findall(r"\w+", query.lower()))
+    fashion_keywords = set(["outfit","wear","wardrobe","style","colors","color","layer","layering",
+                            "blazer","trousers","dress","shirt","shoes","boots","sweater","jacket",
+                            "care","wash","dry","clean","wool","cotton","silk","linen","fit","tailor",
+                            "versatile","neutral","accessory","belt","bag","occasion","season","fall"])
+    keywords = query_tokens.union(fashion_keywords)
+
+    scored = []
+    for s in sentences:
+        s_tokens = set(re.findall(r"\w+", s.lower()))
+        score = len(s_tokens & keywords)
+        # length bonus to prefer richer sentences
+        score += min(3, len(s.split()) // 20)
+        scored.append((score, s))
+
+    scored.sort(key=lambda x: x[0], reverse=True)
+    top_sentences = [s for _, s in scored[:60]]
+
+    # Build structured sections using top sentences + templates
+    def pick(n, start=0):
+        return top_sentences[start:start+n]
+
+    intro = []
+    intro.extend(pick(2, 0))
+    key_items = pick(8, 2)
+    styling = pick(8, 10)
+    care = pick(6, 18)
+    conclusion = pick(4, 24)
+
+    # Add handcrafted, helpful transitions to improve flow
+    template_intro = f"Here's a detailed answer to '{query}'. I'll cover essential wardrobe items, styling tips, and care advice so you can apply these suggestions practically."
+
+    # Ensure care advice includes the user's specific care example if present or add it
+    care_text = "\n\n".join(care)
+    if "dry clean" not in care_text.lower() and "hand wash" not in care_text.lower():
+        care_text += "\n\nDry clean or hand wash in cold water with wool-specific detergent. Never wring out wool - gently squeeze excess water and lay flat to dry on a towel."
+
+    parts = []
+    parts.append(template_intro)
+    if intro:
+        parts.append(" ".join(intro))
+    if key_items:
+        parts.append("Key wardrobe items to prioritize:")
+        parts.append(" ".join(key_items))
+    if styling:
+        parts.append("Practical styling tips:")
+        parts.append(" ".join(styling))
+    if care_text:
+        parts.append("Care & maintenance:")
+        parts.append(care_text)
+    if conclusion:
+        parts.append("Wrapping up:")
+        parts.append(" ".join(conclusion))
+
+    # Combine and refine spacing
+    answer = "\n\n".join(parts)
+
+    # Post-process: ensure target length (approximately 400-700 words)
+    words = answer.split()
+    word_count = len(words)
+
+    # If too short, append templated practical paragraphs built from keywords
+    if word_count < 380:
+        logger.info(f"Extractive answer short ({word_count} words). Appending templated paragraphs.")
+        extra_paragraphs = []
+        extra_paragraphs.append("A reliable strategy is to build around versatile, neutral pieces: a well-fitted blazer, tailored trousers, a versatile dress, and quality shoes. These items can be mixed and matched for many occasions.")
+        extra_paragraphs.append("Focus on fit and fabric: ensure key items are well-tailored, prioritize breathable fabrics for comfort, and choose merino or wool blends for colder seasons to layer effectively.")
+        extra_paragraphs.append("Layering is essential for transitional weather; combine a lightweight sweater under a jacket, and carry a scarf for added warmth and visual interest.")
+        extra_paragraphs.append("Accessories like belts, a structured bag, and minimal jewelry can elevate basic outfits without extra effort. Neutral colors increase versatility and pair well with bolder accents.")
+        answer += "\n\n" + "\n\n".join(extra_paragraphs)
+        words = answer.split()
+        word_count = len(words)
+
+    # If still too long, truncate gracefully
+    if word_count > 750:
+        words = words[:700]
+        answer = " ".join(words) + '...'
+        word_count = 700
+
+    logger.info(f"Extractive answer ready ({word_count} words)")
+    return answer
+
+
 def retrieve_knowledge_langchain(
     query: str,
     vectorstore,
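A small smoke test for the extractive fallback defined in this hunk; the `app` import path and the toy documents are illustrative assumptions, not part of the commit.

    # Illustrative only: run the extractive fallback on two toy documents.
    from langchain.schema import Document
    from app import generate_extractive_answer  # hypothetical import path

    docs = [
        Document(page_content="A well-fitted blazer anchors a capsule wardrobe. "
                              "Neutral colors make it easy to layer over shirts or knitwear."),
        Document(page_content="Wool should be dry cleaned or hand washed in cold water. "
                              "Lay flat to dry so the garment keeps its shape."),
    ]
    answer = generate_extractive_answer("How do I build a fall wardrobe?", docs)
    print(answer if answer else "No usable sentences found.")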
@@ -277,14 +460,19 @@ def generate_llm_answer(
     # (too short or truncated), fall back to an iterative multi-pass generator
     # that appends continuation chunks until we reach the target word count.
 
-
-
-
+    # Adjusted targets for faster generation and user's request: aim ~350 words
+    target_min_words = 320
+    target_max_words = 420
+    chunk_target_words = 140
     max_iterations = 4
 
     def call_model(prompt, max_new_tokens, temperature, top_p, repetition_penalty):
         logger.info(f"Model call (temp={temperature}, max_new_tokens={max_new_tokens})")
         try:
+            if USE_REMOTE_LLM:
+                # Use remote Hugging Face Inference API
+                return remote_generate(prompt, max_new_tokens, temperature, top_p)
+
             out = llm_client(
                 prompt,
                 max_new_tokens=max_new_tokens,
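The multi-pass generator referenced in the comments above is not shown in this hunk; the sketch below is one way `call_model` and the new word targets could drive such a loop. It is an assumption-based illustration, not the app's actual implementation.

    # Sketch only: iterative continuation until the word target is reached.
    def multi_pass_generate(prompt: str) -> str:
        answer = ""
        for _ in range(max_iterations):
            if len(answer.split()) >= target_min_words:
                break
            # Ask for another chunk, feeding back what has been written so far.
            chunk = call_model(
                prompt + "\n\nContinue the answer:\n" + answer,
                max_new_tokens=chunk_target_words * 2,  # rough token budget per pass
                temperature=0.7,
                top_p=0.9,
                repetition_penalty=1.15,
            )
            if not chunk:
                break
            answer = (answer + " " + chunk).strip()
        # Trim to the upper bound so answers stay near the ~350-word target.
        return " ".join(answer.split()[:target_max_words])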
@@ -450,6 +638,18 @@ def generate_answer_langchain(
 
     if not llm_answer:
         logger.error(f"All 2 LLM attempts failed")
+        # Fallback: use an extractive + template-based generator to produce a long,
+        # natural-flowing answer without using the LLM. This helps when the model
+        # repeatedly returns very short outputs or errors.
+        try:
+            logger.info("Using extractive fallback generator")
+            fallback = generate_extractive_answer(query, retrieved_docs)
+            if fallback:
+                logger.info("Extractive fallback produced an answer")
+                return fallback
+        except Exception as e:
+            logger.error(f"Extractive fallback error: {e}")
+
         return "I apologize, but I'm having trouble generating a response. Please try rephrasing your question or ask something else."
 
     return llm_answer
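One downstream implication worth noting (an inference from the code above, not stated in the diff): when the remote path is active, `initialize_llm()` returns None, so callers should branch on `USE_REMOTE_LLM` rather than on the client object. A hedged sketch:

    # Illustrative only: startup-time dispatch between remote and local generation.
    llm_client = initialize_llm()
    if USE_REMOTE_LLM:
        draft = remote_generate("Say hello.", max_new_tokens=32)
    else:
        # text2text-generation pipelines return a list of dicts with "generated_text"
        draft = llm_client("Say hello.", max_new_tokens=32)[0]["generated_text"]
    print(draft)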