jdesiree committed on
Commit fac310e · verified · 1 Parent(s): d4d436a

Update app.py

Files changed (1)
  1. app.py +213 -231
app.py CHANGED
@@ -419,6 +419,8 @@ class Phi3MiniEducationalLLM(Runnable):
         # Fallback to manual Phi-3 format
         return f"<|system|>\n{SYSTEM_PROMPT}<|end|>\n<|user|>\n{prompt}<|end|>\n<|assistant|>\n"
 
 class StopOnSequence(StoppingCriteria):
     def __init__(self, tokenizer, stop_sequence):
         self.tokenizer = tokenizer
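The hunk above ends inside the class's chat-template helper, of which only the manual Phi-3 fallback line is visible. A minimal sketch of how such a helper could pair the tokenizer's built-in chat template with that fallback (the method body is an assumption, not part of this commit):

    # Hypothetical sketch of the surrounding helper; only the fallback line appears in the diff.
    def _format_chat_template(self, prompt: str) -> str:
        messages = [
            {"role": "system", "content": SYSTEM_PROMPT},
            {"role": "user", "content": prompt},
        ]
        try:
            # Prefer the tokenizer's own Phi-3 chat template when it is available.
            return self.tokenizer.apply_chat_template(
                messages, tokenize=False, add_generation_prompt=True
            )
        except Exception:
            # Fallback to manual Phi-3 format, as in the diff above.
            return f"<|system|>\n{SYSTEM_PROMPT}<|end|>\n<|user|>\n{prompt}<|end|>\n<|assistant|>\n"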
@@ -429,251 +431,231 @@ class StopOnSequence(StoppingCriteria):
             return True
         return False
 
-    stop_criteria = StoppingCriteriaList([StopOnSequence(self.tokenizer, "User:")])
-
-    @spaces.GPU(duration=180)
-    def invoke(self, input: Input, config=None) -> Output:
-        """Main invoke method optimized for 4-bit quantized Phi-3-mini"""
-        start_invoke_time = time.perf_counter()
-        current_time = datetime.now()
-
-        # FIX: Handle different input types properly
-        if isinstance(input, dict):
-            if 'input' in input:
-                prompt = input['input']
-            elif 'messages' in input:
-                # Handle messages format
-                prompt = str(input['messages'])
-            else:
-                prompt = str(input)
         else:
             prompt = str(input)
-
         try:
-            # Load model inside GPU context
-            model = self._load_model_if_needed()
-
-            # Format using Phi-3 chat template
-            text = self._format_chat_template(prompt)
-
-            # FIX: Proper tokenization with error handling
-            try:
-                max_input_length = 4096 - 300
-                inputs = self.tokenizer(
-                    text,
-                    return_tensors="pt",
-                    padding=True,
-                    truncation=True,
-                    max_length=max_input_length
-                )
-
-                # Ensure inputs are properly formatted
-                if 'input_ids' not in inputs:
-                    logger.error("Tokenizer did not return input_ids")
-                    return "I encountered an error processing your request. Please try again."
-
-            except Exception as tokenizer_error:
-                logger.error(f"Tokenization error: {tokenizer_error}")
-                return "I encountered an error processing your request. Please try again."
-
-            # Move inputs to model device
-            try:
-                inputs = {k: v.to(model.device) for k, v in inputs.items()}
-            except Exception as device_error:
-                logger.error(f"Device transfer error: {device_error}")
                 return "I encountered an error processing your request. Please try again."
-
-            # Generate with optimized parameters for quantized model
-            with torch.no_grad():
-                try:
-                    outputs = model.generate(
-                        input_ids=inputs['input_ids'],
-                        attention_mask=inputs.get('attention_mask', None),
-                        max_new_tokens=300,
-                        do_sample=True,
-                        temperature=0.7,
-                        top_p=0.9,
-                        top_k=50,
-                        repetition_penalty=1.1,
-                        pad_token_id=self.tokenizer.eos_token_id,
-                        use_cache=False,
-                        past_key_values=None,
-                        stopping_criteria=stop_criteria
-                    )
-                except Exception as generation_error:
-                    logger.error(f"Generation error: {generation_error}")
-                    return "I encountered an error generating the response. Please try again."
-
-            # Decode only new tokens
-            try:
-                new_tokens = outputs[0][len(inputs['input_ids'][0]):]
-                result = self.tokenizer.decode(new_tokens, skip_special_tokens=True).strip()
-
-                # Soft stop cleanup
-                for stop_word in ["User:", "\n\n", "###"]:
-                    if stop_word in result:
-                        result = result.split(stop_word)[0].strip()
-                        break
-
-            except Exception as decode_error:
-                logger.error(f"Decoding error: {decode_error}")
-                return "I encountered an error processing the response. Please try again."
-
-            end_invoke_time = time.perf_counter()
-            invoke_time = end_invoke_time - start_invoke_time
-            log_metric(f"LLM Invoke time (4-bit): {invoke_time:0.4f} seconds. Input length: {len(prompt)} chars. Model: {self.model_name}. Timestamp: {current_time:%Y-%m-%d %H:%M:%S}")
-
-            return result if result else "I'm still learning how to respond to that properly."
 
-
-        except Exception as e:
-            logger.error(f"Generation error with 4-bit model: {e}")
-            end_invoke_time = time.perf_counter()
-            invoke_time = end_invoke_time - start_invoke_time
-            log_metric(f"LLM Invoke time (error): {invoke_time:0.4f} seconds. Model: {self.model_name}. Timestamp: {current_time:%Y-%m-%d %H:%M:%S}")
-            return f"I encountered an error: {str(e)}"
-
-    @spaces.GPU(duration=240)
-    def stream_generate(self, input: Input, config=None):
-        """Streaming generation with 4-bit quantized model and expanded context"""
-        start_stream_time = time.perf_counter()
-        current_time = datetime.now()
-        logger.info("Starting stream_generate with 4-bit quantized model...")
-
-        # Handle input properly
-        if isinstance(input, dict):
-            if 'input' in input:
-                prompt = input['input']
-            else:
-                prompt = str(input)
-        else:
-            prompt = str(input)
-
         try:
-            # Load quantized model inside GPU context
-            model = self._load_model_if_needed()
-
-            # Clear GPU cache
-            if torch.cuda.is_available():
-                torch.cuda.empty_cache()
-
-            text = self._format_chat_template(prompt)
-
-            # Proper tokenization with error handling
             try:
-                inputs = self.tokenizer(
-                    text,
-                    return_tensors="pt",
-                    padding=True,
-                    truncation=True,
-                    max_length=4096
                 )
-
-                if not hasattr(inputs, 'input_ids'):
-                    yield "I encountered an error processing your request. Please try again."
-                    return
-
-            except Exception as tokenizer_error:
-                logger.error(f"Streaming tokenization error: {tokenizer_error}")
-                yield "I encountered an error processing your request. Please try again."
-                return
-
-            # Move inputs to model device
-            try:
-                inputs = {k: v.to(model.device) for k, v in inputs.items()}
-            except Exception as device_error:
-                logger.error(f"Streaming device transfer error: {device_error}")
                 yield "I encountered an error processing your request. Please try again."
                 return
-
-            # Initialize TextIteratorStreamer - this streams the GENERATED TOKENS, not the input
-            streamer = TextIteratorStreamer(
-                self.tokenizer,
-                skip_prompt=True,  # Skip the input prompt in output
-                skip_special_tokens=True
-            )
-
-            # Generation parameters optimized for 4-bit
-            generation_kwargs = {
-                "input_ids": inputs['input_ids'],
-                "attention_mask": inputs.get('attention_mask', None),
-                "max_new_tokens": 1200,
-                "do_sample": True,
-                "temperature": 0.7,
-                "top_p": 0.9,
-                "top_k": 50,
-                "repetition_penalty": 1.2,
-                "pad_token_id": self.tokenizer.eos_token_id,
-                "streamer": streamer,  # This streams the OUTPUT tokens as they're generated
-                "use_cache": False,
-                "past_key_values": None
-            }
-
-            # Start generation in background thread
-            generation_thread = threading.Thread(
-                target=model.generate,
-                kwargs=generation_kwargs
-            )
-            generation_thread.start()
-
-            # Stream the generated tokens as they come from the model
-            generated_text = ""
-            consecutive_repeats = 0
-            last_chunk = ""
-
-            try:
-                # This loop receives tokens as they're generated by the model
-                for new_token_text in streamer:
-                    if not new_token_text:
-                        continue
-
-                    # Accumulate the generated text
-                    generated_text += new_token_text
-
-                    # Simple repetition detection
-                    if new_token_text == last_chunk:
-                        consecutive_repeats += 1
-                        if consecutive_repeats >= 5:
-                            logger.warning("Repetitive generation detected, stopping early")
-                            break
-                    else:
-                        consecutive_repeats = 0
-                    last_chunk = new_token_text
-
-                    # Yield the accumulated generated text (not the input prompt)
-                    yield generated_text
-
-            except Exception as e:
-                logger.error(f"Error in streaming iteration: {e}")
-                if not generated_text.strip():
-                    generated_text = "I apologize, but I'm having trouble generating a response. Please try rephrasing your question."
                 yield generated_text
-
-            generation_thread.join()
-
-            # Ensure we have some content
             if not generated_text.strip():
                 generated_text = "I apologize, but I'm having trouble generating a response. Please try rephrasing your question."
-            yield generated_text
-
-            end_stream_time = time.perf_counter()
-            stream_time = end_stream_time - start_stream_time
-            log_metric(f"LLM Stream time (4-bit): {stream_time:0.4f} seconds. Generated length: {len(generated_text)} chars. Model: {self.model_name}. Timestamp: {current_time:%Y-%m-%d %H:%M:%S}")
-
-        except Exception as e:
-            logger.error(f"4-bit streaming generation error: {e}")
-            end_stream_time = time.perf_counter()
-            stream_time = end_stream_time - start_stream_time
-            log_metric(f"LLM Stream time (error): {stream_time:0.4f} seconds. Model: {self.model_name}. Timestamp: {current_time:%Y-%m-%d %H:%M:%S}")
-            yield "I encountered an error generating the response. Please try again."
-
-    @property
-    def InputType(self) -> Type[Input]:
-        return str
-
-    @property
-    def OutputType(self) -> Type[Output]:
-        return str
 
 # LangGraph Agent Implementation with Tool Calling
 class Educational_Agent:
 
         # Fallback to manual Phi-3 format
         return f"<|system|>\n{SYSTEM_PROMPT}<|end|>\n<|user|>\n{prompt}<|end|>\n<|assistant|>\n"
 
+from transformers import StoppingCriteria, StoppingCriteriaList
+
 class StopOnSequence(StoppingCriteria):
     def __init__(self, tokenizer, stop_sequence):
         self.tokenizer = tokenizer
 
             return True
         return False
 
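Only the first and last lines of StopOnSequence appear in these hunks. A plausible reconstruction of the complete stopping-criteria class (the 20-token decode window and the exact substring check are assumptions, not shown in the diff):

# Assumed reconstruction of the full class; the hunks only show fragments of it.
class StopOnSequence(StoppingCriteria):
    def __init__(self, tokenizer, stop_sequence):
        self.tokenizer = tokenizer
        self.stop_sequence = stop_sequence

    def __call__(self, input_ids, scores, **kwargs):
        # Decode only the most recent tokens and stop once the stop string appears.
        tail = self.tokenizer.decode(input_ids[0][-20:], skip_special_tokens=True)
        if self.stop_sequence in tail:
            return True
        return False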
+    @spaces.GPU(duration=180)
+    def invoke(self, input: Input, config=None) -> Output:
+        """Main invoke method optimized for 4-bit quantized Phi-3-mini"""
+        start_invoke_time = time.perf_counter()
+        current_time = datetime.now()
+
+        # Handle different input types
+        if isinstance(input, dict):
+            if 'input' in input:
+                prompt = input['input']
+            elif 'messages' in input:
+                prompt = str(input['messages'])
             else:
                 prompt = str(input)
+        else:
+            prompt = str(input)
+
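For reference, the branching above normalizes several call shapes into a single prompt string; a few hypothetical calls it would accept (`llm` is an illustrative instance, not a name from the diff):

# Hypothetical input shapes the handling above would normalize.
llm.invoke("Explain photosynthesis")                              # plain string
llm.invoke({"input": "Explain photosynthesis"})                   # dict with an 'input' key
llm.invoke({"messages": [("user", "Explain photosynthesis")]})    # dict with 'messages', stringified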
+        try:
+            model = self._load_model_if_needed()
+            text = self._format_chat_template(prompt)
+
             try:
+                max_input_length = 4096 - 300
+                inputs = self.tokenizer(
+                    text,
+                    return_tensors="pt",
+                    padding=True,
+                    truncation=True,
+                    max_length=max_input_length
+                )
+                if 'input_ids' not in inputs:
+                    logger.error("Tokenizer did not return input_ids")
                     return "I encountered an error processing your request. Please try again."
+            except Exception as tokenizer_error:
+                logger.error(f"Tokenization error: {tokenizer_error}")
+                return "I encountered an error processing your request. Please try again."
 
             try:
+                inputs = {k: v.to(model.device) for k, v in inputs.items()}
+            except Exception as device_error:
+                logger.error(f"Device transfer error: {device_error}")
+                return "I encountered an error processing your request. Please try again."
+
+            # Define stopping criteria after tokenizer initialization
+            stop_criteria = StoppingCriteriaList([StopOnSequence(self.tokenizer, "User:")])
+
+            with torch.no_grad():
                 try:
+                    outputs = model.generate(
+                        input_ids=inputs['input_ids'],
+                        attention_mask=inputs.get('attention_mask', None),
+                        max_new_tokens=300,
+                        do_sample=True,
+                        temperature=0.7,
+                        top_p=0.9,
+                        top_k=50,
+                        repetition_penalty=1.1,
+                        pad_token_id=self.tokenizer.eos_token_id,
+                        use_cache=False,
+                        past_key_values=None,
+                        stopping_criteria=stop_criteria
                     )
+                except Exception as generation_error:
+                    logger.error(f"Generation error: {generation_error}")
+                    return "I encountered an error generating the response. Please try again."
+
+            try:
+                new_tokens = outputs[0][len(inputs['input_ids'][0]):]
+                result = self.tokenizer.decode(new_tokens, skip_special_tokens=True).strip()
+
+                # Apply soft-stop cleanup
+                for stop_word in ["User:", "\n\n", "###"]:
+                    if stop_word in result:
+                        result = result.split(stop_word)[0].strip()
+                        break
+            except Exception as decode_error:
+                logger.error(f"Decoding error: {decode_error}")
+                return "I encountered an error processing the response. Please try again."
+
+            end_invoke_time = time.perf_counter()
+            invoke_time = end_invoke_time - start_invoke_time
+            log_metric(
+                f"LLM Invoke time (4-bit): {invoke_time:0.4f} seconds. "
+                f"Input length: {len(prompt)} chars. "
+                f"Model: {self.model_name}. "
+                f"Timestamp: {current_time:%Y-%m-%d %H:%M:%S}"
+            )
+
+            return result if result else "I'm still learning how to respond to that properly."
+
+        except Exception as e:
+            logger.error(f"Generation error with 4-bit model: {e}")
+            end_invoke_time = time.perf_counter()
+            invoke_time = end_invoke_time - start_invoke_time
+            log_metric(
+                f"LLM Invoke time (error): {invoke_time:0.4f} seconds. "
+                f"Model: {self.model_name}. "
+                f"Timestamp: {current_time:%Y-%m-%d %H:%M:%S}"
+            )
+            return f"I encountered an error: {str(e)}"
+
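invoke() leans on a `_load_model_if_needed` helper that this diff never shows. A minimal sketch of what a lazy, 4-bit-quantized loader could look like with bitsandbytes (the quantization settings, caching attribute, and model id comment are assumptions):

    # Hypothetical sketch of the lazy 4-bit loader called above; not part of this commit.
    # Assumed imports: import torch; from transformers import AutoModelForCausalLM, BitsAndBytesConfig
    def _load_model_if_needed(self):
        if getattr(self, "model", None) is None:
            quant_config = BitsAndBytesConfig(
                load_in_4bit=True,
                bnb_4bit_quant_type="nf4",
                bnb_4bit_compute_dtype=torch.float16,
            )
            self.model = AutoModelForCausalLM.from_pretrained(
                self.model_name,  # e.g. a Phi-3-mini checkpoint; exact id not shown in the diff
                quantization_config=quant_config,
                device_map="auto",
            )
        return self.model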
+    @spaces.GPU(duration=240)
+    def stream_generate(self, input: Input, config=None):
+        """Streaming generation with 4-bit quantized model and expanded context"""
+        start_stream_time = time.perf_counter()
+        current_time = datetime.now()
+        logger.info("Starting stream_generate with 4-bit quantized model...")
+
+        # Handle input properly
+        if isinstance(input, dict):
+            prompt = input.get('input', str(input))
+        else:
+            prompt = str(input)
+
+        try:
+            model = self._load_model_if_needed()
+            if torch.cuda.is_available():
+                torch.cuda.empty_cache()
+            text = self._format_chat_template(prompt)
+
+            try:
+                inputs = self.tokenizer(
+                    text,
+                    return_tensors="pt",
+                    padding=True,
+                    truncation=True,
+                    max_length=4096
+                )
+                if 'input_ids' not in inputs:
                     yield "I encountered an error processing your request. Please try again."
                     return
+            except Exception as tokenizer_error:
+                logger.error(f"Streaming tokenization error: {tokenizer_error}")
+                yield "I encountered an error processing your request. Please try again."
+                return
+
+            try:
+                inputs = {k: v.to(model.device) for k, v in inputs.items()}
+            except Exception as device_error:
+                logger.error(f"Streaming device transfer error: {device_error}")
+                yield "I encountered an error processing your request. Please try again."
+                return
+
+            streamer = TextIteratorStreamer(
+                self.tokenizer,
+                skip_prompt=True,
+                skip_special_tokens=True
+            )
+
+            generation_kwargs = {
+                "input_ids": inputs['input_ids'],
+                "attention_mask": inputs.get('attention_mask', None),
+                "max_new_tokens": 1200,
+                "do_sample": True,
+                "temperature": 0.7,
+                "top_p": 0.9,
+                "top_k": 50,
+                "repetition_penalty": 1.2,
+                "pad_token_id": self.tokenizer.eos_token_id,
+                "streamer": streamer,
+                "use_cache": False,
+                "past_key_values": None
+            }
+
+            generation_thread = threading.Thread(
+                target=model.generate,
+                kwargs=generation_kwargs
+            )
+            generation_thread.start()
+
+            generated_text = ""
+            consecutive_repeats = 0
+            last_chunk = ""
+
+            try:
+                for new_token_text in streamer:
+                    if not new_token_text:
+                        continue
+                    generated_text += new_token_text
+                    if new_token_text == last_chunk:
+                        consecutive_repeats += 1
+                        if consecutive_repeats >= 5:
+                            logger.warning("Repetitive generation detected, stopping early")
+                            break
+                    else:
+                        consecutive_repeats = 0
+                    last_chunk = new_token_text
                     yield generated_text
+            except Exception as e:
+                logger.error(f"Error in streaming iteration: {e}")
                 if not generated_text.strip():
                     generated_text = "I apologize, but I'm having trouble generating a response. Please try rephrasing your question."
+                yield generated_text
+
+            generation_thread.join()
+            if not generated_text.strip():
+                generated_text = "I apologize, but I'm having trouble generating a response. Please try rephrasing your question."
+            yield generated_text
+
+            end_stream_time = time.perf_counter()
+            stream_time = end_stream_time - start_stream_time
+            log_metric(
+                f"LLM Stream time (4-bit): {stream_time:0.4f} seconds. "
+                f"Generated length: {len(generated_text)} chars. "
+                f"Model: {self.model_name}. "
+                f"Timestamp: {current_time:%Y-%m-%d %H:%M:%S}"
+            )
+        except Exception as e:
+            logger.error(f"4-bit streaming generation error: {e}")
+            end_stream_time = time.perf_counter()
+            stream_time = end_stream_time - start_stream_time
+            log_metric(
+                f"LLM Stream time (error): {stream_time:0.4f} seconds. "
+                f"Model: {self.model_name}. "
+                f"Timestamp: {current_time:%Y-%m-%d %H:%M:%S}"
+            )
+            yield "I encountered an error generating the response. Please try again."
+
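stream_generate() above uses the standard TextIteratorStreamer pattern: `model.generate` runs in a background thread while the streamer is consumed as an iterator on the calling thread. A self-contained sketch of just that pattern with a small stand-in model (generic objects, not the Space's own model or tokenizer):

# Generic illustration of the streamer-plus-thread pattern used above.
import threading
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer

tokenizer = AutoTokenizer.from_pretrained("gpt2")   # small stand-in model for illustration
model = AutoModelForCausalLM.from_pretrained("gpt2")
inputs = tokenizer("Streaming example:", return_tensors="pt")

streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
thread = threading.Thread(
    target=model.generate,
    kwargs={**inputs, "max_new_tokens": 40, "streamer": streamer},
)
thread.start()
for piece in streamer:          # yields decoded text chunks as they are generated
    print(piece, end="", flush=True)
thread.join()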
+    @property
+    def InputType(self) -> Type[Input]:
+        return str
+
+    @property
+    def OutputType(self) -> Type[Output]:
+        return str
 
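With invoke and stream_generate in place, the Runnable can be driven directly; a hypothetical usage sketch (constructor arguments are omitted because they are not shown in this diff):

# Hypothetical driver code; class and method names come from the diff, arguments are assumed.
llm = Phi3MiniEducationalLLM()
print(llm.invoke("What is photosynthesis?"))           # blocking call, returns the full answer
for partial in llm.stream_generate("Explain the water cycle"):
    print(partial)                                     # growing text, yielded as tokens arrive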
 # LangGraph Agent Implementation with Tool Calling
 class Educational_Agent: