jdesiree committed on
Commit
dd436fe
·
verified ·
1 Parent(s): ef96f77

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +21 -10
app.py CHANGED
@@ -29,7 +29,7 @@ from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
29
  from langchain_core.runnables import Runnable
30
  from langchain_core.runnables.utils import Input, Output
31
 
32
- from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig, TextIteratorStreamer
33
  import torch
34
 
35
  load_dotenv(".env")
@@ -283,7 +283,7 @@ Your goal is to be an educational partner who empowers students to succeed throu
283
 
284
  # --- Updated LLM Class with Phi-3-mini and TextIteratorStreamer ---
285
  class Phi3MiniEducationalLLM(Runnable):
286
- """LLM class optimized for Microsoft Phi-3-mini-4k-instruct with TextIteratorStreamer"""
287
 
288
  def __init__(self, model_path: str = "microsoft/Phi-3-mini-4k-instruct", use_4bit: bool = False):
289
  super().__init__()
@@ -318,7 +318,9 @@ class Phi3MiniEducationalLLM(Runnable):
318
  torch_dtype=torch.float16,
319
  trust_remote_code=True,
320
  low_cpu_mem_usage=True,
321
- token=hf_token
 
 
322
  )
323
  else:
324
  self._load_optimized_model(model_path)
@@ -340,7 +342,7 @@ class Phi3MiniEducationalLLM(Runnable):
340
  self.streamer = None
341
 
342
  def _load_optimized_model(self, model_path: str):
343
- """Optimized model loading for Phi-3-mini."""
344
  self.model = AutoModelForCausalLM.from_pretrained(
345
  model_path,
346
  torch_dtype=torch.float16, # Use float16 to save memory
@@ -348,7 +350,8 @@ class Phi3MiniEducationalLLM(Runnable):
348
  trust_remote_code=True,
349
  low_cpu_mem_usage=True,
350
  token=hf_token,
351
- revision="0a67737cc96d2554230f90338b163bc6380a2a85" # Pin revision for security
 
352
  )
353
 
354
  def _format_chat_template(self, prompt: str) -> str:
@@ -371,7 +374,7 @@ class Phi3MiniEducationalLLM(Runnable):
371
  return f"<|system|>\n{SYSTEM_PROMPT}<|end|>\n<|user|>\n{prompt}<|end|>\n<|assistant|>\n"
372
 
373
  def invoke(self, input: Input, config=None) -> Output:
374
- """Main invoke method optimized for Phi-3-mini"""
375
  start_invoke_time = time.perf_counter()
376
  current_time = datetime.now()
377
 
@@ -396,6 +399,9 @@ class Phi3MiniEducationalLLM(Runnable):
396
  # Move to model device
397
  inputs = {k: v.to(self.model.device) for k, v in inputs.items()}
398
 
 
 
 
399
  with torch.no_grad():
400
  outputs = self.model.generate(
401
  **inputs,
@@ -407,7 +413,8 @@ class Phi3MiniEducationalLLM(Runnable):
407
  repetition_penalty=1.1,
408
  pad_token_id=self.tokenizer.eos_token_id,
409
  early_stopping=True,
410
- use_cache=False # Disable cache to avoid compatibility issues
 
411
  )
412
 
413
  # Decode only new tokens
@@ -428,7 +435,7 @@ class Phi3MiniEducationalLLM(Runnable):
428
  return f"[Error generating response: {str(e)}]"
429
 
430
  def stream_generate(self, input: Input, config=None):
431
- """Streaming generation using TextIteratorStreamer"""
432
  start_stream_time = time.perf_counter()
433
  current_time = datetime.now()
434
  logger.info("Starting stream_generate with TextIteratorStreamer...")
@@ -461,6 +468,9 @@ class Phi3MiniEducationalLLM(Runnable):
461
  skip_special_tokens=True
462
  )
463
 
 
 
 
464
  # Generation parameters
465
  generation_kwargs = {
466
  **inputs,
@@ -472,7 +482,8 @@ class Phi3MiniEducationalLLM(Runnable):
472
  "repetition_penalty": 1.1,
473
  "pad_token_id": self.tokenizer.eos_token_id,
474
  "streamer": streamer,
475
- "use_cache": True
 
476
  }
477
 
478
  # Start generation in a separate thread
@@ -794,7 +805,7 @@ mathjax_config = '''
794
  window.MathJax = {
795
  tex: {
796
  inlineMath: [['\\\\(', '\\\\)']],
797
- displayMath: [['$', '$'], ['\\\\[', '\\\\]']],
798
  packages: {'[+]': ['ams']}
799
  },
800
  svg: {fontCache: 'global'},
 
29
  from langchain_core.runnables import Runnable
30
  from langchain_core.runnables.utils import Input, Output
31
 
32
+ from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig, TextIteratorStreamer, DynamicCache
33
  import torch
34
 
35
  load_dotenv(".env")
 
283
 
284
  # --- Updated LLM Class with Phi-3-mini and TextIteratorStreamer ---
285
  class Phi3MiniEducationalLLM(Runnable):
286
+ """LLM class optimized for Microsoft Phi-3-mini-4k-instruct with TextIteratorStreamer and proper cache handling"""
287
 
288
  def __init__(self, model_path: str = "microsoft/Phi-3-mini-4k-instruct", use_4bit: bool = False):
289
  super().__init__()
 
318
  torch_dtype=torch.float16,
319
  trust_remote_code=True,
320
  low_cpu_mem_usage=True,
321
+ token=hf_token,
322
+ # Use eager attention for better compatibility in HF Spaces
323
+ attn_implementation="eager"
324
  )
325
  else:
326
  self._load_optimized_model(model_path)
 
342
  self.streamer = None
343
 
344
  def _load_optimized_model(self, model_path: str):
345
+ """Optimized model loading for Phi-3-mini with proper cache support."""
346
  self.model = AutoModelForCausalLM.from_pretrained(
347
  model_path,
348
  torch_dtype=torch.float16, # Use float16 to save memory
 
350
  trust_remote_code=True,
351
  low_cpu_mem_usage=True,
352
  token=hf_token,
353
+ # Use eager attention for better compatibility in HF Spaces
354
+ attn_implementation="eager"
355
  )
356
 
357
  def _format_chat_template(self, prompt: str) -> str:
 
374
  return f"<|system|>\n{SYSTEM_PROMPT}<|end|>\n<|user|>\n{prompt}<|end|>\n<|assistant|>\n"
375
 
376
  def invoke(self, input: Input, config=None) -> Output:
377
+ """Main invoke method optimized for Phi-3-mini with proper cache handling"""
378
  start_invoke_time = time.perf_counter()
379
  current_time = datetime.now()
380
 
 
399
  # Move to model device
400
  inputs = {k: v.to(self.model.device) for k, v in inputs.items()}
401
 
402
+ # Initialize DynamicCache for proper caching
403
+ past_key_values = DynamicCache()
404
+
405
  with torch.no_grad():
406
  outputs = self.model.generate(
407
  **inputs,
 
413
  repetition_penalty=1.1,
414
  pad_token_id=self.tokenizer.eos_token_id,
415
  early_stopping=True,
416
+ past_key_values=past_key_values, # Use DynamicCache properly
417
+ use_cache=True # Enable cache for performance
418
  )
419
 
420
  # Decode only new tokens
 
435
  return f"[Error generating response: {str(e)}]"
436
 
437
  def stream_generate(self, input: Input, config=None):
438
+ """Streaming generation using TextIteratorStreamer with proper cache handling"""
439
  start_stream_time = time.perf_counter()
440
  current_time = datetime.now()
441
  logger.info("Starting stream_generate with TextIteratorStreamer...")
 
468
  skip_special_tokens=True
469
  )
470
 
471
+ # Initialize DynamicCache for proper caching
472
+ past_key_values = DynamicCache()
473
+
474
  # Generation parameters
475
  generation_kwargs = {
476
  **inputs,
 
482
  "repetition_penalty": 1.1,
483
  "pad_token_id": self.tokenizer.eos_token_id,
484
  "streamer": streamer,
485
+ "past_key_values": past_key_values, # Use DynamicCache properly
486
+ "use_cache": True # Enable cache for performance
487
  }
488
 
489
  # Start generation in a separate thread
 
805
  window.MathJax = {
806
  tex: {
807
  inlineMath: [['\\\\(', '\\\\)']],
808
+ displayMath: [['$', '$'], ['\\\\[', '\\\\]']],
809
  packages: {'[+]': ['ams']}
810
  },
811
  svg: {fontCache: 'global'},