Update app.py
app.py CHANGED
@@ -158,17 +158,50 @@ def initialize_llm():
         # Skip torch.compile - can cause issues on Hugging Face Spaces
         logger.info(" Model ready for inference")

-        # …
-        # …
-        logger.info("…
-        tokenizer…
-        … (rest of the deleted pipeline-setup block is truncated in the page capture)
+        # Store model and tokenizer directly for faster inference
+        # We'll use direct generation instead of pipeline
+        logger.info(" Configuring direct model inference (faster than pipeline)...")
+
+        # Create a simple wrapper that mimics pipeline interface
+        class FastPHIGenerator:
+            def __init__(self, model, tokenizer):
+                self.model = model
+                self.tokenizer = tokenizer
+
+            def __call__(self, prompt, max_new_tokens=150, temperature=0.7, top_p=0.9,
+                         do_sample=True, repetition_penalty=1.1, **kwargs):
+                """Direct generation - faster than pipeline"""
+                try:
+                    # Tokenize
+                    inputs = self.tokenizer(prompt, return_tensors="pt", truncation=True, max_length=512)
+                    input_ids = inputs["input_ids"]
+
+                    # Generate
+                    with torch.no_grad():
+                        outputs = self.model.generate(
+                            input_ids,
+                            max_new_tokens=max_new_tokens,
+                            temperature=temperature,
+                            top_p=top_p,
+                            do_sample=do_sample,
+                            repetition_penalty=repetition_penalty,
+                            pad_token_id=self.tokenizer.eos_token_id,
+                            eos_token_id=self.tokenizer.eos_token_id,
+                            early_stopping=True
+                        )
+
+                    # Decode only the new tokens
+                    generated_ids = outputs[0][input_ids.shape[1]:]
+                    generated_text = self.tokenizer.decode(generated_ids, skip_special_tokens=True)
+
+                    return [{"generated_text": generated_text}]
+
+                except Exception as e:
+                    logger.error(f"Generation error: {e}")
+                    return [{"generated_text": ""}]
+
+        llm_client = FastPHIGenerator(model, tokenizer)
+        llm_client.tokenizer = tokenizer  # Add tokenizer reference for compatibility

         CONFIG["llm_model"] = LOCAL_PHI_MODEL
         CONFIG["model_type"] = "phi_local"
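The wrapper in this hunk is deliberately shaped like a `transformers` text-generation pipeline: `__call__` returns `[{"generated_text": ...}]`, and its `**kwargs` quietly absorbs pipeline-only arguments such as `return_full_text` or `num_return_sequences`, so call sites written against the old pipeline keep working. A minimal sketch of such a call site under that assumption (the prompt text and parameter values here are illustrative, not from the commit):

# Hypothetical call site; llm_client is the FastPHIGenerator built above.
out = llm_client(
    "Q: What should I wear to a summer wedding?\n\nA:",
    max_new_tokens=80,
    do_sample=False,           # greedy decoding, matching the speed-tuned settings later in this diff
    return_full_text=False,    # pipeline-only flag, harmlessly swallowed by **kwargs
)
answer = out[0]["generated_text"]  # same result shape a real pipeline returns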
@@ -655,12 +688,12 @@ def generate_llm_answer(
     scored_docs.sort(key=lambda x: x[1], reverse=True)
     top_docs = [doc[0] for doc in scored_docs[:8]]

-    # …
+    # Minimal context for speed
     context_parts = []
-    for doc in top_docs[:…
+    for doc in top_docs[:3]:  # Only 3 best documents
         content = doc.page_content.strip()
-        if len(content) > …
-            content = content[:…
+        if len(content) > 200:  # Much shorter snippets
+            content = content[:200] + "..."
         context_parts.append(content)

     context_text = "\n\n".join(context_parts)
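This cap is the main retrieval-side speed lever: three snippets of at most 203 characters each (200 kept plus the ellipsis) give roughly 600 characters of context, and the prompt built in the next hunk slices `context_text[:300]` on top of that, so only about the first snippet and part of the second ever reach the model. A tiny illustration of the clipping rule (the sample string is made up):

# Made-up 450-character snippet, clipped the same way as doc.page_content above.
content = "a" * 450
if len(content) > 200:
    content = content[:200] + "..."
assert len(content) == 203  # 200 kept characters plus the "..."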
@@ -672,71 +705,90 @@
     max_iterations = 0  # Single-shot only for speed

     def call_model(prompt, max_new_tokens, temperature, top_p, repetition_penalty):
-        """Optimized for PHI-2…
-        … (rest of the deleted body is truncated in the page capture)
+        """Optimized for PHI-2 with timeout protection"""
+        import threading
+
+        result_container = {'output': None, 'error': None}
+
+        def generate_with_timeout():
+            try:
+                # Ultra-simple prompt
+                formatted_prompt = f"{prompt}\n\nAnswer:"
+
+                logger.info(f" β PHI-2 generating (max_tokens={max_new_tokens})")
+
+                # MINIMAL settings - most restrictive for speed
+                out = llm_client(
+                    formatted_prompt,
+                    max_new_tokens=max_new_tokens,
+                    temperature=temperature,
+                    top_p=top_p,
+                    do_sample=False,  # Greedy decoding for speed
+                    repetition_penalty=repetition_penalty,
+                    num_return_sequences=1,
+                    return_full_text=False,
+                    early_stopping=True
+                )
+
+                result_container['output'] = out
+                logger.info(f" β Generation done")
+
+            except Exception as e:
+                result_container['error'] = str(e)
+                logger.error(f" β Generation error: {e}")
+
+        # Run generation in thread with timeout
+        gen_thread = threading.Thread(target=generate_with_timeout)
+        gen_thread.daemon = True
+        gen_thread.start()
+        gen_thread.join(timeout=45)  # 45 second timeout
+
+        if gen_thread.is_alive():
+            logger.error(" β Generation TIMEOUT after 45s")
+            return ''
+
+        if result_container['error']:
+            logger.error(f" β Error: {result_container['error']}")
+            return ''
+
+        out = result_container['output']
+
+        # Extract text quickly
+        if not out or not isinstance(out, list) or len(out) == 0:
+            logger.warning(" β Empty output")
             return ''
+
+        generated = out[0].get('generated_text', '') if isinstance(out[0], dict) else str(out[0])
+
+        # Quick cleanup
+        formatted_prompt = f"{prompt}\n\nAnswer:"
+        for remove in [formatted_prompt, 'Answer:', 'Response:', 'Output:']:
+            generated = generated.replace(remove, '')
+
+        generated = generated.strip()
+        word_count = len(generated.split())
+
+        logger.info(f" β Generated {word_count} words")
+        return generated

-    # …
-    base_prompt = f"""Question: {query}
-    … (rest of the deleted prompt is truncated in the page capture)
-    # …
-    # Shorter outputs = faster generation on Hugging Face Spaces
+    # ULTRA-SHORT prompt for speed
+    base_prompt = f"""Q: {query}
+
+{context_text[:300]}
+
+A:"""
+
+    # AGGRESSIVE speed optimization
     if attempt == 1:
+        temperature = 0.6  # Lower = faster
+        max_new_tokens = 150  # Much shorter
+        top_p = 0.85
+        repetition_penalty = 1.2
+    else:
         temperature = 0.7
-        max_new_tokens = …
+        max_new_tokens = 180
         top_p = 0.9
-        repetition_penalty = 1.…
-    else:
-        temperature = 0.75
-        max_new_tokens = 250
-        top_p = 0.92
-        repetition_penalty = 1.2
+        repetition_penalty = 1.25

     logger.info(f" β Starting generation with prompt: {base_prompt[:200]}...")
     initial_output = call_model(base_prompt, max_new_tokens, temperature, top_p, repetition_penalty)
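One caveat on the timeout pattern in this hunk: `gen_thread.join(timeout=45)` abandons the worker rather than stopping it, so after a timeout `model.generate()` keeps running (and consuming CPU) in the background; the daemon flag only keeps it from blocking interpreter shutdown. A self-contained sketch of the same pattern, with hypothetical names:

import threading

def run_with_timeout(fn, timeout_s=45):
    # Run fn on a daemon thread; give up after timeout_s, but note the worker
    # is NOT killed - it finishes (or spins) on its own in the background.
    box = {"result": None, "error": None}

    def worker():
        try:
            box["result"] = fn()
        except Exception as e:  # mirror the broad except in call_model
            box["error"] = str(e)

    t = threading.Thread(target=worker, daemon=True)
    t.start()
    t.join(timeout=timeout_s)
    if t.is_alive():  # timed out; the thread is still running
        return None, "timeout"
    return box["result"], box["error"]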
@@ -774,6 +826,11 @@ Answer with fashion advice:"""
     if word_count >= 10:
         logger.info(f" ⚠️ Very short response ({word_count} words) but accepting")
         return response
+
+    # EMERGENCY: accept even 5+ words if that's all we get
+    if word_count >= 5:
+        logger.info(f" ⚠️ EMERGENCY: Accepting tiny response ({word_count} words)")
+        return response

     # Otherwise, try iterative continuation to build up to the target
     accumulated = response