Update app.py
app.py CHANGED
@@ -637,7 +637,8 @@ def generate_llm_answer(
     def call_model(prompt, max_new_tokens, temperature):
         """Generate with DistilGPT2 - simple and fast"""
         try:
-
+            # Better prompt format for DistilGPT2
+            formatted_prompt = f"Fashion advice: {prompt}\n\nAnswer:"
 
             logger.info(f" β Generating (max_tokens={max_new_tokens})")
 
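Note: DistilGPT2 is a plain causal LM with no instruction tuning, so the completion-style scaffold added above nudges the model to continue after "Answer:" rather than ramble. A quick illustration of what formatted_prompt expands to (the sample query is made up):

    prompt = "What shoes go with a navy suit?"   # hypothetical query
    formatted_prompt = f"Fashion advice: {prompt}\n\nAnswer:"
    print(formatted_prompt)
    # Fashion advice: What shoes go with a navy suit?
    #
    # Answer: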
@@ -647,7 +648,12 @@ def generate_llm_answer(
                 temperature=temperature,
                 do_sample=True,
                 return_full_text=False,
-
+                repetition_penalty=1.2,   # Prevent repetition
+                no_repeat_ngram_size=3,   # Prevent repeating 3-grams
+                top_k=50,
+                top_p=0.95,
+                pad_token_id=llm_client.tokenizer.eos_token_id,
+                eos_token_id=llm_client.tokenizer.eos_token_id
             )
 
             if not out or not isinstance(out, list) or len(out) == 0:
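Note: the added keyword arguments are standard Hugging Face generate() sampling controls, forwarded through the text-generation pipeline. A minimal standalone sketch of the assumed call shape (llm_client presumably wraps such a pipeline; the model name and prompt below are assumptions, not the app's code):

    from transformers import pipeline

    generator = pipeline("text-generation", model="distilgpt2")
    out = generator(
        "Fashion advice: What shoes go with a navy suit?\n\nAnswer:",
        max_new_tokens=150,
        temperature=0.7,
        do_sample=True,
        return_full_text=False,           # return only the continuation
        repetition_penalty=1.2,           # >1.0 penalizes already-used tokens
        no_repeat_ngram_size=3,           # forbid any 3-gram from repeating
        top_k=50,                         # sample from the 50 likeliest tokens
        top_p=0.95,                       # ...truncated to 95% cumulative mass
        pad_token_id=generator.tokenizer.eos_token_id,  # GPT-2 has no pad token
    )
    print(out[0]["generated_text"])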
@@ -656,8 +662,15 @@ def generate_llm_answer(
             generated = out[0].get('generated_text', '').strip()
 
             # Remove prompt if present
-            if
-            generated = generated.
+            if "Answer:" in generated:
+                generated = generated.split("Answer:")[-1].strip()
+
+            # Clean up bad patterns
+            import re
+            # Remove patterns like "A: B: C:" or "I: I: I:"
+            generated = re.sub(r'\b([A-Z]):\s*\1:\s*', '', generated)
+            generated = re.sub(r'\b[A-Z]:\s*(?=[A-Z]:)', '', generated)
+            generated = generated.strip()
 
             word_count = len(generated.split())
             logger.info(f" β Generated {word_count} words")
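Note: the two substitutions target the degenerate "X: X:" and "X: Y:" prefixes that small GPT-2 models tend to emit. A short demonstration on a made-up string:

    import re

    generated = "I: I: Pair a navy blazer with white sneakers. A: B: Keep it simple."
    # First pass strips doubled prefixes such as "I: I: "
    generated = re.sub(r'\b([A-Z]):\s*\1:\s*', '', generated)
    # Second pass strips each "X:" that directly precedes another "Y:" (the last one survives)
    generated = re.sub(r'\b[A-Z]:\s*(?=[A-Z]:)', '', generated)
    print(generated.strip())
    # Pair a navy blazer with white sneakers. B: Keep it simple.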
@@ -667,20 +680,21 @@ def generate_llm_answer(
             logger.error(f" β Error: {e}")
             return ''
 
-    #
-    base_prompt = f"""
-
-
-    # DistilGPT2 parameters
+    # Better prompt format with context
+    base_prompt = f"""Question: {query}
+
+Context from fashion knowledge base:
+{context_text[:400]}
+
+Based on the above information, here is detailed fashion advice:"""
+
+    # DistilGPT2 parameters - adjusted for better quality
     if attempt == 1:
-        max_new_tokens =
-        temperature = 0.
+        max_new_tokens = 150
+        temperature = 0.7
     else:
-        max_new_tokens =
-        temperature = 0.
+        max_new_tokens = 200
+        temperature = 0.75
 
     logger.info(f" β Starting generation with prompt: {base_prompt[:200]}...")
     initial_output = call_model(base_prompt, max_new_tokens, temperature)
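Note: the attempt == 1 branch implies a retry loop around call_model that sits outside this diff. A hedged sketch of the surrounding shape (the function name, loop, and acceptance threshold below are invented for illustration; only the token and temperature values come from the diff):

    # Hypothetical reconstruction of the retry flow implied by `attempt`.
    def generate_with_retry(query: str, context_text: str, min_words: int = 20) -> str:
        answer = ''
        for attempt in (1, 2):
            max_new_tokens = 150 if attempt == 1 else 200   # values from the diff
            temperature = 0.7 if attempt == 1 else 0.75     # second try samples a bit hotter
            base_prompt = (
                f"Question: {query}\n\n"
                f"Context from fashion knowledge base:\n{context_text[:400]}\n\n"
                "Based on the above information, here is detailed fashion advice:"
            )
            answer = call_model(base_prompt, max_new_tokens, temperature)
            if len(answer.split()) >= min_words:            # invented acceptance check
                break
        return answer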