Update app.py
app.py
CHANGED

@@ -19,6 +19,10 @@ from langchain_community.vectorstores import FAISS
 from langchain_community.embeddings import HuggingFaceEmbeddings
 from langchain.schema import Document
 
+# Suppress transformers warnings about generation flags
+import os
+os.environ['TRANSFORMERS_VERBOSITY'] = 'error'
+
 # Setup logging
 logging.basicConfig(level=logging.INFO)
 logger = logging.getLogger(__name__)
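
The added lines quiet transformers log output through the TRANSFORMERS_VERBOSITY environment variable, which the library reads when it first configures its logger, so it has to be set before transformers is first used. A minimal sketch (not part of the commit) of the equivalent programmatic call from the library's documented logging helpers, useful if the variable ends up being set too late:

from transformers.utils import logging as hf_logging

hf_logging.set_verbosity_error()  # keep only error-level messages from transformers
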
@@ -27,10 +31,22 @@ logger = logging.getLogger(__name__)
 torch.set_num_threads(4)  # Limit threads for better CPU performance
 torch.set_grad_enabled(False)  # Disable gradients (inference only)
 
-# Suppress specific warnings
+# Suppress specific warnings and asyncio issues
 import warnings
 warnings.filterwarnings("ignore", message="MatMul8bitLt")
 warnings.filterwarnings("ignore", message="torch_dtype")
+warnings.filterwarnings("ignore", message="Invalid file descriptor")
+warnings.filterwarnings("ignore", message="generation flags")
+warnings.filterwarnings("ignore", category=UserWarning)
+
+# Fix asyncio file descriptor warnings
+import asyncio
+import sys
+if sys.platform == 'linux':
+    try:
+        asyncio.set_event_loop_policy(asyncio.DefaultEventLoopPolicy())
+    except:
+        pass
 
 # ============================================================================
 # CONFIGURATION
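
One detail about these filters: the message argument of warnings.filterwarnings is a regular expression matched against the start of the warning text, so a phrase that appears mid-message is only caught if the pattern allows a prefix. A short self-contained sketch (not from the commit) showing the difference:

import re
import warnings

# The filter regex is anchored at the start of the warning message.
print(bool(re.compile("MatMul8bitLt").match("Detected MatMul8bitLt call")))    # False
print(bool(re.compile(".*MatMul8bitLt").match("Detected MatMul8bitLt call")))  # True

# A prefix wildcard makes the suppression robust to where the phrase appears.
warnings.filterwarnings("ignore", message=".*MatMul8bitLt")

Note also that the category=UserWarning filter is broad: it hides every UserWarning, not only the messages named above.
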
@@ -666,8 +682,9 @@ def generate_llm_answer(
 
     def call_model(prompt, max_new_tokens, temperature, top_p, repetition_penalty):
         logger.info(f" β PHI model call (temp={temperature}, max_new_tokens={max_new_tokens})")
+        logger.info(f" β Prompt length: {len(prompt)} chars")
         try:
-            # Call local PHI model with
+            # Call local PHI model with optimized parameters
            out = llm_client(
                prompt,
                max_new_tokens=max_new_tokens,
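
Throughout the diff, llm_client is invoked as a callable that takes generation keyword arguments and exposes a .tokenizer attribute, which matches a Hugging Face transformers text-generation pipeline. A minimal sketch of how such a client could be built; the checkpoint name is a placeholder, not taken from this commit:

from transformers import pipeline

# Hypothetical construction of llm_client; "microsoft/phi-2" is only a
# placeholder PHI-family checkpoint, not necessarily the one this Space uses.
llm_client = pipeline("text-generation", model="microsoft/phi-2", device=-1)  # device=-1 runs on CPU

out = llm_client(
    "Question: ...\nAnswer:",
    max_new_tokens=64,
    do_sample=True,
    temperature=0.7,
    pad_token_id=llm_client.tokenizer.eos_token_id,
)
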
@@ -678,26 +695,41 @@ def generate_llm_answer(
                num_return_sequences=1,
                pad_token_id=llm_client.tokenizer.eos_token_id,
                eos_token_id=llm_client.tokenizer.eos_token_id,
-
-
-                use_cache=True  # Use KV cache for speed
+                truncation=True,
+                return_full_text=False  # Only return new generation, not prompt
            )
 
+            logger.info(f" β Raw output type: {type(out)}")
+
            # Extract generated text from pipeline output
-            if isinstance(out, list) and out:
-
+            if isinstance(out, list) and len(out) > 0:
+                first_item = out[0]
+                if isinstance(first_item, dict):
+                    generated = first_item.get('generated_text', '')
+                else:
+                    generated = str(first_item)
            else:
-                generated = str(out)
+                generated = str(out) if out else ''
+
+            logger.info(f" β Generated length before cleanup: {len(generated)} chars")
 
-            # PHI models
-            if prompt in generated:
-
+            # PHI models may still include prompt, remove it
+            if generated and prompt in generated:
+                prompt_end = generated.find(prompt) + len(prompt)
+                generated = generated[prompt_end:].strip()
+
+            # Additional cleanup: remove any leading prompt fragments
+            if generated and generated.startswith(prompt[:50]):
                generated = generated[len(prompt):].strip()
 
-
+            logger.info(f" β Final generated length: {len(generated)} chars, words: {len(generated.split())}")
+
+            return generated.strip()
 
        except Exception as e:
            logger.error(f" β PHI model call error: {e}")
+            import traceback
+            logger.error(f" β Traceback: {traceback.format_exc()}")
            return ''
 
    # Natural prompt: let the model generate complete, flowing responses
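
The new extraction code accounts for the usual output shape of a text-generation pipeline: a list with one dict per returned sequence, where return_full_text=False leaves only the completion in generated_text. A short sketch of what the parsing above expects (the value is illustrative, not taken from the commit):

# Typical shape of `out` with num_return_sequences=1 (illustrative value):
out = [{"generated_text": " The capital of France is Paris."}]

first_item = out[0] if isinstance(out, list) and out else {}
generated = first_item.get("generated_text", "") if isinstance(first_item, dict) else str(first_item)
print(generated.strip())  # -> "The capital of France is Paris."
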
@@ -724,16 +756,20 @@ Answer:"""
    top_p = 0.93
    repetition_penalty = 1.10
 
+    logger.info(f" β Starting generation with prompt: {base_prompt[:200]}...")
    initial_output = call_model(base_prompt, max_new_tokens, temperature, top_p, repetition_penalty)
    response = (initial_output or '').strip()
 
    # Basic sanity checks
    if not response:
-        logger.warning(" β Empty initial response")
+        logger.warning(" β Empty initial response - model may not be generating")
+        logger.warning(f" β Prompt was: {base_prompt[:300]}")
        response = ''
 
    words = response.split()
    word_count = len(words)
+
+    logger.info(f" β Initial response: {word_count} words")
 
    # Natural mode: accept ANY response length - let model decide
    # No truncation, no artificial limits
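
The (initial_output or '').strip() guard normalises both a None and an empty string coming back from call_model before the word count is taken; a tiny illustration (not from the commit):

for raw in (None, "", "  some text  "):
    response = (raw or "").strip()
    print(repr(response), len(response.split()))
# prints: '' 0, then '' 0, then 'some text' 2
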
@@ -746,6 +782,11 @@ Answer:"""
    if word_count >= 50:
        logger.info(f" β Accepted natural response ({word_count} words)")
        return response
+
+    # Very permissive: accept anything with 20+ words
+    if word_count >= 20:
+        logger.info(f" β οΈ Short but acceptable response ({word_count} words)")
+        return response
 
    # Otherwise, try iterative continuation to build up to the target
    accumulated = response
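
With this hunk, the acceptance rule becomes: keep any answer of 50 or more words, now also keep 20 to 49 word answers (with a warning), and fall back to iterative continuation only below 20 words. A standalone sketch of that decision; the function name and defaults are mine, not from the commit:

def accept_response(response: str, full_target: int = 50, minimum: int = 20) -> bool:
    """Return True when a generated answer is long enough to use without continuation."""
    word_count = len(response.split())
    if word_count >= full_target:
        return True   # full-length natural answer
    if word_count >= minimum:
        return True   # short but acceptable
    return False      # too short: caller keeps generating


assert accept_response("word " * 60)
assert accept_response("word " * 25)
assert not accept_response("word " * 5)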