Update app.py
app.py CHANGED

@@ -62,8 +62,8 @@ CONFIG = {
 }
 
 # LLM Configuration - LOCAL ONLY
-# Using
-LOCAL_LLM_MODEL = os.environ.get("LOCAL_LLM_MODEL", "
+# Using Flan-T5 Base: 250M params, instruction-tuned, fast and high quality
+LOCAL_LLM_MODEL = os.environ.get("LOCAL_LLM_MODEL", "google/flan-t5-base")
 USE_8BIT_QUANTIZATION = False
 USE_REMOTE_LLM = False # LOCAL ONLY
 
@@ -94,18 +94,18 @@ if HF_INFERENCE_API_KEY:
 # ============================================================================
 
 def initialize_llm():
-    """Initialize
+    """Initialize Flan-T5 Base for local CPU generation.
 
-
-
+    Flan-T5 is instruction-tuned, produces high-quality answers,
+    and is fast on CPU (3-5 seconds per response).
     """
     global LOCAL_LLM_MODEL
 
-    logger.info(f"Initializing
-    logger.info("
+    logger.info(f"Initializing Flan-T5 Base: {LOCAL_LLM_MODEL}")
+    logger.info(" Instruction-tuned for high-quality Q&A")
 
     try:
-        from transformers import AutoTokenizer,
+        from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
 
         device = "cuda" if torch.cuda.is_available() else "cpu"
         logger.info(f" Device: {device}")
 
@@ -113,15 +113,11 @@ def initialize_llm():
         # Load tokenizer
         logger.info(" Loading tokenizer...")
         tokenizer = AutoTokenizer.from_pretrained(LOCAL_LLM_MODEL)
-
-        # Set pad token
-        if tokenizer.pad_token is None:
-            tokenizer.pad_token = tokenizer.eos_token
-
         logger.info(" Tokenizer ready")
+
         # Load model
-        logger.info(" Loading
-        model =
+        logger.info(" Loading Flan-T5 Base (10-15 seconds)...")
+        model = AutoModelForSeq2SeqLM.from_pretrained(
             LOCAL_LLM_MODEL,
             torch_dtype=torch.float32
         )
 
@@ -130,23 +126,20 @@ def initialize_llm():
         model.eval()
         logger.info(" Model ready")
 
-        #
-
-
-
-
-
-            device=0 if device == "cuda" else -1,
-            max_new_tokens=150
-        )
+        # Store model and tokenizer for custom generation
+        llm_client = {
+            'model': model,
+            'tokenizer': tokenizer,
+            'device': device
+        }
 
         CONFIG["llm_model"] = LOCAL_LLM_MODEL
-        CONFIG["model_type"] = "
+        CONFIG["model_type"] = "flan_t5_base_local"
 
-        logger.info(f"
-        logger.info(f" Size:
-        logger.info(f" Quality:
-        logger.info(f" Speed: 5
+        logger.info(f"Flan-T5 Base initialized: {LOCAL_LLM_MODEL}")
+        logger.info(f" Size: 250M parameters (instruction-tuned)")
+        logger.info(f" Quality: Excellent for fashion Q&A")
+        logger.info(f" Speed: 3-5 seconds per 200 words")
 
         return llm_client
 
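Note: this change swaps the old text-generation pipeline for a seq2seq model held in a plain dict. A minimal standalone sketch of the same loading path, assuming the google/flan-t5-base checkpoint named above (the explicit .to(device) call is added here only for illustration):

import torch
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM

# Load the instruction-tuned checkpoint on CPU or GPU, mirroring initialize_llm() above.
model_name = "google/flan-t5-base"
device = "cuda" if torch.cuda.is_available() else "cpu"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForSeq2SeqLM.from_pretrained(model_name, torch_dtype=torch.float32).to(device)
model.eval()

# The app now passes the pieces around as a plain dict instead of a transformers pipeline.
llm_client = {"model": model, "tokenizer": tokenizer, "device": device}
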
@@ -355,87 +348,84 @@ def load_vector_store(embeddings):
 # ============================================================================
 
 def generate_extractive_answer(query: str, retrieved_docs: List[Document]) -> Optional[str]:
-    """Build a
-
-    repeatedly fails or returns very short outputs.
+    """Build a focused, intelligent answer from retrieved documents.
+    Filters out product catalogs and provides concise, relevant fashion advice.
     """
-    logger.info(f"
+    logger.info(f"Generating smart extractive answer for: '{query}'")
 
-    # Collect text and split into sentences
     import re
 
-    all_text = "\n\n".join([d.page_content for d in retrieved_docs])
-    # Basic sentence split (keeps punctuation)
+    all_text = "\n\n".join([d.page_content for d in retrieved_docs[:10]])  # Top 10 docs only
     sentences = re.split(r'(?<=[.!?])\s+', all_text)
-    sentences = [s.strip() for s in sentences if len(s.strip()) >
+    sentences = [s.strip() for s in sentences if len(s.strip()) > 40]
 
     if not sentences:
-        logger.warning(" No sentences found
+        logger.warning(" No sentences found")
        return None
 
-    #
+    # Filter out product catalog noise
+    filtered_sentences = []
+    for s in sentences:
+        # Skip sentences that are clearly product listings
+        if re.search(r'Category:|Season:|Usage:|Color:|Price:|SKU:', s, re.IGNORECASE):
+            continue
+        # Skip sentences with brand names followed by product codes
+        if re.search(r'(Men|Women|Kids|Boys|Girls)\s+[A-Z][a-z]+\s+[A-Z]', s):
+            continue
+        # Keep only advice/guidance sentences
+        if any(word in s.lower() for word in ['wear', 'pair', 'choose', 'opt', 'works', 'complement',
+                                              'match', 'combine', 'style', 'look', 'consider', 'add']):
+            filtered_sentences.append(s)
+
+    if not filtered_sentences:
+        # Fallback: use all sentences if filtering was too aggressive
+        filtered_sentences = [s for s in sentences if len(s.split()) > 10][:15]
+
+    # Score by relevance to query
     query_tokens = set(re.findall(r"\w+", query.lower()))
-
-        "blazer","trousers","dress","shirt","shoes","boots","sweater","jacket",
-        "care","wash","dry","clean","wool","cotton","silk","linen","fit","tailor",
-        "versatile","neutral","accessory","belt","bag","occasion","season","fall"])
-    keywords = query_tokens.union(fashion_keywords)
-
+
     scored = []
-    for s in
+    for s in filtered_sentences:
         s_tokens = set(re.findall(r"\w+", s.lower()))
-        score = len(s_tokens &
-        #
-        score += min(
+        score = len(s_tokens & query_tokens)
+        # Bonus for sentence length (prefer substantial advice)
+        score += min(2, len(s.split()) // 30)
         scored.append((score, s))
 
     scored.sort(key=lambda x: x[0], reverse=True)
-
-    #
-        parts.append(care_text)
-    if conclusion:
-        parts.append("Wrapping up:")
-        parts.append(" ".join(conclusion))
-
-    # Combine and refine spacing
-    answer = "\n\n".join(parts)
-
-    # Natural length - no artificial padding or truncation
-    words = answer.split()
-    word_count = len(words)
-
-    logger.info(f" Extractive answer ready ({word_count} words)")
+
+    # Take top 5-8 most relevant sentences
+    top_sentences = [s for _, s in scored[:8] if s]
+
+    if not top_sentences:
+        return None
+
+    # Build concise answer
+    answer_parts = []
+
+    # Add 3-5 best sentences with natural flow
+    for i, sentence in enumerate(top_sentences[:5]):
+        answer_parts.append(sentence)
+
+    answer = " ".join(answer_parts)
+
+    # Clean up any remaining noise
+    answer = re.sub(r'\s+', ' ', answer).strip()
+
+    word_count = len(answer.split())
+
+    # Ensure answer is substantial but not too long (100-200 words ideal)
+    if word_count < 50:
+        logger.warning(f" Answer too short ({word_count} words)")
+        return None
+
+    if word_count > 250:
+        # Trim to ~200 words
+        words = answer.split()[:200]
+        answer = " ".join(words) + "..."
+        word_count = 200
+
+    logger.info(f" Smart answer ready ({word_count} words)")
     return answer
 
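Note: a toy, self-contained illustration of the new catalog-noise filter and query-overlap scoring (the sentences and query below are invented; the regex and keyword list are copied from the hunk above):

import re

# Invented inputs purely for illustration of the catalog filter + keyword scoring.
sentences = [
    "Men Solid Casual Shirt Category: Apparel Season: Summer Color: Blue",   # catalog noise
    "Pair a navy blazer with grey wool trousers for a versatile smart-casual look.",  # advice
]
advice_words = ['wear', 'pair', 'choose', 'opt', 'works', 'complement',
                'match', 'combine', 'style', 'look', 'consider', 'add']

# Drop product-listing noise, keep sentences that sound like advice.
kept = [s for s in sentences
        if not re.search(r'Category:|Season:|Usage:|Color:|Price:|SKU:', s, re.IGNORECASE)
        and any(w in s.lower() for w in advice_words)]

# Score remaining sentences by word overlap with the query.
query = "What goes with a navy blazer?"
query_tokens = set(re.findall(r"\w+", query.lower()))
scored = sorted(((len(set(re.findall(r"\w+", s.lower())) & query_tokens), s) for s in kept),
                reverse=True)
print(scored[0])  # the blazer advice sentence scores highest
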
@@ -594,16 +584,22 @@ def generate_llm_answer(
     llm_client,
     attempt: int = 1
 ) -> Optional[str]:
-
+    """Generate answer using Flan-T5 Base - instruction-tuned for Q&A."""
     if not llm_client:
-        logger.error("
+        logger.error(" Flan-T5 model not initialized")
         return None
 
+    # Extract model components
+    model = llm_client['model']
+    tokenizer = llm_client['tokenizer']
+    device = llm_client['device']
+
+    # Select best documents
     query_lower = query.lower()
     query_words = set(query_lower.split())
 
     scored_docs = []
-    for doc in retrieved_docs[:
+    for doc in retrieved_docs[:15]:
         content = doc.page_content.lower()
         doc_words = set(content.split())
         overlap = len(query_words.intersection(doc_words))
 
@@ -617,180 +613,73 @@ def generate_llm_answer(
         scored_docs.append((doc, overlap))
 
     scored_docs.sort(key=lambda x: x[1], reverse=True)
-    top_docs = [doc[0] for doc in scored_docs[:
+    top_docs = [doc[0] for doc in scored_docs[:5]]
 
-    #
+    # Build rich context (Flan-T5 can handle more context)
     context_parts = []
-    for doc in top_docs
+    for doc in top_docs:
         content = doc.page_content.strip()
-        if len(content) >
-            content = content[:
+        if len(content) > 300:
+            content = content[:300] + "..."
         context_parts.append(content)
 
     context_text = "\n\n".join(context_parts)
 
-    #
-
-    target_max_words = 999999 # No maximum - let model complete naturally
-    chunk_target_words = 0 # Not used in natural mode
-    max_iterations = 0 # Single-shot only for speed
-
-    def call_model(prompt, max_new_tokens, temperature):
-        """Generate with GPT-2 Medium - better quality"""
-        try:
-            logger.info(f" Generating (max_tokens={max_new_tokens})")
-
-            out = llm_client(
-                prompt,
-                max_new_tokens=max_new_tokens,
-                temperature=temperature,
-                do_sample=True,
-                return_full_text=False,
-                repetition_penalty=1.15, # Moderate penalty for better flow
-                top_k=50,
-                top_p=0.92,
-                pad_token_id=llm_client.tokenizer.eos_token_id,
-                eos_token_id=llm_client.tokenizer.eos_token_id
-            )
-
-            if not out or not isinstance(out, list) or len(out) == 0:
-                return ''
-
-            generated = out[0].get('generated_text', '').strip()
-
-            word_count = len(generated.split())
-            logger.info(f" Generated {word_count} words")
-            return generated
-
-        except Exception as e:
-            logger.error(f" Error: {e}")
-            return ''
+    # Flan-T5 instruction prompt - direct and clear
+    prompt = f"""Answer this fashion question with specific, practical advice (150-200 words):
 
-
-    base_prompt = f"""Question: {query}
+Question: {query}
 
-
-{context_text[:
+Fashion Knowledge:
+{context_text[:600]}
 
-
+Provide detailed fashion advice:"""
 
-            logger.info(f" Short but acceptable response ({word_count} words)")
-            return response
-
-        # Ultra permissive: accept ANYTHING with 10+ words to show something
-        if word_count >= 10:
-            logger.info(f" Very short response ({word_count} words) but accepting")
-            return response
-
-        # EMERGENCY: accept even 5+ words if that's all we get
-        if word_count >= 5:
-            logger.info(f" EMERGENCY: Accepting tiny response ({word_count} words)")
+    try:
+        logger.info(f" Generating with Flan-T5 (target: 200 words)")
+
+        # Tokenize input
+        inputs = tokenizer(prompt, return_tensors="pt", truncation=True, max_length=512)
+        inputs = {k: v.to(device) for k, v in inputs.items()}
+
+        # Generate with Flan-T5 optimized parameters
+        with torch.no_grad():
+            outputs = model.generate(
+                **inputs,
+                max_new_tokens=250, # ~200 words
+                min_length=120, # Ensure substantial answers
+                temperature=0.8, # Balanced creativity
+                top_p=0.9,
+                do_sample=True,
+                repetition_penalty=1.2,
+                no_repeat_ngram_size=3,
+                early_stopping=False
+            )
+
+        # Decode output
+        response = tokenizer.decode(outputs[0], skip_special_tokens=True).strip()
+
+        word_count = len(response.split())
+        logger.info(f" Generated {word_count} words with Flan-T5")
+
+        # Validate quality
+        if word_count < 50:
+            logger.warning(f" Response too short ({word_count} words)")
+            return None
+
+        # Check for generic/irrelevant content
+        if any(phrase in response.lower() for phrase in ["i cannot", "i can't", "i'm sorry", "as an ai"]):
+            logger.warning(" Generic response detected")
+            return None
+
         return response
-
-        for i in range(max_iterations):
-            remaining = max(0, target_min_words - len(accumulated.split()))
-            if remaining <= 0:
-                break
-
-            # Ask the model to continue without repeating previous content
-            continue_prompt = f"""Add {min(chunk_target_words, remaining)} more words to complete this answer:
-
-{accumulated[-400:]}
-
-Continue naturally:
-"""
-
-            # Optimized continuation parameters for speed
-            cont_output = call_model(continue_prompt, max_new_tokens=250, temperature=0.80, top_p=0.90, repetition_penalty=1.10)
-            cont_text = (cont_output or '').strip()
-
-            if not cont_text:
-                logger.warning(f" Continuation {i+1} returned empty - stopping")
-                break
-
-            # Avoid trivial repeats: if continuation repeats the accumulated text, stop
-            if cont_text in accumulated or accumulated.endswith(cont_text[:50]):
-                logger.warning(f" Continuation {i+1} appears repetitive - stopping")
-                break
-
-            # Append and normalize spacing
-            accumulated = accumulated.rstrip() + '\n\n' + cont_text
-
-            current_word_count = len(accumulated.split())
-            logger.info(f" After continuation {i+1}, words={current_word_count}")
-
-            # Stop early if we've reached or exceeded the minimum target
-            if current_word_count >= target_min_words:
-                break
-
-            # Safety: if no progress, break
-            if current_word_count == prev_word_count:
-                logger.warning(" No progress from continuation - stopping")
-                break
-            prev_word_count = current_word_count
-
-        final_words = accumulated.split()
-        final_count = len(final_words)
-
-        if final_count < target_min_words:
-            logger.warning(f" Final answer too short ({final_count} words) after continuations")
-            return None
-
-        if final_count > target_max_words:
-            logger.info(f" Final answer long ({final_count} words). Truncating to {target_max_words} words.")
-            accumulated = ' '.join(final_words[:target_max_words]) + '...'
-            final_count = target_max_words
-
-        # Final check for apology/hedging at start
-        apology_phrases = ["i cannot", "i can't", "i'm sorry", "i apologize", "i don't have"]
-        if any(phrase in accumulated.lower()[:200] for phrase in apology_phrases):
-            logger.warning(" Apology/hedging detected in final answer")
+
+    except Exception as e:
+        logger.error(f" Flan-T5 generation error: {e}")
+        import traceback
+        logger.error(traceback.format_exc())
         return None
 
-    logger.info(f" Built long-form answer ({final_count} words)")
-    return accumulated
-
 def generate_answer_langchain(
     query: str,
     vectorstore,
 
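Note: the new generation path is single-shot; stripped of logging and fallbacks it amounts to the sketch below (generation parameters are copied from the hunk above; the prompt text and the locally loaded model/tokenizer are stand-ins for the app's own objects):

import torch
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM

# Stand-ins; in app.py these come from initialize_llm()'s llm_client dict.
device = "cuda" if torch.cuda.is_available() else "cpu"
tokenizer = AutoTokenizer.from_pretrained("google/flan-t5-base")
model = AutoModelForSeq2SeqLM.from_pretrained("google/flan-t5-base").to(device).eval()

# Instruction prompt in the same shape as the diff's prompt (placeholder context).
prompt = (
    "Answer this fashion question with specific, practical advice (150-200 words):\n\n"
    "Question: What goes with a navy blazer?\n\n"
    "Fashion Knowledge:\n<retrieved context here>\n\n"
    "Provide detailed fashion advice:"
)

inputs = tokenizer(prompt, return_tensors="pt", truncation=True, max_length=512)
inputs = {k: v.to(device) for k, v in inputs.items()}

with torch.no_grad():
    outputs = model.generate(
        **inputs,
        max_new_tokens=250,
        min_length=120,
        do_sample=True,
        temperature=0.8,
        top_p=0.9,
        repetition_penalty=1.2,
        no_repeat_ngram_size=3,
    )

answer = tokenizer.decode(outputs[0], skip_special_tokens=True).strip()
print(answer)
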
@@ -809,8 +698,18 @@ def generate_answer_langchain(
     if not retrieved_docs:
         return "I couldn't find relevant information to answer your question."
 
-    #
-    logger.info("
+    # Try Flan-T5 first (instruction-tuned, high quality)
+    logger.info(" Attempting Flan-T5 generation (primary method)")
+    try:
+        llm_answer = generate_llm_answer(query, retrieved_docs, llm_client, attempt=1)
+        if llm_answer:
+            logger.info(f" Flan-T5 answer generated successfully")
+            return llm_answer
+    except Exception as e:
+        logger.error(f" Flan-T5 error: {e}")
+
+    # Fallback to extractive if Flan-T5 fails
+    logger.info(" Fallback: Using extractive answer generator")
     try:
         extractive_answer = generate_extractive_answer(query, retrieved_docs)
         if extractive_answer:
 
@@ -845,12 +744,17 @@ def fashion_chatbot(message: str, history: List[List[str]]):
 
     yield f"Generating fashion advice ({len(retrieved_docs)} sources found)..."
 
-    #
-    logger.info(" Generating
-    llm_answer =
+    # Try Flan-T5 first (fast and high quality)
+    logger.info(" Generating with Flan-T5")
+    llm_answer = generate_llm_answer(message.strip(), retrieved_docs, llm_client, attempt=1)
+
+    # Fallback to extractive if needed
+    if not llm_answer:
+        logger.info(" Fallback: Using extractive answer")
+        llm_answer = generate_extractive_answer(message.strip(), retrieved_docs)
 
     if not llm_answer:
-        logger.error(f"
+        logger.error(f" All generation methods failed")
         yield "I apologize, but I'm having trouble generating a response. Please try rephrasing your question."
         return
 