hamxaameer committed
Commit 9e4bdce · verified · 1 Parent(s): c4e1d4a

Update app.py

Files changed (1):
  1. app.py +27 -13
app.py CHANGED
@@ -637,7 +637,8 @@ def generate_llm_answer(
     def call_model(prompt, max_new_tokens, temperature):
         """Generate with DistilGPT2 - simple and fast"""
         try:
-            formatted_prompt = f"Q: {prompt}\nA:"
+            # Better prompt format for DistilGPT2
+            formatted_prompt = f"Fashion advice: {prompt}\n\nAnswer:"
 
             logger.info(f" → Generating (max_tokens={max_new_tokens})")
 
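The first hunk swaps the bare "Q:/A:" template for a task-specific prefix, which tends to anchor a small model like DistilGPT2 to the domain. Rendered standalone, the new template looks like this (the query is a made-up example; only the f-string comes from the diff):

```python
prompt = "What should I wear to a summer wedding?"  # hypothetical query
formatted_prompt = f"Fashion advice: {prompt}\n\nAnswer:"
print(formatted_prompt)
# Fashion advice: What should I wear to a summer wedding?
#
# Answer:
```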
@@ -647,7 +648,12 @@ def generate_llm_answer(
             temperature=temperature,
             do_sample=True,
             return_full_text=False,
-            pad_token_id=llm_client.tokenizer.eos_token_id
+            repetition_penalty=1.2,  # Prevent repetition
+            no_repeat_ngram_size=3,  # Prevent repeating 3-grams
+            top_k=50,
+            top_p=0.95,
+            pad_token_id=llm_client.tokenizer.eos_token_id,
+            eos_token_id=llm_client.tokenizer.eos_token_id
         )
 
         if not out or not isinstance(out, list) or len(out) == 0:
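The second hunk adds standard Hugging Face sampling kwargs to curb DistilGPT2's tendency to loop. A minimal sketch of the surrounding call, assuming `llm_client` is a `transformers` text-generation pipeline (the diff shows only the keyword arguments, so the pipeline construction below is an assumption):

```python
from transformers import pipeline

# Assumption: llm_client is a text-generation pipeline over distilgpt2.
llm_client = pipeline("text-generation", model="distilgpt2")

out = llm_client(
    "Fashion advice: What should I wear in the rain?\n\nAnswer:",
    max_new_tokens=100,
    temperature=0.7,
    do_sample=True,
    return_full_text=False,          # return only the completion
    repetition_penalty=1.2,          # penalize already-generated tokens
    no_repeat_ngram_size=3,          # forbid any repeated 3-gram
    top_k=50,                        # sample from the 50 most likely tokens
    top_p=0.95,                      # ...within the top 0.95 probability mass
    pad_token_id=llm_client.tokenizer.eos_token_id,  # GPT-2 has no pad token
    eos_token_id=llm_client.tokenizer.eos_token_id,
)
print(out[0]["generated_text"])
```

Reusing the EOS id as `pad_token_id` is the usual way to silence the missing-pad-token warning on GPT-2-family models.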
@@ -656,8 +662,15 @@ def generate_llm_answer(
         generated = out[0].get('generated_text', '').strip()
 
         # Remove prompt if present
-        if formatted_prompt in generated:
-            generated = generated.replace(formatted_prompt, '').strip()
+        if "Answer:" in generated:
+            generated = generated.split("Answer:")[-1].strip()
+
+        # Clean up bad patterns
+        import re
+        # Remove patterns like "A: B: C:" or "I: I: I:"
+        generated = re.sub(r'\b([A-Z]):\s*\1:\s*', '', generated)
+        generated = re.sub(r'\b[A-Z]:\s*(?=[A-Z]:)', '', generated)
+        generated = generated.strip()
 
         word_count = len(generated.split())
         logger.info(f" ✅ Generated {word_count} words")
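The third hunk switches prompt-stripping from exact-match replacement to splitting on the "Answer:" marker, then scrubs the degenerate letter-colon chains that small GPT-2 models often emit. A quick standalone check of what the two substitutions catch (the sample strings are invented; the regexes are the ones from the diff):

```python
import re

def clean(generated: str) -> str:
    # Collapse a doubled prefix such as "I: I: " (same letter twice).
    generated = re.sub(r'\b([A-Z]):\s*\1:\s*', '', generated)
    # Drop each "X: " that is immediately followed by another "Y:".
    generated = re.sub(r'\b[A-Z]:\s*(?=[A-Z]:)', '', generated)
    return generated.strip()

print(clean("I: I: Wear neutral colors."))  # -> Wear neutral colors.
print(clean("A: B: C: Try layering."))      # -> C: Try layering.
```

Note that the last prefix in a chain survives (the "C:" above), since the lookahead finds nothing after it.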
@@ -667,20 +680,21 @@ def generate_llm_answer(
         logger.error(f" ✗ Error: {e}")
         return ''
 
-    # ULTRA-SHORT prompt for speed
-    base_prompt = f"""Q: {query}
+    # Better prompt format with context
+    base_prompt = f"""Question: {query}
 
-{context_text[:300]}
+Context from fashion knowledge base:
+{context_text[:400]}
 
-A:"""
+Based on the above information, here is detailed fashion advice:"""
 
-    # DistilGPT2 parameters
+    # DistilGPT2 parameters - adjusted for better quality
     if attempt == 1:
-        max_new_tokens = 100
-        temperature = 0.8
+        max_new_tokens = 150
+        temperature = 0.7
     else:
-        max_new_tokens = 120
-        temperature = 0.9
+        max_new_tokens = 200
+        temperature = 0.75
 
     logger.info(f" → Starting generation with prompt: {base_prompt[:200]}...")
     initial_output = call_model(base_prompt, max_new_tokens, temperature)
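The last hunk grows the retrieved-context slice from 300 to 400 characters, labels it explicitly, and raises the token budget while lowering temperature on both attempts. A compact sketch of how these pieces could fit together, assuming a simple two-attempt retry that fires only when the first attempt returns nothing (the retry trigger itself is outside this diff; `query`, `context_text`, and the `call_model` stub are stand-ins):

```python
def call_model(prompt, max_new_tokens, temperature):
    """Stub standing in for the pipeline call shown in the earlier hunks."""
    return "Pair it with white sneakers and a light scarf."

def generate(query: str, context_text: str) -> str:
    base_prompt = (
        f"Question: {query}\n\n"
        f"Context from fashion knowledge base:\n{context_text[:400]}\n\n"
        "Based on the above information, here is detailed fashion advice:"
    )
    for attempt in (1, 2):
        if attempt == 1:
            max_new_tokens, temperature = 150, 0.7   # first pass
        else:
            max_new_tokens, temperature = 200, 0.75  # retry: longer, warmer
        answer = call_model(base_prompt, max_new_tokens, temperature)
        if answer:                                   # assumed retry condition
            return answer
    return ""

print(generate("How do I style a denim jacket?", "Denim pairs well with ..."))
```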
 