jdesiree committed on
Commit
9a01e13
·
verified ·
1 Parent(s): 5a61673

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +56 -103
app.py CHANGED
@@ -1,12 +1,6 @@
1
  import gradio as gr
2
  from graph_tool import generate_plot
3
  import os
4
-
5
- # Environment Variables
6
- os.environ['HF_HOME'] = '/tmp/huggingface'
7
- os.environ['HF_DATASETS_CACHE'] = '/tmp/huggingface'
8
-
9
- import time
10
  import platform
11
  from dotenv import load_dotenv
12
  import logging
@@ -30,10 +24,20 @@ from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
30
  from langchain_core.runnables import Runnable
31
  from langchain_core.runnables.utils import Input, Output
32
 
33
- from transformers import AutoTokenizer, TextIteratorStreamer
34
- from optimum.onnxruntime import ORTModelForCausalLM, ORTQuantizer
35
- from optimum.onnxruntime.configuration import AutoQuantizationConfig
36
  import torch
 
 
 
 
 
 
 
 
 
 
 
 
37
 
38
  load_dotenv(".env")
39
  HF_TOKEN = os.getenv("HF_TOKEN") or os.getenv("HUGGINGFACEHUB_API_TOKEN")
@@ -284,20 +288,17 @@ Rather than providing complete solutions, you should:
284
 
285
  Your goal is to be an educational partner who empowers students to succeed through understanding."""
286
 
287
- # --- Updated LLM Class with Phi-3-mini and TextIteratorStreamer ---
288
  class Phi3MiniEducationalLLM(Runnable):
289
- """LLM class optimized for Microsoft Phi-3-mini-4k-instruct with ONNX Runtime quantization"""
290
 
291
- def __init__(self, model_path: str = "microsoft/Phi-3-mini-4k-instruct", use_quantization: bool = True,
292
- quantization_type: str = "avx512_vnni"):
293
  super().__init__()
294
- logger.info(f"Loading Phi-3-mini model: {model_path} (quantization={use_quantization}, type={quantization_type})")
295
  start_Loading_Model_time = time.perf_counter()
296
  current_time = datetime.now()
297
 
298
  self.model_name = model_path
299
- self.use_quantization = use_quantization
300
- self.quantization_type = quantization_type
301
 
302
  try:
303
  # Load tokenizer - Phi-3 requires trust_remote_code
@@ -307,15 +308,21 @@ class Phi3MiniEducationalLLM(Runnable):
307
  token=hf_token
308
  )
309
 
310
- if use_quantization:
311
- self._load_quantized_model(model_path, quantization_type)
312
- else:
313
- self._load_standard_onnx_model(model_path)
 
 
 
 
 
 
314
 
315
  # Success path - log timing
316
  end_Loading_Model_time = time.perf_counter()
317
  Loading_Model_time = end_Loading_Model_time - start_Loading_Model_time
318
- log_metric(f"Model Load time: {Loading_Model_time:0.4f} seconds. Model: {model_path}. Quantization: {use_quantization}. Timestamp: {current_time:%Y-%m-%d %H:%M:%S}")
319
 
320
  except Exception as e:
321
  logger.error(f"Failed to load Phi-3-mini model {model_path}: {e}")
@@ -328,67 +335,6 @@ class Phi3MiniEducationalLLM(Runnable):
328
  # Initialize TextIteratorStreamer
329
  self.streamer = None
330
 
331
- def _load_quantized_model(self, model_path: str, quantization_type: str):
332
- """Load model with ONNX Runtime quantization."""
333
- try:
334
- # First, load the model as ONNX format
335
- onnx_model = ORTModelForCausalLM.from_pretrained(
336
- model_path,
337
- export=True, # Convert PyTorch to ONNX if needed
338
- trust_remote_code=True,
339
- token=hf_token,
340
- provider="CPUExecutionProvider" # Force CPU execution
341
- )
342
-
343
- # Create quantizer
344
- quantizer = ORTQuantizer.from_pretrained(onnx_model)
345
-
346
- # Define quantization configuration based on type
347
- if quantization_type == "avx512_vnni":
348
- qconfig = AutoQuantizationConfig.avx512_vnni(is_static=False, per_channel=False)
349
- elif quantization_type == "avx512":
350
- qconfig = AutoQuantizationConfig.avx512(is_static=False, per_channel=False)
351
- elif quantization_type == "avx2":
352
- qconfig = AutoQuantizationConfig.avx2(is_static=False, per_channel=False)
353
- elif quantization_type == "arm64":
354
- qconfig = AutoQuantizationConfig.arm64(is_static=False, per_channel=False)
355
- else:
356
- logger.warning(f"Unknown quantization type {quantization_type}, using avx512_vnni")
357
- qconfig = AutoQuantizationConfig.avx512_vnni(is_static=False, per_channel=False)
358
-
359
- # Create temporary directory for quantized model
360
- quantized_model_dir = f"./quantized_{model_path.replace('/', '_')}"
361
- os.makedirs(quantized_model_dir, exist_ok=True)
362
-
363
- # Quantize the model
364
- logger.info(f"Quantizing model with {quantization_type}...")
365
- model_quantized_path = quantizer.quantize(
366
- save_dir=quantized_model_dir,
367
- quantization_config=qconfig,
368
- )
369
-
370
- # Load the quantized model
371
- self.model = ORTModelForCausalLM.from_pretrained(
372
- quantized_model_dir,
373
- provider="CPUExecutionProvider"
374
- )
375
-
376
- logger.info(f"Successfully loaded quantized model from {model_quantized_path}")
377
-
378
- except Exception as e:
379
- logger.warning(f"Quantization failed ({e}), falling back to standard ONNX model")
380
- self._load_standard_onnx_model(model_path)
381
-
382
- def _load_standard_onnx_model(self, model_path: str):
383
- """Load standard ONNX model without quantization."""
384
- self.model = ORTModelForCausalLM.from_pretrained(
385
- model_path,
386
- export=True, # Convert PyTorch to ONNX if needed
387
- trust_remote_code=True,
388
- token=hf_token,
389
- provider="CPUExecutionProvider" # Force CPU execution
390
- )
391
-
392
  def _format_chat_template(self, prompt: str) -> str:
393
  """Format prompt using Phi-3's chat template"""
394
  try:
@@ -409,7 +355,7 @@ class Phi3MiniEducationalLLM(Runnable):
409
  return f"<|system|>\n{SYSTEM_PROMPT}<|end|>\n<|user|>\n{prompt}<|end|>\n<|assistant|>\n"
410
 
411
  def invoke(self, input: Input, config=None) -> Output:
412
- """Main invoke method optimized for Phi-3-mini with ONNX Runtime"""
413
  start_invoke_time = time.perf_counter()
414
  current_time = datetime.now()
415
 
@@ -431,19 +377,23 @@ class Phi3MiniEducationalLLM(Runnable):
431
  max_length=3072 # Leave room for generation within 4k context
432
  )
433
 
434
- # Generate with ONNX Runtime model
435
- outputs = self.model.generate(
436
- **inputs,
437
- max_new_tokens=800, # Increased for comprehensive responses
438
- do_sample=True,
439
- temperature=0.7, # Good balance for educational content
440
- top_p=0.9,
441
- top_k=50,
442
- repetition_penalty=1.1,
443
- pad_token_id=self.tokenizer.eos_token_id,
444
- early_stopping=True,
445
- use_cache=True # Enable cache for performance
446
- )
 
 
 
 
447
 
448
  # Decode only new tokens
449
  new_tokens = outputs[0][len(inputs.input_ids[0]):]
@@ -463,10 +413,10 @@ class Phi3MiniEducationalLLM(Runnable):
463
  return f"[Error generating response: {str(e)}]"
464
 
465
  def stream_generate(self, input: Input, config=None):
466
- """Streaming generation using TextIteratorStreamer with ONNX Runtime"""
467
  start_stream_time = time.perf_counter()
468
  current_time = datetime.now()
469
- logger.info("Starting stream_generate with TextIteratorStreamer and ONNX Runtime...")
470
 
471
  # Handle both string and dict inputs
472
  if isinstance(input, dict):
@@ -486,6 +436,9 @@ class Phi3MiniEducationalLLM(Runnable):
486
  max_length=3072
487
  )
488
 
 
 
 
489
  # Initialize TextIteratorStreamer
490
  streamer = TextIteratorStreamer(
491
  self.tokenizer,
@@ -493,7 +446,7 @@ class Phi3MiniEducationalLLM(Runnable):
493
  skip_special_tokens=True
494
  )
495
 
496
- # Generation parameters for ONNX Runtime model
497
  generation_kwargs = {
498
  **inputs,
499
  "max_new_tokens": 800,
@@ -529,14 +482,14 @@ class Phi3MiniEducationalLLM(Runnable):
529
  generation_thread.join()
530
 
531
  end_stream_time = time.perf_counter()
532
- stream_time = end_stream_time - start_invoke_time
533
  log_metric(f"LLM Stream time: {stream_time:0.4f} seconds. Generated length: {len(generated_text)} chars. Model: {self.model_name}. Timestamp: {current_time:%Y-%m-%d %H:%M:%S}")
534
  logger.info(f"Stream generation completed: {len(generated_text)} chars in {stream_time:.2f}s")
535
 
536
  except Exception as e:
537
  logger.error(f"Streaming generation error: {e}")
538
  end_stream_time = time.perf_counter()
539
- stream_time = end_stream_time - start_invoke_time
540
  log_metric(f"LLM Stream time (error): {stream_time:0.4f} seconds. Model: {self.model_name}. Timestamp: {current_time:%Y-%m-%d %H:%M:%S}")
541
  yield f"[Error in streaming generation: {str(e)}]"
542
 
@@ -556,7 +509,7 @@ class Educational_Agent:
556
  start_init_and_langgraph_time = time.perf_counter()
557
  current_time = datetime.now()
558
 
559
- self.llm = Phi3MiniEducationalLLM(model_path="microsoft/Phi-3-mini-4k-instruct", use_quantization=True)
560
  self.tool_decision_engine = Tool_Decision_Engine(self.llm)
561
 
562
  # Create LangGraph workflow
@@ -1081,7 +1034,7 @@ def create_interface():
1081
  if __name__ == "__main__":
1082
  try:
1083
  logger.info("=" * 50)
1084
- logger.info("Starting Mimir Application with Microsoft Phi-3-mini-4k-instruct and ONNX Runtime Quantization")
1085
  logger.info("=" * 50)
1086
 
1087
  # Step 1: Preload the model and agent
 
1
  import gradio as gr
2
  from graph_tool import generate_plot
3
  import os
 
 
 
 
 
 
4
  import platform
5
  from dotenv import load_dotenv
6
  import logging
 
24
  from langchain_core.runnables import Runnable
25
  from langchain_core.runnables.utils import Input, Output
26
 
27
+ from transformers import AutoTokenizer, TextIteratorStreamer, AutoModelForCausalLM
 
 
28
  import torch
29
+ import time
30
+ import warnings
31
+
32
+ # Updated environment variables
33
+ os.environ['HF_HOME'] = '/tmp/huggingface'
34
+ os.environ['HF_DATASETS_CACHE'] = '/tmp/huggingface'
35
+
36
+ # Suppress warnings
37
+ warnings.filterwarnings("ignore", message="Special tokens have been added")
38
+ warnings.filterwarnings("ignore", category=UserWarning, module="transformers")
39
+ warnings.filterwarnings("ignore", category=FutureWarning, module="huggingface_hub")
40
+ warnings.filterwarnings("ignore", category=UserWarning, module="torch")
41
 
42
  load_dotenv(".env")
43
  HF_TOKEN = os.getenv("HF_TOKEN") or os.getenv("HUGGINGFACEHUB_API_TOKEN")
 
288
 
289
  Your goal is to be an educational partner who empowers students to succeed through understanding."""
290
 
291
+ # --- Updated LLM Class with Phi-3-mini ---
292
  class Phi3MiniEducationalLLM(Runnable):
293
+ """LLM class optimized for Microsoft Phi-3-mini-4k-instruct without quantization"""
294
 
295
+ def __init__(self, model_path: str = "microsoft/Phi-3-mini-4k-instruct"):
 
296
  super().__init__()
297
+ logger.info(f"Loading Phi-3-mini model: {model_path}")
298
  start_Loading_Model_time = time.perf_counter()
299
  current_time = datetime.now()
300
 
301
  self.model_name = model_path
 
 
302
 
303
  try:
304
  # Load tokenizer - Phi-3 requires trust_remote_code
 
308
  token=hf_token
309
  )
310
 
311
+ # Load model with memory-efficient settings
312
+ self.model = AutoModelForCausalLM.from_pretrained(
313
+ model_path,
314
+ torch_dtype=torch.float16, # Use float16 to reduce memory usage
315
+ device_map="auto", # Let it handle device placement
316
+ trust_remote_code=True,
317
+ low_cpu_mem_usage=True, # Essential for memory efficiency
318
+ token=hf_token,
319
+ attn_implementation="eager" # Use eager attention for compatibility
320
+ )
321
 
322
  # Success path - log timing
323
  end_Loading_Model_time = time.perf_counter()
324
  Loading_Model_time = end_Loading_Model_time - start_Loading_Model_time
325
+ log_metric(f"Model Load time: {Loading_Model_time:0.4f} seconds. Model: {model_path}. Timestamp: {current_time:%Y-%m-%d %H:%M:%S}")
326
 
327
  except Exception as e:
328
  logger.error(f"Failed to load Phi-3-mini model {model_path}: {e}")
 
335
  # Initialize TextIteratorStreamer
336
  self.streamer = None
337
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
338
  def _format_chat_template(self, prompt: str) -> str:
339
  """Format prompt using Phi-3's chat template"""
340
  try:
 
355
  return f"<|system|>\n{SYSTEM_PROMPT}<|end|>\n<|user|>\n{prompt}<|end|>\n<|assistant|>\n"
356
 
357
  def invoke(self, input: Input, config=None) -> Output:
358
+ """Main invoke method optimized for Phi-3-mini"""
359
  start_invoke_time = time.perf_counter()
360
  current_time = datetime.now()
361
 
 
377
  max_length=3072 # Leave room for generation within 4k context
378
  )
379
 
380
+ # Move inputs to model device
381
+ inputs = {k: v.to(self.model.device) for k, v in inputs.items()}
382
+
383
+ # Generate with the model
384
+ with torch.no_grad():
385
+ outputs = self.model.generate(
386
+ **inputs,
387
+ max_new_tokens=800, # Increased for comprehensive responses
388
+ do_sample=True,
389
+ temperature=0.7, # Good balance for educational content
390
+ top_p=0.9,
391
+ top_k=50,
392
+ repetition_penalty=1.1,
393
+ pad_token_id=self.tokenizer.eos_token_id,
394
+ early_stopping=True,
395
+ use_cache=True # Enable cache for performance
396
+ )
397
 
398
  # Decode only new tokens
399
  new_tokens = outputs[0][len(inputs["input_ids"][0]):]
 
413
  return f"[Error generating response: {str(e)}]"
414
 
415
  def stream_generate(self, input: Input, config=None):
416
+ """Streaming generation using TextIteratorStreamer"""
417
  start_stream_time = time.perf_counter()
418
  current_time = datetime.now()
419
+ logger.info("Starting stream_generate with TextIteratorStreamer...")
420
 
421
  # Handle both string and dict inputs
422
  if isinstance(input, dict):
 
436
  max_length=3072
437
  )
438
 
439
+ # Move inputs to model device
440
+ inputs = {k: v.to(self.model.device) for k, v in inputs.items()}
441
+
442
  # Initialize TextIteratorStreamer
443
  streamer = TextIteratorStreamer(
444
  self.tokenizer,
 
446
  skip_special_tokens=True
447
  )
448
 
449
+ # Generation parameters
450
  generation_kwargs = {
451
  **inputs,
452
  "max_new_tokens": 800,
 
482
  generation_thread.join()
483
 
484
  end_stream_time = time.perf_counter()
485
+ stream_time = end_stream_time - start_stream_time
486
  log_metric(f"LLM Stream time: {stream_time:0.4f} seconds. Generated length: {len(generated_text)} chars. Model: {self.model_name}. Timestamp: {current_time:%Y-%m-%d %H:%M:%S}")
487
  logger.info(f"Stream generation completed: {len(generated_text)} chars in {stream_time:.2f}s")
488
 
489
  except Exception as e:
490
  logger.error(f"Streaming generation error: {e}")
491
  end_stream_time = time.perf_counter()
492
+ stream_time = end_stream_time - start_stream_time
493
  log_metric(f"LLM Stream time (error): {stream_time:0.4f} seconds. Model: {self.model_name}. Timestamp: {current_time:%Y-%m-%d %H:%M:%S}")
494
  yield f"[Error in streaming generation: {str(e)}]"
495
 
 
509
  start_init_and_langgraph_time = time.perf_counter()
510
  current_time = datetime.now()
511
 
512
+ self.llm = Phi3MiniEducationalLLM(model_path="microsoft/Phi-3-mini-4k-instruct")
513
  self.tool_decision_engine = Tool_Decision_Engine(self.llm)
514
 
515
  # Create LangGraph workflow
 
1034
  if __name__ == "__main__":
1035
  try:
1036
  logger.info("=" * 50)
1037
+ logger.info("Starting Mimir Application with Microsoft Phi-3-mini-4k-instruct")
1038
  logger.info("=" * 50)
1039
 
1040
  # Step 1: Preload the model and agent