jdesiree committed on
Commit
613dbea
·
verified ·
1 Parent(s): 578ef70

Re-added Quantization

Browse files
Files changed (1) hide show
  1. app.py +75 -57
app.py CHANGED
@@ -25,7 +25,7 @@ from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
25
  from langchain_core.runnables import Runnable
26
  from langchain_core.runnables.utils import Input, Output
27
 
28
- from transformers import AutoTokenizer, TextIteratorStreamer, AutoModelForCausalLM
29
  import torch
30
  import time
31
  import warnings
@@ -292,28 +292,38 @@ Rather than providing complete solutions, you should:
292
  Your goal is to be an educational partner who empowers students to succeed through understanding."""
293
 
294
  # --- Updated LLM Class with Phi-3-mini ---
 
295
  class Phi3MiniEducationalLLM(Runnable):
296
- """LLM class optimized for Microsoft Phi-3-mini-4k-instruct without quantization"""
297
 
298
  def __init__(self, model_path: str = "microsoft/Phi-3-mini-4k-instruct"):
299
  super().__init__()
300
- logger.info(f"Loading Phi-3-mini model: {model_path}")
301
  start_Loading_Model_time = time.perf_counter()
302
  current_time = datetime.now()
303
 
304
  self.model_name = model_path
305
 
306
  try:
307
- # Load tokenizer
308
  self.tokenizer = AutoTokenizer.from_pretrained(
309
  model_path,
310
  trust_remote_code=True,
311
- token=hf_token
 
 
 
 
 
 
 
 
 
312
  )
313
 
314
- # Store model path instead of loading model immediately
315
  self.model_path = model_path
316
- self.model = None # Load model lazily in GPU methods
317
 
318
  except Exception as e:
319
  logger.error(f"Failed to initialize Phi-3-mini model {model_path}: {e}")
@@ -326,16 +336,24 @@ class Phi3MiniEducationalLLM(Runnable):
326
  self.streamer = None
327
 
328
  def _load_model_if_needed(self):
329
- """Load model only when needed inside GPU context"""
330
  if self.model is None:
331
- self.model = AutoModelForCausalLM.from_pretrained(
332
- self.model_path,
333
- torch_dtype=torch.float16,
334
- trust_remote_code=True,
335
- low_cpu_mem_usage=True,
336
- token=hf_token,
337
- attn_implementation="eager"
338
- )
 
 
 
 
 
 
 
 
339
  return self.model
340
 
341
  def _format_chat_template(self, prompt: str) -> str:
@@ -357,80 +375,82 @@ class Phi3MiniEducationalLLM(Runnable):
357
  # Fallback to manual Phi-3 format
358
  return f"<|system|>\n{SYSTEM_PROMPT}<|end|>\n<|user|>\n{prompt}<|end|>\n<|assistant|>\n"
359
 
360
- @spaces.GPU(duration=60)
361
  def invoke(self, input: Input, config=None) -> Output:
362
- """Main invoke method optimized for Phi-3-mini"""
363
  start_invoke_time = time.perf_counter()
364
  current_time = datetime.now()
365
 
366
- # Handle both string and dict inputs for flexibility
367
  if isinstance(input, dict):
368
  prompt = input.get('input', str(input))
369
  else:
370
  prompt = str(input)
371
 
372
  try:
 
 
 
373
  # Format using Phi-3 chat template
374
  text = self._format_chat_template(prompt)
375
-
376
  inputs = self.tokenizer(
377
  text,
378
  return_tensors="pt",
379
  padding=True,
380
  truncation=True,
381
- max_length=3072 # Leave room for generation within 4k context
382
  )
383
-
384
  # Move inputs to model device
385
- inputs = {k: v.to(self.model.device) for k, v in inputs.items()}
386
-
387
- # Generate with the model
388
  with torch.no_grad():
389
- outputs = self.model.generate(
390
  **inputs,
391
- max_new_tokens=800, # Increased for comprehensive responses
392
  do_sample=True,
393
- temperature=0.7, # Good balance for educational content
394
  top_p=0.9,
395
  top_k=50,
396
  repetition_penalty=1.1,
397
  pad_token_id=self.tokenizer.eos_token_id,
398
  early_stopping=True,
399
- use_cache=False,
400
  past_key_values=None
401
  )
402
-
403
  # Decode only new tokens
404
  new_tokens = outputs[0][len(inputs.input_ids[0]):]
405
  result = self.tokenizer.decode(new_tokens, skip_special_tokens=True).strip()
406
 
407
  end_invoke_time = time.perf_counter()
408
  invoke_time = end_invoke_time - start_invoke_time
409
- log_metric(f"LLM Invoke time: {invoke_time:0.4f} seconds. Input length: {len(prompt)} chars. Model: {self.model_name}. Timestamp: {current_time:%Y-%m-%d %H:%M:%S}")
410
 
411
  return result if result else "I'm still learning how to respond to that properly."
412
-
413
  except Exception as e:
414
- logger.error(f"Generation error: {e}")
415
  end_invoke_time = time.perf_counter()
416
  invoke_time = end_invoke_time - start_invoke_time
417
  log_metric(f"LLM Invoke time (error): {invoke_time:0.4f} seconds. Model: {self.model_name}. Timestamp: {current_time:%Y-%m-%d %H:%M:%S}")
418
  return f"[Error generating response: {str(e)}]"
419
 
420
- @spaces.GPU(duration=120)
421
  def stream_generate(self, input: Input, config=None):
422
- """Streaming generation using TextIteratorStreamer with loop detection and early escape."""
423
  start_stream_time = time.perf_counter()
424
  current_time = datetime.now()
425
- logger.info("Starting stream_generate with TextIteratorStreamer and loop detection...")
426
-
427
  if isinstance(input, dict):
428
  prompt = input.get('input', str(input))
429
  else:
430
  prompt = str(input)
431
-
432
  try:
433
- # Load model inside GPU context
434
  model = self._load_model_if_needed()
435
 
436
  # Clear GPU cache
@@ -438,7 +458,7 @@ class Phi3MiniEducationalLLM(Runnable):
438
  torch.cuda.empty_cache()
439
 
440
  text = self._format_chat_template(prompt)
441
-
442
  inputs = self.tokenizer(
443
  text,
444
  return_tensors="pt",
@@ -446,18 +466,18 @@ class Phi3MiniEducationalLLM(Runnable):
446
  truncation=True,
447
  max_length=3072
448
  )
449
-
450
- # Move inputs to model device - now model is not None
451
  inputs = {k: v.to(model.device) for k, v in inputs.items()}
452
-
453
  # Initialize TextIteratorStreamer
454
  streamer = TextIteratorStreamer(
455
  self.tokenizer,
456
  skip_prompt=True,
457
  skip_special_tokens=True
458
  )
459
-
460
- # Generation parameters
461
  generation_kwargs = {
462
  **inputs,
463
  "max_new_tokens": 800,
@@ -471,15 +491,15 @@ class Phi3MiniEducationalLLM(Runnable):
471
  "use_cache": False,
472
  "past_key_values": None
473
  }
474
-
475
  # Start generation in background
476
  generation_thread = threading.Thread(
477
- target=model.generate, # Use the loaded model
478
  kwargs=generation_kwargs
479
  )
480
  generation_thread.start()
481
-
482
- # Track outputs
483
  generated_text = ""
484
  token_history = []
485
  loop_window = 20
@@ -492,39 +512,37 @@ class Phi3MiniEducationalLLM(Runnable):
492
 
493
  generated_text += new_text
494
 
495
- # Tokenize and track
496
  tokens = self.tokenizer.tokenize(new_text)
497
  token_history.extend(tokens)
498
 
499
- # Check for loops
500
  if len(token_history) >= 2 * loop_window:
501
  recent = token_history[-loop_window:]
502
  prev = token_history[-2*loop_window:-loop_window]
503
  overlap = sum(1 for r, p in zip(recent, prev) if r == p)
504
 
505
  if overlap >= loop_threshold:
506
- logger.warning(f"Looping detected (overlap: {overlap}/{loop_window}). Aborting generation.")
507
  yield "[Looping detected — generation stopped early]"
508
  break
509
 
510
  yield generated_text
511
  except Exception as e:
512
- logger.error(f"Error in streaming iteration: {e}")
513
  yield f"[Streaming error: {str(e)}]"
514
 
515
  generation_thread.join()
516
 
517
  end_stream_time = time.perf_counter()
518
  stream_time = end_stream_time - start_stream_time
519
- log_metric(f"LLM Stream time: {stream_time:0.4f} seconds. Generated length: {len(generated_text)} chars. Model: {self.model_name}. Timestamp: {current_time:%Y-%m-%d %H:%M:%S}")
520
- logger.info(f"Stream generation completed: {len(generated_text)} chars in {stream_time:.2f}s")
521
 
522
  except Exception as e:
523
- logger.error(f"Streaming generation error: {e}")
524
  end_stream_time = time.perf_counter()
525
  stream_time = end_stream_time - start_stream_time
526
  log_metric(f"LLM Stream time (error): {stream_time:0.4f} seconds. Model: {self.model_name}. Timestamp: {current_time:%Y-%m-%d %H:%M:%S}")
527
- yield f"[Error in streaming generation: {str(e)}]"
528
 
529
 
530
  @property
 
25
  from langchain_core.runnables import Runnable
26
  from langchain_core.runnables.utils import Input, Output
27
 
28
+ from transformers import AutoTokenizer, TextIteratorStreamer, AutoModelForCausalLM, BitsAndBytesConfig
29
  import torch
30
  import time
31
  import warnings
 
292
  Your goal is to be an educational partner who empowers students to succeed through understanding."""
293
 
294
  # --- Updated LLM Class with Phi-3-mini ---
295
+
296
  class Phi3MiniEducationalLLM(Runnable):
297
+ """LLM class optimized for Microsoft Phi-3-mini-4k-instruct with 4-bit quantization"""
298
 
299
  def __init__(self, model_path: str = "microsoft/Phi-3-mini-4k-instruct"):
300
  super().__init__()
301
+ logger.info(f"Loading Phi-3-mini model with 4-bit quantization: {model_path}")
302
  start_Loading_Model_time = time.perf_counter()
303
  current_time = datetime.now()
304
 
305
  self.model_name = model_path
306
 
307
  try:
308
+ # Load tokenizer (can be done on CPU)
309
  self.tokenizer = AutoTokenizer.from_pretrained(
310
  model_path,
311
  trust_remote_code=True,
312
+ token=hf_token,
313
+ use_fast=False
314
+ )
315
+
316
+ # Configure 4-bit quantization
317
+ self.quantization_config = BitsAndBytesConfig(
318
+ load_in_4bit=True,
319
+ bnb_4bit_compute_dtype=torch.bfloat16,
320
+ bnb_4bit_quant_type="nf4", # NormalFloat 4-bit
321
+ bnb_4bit_use_double_quant=True, # Nested quantization for extra savings
322
  )
323
 
324
+ # Store model path - model will be loaded inside GPU context
325
  self.model_path = model_path
326
+ self.model = None
327
 
328
  except Exception as e:
329
  logger.error(f"Failed to initialize Phi-3-mini model {model_path}: {e}")
 
336
  self.streamer = None
337
 
338
  def _load_model_if_needed(self):
339
+ """Load model with 4-bit quantization only when needed inside GPU context"""
340
  if self.model is None:
341
+ logger.info("Loading model with 4-bit quantization...")
342
+ try:
343
+ self.model = AutoModelForCausalLM.from_pretrained(
344
+ self.model_path,
345
+ quantization_config=self.quantization_config,
346
+ torch_dtype=torch.bfloat16,
347
+ trust_remote_code=True,
348
+ low_cpu_mem_usage=True,
349
+ token=hf_token,
350
+ attn_implementation="eager",
351
+ device_map="auto"
352
+ )
353
+ logger.info(f"Model loaded successfully. Memory footprint reduced to ~2.2GB with 4-bit quantization")
354
+ except Exception as e:
355
+ logger.error(f"Failed to load quantized model: {e}")
356
+ raise
357
  return self.model
358
 
359
  def _format_chat_template(self, prompt: str) -> str:
 
375
  # Fallback to manual Phi-3 format
376
  return f"<|system|>\n{SYSTEM_PROMPT}<|end|>\n<|user|>\n{prompt}<|end|>\n<|assistant|>\n"
377
 
378
+ @spaces.GPU(duration=180)
379
  def invoke(self, input: Input, config=None) -> Output:
380
+ """Main invoke method optimized for 4-bit quantized Phi-3-mini"""
381
  start_invoke_time = time.perf_counter()
382
  current_time = datetime.now()
383
 
 
384
  if isinstance(input, dict):
385
  prompt = input.get('input', str(input))
386
  else:
387
  prompt = str(input)
388
 
389
  try:
390
+ # Load model inside GPU context
391
+ model = self._load_model_if_needed()
392
+
393
  # Format using Phi-3 chat template
394
  text = self._format_chat_template(prompt)
395
+
396
  inputs = self.tokenizer(
397
  text,
398
  return_tensors="pt",
399
  padding=True,
400
  truncation=True,
401
+ max_length=3072
402
  )
403
+
404
  # Move inputs to model device
405
+ inputs = {k: v.to(model.device) for k, v in inputs.items()}
406
+
407
+ # Generate with optimized parameters for quantized model
408
  with torch.no_grad():
409
+ outputs = model.generate(
410
  **inputs,
411
+ max_new_tokens=800,
412
  do_sample=True,
413
+ temperature=0.7,
414
  top_p=0.9,
415
  top_k=50,
416
  repetition_penalty=1.1,
417
  pad_token_id=self.tokenizer.eos_token_id,
418
  early_stopping=True,
419
+ use_cache=False, # Disable cache for compatibility
420
  past_key_values=None
421
  )
422
+
423
  # Decode only new tokens
424
  new_tokens = outputs[0][len(inputs.input_ids[0]):]
425
  result = self.tokenizer.decode(new_tokens, skip_special_tokens=True).strip()
426
 
427
  end_invoke_time = time.perf_counter()
428
  invoke_time = end_invoke_time - start_invoke_time
429
+ log_metric(f"LLM Invoke time (4-bit): {invoke_time:0.4f} seconds. Input length: {len(prompt)} chars. Model: {self.model_name}. Timestamp: {current_time:%Y-%m-%d %H:%M:%S}")
430
 
431
  return result if result else "I'm still learning how to respond to that properly."
432
+
433
  except Exception as e:
434
+ logger.error(f"Generation error with 4-bit model: {e}")
435
  end_invoke_time = time.perf_counter()
436
  invoke_time = end_invoke_time - start_invoke_time
437
  log_metric(f"LLM Invoke time (error): {invoke_time:0.4f} seconds. Model: {self.model_name}. Timestamp: {current_time:%Y-%m-%d %H:%M:%S}")
438
  return f"[Error generating response: {str(e)}]"
439
 
440
+ @spaces.GPU(duration=240)
441
  def stream_generate(self, input: Input, config=None):
442
+ """Streaming generation with 4-bit quantized model"""
443
  start_stream_time = time.perf_counter()
444
  current_time = datetime.now()
445
+ logger.info("Starting stream_generate with 4-bit quantized model...")
446
+
447
  if isinstance(input, dict):
448
  prompt = input.get('input', str(input))
449
  else:
450
  prompt = str(input)
451
+
452
  try:
453
+ # Load quantized model inside GPU context
454
  model = self._load_model_if_needed()
455
 
456
  # Clear GPU cache
 
458
  torch.cuda.empty_cache()
459
 
460
  text = self._format_chat_template(prompt)
461
+
462
  inputs = self.tokenizer(
463
  text,
464
  return_tensors="pt",
 
466
  truncation=True,
467
  max_length=3072
468
  )
469
+
470
+ # Move inputs to model device
471
  inputs = {k: v.to(model.device) for k, v in inputs.items()}
472
+
473
  # Initialize TextIteratorStreamer
474
  streamer = TextIteratorStreamer(
475
  self.tokenizer,
476
  skip_prompt=True,
477
  skip_special_tokens=True
478
  )
479
+
480
+ # Generation parameters optimized for 4-bit
481
  generation_kwargs = {
482
  **inputs,
483
  "max_new_tokens": 800,
 
491
  "use_cache": False,
492
  "past_key_values": None
493
  }
494
+
495
  # Start generation in background
496
  generation_thread = threading.Thread(
497
+ target=model.generate,
498
  kwargs=generation_kwargs
499
  )
500
  generation_thread.start()
501
+
502
+ # Stream results with loop detection
503
  generated_text = ""
504
  token_history = []
505
  loop_window = 20
 
512
 
513
  generated_text += new_text
514
 
515
+ # Loop detection logic
516
  tokens = self.tokenizer.tokenize(new_text)
517
  token_history.extend(tokens)
518
 
 
519
  if len(token_history) >= 2 * loop_window:
520
  recent = token_history[-loop_window:]
521
  prev = token_history[-2*loop_window:-loop_window]
522
  overlap = sum(1 for r, p in zip(recent, prev) if r == p)
523
 
524
  if overlap >= loop_threshold:
525
+ logger.warning(f"Looping detected with 4-bit model. Stopping generation.")
526
  yield "[Looping detected — generation stopped early]"
527
  break
528
 
529
  yield generated_text
530
  except Exception as e:
531
+ logger.error(f"Error in 4-bit streaming iteration: {e}")
532
  yield f"[Streaming error: {str(e)}]"
533
 
534
  generation_thread.join()
535
 
536
  end_stream_time = time.perf_counter()
537
  stream_time = end_stream_time - start_stream_time
538
+ log_metric(f"LLM Stream time (4-bit): {stream_time:0.4f} seconds. Generated length: {len(generated_text)} chars. Model: {self.model_name}. Timestamp: {current_time:%Y-%m-%d %H:%M:%S}")
 
539
 
540
  except Exception as e:
541
+ logger.error(f"4-bit streaming generation error: {e}")
542
  end_stream_time = time.perf_counter()
543
  stream_time = end_stream_time - start_stream_time
544
  log_metric(f"LLM Stream time (error): {stream_time:0.4f} seconds. Model: {self.model_name}. Timestamp: {current_time:%Y-%m-%d %H:%M:%S}")
545
+ yield f"[Error in 4-bit streaming generation: {str(e)}]"
546
 
547
 
548
  @property