hamxaameer committed
Commit 35c968b · verified · 1 Parent(s): a62a145

Update app.py

Files changed (1):
  1. app.py +42 -39
app.py CHANGED
@@ -635,23 +635,21 @@ def generate_llm_answer(
     max_iterations = 0 # Single-shot only for speed
 
     def call_model(prompt, max_new_tokens, temperature):
-        """Generate with DistilGPT2 - simple and fast"""
+        """Generate with DistilGPT2"""
         try:
-            # Better prompt format for DistilGPT2
-            formatted_prompt = f"Fashion advice: {prompt}\n\nAnswer:"
-
+            # Simple, direct prompt - no special formatting
             logger.info(f" → Generating (max_tokens={max_new_tokens})")
 
             out = llm_client(
-                formatted_prompt,
+                prompt,
                 max_new_tokens=max_new_tokens,
                 temperature=temperature,
                 do_sample=True,
                 return_full_text=False,
-                repetition_penalty=1.2, # Prevent repetition
-                no_repeat_ngram_size=3, # Prevent repeating 3-grams
-                top_k=50,
-                top_p=0.95,
+                repetition_penalty=1.3, # Strong penalty against repetition
+                no_repeat_ngram_size=2, # Prevent repeating 2-grams
+                top_k=40,
+                top_p=0.9,
                 pad_token_id=llm_client.tokenizer.eos_token_id,
                 eos_token_id=llm_client.tokenizer.eos_token_id
             )
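Note on the decoding change above: below is a minimal, self-contained sketch of the new sampling setup, assuming `llm_client` is a Hugging Face `transformers` text-generation pipeline around DistilGPT2 (the pipeline construction and the sample prompt are illustrative, not taken from app.py).

```python
# Sketch of the new decoding parameters; pipeline setup and prompt are assumptions.
from transformers import pipeline

llm_client = pipeline("text-generation", model="distilgpt2")

out = llm_client(
    'For the question "What goes with a navy blazer?", here is helpful fashion advice:',
    max_new_tokens=120,
    temperature=0.6,
    do_sample=True,
    return_full_text=False,   # return only the newly generated tokens
    repetition_penalty=1.3,   # penalize tokens that were already generated
    no_repeat_ngram_size=2,   # block any 2-gram from repeating
    top_k=40,                 # keep only the 40 most likely tokens...
    top_p=0.9,                # ...then truncate to 90% cumulative probability
    pad_token_id=llm_client.tokenizer.eos_token_id,  # GPT-2 has no pad token
    eos_token_id=llm_client.tokenizer.eos_token_id,
)
print(out[0]["generated_text"].strip())
```

Tightening `top_k`/`top_p` and raising `repetition_penalty` trades sampling diversity for coherence, which is usually the right trade-off for a model as small as DistilGPT2.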
@@ -661,15 +659,11 @@ def generate_llm_answer(
 
             generated = out[0].get('generated_text', '').strip()
 
-            # Remove prompt if present
-            if "Answer:" in generated:
-                generated = generated.split("Answer:")[-1].strip()
-
             # Clean up bad patterns
             import re
-            # Remove patterns like "A: B: C:" or "I: I: I:"
-            generated = re.sub(r'\b([A-Z]):\s*\1:\s*', '', generated)
+            # Remove nonsensical patterns like "A: B: C:" or single letters
             generated = re.sub(r'\b[A-Z]:\s*(?=[A-Z]:)', '', generated)
+            generated = re.sub(r'^[A-Z]:\s*', '', generated) # Remove leading letters
             generated = generated.strip()
 
             word_count = len(generated.split())
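For reference, the kept lookahead regex and the newly added leading-label regex behave as follows on a made-up degenerate output:

```python
# Demonstration of the two cleanup regexes; the input string is hypothetical.
import re

generated = "A: B: C: pair neutral basics with one statement accessory."

# Drop any "X:" label that is immediately followed by another "X:" label.
generated = re.sub(r'\b[A-Z]:\s*(?=[A-Z]:)', '', generated)
# Drop a single leftover label at the start of the string (added in this commit).
generated = re.sub(r'^[A-Z]:\s*', '', generated)

print(generated.strip())
# -> pair neutral basics with one statement accessory.
```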
@@ -680,21 +674,20 @@ def generate_llm_answer(
             logger.error(f" ✗ Error: {e}")
             return ''
 
-    # Better prompt format with context
-    base_prompt = f"""Question: {query}
+    # Simple, natural prompt that DistilGPT2 can handle
+    base_prompt = f"""For the question "{query}", here is helpful fashion advice:
 
-Context from fashion knowledge base:
-{context_text[:400]}
+{context_text[:300]}
 
-Based on the above information, here is detailed fashion advice:"""
+To summarize:"""
 
-    # DistilGPT2 parameters - adjusted for better quality
+    # DistilGPT2 parameters - lower temperature for more coherent output
     if attempt == 1:
-        max_new_tokens = 150
-        temperature = 0.7
+        max_new_tokens = 120
+        temperature = 0.6
     else:
-        max_new_tokens = 200
-        temperature = 0.75
+        max_new_tokens = 150
+        temperature = 0.65
 
     logger.info(f" → Starting generation with prompt: {base_prompt[:200]}...")
     initial_output = call_model(base_prompt, max_new_tokens, temperature)
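To make the prompt change concrete, this is how the new `base_prompt` renders for a sample query (`query` and `context_text` below are hypothetical placeholders):

```python
# Rendering of the new base_prompt; both input values are placeholders.
query = "What should I wear to a summer wedding?"
context_text = "Light linen suits and breathable pastel dresses work well for daytime summer weddings."

base_prompt = f"""For the question "{query}", here is helpful fashion advice:

{context_text[:300]}

To summarize:"""

print(base_prompt)
```

Note the context slice also shrank from 400 to 300 characters, keeping the prompt shorter.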
@@ -823,17 +816,22 @@ def generate_answer_langchain(
     if not retrieved_docs:
         return "I couldn't find relevant information to answer your question."
 
-    # Use extractive answer as PRIMARY method - reliable and high-quality
-    # Small LLMs (DistilGPT2) produce nonsensical output on CPU
-    logger.info(" → Using extractive answer generator (primary method)")
-    try:
-        extractive_answer = generate_extractive_answer(query, retrieved_docs)
-        if extractive_answer:
-            logger.info(" ✅ Extractive answer generated successfully")
-            return extractive_answer
-    except Exception as e:
-        logger.error(f" ✗ Extractive answer error: {e}")
+    # Try LLM generation with multiple attempts
+    max_attempts = 2
 
+    llm_answer = None
+    for attempt in range(1, max_attempts + 1):
+        logger.info(f"\n 🤖 LLM Generation Attempt {attempt}/{max_attempts}")
+        llm_answer = generate_llm_answer(query, retrieved_docs, llm_client, attempt)
+
+        if llm_answer:
+            logger.info(f" ✅ LLM answer generated successfully")
+            return llm_answer
+        else:
+            if attempt < max_attempts:
+                logger.warning(f" → Attempt {attempt}/{max_attempts} failed, retrying...")
+
+    logger.error(f" ✗ All {max_attempts} LLM attempts failed")
     return "I apologize, but I'm having trouble generating a response. Please try rephrasing your question or ask something else."
 
 # ============================================================================
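The retry logic added in this hunk reduces to the following self-contained pattern (the helper name and its `generate_fn` callback are illustrative, not part of app.py):

```python
# Illustrative reduction of the new retry loop around the LLM call.
import logging

logger = logging.getLogger(__name__)

def answer_with_retries(generate_fn, max_attempts=2):
    """Call generate_fn(attempt) until it returns a non-empty answer."""
    for attempt in range(1, max_attempts + 1):
        answer = generate_fn(attempt)
        if answer:
            return answer
        if attempt < max_attempts:
            logger.warning("Attempt %d/%d failed, retrying...", attempt, max_attempts)
    return None  # caller falls back to an apology message
```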
@@ -860,12 +858,17 @@ def fashion_chatbot(message: str, history: List[List[str]]):
 
     yield f"💭 Generating answer ({len(retrieved_docs)} sources found)..."
 
-    # Use extractive answer - reliable and high-quality
-    logger.info(" → Generating extractive answer")
-    llm_answer = generate_extractive_answer(message.strip(), retrieved_docs)
+    # Generate with LLM
+    llm_answer = None
+    for attempt in range(1, 3):
+        logger.info(f"\n 🤖 LLM Generation Attempt {attempt}/2")
+        llm_answer = generate_llm_answer(message.strip(), retrieved_docs, llm_client, attempt)
+
+        if llm_answer:
+            break
 
     if not llm_answer:
-        logger.error(f" ✗ Extractive answer generation failed")
+        logger.error(f" ✗ All LLM attempts failed")
         yield "I apologize, but I'm having trouble generating a response. Please try rephrasing your question."
         return
 
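Since `fashion_chatbot` is a generator, each `yield` streams an interim status line to the UI and the final yield becomes the displayed answer. A minimal sketch of that pattern, assuming a Gradio `ChatInterface` front end (the Gradio wiring is an assumption, not shown in this diff):

```python
# Sketch of the yield-based streaming pattern; the Gradio front end is assumed.
import gradio as gr

def chat_fn(message, history):
    yield "🔍 Searching the knowledge base..."   # interim status, replaced by the next yield
    yield "💭 Generating answer..."
    yield f"Here is some advice about: {message}"  # the last yield is what the user sees

gr.ChatInterface(chat_fn).launch()
```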