hamxaameer commited on
Commit
4fc4e4b
Β·
verified Β·
1 Parent(s): 0e156ba

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +208 -8
app.py CHANGED
@@ -14,6 +14,7 @@ import pickle
14
  import torch
15
  from transformers import pipeline
16
  from sentence_transformers import SentenceTransformer
 
17
  from langchain_community.vectorstores import FAISS
18
  from langchain_community.embeddings import HuggingFaceEmbeddings
19
  from langchain.schema import Document
@@ -35,37 +36,116 @@ CONFIG = {
35
  "max_tokens": 350,
36
  }
37
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
38
  # ============================================================================
39
  # INITIALIZE MODELS
40
  # ============================================================================
41
 
42
  def initialize_llm():
 
 
 
 
 
 
 
 
 
43
  logger.info("πŸ”„ Initializing FREE local language model...")
44
  model_name = "google/flan-t5-large"
45
-
46
  try:
47
  logger.info(f" Loading {model_name}...")
48
  device = 0 if torch.cuda.is_available() else -1
49
-
50
  model_kwargs = {"low_cpu_mem_usage": True}
51
-
52
  llm_client = pipeline(
53
  "text2text-generation",
54
  model=model_name,
55
  device=device,
56
  model_kwargs=model_kwargs
57
  )
58
-
59
  CONFIG["llm_model"] = model_name
60
  CONFIG["model_type"] = "t5"
61
  logger.info(f"βœ… LLM initialized: {model_name}")
62
  logger.info(f" Device: {'GPU' if device == 0 else 'CPU'}")
63
  return llm_client
64
-
65
  except Exception as e:
66
  logger.error(f"❌ Failed to load model: {str(e)}")
67
  raise Exception(f"Failed to initialize LLM: {str(e)}")
68
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
69
  def initialize_embeddings():
70
  logger.info("πŸ”„ Initializing embeddings model...")
71
 
@@ -185,6 +265,109 @@ def load_vector_store(embeddings):
185
  # RAG PIPELINE FUNCTIONS
186
  # ============================================================================
187
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
188
  def retrieve_knowledge_langchain(
189
  query: str,
190
  vectorstore,
@@ -277,14 +460,19 @@ def generate_llm_answer(
277
  # (too short or truncated), fall back to an iterative multi-pass generator
278
  # that appends continuation chunks until we reach the target word count.
279
 
280
- target_min_words = 400
281
- target_max_words = 700
282
- chunk_target_words = 200
 
283
  max_iterations = 4
284
 
285
  def call_model(prompt, max_new_tokens, temperature, top_p, repetition_penalty):
286
  logger.info(f" β†’ Model call (temp={temperature}, max_new_tokens={max_new_tokens})")
287
  try:
 
 
 
 
288
  out = llm_client(
289
  prompt,
290
  max_new_tokens=max_new_tokens,
@@ -450,6 +638,18 @@ def generate_answer_langchain(
450
 
451
  if not llm_answer:
452
  logger.error(f" βœ— All 2 LLM attempts failed")
 
 
 
 
 
 
 
 
 
 
 
 
453
  return "I apologize, but I'm having trouble generating a response. Please try rephrasing your question or ask something else."
454
 
455
  return llm_answer
 
14
  import torch
15
  from transformers import pipeline
16
  from sentence_transformers import SentenceTransformer
17
+ import requests
18
  from langchain_community.vectorstores import FAISS
19
  from langchain_community.embeddings import HuggingFaceEmbeddings
20
  from langchain.schema import Document
 
36
  "max_tokens": 350,
37
  }
38
 
39
+ # Remote inference config (optional). If `HF_INFERENCE_API_KEY` is set in the
40
+ # environment, the app will prefer calling the Hugging Face Inference API (remote
41
+ # hosted model) which can generate longer outputs faster than a CPU-bound local
42
+ # model. Set `HF_INFERENCE_MODEL` to choose the remote model (instruction-tuned
43
+ # model recommended).
44
+ USE_REMOTE_LLM = False
45
+ REMOTE_LLM_MODEL = os.environ.get("HF_INFERENCE_MODEL", "tiiuae/falcon-7b-instruct")
46
+
47
+ # Prefer the environment variable, but also allow a local token file for users
48
+ # who don't know how to set env vars. Create a file named `hf_token.txt` in the
49
+ # project root containing only the token (no newline is necessary). DO NOT
50
+ # commit that file to version control. A .gitignore entry will be added.
51
+ HF_INFERENCE_API_KEY = os.environ.get("HF_INFERENCE_API_KEY")
52
+ if not HF_INFERENCE_API_KEY:
53
+ try:
54
+ token_path = Path("hf_token.txt")
55
+ if token_path.exists():
56
+ HF_INFERENCE_API_KEY = token_path.read_text(encoding="utf-8").strip()
57
+ logger.info("Loaded HF token from hf_token.txt (ensure this file is private and not committed)")
58
+ except Exception:
59
+ logger.warning("Could not read hf_token.txt for HF token")
60
+
61
+ if HF_INFERENCE_API_KEY:
62
+ USE_REMOTE_LLM = True
63
+
64
  # ============================================================================
65
  # INITIALIZE MODELS
66
  # ============================================================================
67
 
68
  def initialize_llm():
69
+ # If a remote HF Inference API key is provided, we won't instantiate a local
70
+ # heavy model; instead generation will be performed via the HTTP API.
71
+ global USE_REMOTE_LLM, REMOTE_LLM_MODEL
72
+ if USE_REMOTE_LLM:
73
+ logger.info(f"πŸ”„ Using remote Hugging Face Inference model: {REMOTE_LLM_MODEL}")
74
+ CONFIG["llm_model"] = REMOTE_LLM_MODEL
75
+ CONFIG["model_type"] = "remote"
76
+ return None
77
+
78
  logger.info("πŸ”„ Initializing FREE local language model...")
79
  model_name = "google/flan-t5-large"
80
+
81
  try:
82
  logger.info(f" Loading {model_name}...")
83
  device = 0 if torch.cuda.is_available() else -1
84
+
85
  model_kwargs = {"low_cpu_mem_usage": True}
86
+
87
  llm_client = pipeline(
88
  "text2text-generation",
89
  model=model_name,
90
  device=device,
91
  model_kwargs=model_kwargs
92
  )
93
+
94
  CONFIG["llm_model"] = model_name
95
  CONFIG["model_type"] = "t5"
96
  logger.info(f"βœ… LLM initialized: {model_name}")
97
  logger.info(f" Device: {'GPU' if device == 0 else 'CPU'}")
98
  return llm_client
99
+
100
  except Exception as e:
101
  logger.error(f"❌ Failed to load model: {str(e)}")
102
  raise Exception(f"Failed to initialize LLM: {str(e)}")
103
 
104
+
105
+ def remote_generate(prompt: str, max_new_tokens: int = 512, temperature: float = 0.7, top_p: float = 0.9) -> str:
106
+ """Call the Hugging Face Inference API for remote generation. Requires
107
+ `HF_INFERENCE_API_KEY` env var to be set and a model name in
108
+ `REMOTE_LLM_MODEL`.
109
+ """
110
+ if not HF_INFERENCE_API_KEY:
111
+ raise Exception("HF_INFERENCE_API_KEY not set for remote generation")
112
+
113
+ url = f"https://api-inference.huggingface.co/models/{REMOTE_LLM_MODEL}"
114
+ headers = {"Authorization": f"Bearer {HF_INFERENCE_API_KEY}"}
115
+ payload = {
116
+ "inputs": prompt,
117
+ "parameters": {
118
+ "max_new_tokens": max_new_tokens,
119
+ "temperature": temperature,
120
+ "top_p": top_p,
121
+ "return_full_text": False
122
+ }
123
+ }
124
+
125
+ logger.info(f" β†’ Remote inference request to {REMOTE_LLM_MODEL} (tokens={max_new_tokens}, temp={temperature})")
126
+ r = requests.post(url, headers=headers, json=payload, timeout=60)
127
+ if r.status_code != 200:
128
+ logger.error(f" βœ— Remote inference error {r.status_code}: {r.text[:200]}")
129
+ return ""
130
+
131
+ result = r.json()
132
+ if isinstance(result, dict) and result.get("error"):
133
+ logger.error(f" βœ— Remote inference returned error: {result.get('error')}")
134
+ return ""
135
+
136
+ # The HF Inference API can return a list of generated outputs or text
137
+ if isinstance(result, list) and result:
138
+ # entries may be strings or dicts like {"generated_text": "..."}
139
+ first = result[0]
140
+ if isinstance(first, dict):
141
+ return first.get("generated_text", "").strip()
142
+ return str(first).strip()
143
+
144
+ if isinstance(result, dict) and "generated_text" in result:
145
+ return result["generated_text"].strip()
146
+
147
+ return str(result).strip()
148
+
149
  def initialize_embeddings():
150
  logger.info("πŸ”„ Initializing embeddings model...")
151
 
 
265
  # RAG PIPELINE FUNCTIONS
266
  # ============================================================================
267
 
268
+ def generate_extractive_answer(query: str, retrieved_docs: List[Document]) -> Optional[str]:
269
+ """Build a long-form answer from retrieved documents using extractive
270
+ selection + templated transitions. This avoids calling the LLM when it
271
+ repeatedly fails or returns very short outputs.
272
+ """
273
+ logger.info(f"πŸ”§ Running extractive fallback for: '{query}'")
274
+
275
+ # Collect text and split into sentences
276
+ import re
277
+
278
+ all_text = "\n\n".join([d.page_content for d in retrieved_docs])
279
+ # Basic sentence split (keeps punctuation)
280
+ sentences = re.split(r'(?<=[.!?])\s+', all_text)
281
+ sentences = [s.strip() for s in sentences if len(s.strip()) > 30]
282
+
283
+ if not sentences:
284
+ logger.warning(" βœ— No sentences found in retrieved documents for extractive fallback")
285
+ return None
286
+
287
+ # Scoring: keyword overlap with query and fashion terms
288
+ query_tokens = set(re.findall(r"\w+", query.lower()))
289
+ fashion_keywords = set(["outfit","wear","wardrobe","style","colors","color","layer","layering",
290
+ "blazer","trousers","dress","shirt","shoes","boots","sweater","jacket",
291
+ "care","wash","dry","clean","wool","cotton","silk","linen","fit","tailor",
292
+ "versatile","neutral","accessory","belt","bag","occasion","season","fall"])
293
+ keywords = query_tokens.union(fashion_keywords)
294
+
295
+ scored = []
296
+ for s in sentences:
297
+ s_tokens = set(re.findall(r"\w+", s.lower()))
298
+ score = len(s_tokens & keywords)
299
+ # length bonus to prefer richer sentences
300
+ score += min(3, len(s.split()) // 20)
301
+ scored.append((score, s))
302
+
303
+ scored.sort(key=lambda x: x[0], reverse=True)
304
+ top_sentences = [s for _, s in scored[:60]]
305
+
306
+ # Build structured sections using top sentences + templates
307
+ def pick(n, start=0):
308
+ return top_sentences[start:start+n]
309
+
310
+ intro = []
311
+ intro.extend(pick(2, 0))
312
+ key_items = pick(8, 2)
313
+ styling = pick(8, 10)
314
+ care = pick(6, 18)
315
+ conclusion = pick(4, 24)
316
+
317
+ # Add handcrafted, helpful transitions to improve flow
318
+ template_intro = f"Here's a detailed answer to '{query}'. I'll cover essential wardrobe items, styling tips, and care advice so you can apply these suggestions practically."
319
+
320
+ # Ensure care advice includes the user's specific care example if present or add it
321
+ care_text = "\n\n".join(care)
322
+ if "dry clean" not in care_text.lower() and "hand wash" not in care_text.lower():
323
+ care_text += "\n\nDry clean or hand wash in cold water with wool-specific detergent. Never wring out wool - gently squeeze excess water and lay flat to dry on a towel."
324
+
325
+ parts = []
326
+ parts.append(template_intro)
327
+ if intro:
328
+ parts.append(" ".join(intro))
329
+ if key_items:
330
+ parts.append("Key wardrobe items to prioritize:")
331
+ parts.append(" ".join(key_items))
332
+ if styling:
333
+ parts.append("Practical styling tips:")
334
+ parts.append(" ".join(styling))
335
+ if care_text:
336
+ parts.append("Care & maintenance:")
337
+ parts.append(care_text)
338
+ if conclusion:
339
+ parts.append("Wrapping up:")
340
+ parts.append(" ".join(conclusion))
341
+
342
+ # Combine and refine spacing
343
+ answer = "\n\n".join(parts)
344
+
345
+ # Post-process: ensure target length (approximately 400-700 words)
346
+ words = answer.split()
347
+ word_count = len(words)
348
+
349
+ # If too short, append templated practical paragraphs built from keywords
350
+ if word_count < 380:
351
+ logger.info(f" β†’ Extractive answer short ({word_count} words). Appending templated paragraphs.")
352
+ extra_paragraphs = []
353
+ extra_paragraphs.append("A reliable strategy is to build around versatile, neutral pieces: a well-fitted blazer, tailored trousers, a versatile dress, and quality shoes. These items can be mixed and matched for many occasions.")
354
+ extra_paragraphs.append("Focus on fit and fabric: ensure key items are well-tailored, prioritize breathable fabrics for comfort, and choose merino or wool blends for colder seasons to layer effectively.")
355
+ extra_paragraphs.append("Layering is essential for transitional weather; combine a lightweight sweater under a jacket, and carry a scarf for added warmth and visual interest.")
356
+ extra_paragraphs.append("Accessories like belts, a structured bag, and minimal jewelry can elevate basic outfits without extra effort. Neutral colors increase versatility and pair well with bolder accents.")
357
+ answer += "\n\n" + "\n\n".join(extra_paragraphs)
358
+ words = answer.split()
359
+ word_count = len(words)
360
+
361
+ # If still too long, truncate gracefully
362
+ if word_count > 750:
363
+ words = words[:700]
364
+ answer = " ".join(words) + '...'
365
+ word_count = 700
366
+
367
+ logger.info(f" βœ… Extractive answer ready ({word_count} words)")
368
+ return answer
369
+
370
+
371
  def retrieve_knowledge_langchain(
372
  query: str,
373
  vectorstore,
 
460
  # (too short or truncated), fall back to an iterative multi-pass generator
461
  # that appends continuation chunks until we reach the target word count.
462
 
463
+ # Adjusted targets for faster generation and user's request: aim ~350 words
464
+ target_min_words = 320
465
+ target_max_words = 420
466
+ chunk_target_words = 140
467
  max_iterations = 4
468
 
469
  def call_model(prompt, max_new_tokens, temperature, top_p, repetition_penalty):
470
  logger.info(f" β†’ Model call (temp={temperature}, max_new_tokens={max_new_tokens})")
471
  try:
472
+ if USE_REMOTE_LLM:
473
+ # Use remote Hugging Face Inference API
474
+ return remote_generate(prompt, max_new_tokens, temperature, top_p)
475
+
476
  out = llm_client(
477
  prompt,
478
  max_new_tokens=max_new_tokens,
 
638
 
639
  if not llm_answer:
640
  logger.error(f" βœ— All 2 LLM attempts failed")
641
+ # Fallback: use an extractive + template-based generator to produce a long,
642
+ # natural-flowing answer without using the LLM. This helps when the model
643
+ # repeatedly returns very short outputs or errors.
644
+ try:
645
+ logger.info(" β†’ Using extractive fallback generator")
646
+ fallback = generate_extractive_answer(query, retrieved_docs)
647
+ if fallback:
648
+ logger.info(" βœ… Extractive fallback produced an answer")
649
+ return fallback
650
+ except Exception as e:
651
+ logger.error(f" βœ— Extractive fallback error: {e}")
652
+
653
  return "I apologize, but I'm having trouble generating a response. Please try rephrasing your question or ask something else."
654
 
655
  return llm_answer