Update app.py
app.py
CHANGED
@@ -65,11 +65,11 @@ CONFIG = {
 # PHI-2 is optimal for CPU deployment: 2.7B parameters, excellent quality
 # Can be swapped with Phi-3-mini-4k-instruct if more memory is available
 LOCAL_PHI_MODEL = os.environ.get("LOCAL_PHI_MODEL", "microsoft/phi-2")
-USE_8BIT_QUANTIZATION = True
+USE_8BIT_QUANTIZATION = False # DISABLED: causes hanging on CPU
 USE_REMOTE_LLM = False
 
 # Natural flow mode: No word limits, let model decide length
-MAX_CONTEXT_LENGTH = 1200
+MAX_CONTEXT_LENGTH = 400 # Reduced for faster generation
 USE_CACHING = True # Cache model outputs for repeated patterns
 ENABLE_FAST_MODE = False # Allow natural completion, no artificial limits
 
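An aside on the first toggle above: if 8-bit loading is ever re-enabled, it is safer to probe for support up front than to let model loading stall on a CPU-only box. A minimal sketch, assuming bitsandbytes as the quantization backend (the can_use_8bit helper is hypothetical, not part of app.py):

import importlib.util

import torch

def can_use_8bit() -> bool:
    # bitsandbytes' 8-bit kernels target CUDA; require both the package
    # and a visible GPU before opting in, instead of failing mid-load.
    has_bnb = importlib.util.find_spec("bitsandbytes") is not None
    return has_bnb and torch.cuda.is_available()

Gating USE_8BIT_QUANTIZATION on a check like this would make the flag self-correcting rather than hand-maintained.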
@@ -122,26 +122,22 @@ def initialize_llm():
         use_fast=True
     )
 
-    #
+    # Configure tokenizer for PHI models
     if tokenizer.pad_token is None:
         tokenizer.pad_token = tokenizer.eos_token
+    if tokenizer.pad_token_id is None:
+        tokenizer.pad_token_id = tokenizer.eos_token_id
 
-
+    logger.info(f" Tokenizer configured: vocab_size={len(tokenizer)}, eos_token={tokenizer.eos_token}")
+
+    # Configure model loading for CPU efficiency (NO quantization)
     model_kwargs = {
         "trust_remote_code": True,
         "low_cpu_mem_usage": True,
         "torch_dtype": torch.float32, # CPU works best with float32
+        "device_map": "auto", # Let transformers handle device placement
     }
 
-    # Try to use 8-bit quantization if available (requires bitsandbytes)
-    if USE_8BIT_QUANTIZATION and device == "cpu":
-        try:
-            logger.info(" Attempting 8-bit quantization for memory efficiency...")
-            model_kwargs["load_in_8bit"] = True
-        except Exception as quant_error:
-            logger.warning(f" 8-bit quantization unavailable: {quant_error}")
-            logger.info(" Falling back to float32 (will use more memory)")
-
     # Load the model with optimization
     logger.info(" Loading PHI model (this may take 30-60 seconds)...")
     model = AutoModelForCausalLM.from_pretrained(
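Condensed out of the diff, the load path now amounts to roughly the following. A sketch, assuming the accelerate package is installed (device_map="auto" depends on it):

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_name = "microsoft/phi-2"
tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=True)
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token  # PHI-2 ships without a pad token

model = AutoModelForCausalLM.from_pretrained(
    model_name,
    trust_remote_code=True,
    low_cpu_mem_usage=True,     # stream weights in, don't materialize twice
    torch_dtype=torch.float32,  # full precision; the safe default on CPU
    device_map="auto",          # let accelerate pick CPU/GPU placement
)
model.eval()  # inference only: disables dropout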
@@ -159,24 +155,19 @@ def initialize_llm():
     # Move to eval mode to disable dropout and save memory
     model.eval()
 
-    #
-    try:
-        if hasattr(torch, 'compile') and not USE_8BIT_QUANTIZATION:
-            logger.info(" Applying torch.compile for faster inference...")
-            model = torch.compile(model, mode="reduce-overhead")
-    except Exception as compile_error:
-        logger.info(f" Torch compile not available or failed: {compile_error}")
+    # Skip torch.compile - can cause issues on Hugging Face Spaces
+    logger.info(" Model ready for inference")
 
     # Create pipeline for generation
-    # NOTE: When using accelerate/quantization, do NOT specify device parameter
     logger.info(" Creating text-generation pipeline...")
     llm_client = pipeline(
         "text-generation",
         model=model,
         tokenizer=tokenizer,
-        max_new_tokens=
+        max_new_tokens=200, # Reduced for faster generation
         pad_token_id=tokenizer.eos_token_id,
-
+        eos_token_id=tokenizer.eos_token_id,
+        device=0 if device == "cuda" else -1 # -1 for CPU
     )
 
     CONFIG["llm_model"] = LOCAL_PHI_MODEL
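One caveat, echoing the NOTE deleted in this hunk: transformers objects to an explicit device= argument for a model that accelerate has already dispatched via device_map="auto" (such models carry an hf_device_map attribute). A defensive sketch of the pipeline call under that constraint (the guard is hypothetical, not in this commit; model, tokenizer, and device are the variables from the surrounding code):

from transformers import pipeline

pipeline_kwargs = {}
if not hasattr(model, "hf_device_map"):
    # Only steer placement by hand when accelerate has not done it already.
    pipeline_kwargs["device"] = 0 if device == "cuda" else -1

llm_client = pipeline(
    "text-generation",
    model=model,
    tokenizer=tokenizer,
    max_new_tokens=200,
    pad_token_id=tokenizer.eos_token_id,
    **pipeline_kwargs,
)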
@@ -681,80 +672,71 @@ def generate_llm_answer(
     max_iterations = 0 # Single-shot only for speed
 
     def call_model(prompt, max_new_tokens, temperature, top_p, repetition_penalty):
-
-        logger.info(f" → Prompt length: {len(prompt)} chars")
+        """Optimized for PHI-2 - fast generation on CPU"""
         try:
-            #
+            # Simple direct prompt - no fancy formatting
+            formatted_prompt = f"{prompt}\n\nAnswer:"
+
+            logger.info(f" → Calling PHI-2 (tokens={max_new_tokens}, temp={temperature})")
+            logger.info(f" → Formatted prompt length: {len(formatted_prompt)} chars")
+
+            # Call PHI-2 with MINIMAL settings for speed
             out = llm_client(
-
+                formatted_prompt,
                 max_new_tokens=max_new_tokens,
                 temperature=temperature,
                 top_p=top_p,
                 do_sample=True,
                 repetition_penalty=repetition_penalty,
                 num_return_sequences=1,
-
-                eos_token_id=llm_client.tokenizer.eos_token_id,
-                truncation=True,
-                return_full_text=False # Only return new generation, not prompt
+                return_full_text=False
             )
 
-            logger.info(f" →
-
-            # Extract generated text from pipeline output
-            if isinstance(out, list) and len(out) > 0:
-                first_item = out[0]
-                if isinstance(first_item, dict):
-                    generated = first_item.get('generated_text', '')
-                else:
-                    generated = str(first_item)
-            else:
-                generated = str(out) if out else ''
+            logger.info(f" → Generation completed")
 
-
+            # Extract text quickly
+            if not out or not isinstance(out, list) or len(out) == 0:
+                logger.warning(" ✗ Empty output")
+                return ''
 
-
-            if generated and prompt in generated:
-                prompt_end = generated.find(prompt) + len(prompt)
-                generated = generated[prompt_end:].strip()
+            generated = out[0].get('generated_text', '') if isinstance(out[0], dict) else str(out[0])
 
-            #
-
-            generated = generated
+            # Quick cleanup
+            for remove in [formatted_prompt, 'Answer:', 'Response:', 'Output:']:
+                generated = generated.replace(remove, '')
 
-
+            generated = generated.strip()
+            word_count = len(generated.split())
 
-
+            logger.info(f" ✅ Generated {word_count} words")
+            return generated
 
         except Exception as e:
-            logger.error(f" ✗
+            logger.error(f" ✗ Error: {e}")
             import traceback
-            logger.error(
+            logger.error(traceback.format_exc())
             return ''
 
-    #
-
+    # PHI-2 optimized: VERY short prompt for fast generation
+    # Long prompts cause slow/hanging generation on CPU
+    base_prompt = f"""Question: {query}
 
-
+Context: {context_text[:400]}
 
-
-{context_text[:1200]}
+Answer with fashion advice:"""
 
-
-
-Answer:"""
-
-    # Natural generation parameters: quality over speed, no artificial limits
+    # PHI-2 generation parameters: SPEED OPTIMIZED for CPU
+    # Shorter outputs = faster generation on Hugging Face Spaces
     if attempt == 1:
-        temperature = 0.
-        max_new_tokens =
-        top_p = 0.
-        repetition_penalty = 1.
+        temperature = 0.7
+        max_new_tokens = 200 # Reduced for faster generation
+        top_p = 0.9
+        repetition_penalty = 1.15 # Higher to prevent loops
     else:
-        temperature = 0.
-        max_new_tokens =
-        top_p = 0.
-        repetition_penalty = 1.
+        temperature = 0.75
+        max_new_tokens = 250
+        top_p = 0.92
+        repetition_penalty = 1.2
 
     logger.info(f" → Starting generation with prompt: {base_prompt[:200]}...")
     initial_output = call_model(base_prompt, max_new_tokens, temperature, top_p, repetition_penalty)
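For reference, the pipeline output that the rewritten call_model unpacks has a fixed shape. A standalone sketch with an illustrative prompt (not from the app):

out = llm_client(
    "Question: What pairs well with a navy blazer?\n\nAnswer:",
    max_new_tokens=200,
    do_sample=True,
    temperature=0.7,
    top_p=0.9,
    repetition_penalty=1.15,
    num_return_sequences=1,
    return_full_text=False,  # return only the completion, not the prompt
)
# With num_return_sequences=1 the result is a one-element list of dicts:
# [{'generated_text': ' A crisp white shirt, grey trousers, ...'}]
completion = out[0]["generated_text"].strip()

Note that return_full_text=False already strips the prompt from the output, which is why the rewritten call_model no longer searches for the prompt inside the generated text the way the deleted code did.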
@@ -787,6 +769,11 @@ Answer:"""
     if word_count >= 20:
         logger.info(f" ⚠️ Short but acceptable response ({word_count} words)")
         return response
+
+    # Ultra permissive: accept ANYTHING with 10+ words to show something
+    if word_count >= 10:
+        logger.info(f" ⚠️ Very short response ({word_count} words) but accepting")
+        return response
 
     # Otherwise, try iterative continuation to build up to the target
     accumulated = response
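Spelled out, the acceptance logic after this change is a cascade of word-count tiers; schematically (a control-flow sketch, not verbatim app code):

word_count = len(response.split())
if word_count >= 20:
    return response   # short but acceptable
if word_count >= 10:
    return response   # very short, accepted so the UI still shows something
# under 10 words: fall through to iterative continuation
accumulated = response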