hamxaameer committed (verified)
Commit 21776b6 · 1 Parent(s): d85f59c

Update app.py

Files changed (1)
  app.py (+54 −13)
app.py CHANGED
@@ -19,6 +19,10 @@ from langchain_community.vectorstores import FAISS
 from langchain_community.embeddings import HuggingFaceEmbeddings
 from langchain.schema import Document
 
+# Suppress transformers warnings about generation flags
+import os
+os.environ['TRANSFORMERS_VERBOSITY'] = 'error'
+
 # Setup logging
 logging.basicConfig(level=logging.INFO)
 logger = logging.getLogger(__name__)
@@ -27,10 +31,22 @@ logger = logging.getLogger(__name__)
 torch.set_num_threads(4)  # Limit threads for better CPU performance
 torch.set_grad_enabled(False)  # Disable gradients (inference only)
 
-# Suppress specific warnings
+# Suppress specific warnings and asyncio issues
 import warnings
 warnings.filterwarnings("ignore", message="MatMul8bitLt")
 warnings.filterwarnings("ignore", message="torch_dtype")
+warnings.filterwarnings("ignore", message="Invalid file descriptor")
+warnings.filterwarnings("ignore", message="generation flags")
+warnings.filterwarnings("ignore", category=UserWarning)
+
+# Fix asyncio file descriptor warnings
+import asyncio
+import sys
+if sys.platform == 'linux':
+    try:
+        asyncio.set_event_loop_policy(asyncio.DefaultEventLoopPolicy())
+    except:
+        pass
 
 # ============================================================================
 # CONFIGURATION
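Note on the hunk above: setting the TRANSFORMERS_VERBOSITY environment variable before transformers is imported silences most of the library's warning chatter. As a rough standalone sketch (not taken from app.py; the hf_logging alias is mine), the same effect is also available through the library's logging utilities:

    import os
    os.environ["TRANSFORMERS_VERBOSITY"] = "error"  # read when transformers first configures its logger

    from transformers.utils import logging as hf_logging
    hf_logging.set_verbosity_error()  # programmatic equivalent of the env var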
@@ -666,8 +682,9 @@ def generate_llm_answer(
 
     def call_model(prompt, max_new_tokens, temperature, top_p, repetition_penalty):
         logger.info(f" → PHI model call (temp={temperature}, max_new_tokens={max_new_tokens})")
+        logger.info(f" → Prompt length: {len(prompt)} chars")
         try:
-            # Call local PHI model with speed optimizations
+            # Call local PHI model with optimized parameters
             out = llm_client(
                 prompt,
                 max_new_tokens=max_new_tokens,
@@ -678,26 +695,41 @@ def generate_llm_answer(
                 num_return_sequences=1,
                 pad_token_id=llm_client.tokenizer.eos_token_id,
                 eos_token_id=llm_client.tokenizer.eos_token_id,
-                num_beams=1,  # Greedy/sampling is faster than beam search
-                early_stopping=True,  # Stop as soon as EOS is generated
-                use_cache=True  # Use KV cache for speed
+                truncation=True,
+                return_full_text=False  # Only return new generation, not prompt
             )
 
+            logger.info(f" → Raw output type: {type(out)}")
+
             # Extract generated text from pipeline output
-            if isinstance(out, list) and out:
-                generated = out[0].get('generated_text', '') if isinstance(out[0], dict) else str(out[0])
+            if isinstance(out, list) and len(out) > 0:
+                first_item = out[0]
+                if isinstance(first_item, dict):
+                    generated = first_item.get('generated_text', '')
+                else:
+                    generated = str(first_item)
             else:
-                generated = str(out)
+                generated = str(out) if out else ''
+
+            logger.info(f" → Generated length before cleanup: {len(generated)} chars")
 
-            # PHI models return prompt + completion, extract only new text
-            if prompt in generated:
-                # Remove the prompt from the output
+            # PHI models may still include prompt, remove it
+            if generated and prompt in generated:
+                prompt_end = generated.find(prompt) + len(prompt)
+                generated = generated[prompt_end:].strip()
+
+            # Additional cleanup: remove any leading prompt fragments
+            if generated and generated.startswith(prompt[:50]):
                 generated = generated[len(prompt):].strip()
 
-            return generated
+            logger.info(f" → Final generated length: {len(generated)} chars, words: {len(generated.split())}")
+
+            return generated.strip()
 
         except Exception as e:
             logger.error(f" ✗ PHI model call error: {e}")
+            import traceback
+            logger.error(f" ✗ Traceback: {traceback.format_exc()}")
             return ''
 
     # Natural prompt: let the model generate complete, flowing responses
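For context on the return_full_text change above: with the Hugging Face text-generation pipeline, return_full_text=False makes the pipeline return only the completion, so the manual prompt-stripping retained in the hunk is effectively a fallback. A minimal self-contained sketch of that pipeline behaviour (the model id, prompt, and sampling values are placeholders, not taken from the commit):

    from transformers import pipeline

    generator = pipeline("text-generation", model="microsoft/phi-2")  # placeholder model id

    out = generator(
        "Explain retrieval-augmented generation in one sentence.",
        max_new_tokens=64,
        do_sample=True,
        temperature=0.7,
        top_p=0.93,
        repetition_penalty=1.10,
        truncation=True,
        return_full_text=False,  # completion only; the prompt is not echoed back
        pad_token_id=generator.tokenizer.eos_token_id,
    )

    # The pipeline returns a list of dicts with a 'generated_text' key
    completion = out[0]["generated_text"].strip() if out else ""
    print(completion)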
@@ -724,16 +756,20 @@ Answer:"""
     top_p = 0.93
     repetition_penalty = 1.10
 
+    logger.info(f" → Starting generation with prompt: {base_prompt[:200]}...")
     initial_output = call_model(base_prompt, max_new_tokens, temperature, top_p, repetition_penalty)
     response = (initial_output or '').strip()
 
     # Basic sanity checks
     if not response:
-        logger.warning(" ✗ Empty initial response")
+        logger.warning(" ✗ Empty initial response - model may not be generating")
+        logger.warning(f" ✗ Prompt was: {base_prompt[:300]}")
         response = ''
 
     words = response.split()
     word_count = len(words)
+
+    logger.info(f" → Initial response: {word_count} words")
 
     # Natural mode: accept ANY response length - let model decide
     # No truncation, no artificial limits
@@ -746,6 +782,11 @@ Answer:"""
     if word_count >= 50:
         logger.info(f" ✅ Accepted natural response ({word_count} words)")
         return response
+
+    # Very permissive: accept anything with 20+ words
+    if word_count >= 20:
+        logger.info(f" ⚠️ Short but acceptable response ({word_count} words)")
+        return response
 
     # Otherwise, try iterative continuation to build up to the target
     accumulated = response
 
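The last hunk relaxes the acceptance threshold so short answers are no longer forced through iterative continuation. A hypothetical helper distilling that logic (the function and its defaults are illustrative, not part of app.py):

    def accept_response(response: str, full_target: int = 50, minimum: int = 20) -> bool:
        """Return True if the draft answer is long enough to use as-is."""
        word_count = len(response.split())
        if word_count >= full_target:
            return True   # natural-length answer, accept outright
        if word_count >= minimum:
            return True   # short but acceptable, per the relaxed check
        return False      # too short; caller falls back to iterative continuation

    # Example: a 24-word draft now passes instead of triggering continuation
    draft = ("Retrieval-augmented generation combines a vector search step with a language model "
             "so answers can cite retrieved context rather than rely on parametric memory alone.")
    print(accept_response(draft))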