hamxaameer committed
Commit 4b54bb9 · verified · 1 Parent(s): 7a3d769

Update app.py

Files changed (1)
  1. app.py +129 -72
app.py CHANGED
@@ -158,17 +158,50 @@ def initialize_llm():
     # Skip torch.compile - can cause issues on Hugging Face Spaces
     logger.info(" Model ready for inference")
 
-    # Create pipeline for generation
-    # CRITICAL: Do NOT specify device when using device_map="auto"
-    logger.info(" Creating text-generation pipeline...")
-    llm_client = pipeline(
-        "text-generation",
-        model=model,
-        tokenizer=tokenizer,
-        max_new_tokens=200,  # Reduced for faster generation
-        pad_token_id=tokenizer.eos_token_id,
-        eos_token_id=tokenizer.eos_token_id
-    )
+    # Store model and tokenizer directly for faster inference
+    # We'll use direct generation instead of pipeline
+    logger.info(" Configuring direct model inference (faster than pipeline)...")
+
+    # Create a simple wrapper that mimics pipeline interface
+    class FastPHIGenerator:
+        def __init__(self, model, tokenizer):
+            self.model = model
+            self.tokenizer = tokenizer
+
+        def __call__(self, prompt, max_new_tokens=150, temperature=0.7, top_p=0.9,
+                     do_sample=True, repetition_penalty=1.1, **kwargs):
+            """Direct generation - faster than pipeline"""
+            try:
+                # Tokenize
+                inputs = self.tokenizer(prompt, return_tensors="pt", truncation=True, max_length=512)
+                input_ids = inputs["input_ids"]
+
+                # Generate
+                with torch.no_grad():
+                    outputs = self.model.generate(
+                        input_ids,
+                        max_new_tokens=max_new_tokens,
+                        temperature=temperature,
+                        top_p=top_p,
+                        do_sample=do_sample,
+                        repetition_penalty=repetition_penalty,
+                        pad_token_id=self.tokenizer.eos_token_id,
+                        eos_token_id=self.tokenizer.eos_token_id,
+                        early_stopping=True
+                    )
+
+                # Decode only the new tokens
+                generated_ids = outputs[0][input_ids.shape[1]:]
+                generated_text = self.tokenizer.decode(generated_ids, skip_special_tokens=True)
+
+                return [{"generated_text": generated_text}]
+
+            except Exception as e:
+                logger.error(f"Generation error: {e}")
+                return [{"generated_text": ""}]
+
+    llm_client = FastPHIGenerator(model, tokenizer)
+    llm_client.tokenizer = tokenizer  # Add tokenizer reference for compatibility
 
     CONFIG["llm_model"] = LOCAL_PHI_MODEL
     CONFIG["model_type"] = "phi_local"
@@ -655,12 +688,12 @@ def generate_llm_answer(
     scored_docs.sort(key=lambda x: x[1], reverse=True)
     top_docs = [doc[0] for doc in scored_docs[:8]]
 
-    # Natural flow: use rich context from top documents
+    # Minimal context for speed
     context_parts = []
-    for doc in top_docs[:6]:  # Use 6 best documents
+    for doc in top_docs[:3]:  # Only 3 best documents
         content = doc.page_content.strip()
-        if len(content) > 500:  # Keep more content
-            content = content[:500] + "..."
+        if len(content) > 200:  # Much shorter snippets
+            content = content[:200] + "..."
         context_parts.append(content)
 
     context_text = "\n\n".join(context_parts)
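Note: the tighter budget here feeds the prompt cap `{context_text[:300]}` introduced in the next hunk, and the wrapper tokenizes with `truncation=True, max_length=512`, so the full prompt stays inside the window. A rough sanity check, assuming the same `tokenizer`, `query`, and `context_text` objects are in scope (illustrative only, not part of this commit):

    # Illustrative budget check: 3 snippets x 200 chars plus the question
    # should tokenize well below the wrapper's max_length=512 cap.
    prompt = f"Q: {query}\n\n{context_text[:300]}\n\nA:"
    n_tokens = len(tokenizer(prompt)["input_ids"])
    assert n_tokens < 512, f"prompt unexpectedly long: {n_tokens} tokens"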
@@ -672,71 +705,90 @@
     max_iterations = 0  # Single-shot only for speed
 
     def call_model(prompt, max_new_tokens, temperature, top_p, repetition_penalty):
-        """Optimized for PHI-2 - fast generation on CPU"""
-        try:
-            # Simple direct prompt - no fancy formatting
-            formatted_prompt = f"{prompt}\n\nAnswer:"
-
-            logger.info(f" → Calling PHI-2 (tokens={max_new_tokens}, temp={temperature})")
-            logger.info(f" → Formatted prompt length: {len(formatted_prompt)} chars")
-
-            # Call PHI-2 with MINIMAL settings for speed
-            out = llm_client(
-                formatted_prompt,
-                max_new_tokens=max_new_tokens,
-                temperature=temperature,
-                top_p=top_p,
-                do_sample=True,
-                repetition_penalty=repetition_penalty,
-                num_return_sequences=1,
-                return_full_text=False
-            )
-
-            logger.info(f" → Generation completed")
-
-            # Extract text quickly
-            if not out or not isinstance(out, list) or len(out) == 0:
-                logger.warning(" ✗ Empty output")
-                return ''
-
-            generated = out[0].get('generated_text', '') if isinstance(out[0], dict) else str(out[0])
-
-            # Quick cleanup
-            for remove in [formatted_prompt, 'Answer:', 'Response:', 'Output:']:
-                generated = generated.replace(remove, '')
-
-            generated = generated.strip()
-            word_count = len(generated.split())
-
-            logger.info(f" ✅ Generated {word_count} words")
-            return generated
-
-        except Exception as e:
-            logger.error(f" ✗ Error: {e}")
-            import traceback
-            logger.error(traceback.format_exc())
+        """Optimized for PHI-2 with timeout protection"""
+        import threading
+
+        result_container = {'output': None, 'error': None}
+
+        def generate_with_timeout():
+            try:
+                # Ultra-simple prompt
+                formatted_prompt = f"{prompt}\n\nAnswer:"
+
+                logger.info(f" → PHI-2 generating (max_tokens={max_new_tokens})")
+
+                # MINIMAL settings - most restrictive for speed
+                out = llm_client(
+                    formatted_prompt,
+                    max_new_tokens=max_new_tokens,
+                    temperature=temperature,
+                    top_p=top_p,
+                    do_sample=False,  # Greedy decoding for speed
+                    repetition_penalty=repetition_penalty,
+                    num_return_sequences=1,
+                    return_full_text=False,
+                    early_stopping=True
+                )
+
+                result_container['output'] = out
+                logger.info(f" ✓ Generation done")
+
+            except Exception as e:
+                result_container['error'] = str(e)
+                logger.error(f" ✗ Generation error: {e}")
+
+        # Run generation in thread with timeout
+        gen_thread = threading.Thread(target=generate_with_timeout)
+        gen_thread.daemon = True
+        gen_thread.start()
+        gen_thread.join(timeout=45)  # 45 second timeout
+
+        if gen_thread.is_alive():
+            logger.error(" ✗ Generation TIMEOUT after 45s")
+            return ''
+
+        if result_container['error']:
+            logger.error(f" ✗ Error: {result_container['error']}")
+            return ''
+
+        out = result_container['output']
+
+        # Extract text quickly
+        if not out or not isinstance(out, list) or len(out) == 0:
+            logger.warning(" ✗ Empty output")
             return ''
+
+        generated = out[0].get('generated_text', '') if isinstance(out[0], dict) else str(out[0])
+
+        # Quick cleanup
+        formatted_prompt = f"{prompt}\n\nAnswer:"
+        for remove in [formatted_prompt, 'Answer:', 'Response:', 'Output:']:
+            generated = generated.replace(remove, '')
+
+        generated = generated.strip()
+        word_count = len(generated.split())
+
+        logger.info(f" ✅ Generated {word_count} words")
+        return generated
 
-    # PHI-2 optimized: VERY short prompt for fast generation
-    # Long prompts cause slow/hanging generation on CPU
-    base_prompt = f"""Question: {query}
+    # ULTRA-SHORT prompt for speed
+    base_prompt = f"""Q: {query}
 
-Context: {context_text[:400]}
+{context_text[:300]}
 
-Answer with fashion advice:"""
+A:"""
 
-    # PHI-2 generation parameters: SPEED OPTIMIZED for CPU
-    # Shorter outputs = faster generation on Hugging Face Spaces
+    # AGGRESSIVE speed optimization
     if attempt == 1:
-        temperature = 0.7
-        max_new_tokens = 200  # Reduced for faster generation
-        top_p = 0.9
-        repetition_penalty = 1.15  # Higher to prevent loops
-    else:
-        temperature = 0.75
-        max_new_tokens = 250
-        top_p = 0.92
-        repetition_penalty = 1.2
+        temperature = 0.6  # Lower = faster
+        max_new_tokens = 150  # Much shorter
+        top_p = 0.85
+        repetition_penalty = 1.2
+    else:
+        temperature = 0.7
+        max_new_tokens = 180
+        top_p = 0.9
+        repetition_penalty = 1.25
 
     logger.info(f" → Starting generation with prompt: {base_prompt[:200]}...")
     initial_output = call_model(base_prompt, max_new_tokens, temperature, top_p, repetition_penalty)
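Note: the timeout guard above is the standard daemon-thread pattern. `Thread.join(timeout=...)` bounds how long the caller waits, but it does not kill the worker: a timed-out `model.generate` keeps running in the background until the process exits. The same idea as a reusable helper (a sketch for illustration, not part of this commit):

    import threading

    def run_with_timeout(fn, timeout_s=45):
        """Run fn() in a daemon thread; return its result, or None on timeout."""
        box = {}

        def worker():
            try:
                box["value"] = fn()
            except Exception as e:  # surface worker errors to the caller
                box["error"] = e

        t = threading.Thread(target=worker, daemon=True)
        t.start()
        t.join(timeout=timeout_s)
        if t.is_alive():  # worker still running; stop waiting (thread is NOT killed)
            return None
        if "error" in box:
            raise box["error"]
        return box.get("value")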
@@ -774,6 +826,11 @@ Answer with fashion advice:"""
     if word_count >= 10:
         logger.info(f" ⚠️ Very short response ({word_count} words) but accepting")
         return response
+
+    # EMERGENCY: accept even 5+ words if that's all we get
+    if word_count >= 5:
+        logger.info(f" ⚠️ EMERGENCY: Accepting tiny response ({word_count} words)")
+        return response
 
     # Otherwise, try iterative continuation to build up to the target
     accumulated = response
 