Update app.py
app.py CHANGED

@@ -635,23 +635,21 @@ def generate_llm_answer(
     max_iterations = 0  # Single-shot only for speed
 
     def call_model(prompt, max_new_tokens, temperature):
-        """Generate with DistilGPT2
+        """Generate with DistilGPT2"""
         try:
-            #
-            formatted_prompt = f"Fashion advice: {prompt}\n\nAnswer:"
-
+            # Simple, direct prompt - no special formatting
             logger.info(f" → Generating (max_tokens={max_new_tokens})")
 
             out = llm_client(
-                formatted_prompt,
+                prompt,
                 max_new_tokens=max_new_tokens,
                 temperature=temperature,
                 do_sample=True,
                 return_full_text=False,
-                repetition_penalty=1.
-                no_repeat_ngram_size=
-                top_k=
-                top_p=0.
+                repetition_penalty=1.3,  # Strong penalty against repetition
+                no_repeat_ngram_size=2,  # Prevent repeating 2-grams
+                top_k=40,
+                top_p=0.9,
                 pad_token_id=llm_client.tokenizer.eos_token_id,
                 eos_token_id=llm_client.tokenizer.eos_token_id
             )
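
All of the new arguments are standard Hugging Face generate() sampling options that the text-generation pipeline forwards to the model. A minimal standalone sketch of the tuned call (the model name and prompt below are illustrative, not taken from the app):

from transformers import pipeline

# Illustrative stand-in for the app's llm_client
generator = pipeline("text-generation", model="distilgpt2")

out = generator(
    "Here is helpful fashion advice:",
    max_new_tokens=120,
    temperature=0.6,
    do_sample=True,
    return_full_text=False,   # return only the continuation, not the prompt
    repetition_penalty=1.3,   # down-weight tokens that already appeared
    no_repeat_ngram_size=2,   # hard-block any repeated 2-gram
    top_k=40,                 # restrict sampling to the 40 likeliest tokens
    top_p=0.9,                # ...then to the top 90% of probability mass
    pad_token_id=generator.tokenizer.eos_token_id,
)
print(out[0]["generated_text"].strip())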
@@ -661,15 +659,11 @@ def generate_llm_answer(
 
             generated = out[0].get('generated_text', '').strip()
 
-            # Remove prompt if present
-            if "Answer:" in generated:
-                generated = generated.split("Answer:")[-1].strip()
-
             # Clean up bad patterns
             import re
-            # Remove patterns like "A: B: C:" or
-            generated = re.sub(r'\b([A-Z]):\s*\1:\s*', '', generated)
+            # Remove nonsensical patterns like "A: B: C:" or single letters
             generated = re.sub(r'\b[A-Z]:\s*(?=[A-Z]:)', '', generated)
+            generated = re.sub(r'^[A-Z]:\s*', '', generated)  # Remove leading letters
             generated = generated.strip()
 
             word_count = len(generated.split())
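
The two regexes target the degenerate "letter-colon" chains DistilGPT2 sometimes emits. A quick sketch of the cleanup in isolation:

import re

def clean(generated: str) -> str:
    # Drop each "X: " that is immediately followed by another "Y: "
    generated = re.sub(r'\b[A-Z]:\s*(?=[A-Z]:)', '', generated)
    # Drop a single leftover "X: " at the start of the string
    generated = re.sub(r'^[A-Z]:\s*', '', generated)
    return generated.strip()

print(clean("A: B: C: Pair the blazer with grey chinos."))
# -> Pair the blazer with grey chinos.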
@@ -680,21 +674,20 @@ def generate_llm_answer(
             logger.error(f" ❌ Error: {e}")
             return ''
 
-    #
-    base_prompt = f"""
+    # Simple, natural prompt that DistilGPT2 can handle
+    base_prompt = f"""For the question "{query}", here is helpful fashion advice:
 
-
-    {context_text[:400]}
+{context_text[:300]}
 
-
+To summarize:"""
 
-    # DistilGPT2 parameters -
+    # DistilGPT2 parameters - lower temperature for more coherent output
     if attempt == 1:
-        max_new_tokens =
-        temperature = 0.
+        max_new_tokens = 120
+        temperature = 0.6
     else:
-        max_new_tokens =
-        temperature = 0.
+        max_new_tokens = 150
+        temperature = 0.65
 
     logger.info(f" → Starting generation with prompt: {base_prompt[:200]}...")
     initial_output = call_model(base_prompt, max_new_tokens, temperature)
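
The per-attempt schedule is small enough to read straight from the diff; factored out as a hypothetical helper (the helper name and sample values are mine, not the app's), it amounts to:

def generation_settings(attempt: int) -> tuple[int, float]:
    # Attempt 1 is short and conservative; the retry allows slightly
    # more length and randomness if the first draft was rejected.
    if attempt == 1:
        return 120, 0.6  # max_new_tokens, temperature
    return 150, 0.65

# Illustrative values; in the app, query and context_text come from retrieval
query = "What goes with a navy blazer?"
context_text = "Navy pairs well with grey, white, and camel tones..."

base_prompt = f"""For the question "{query}", here is helpful fashion advice:

{context_text[:300]}

To summarize:"""

max_new_tokens, temperature = generation_settings(attempt=1)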
@@ -823,17 +816,22 @@ def generate_answer_langchain(
     if not retrieved_docs:
         return "I couldn't find relevant information to answer your question."
 
-    #
-
-    logger.info(" → Using extractive answer generator (primary method)")
-    try:
-        extractive_answer = generate_extractive_answer(query, retrieved_docs)
-        if extractive_answer:
-            logger.info(" ✅ Extractive answer generated successfully")
-            return extractive_answer
-    except Exception as e:
-        logger.error(f" ❌ Extractive answer error: {e}")
+    # Try LLM generation with multiple attempts
+    max_attempts = 2
 
+    llm_answer = None
+    for attempt in range(1, max_attempts + 1):
+        logger.info(f"\n 🤖 LLM Generation Attempt {attempt}/{max_attempts}")
+        llm_answer = generate_llm_answer(query, retrieved_docs, llm_client, attempt)
+
+        if llm_answer:
+            logger.info(f" ✅ LLM answer generated successfully")
+            return llm_answer
+        else:
+            if attempt < max_attempts:
+                logger.warning(f" ⚠️ Attempt {attempt}/{max_attempts} failed, retrying...")
+
+    logger.error(f" ❌ All {max_attempts} LLM attempts failed")
 
     return "I apologize, but I'm having trouble generating a response. Please try rephrasing your question or ask something else."
 
 # ============================================================================
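
The retry loop is a plain first-success-wins pattern. A self-contained sketch of the same idea (the helper name and callable signature are mine):

from typing import Callable, Optional

def first_successful(generate_fn: Callable[[int], str],
                     max_attempts: int = 2) -> Optional[str]:
    # Return the first non-empty answer, passing the attempt number
    # through so the generator can loosen its settings on retries.
    for attempt in range(1, max_attempts + 1):
        answer = generate_fn(attempt)
        if answer:
            return answer
    return None

# Usage against the app's generator (names from the diff):
# llm_answer = first_successful(
#     lambda attempt: generate_llm_answer(query, retrieved_docs, llm_client, attempt))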
@@ -860,12 +858,17 @@ def fashion_chatbot(message: str, history: List[List[str]]):
 
         yield f"🔍 Generating answer ({len(retrieved_docs)} sources found)..."
 
-        #
-
-
+        # Generate with LLM
+        llm_answer = None
+        for attempt in range(1, 3):
+            logger.info(f"\n 🤖 LLM Generation Attempt {attempt}/2")
+            llm_answer = generate_llm_answer(message.strip(), retrieved_docs, llm_client, attempt)
+
+            if llm_answer:
+                break
 
         if not llm_answer:
-            logger.error(f" ❌
+            logger.error(f" ❌ All LLM attempts failed")
             yield "I apologize, but I'm having trouble generating a response. Please try rephrasing your question."
             return
 
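fashion_chatbot is a generator, so each yield streams an intermediate status line to the chat UI before the final answer arrives. A minimal sketch of that shape (fake_generate is a stand-in for generate_llm_answer):

from typing import Iterator, List

def fake_generate(message: str, attempt: int) -> str:
    return f"(attempt {attempt}) advice for: {message}"

def chatbot(message: str, history: List[List[str]]) -> Iterator[str]:
    yield "🔍 Generating answer..."  # streamed progress update
    answer = None
    for attempt in range(1, 3):
        answer = fake_generate(message, attempt)
        if answer:
            break
    yield answer or "I apologize, but I'm having trouble generating a response."

for chunk in chatbot("What goes with a navy blazer?", []):
    print(chunk)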