Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
|
@@ -29,7 +29,7 @@ from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
|
|
| 29 |
from langchain_core.runnables import Runnable
|
| 30 |
from langchain_core.runnables.utils import Input, Output
|
| 31 |
|
| 32 |
-
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig, TextIteratorStreamer
|
| 33 |
import torch
|
| 34 |
|
| 35 |
load_dotenv(".env")
|
|
@@ -283,7 +283,7 @@ Your goal is to be an educational partner who empowers students to succeed throu
|
|
| 283 |
|
| 284 |
# --- Updated LLM Class with Phi-3-mini and TextIteratorStreamer ---
|
| 285 |
class Phi3MiniEducationalLLM(Runnable):
|
| 286 |
-
"""LLM class optimized for Microsoft Phi-3-mini-4k-instruct with TextIteratorStreamer"""
|
| 287 |
|
| 288 |
def __init__(self, model_path: str = "microsoft/Phi-3-mini-4k-instruct", use_4bit: bool = False):
|
| 289 |
super().__init__()
|
|
@@ -318,7 +318,9 @@ class Phi3MiniEducationalLLM(Runnable):
|
|
| 318 |
torch_dtype=torch.float16,
|
| 319 |
trust_remote_code=True,
|
| 320 |
low_cpu_mem_usage=True,
|
| 321 |
-
token=hf_token
|
|
|
|
|
|
|
| 322 |
)
|
| 323 |
else:
|
| 324 |
self._load_optimized_model(model_path)
|
|
@@ -340,7 +342,7 @@ class Phi3MiniEducationalLLM(Runnable):
|
|
| 340 |
self.streamer = None
|
| 341 |
|
| 342 |
def _load_optimized_model(self, model_path: str):
|
| 343 |
-
"""Optimized model loading for Phi-3-mini."""
|
| 344 |
self.model = AutoModelForCausalLM.from_pretrained(
|
| 345 |
model_path,
|
| 346 |
torch_dtype=torch.float16, # Use float16 to save memory
|
|
@@ -348,7 +350,8 @@ class Phi3MiniEducationalLLM(Runnable):
|
|
| 348 |
trust_remote_code=True,
|
| 349 |
low_cpu_mem_usage=True,
|
| 350 |
token=hf_token,
|
| 351 |
-
|
|
|
|
| 352 |
)
|
| 353 |
|
| 354 |
def _format_chat_template(self, prompt: str) -> str:
|
|
@@ -371,7 +374,7 @@ class Phi3MiniEducationalLLM(Runnable):
|
|
| 371 |
return f"<|system|>\n{SYSTEM_PROMPT}<|end|>\n<|user|>\n{prompt}<|end|>\n<|assistant|>\n"
|
| 372 |
|
| 373 |
def invoke(self, input: Input, config=None) -> Output:
|
| 374 |
-
"""Main invoke method optimized for Phi-3-mini"""
|
| 375 |
start_invoke_time = time.perf_counter()
|
| 376 |
current_time = datetime.now()
|
| 377 |
|
|
@@ -396,6 +399,9 @@ class Phi3MiniEducationalLLM(Runnable):
|
|
| 396 |
# Move to model device
|
| 397 |
inputs = {k: v.to(self.model.device) for k, v in inputs.items()}
|
| 398 |
|
|
|
|
|
|
|
|
|
|
| 399 |
with torch.no_grad():
|
| 400 |
outputs = self.model.generate(
|
| 401 |
**inputs,
|
|
@@ -407,7 +413,8 @@ class Phi3MiniEducationalLLM(Runnable):
|
|
| 407 |
repetition_penalty=1.1,
|
| 408 |
pad_token_id=self.tokenizer.eos_token_id,
|
| 409 |
early_stopping=True,
|
| 410 |
-
|
|
|
|
| 411 |
)
|
| 412 |
|
| 413 |
# Decode only new tokens
|
|
@@ -428,7 +435,7 @@ class Phi3MiniEducationalLLM(Runnable):
|
|
| 428 |
return f"[Error generating response: {str(e)}]"
|
| 429 |
|
| 430 |
def stream_generate(self, input: Input, config=None):
|
| 431 |
-
"""Streaming generation using TextIteratorStreamer"""
|
| 432 |
start_stream_time = time.perf_counter()
|
| 433 |
current_time = datetime.now()
|
| 434 |
logger.info("Starting stream_generate with TextIteratorStreamer...")
|
|
@@ -461,6 +468,9 @@ class Phi3MiniEducationalLLM(Runnable):
|
|
| 461 |
skip_special_tokens=True
|
| 462 |
)
|
| 463 |
|
|
|
|
|
|
|
|
|
|
| 464 |
# Generation parameters
|
| 465 |
generation_kwargs = {
|
| 466 |
**inputs,
|
|
@@ -472,7 +482,8 @@ class Phi3MiniEducationalLLM(Runnable):
|
|
| 472 |
"repetition_penalty": 1.1,
|
| 473 |
"pad_token_id": self.tokenizer.eos_token_id,
|
| 474 |
"streamer": streamer,
|
| 475 |
-
"
|
|
|
|
| 476 |
}
|
| 477 |
|
| 478 |
# Start generation in a separate thread
|
|
@@ -794,7 +805,7 @@ mathjax_config = '''
|
|
| 794 |
window.MathJax = {
|
| 795 |
tex: {
|
| 796 |
inlineMath: [['\\\\(', '\\\\)']],
|
| 797 |
-
displayMath: [['
|
| 798 |
packages: {'[+]': ['ams']}
|
| 799 |
},
|
| 800 |
svg: {fontCache: 'global'},
|
|
|
|
| 29 |
from langchain_core.runnables import Runnable
|
| 30 |
from langchain_core.runnables.utils import Input, Output
|
| 31 |
|
| 32 |
+
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig, TextIteratorStreamer, DynamicCache
|
| 33 |
import torch
|
| 34 |
|
| 35 |
load_dotenv(".env")
|
|
|
|
| 283 |
|
| 284 |
# --- Updated LLM Class with Phi-3-mini and TextIteratorStreamer ---
|
| 285 |
class Phi3MiniEducationalLLM(Runnable):
|
| 286 |
+
"""LLM class optimized for Microsoft Phi-3-mini-4k-instruct with TextIteratorStreamer and proper cache handling"""
|
| 287 |
|
| 288 |
def __init__(self, model_path: str = "microsoft/Phi-3-mini-4k-instruct", use_4bit: bool = False):
|
| 289 |
super().__init__()
|
|
|
|
| 318 |
torch_dtype=torch.float16,
|
| 319 |
trust_remote_code=True,
|
| 320 |
low_cpu_mem_usage=True,
|
| 321 |
+
token=hf_token,
|
| 322 |
+
# Use eager attention for better compatibility in HF Spaces
|
| 323 |
+
attn_implementation="eager"
|
| 324 |
)
|
| 325 |
else:
|
| 326 |
self._load_optimized_model(model_path)
|
|
|
|
| 342 |
self.streamer = None
|
| 343 |
|
| 344 |
def _load_optimized_model(self, model_path: str):
|
| 345 |
+
"""Optimized model loading for Phi-3-mini with proper cache support."""
|
| 346 |
self.model = AutoModelForCausalLM.from_pretrained(
|
| 347 |
model_path,
|
| 348 |
torch_dtype=torch.float16, # Use float16 to save memory
|
|
|
|
| 350 |
trust_remote_code=True,
|
| 351 |
low_cpu_mem_usage=True,
|
| 352 |
token=hf_token,
|
| 353 |
+
# Use eager attention for better compatibility in HF Spaces
|
| 354 |
+
attn_implementation="eager"
|
| 355 |
)
|
| 356 |
|
| 357 |
def _format_chat_template(self, prompt: str) -> str:
|
|
|
|
| 374 |
return f"<|system|>\n{SYSTEM_PROMPT}<|end|>\n<|user|>\n{prompt}<|end|>\n<|assistant|>\n"
|
| 375 |
|
| 376 |
def invoke(self, input: Input, config=None) -> Output:
|
| 377 |
+
"""Main invoke method optimized for Phi-3-mini with proper cache handling"""
|
| 378 |
start_invoke_time = time.perf_counter()
|
| 379 |
current_time = datetime.now()
|
| 380 |
|
|
|
|
| 399 |
# Move to model device
|
| 400 |
inputs = {k: v.to(self.model.device) for k, v in inputs.items()}
|
| 401 |
|
| 402 |
+
# Initialize DynamicCache for proper caching
|
| 403 |
+
past_key_values = DynamicCache()
|
| 404 |
+
|
| 405 |
with torch.no_grad():
|
| 406 |
outputs = self.model.generate(
|
| 407 |
**inputs,
|
|
|
|
| 413 |
repetition_penalty=1.1,
|
| 414 |
pad_token_id=self.tokenizer.eos_token_id,
|
| 415 |
early_stopping=True,
|
| 416 |
+
past_key_values=past_key_values, # Use DynamicCache properly
|
| 417 |
+
use_cache=True # Enable cache for performance
|
| 418 |
)
|
| 419 |
|
| 420 |
# Decode only new tokens
|
|
|
|
| 435 |
return f"[Error generating response: {str(e)}]"
|
| 436 |
|
| 437 |
def stream_generate(self, input: Input, config=None):
|
| 438 |
+
"""Streaming generation using TextIteratorStreamer with proper cache handling"""
|
| 439 |
start_stream_time = time.perf_counter()
|
| 440 |
current_time = datetime.now()
|
| 441 |
logger.info("Starting stream_generate with TextIteratorStreamer...")
|
|
|
|
| 468 |
skip_special_tokens=True
|
| 469 |
)
|
| 470 |
|
| 471 |
+
# Initialize DynamicCache for proper caching
|
| 472 |
+
past_key_values = DynamicCache()
|
| 473 |
+
|
| 474 |
# Generation parameters
|
| 475 |
generation_kwargs = {
|
| 476 |
**inputs,
|
|
|
|
| 482 |
"repetition_penalty": 1.1,
|
| 483 |
"pad_token_id": self.tokenizer.eos_token_id,
|
| 484 |
"streamer": streamer,
|
| 485 |
+
"past_key_values": past_key_values, # Use DynamicCache properly
|
| 486 |
+
"use_cache": True # Enable cache for performance
|
| 487 |
}
|
| 488 |
|
| 489 |
# Start generation in a separate thread
|
|
|
|
| 805 |
window.MathJax = {
|
| 806 |
tex: {
|
| 807 |
inlineMath: [['\\\\(', '\\\\)']],
|
| 808 |
+
displayMath: [['$$', '$$'], ['\\\\[', '\\\\]']],
|
| 809 |
packages: {'[+]': ['ams']}
|
| 810 |
},
|
| 811 |
svg: {fontCache: 'global'},
|