jdesiree committed
Commit 467da8b · verified · 1 Parent(s): fac310e

Update app.py

Files changed (1)
  1. app.py +197 -209
app.py CHANGED
@@ -336,7 +336,7 @@ Rather than providing complete solutions, you should:
 
 Your goal is to be an educational partner who empowers students to succeed through understanding."""
 
-# FIXED LLM Class with Phi-3-mini
+# --- LLM Class with Phi-3 Mini ---
 class Phi3MiniEducationalLLM(Runnable):
     """LLM class optimized for Microsoft Phi-3-mini-4k-instruct with 4-bit quantization"""
 
@@ -419,243 +419,231 @@ class Phi3MiniEducationalLLM(Runnable):
         # Fallback to manual Phi-3 format
         return f"<|system|>\n{SYSTEM_PROMPT}<|end|>\n<|user|>\n{prompt}<|end|>\n<|assistant|>\n"
 
-    from transformers import StoppingCriteria, StoppingCriteriaList
-
-    class StopOnSequence(StoppingCriteria):
-        def __init__(self, tokenizer, stop_sequence):
-            self.tokenizer = tokenizer
-            self.stop_sequence = tokenizer.encode(stop_sequence, add_special_tokens=False)
-
-        def __call__(self, input_ids, scores, **kwargs):
-            if input_ids[0, -len(self.stop_sequence):].tolist() == self.stop_sequence:
-                return True
-            return False
-
-    @spaces.GPU(duration=180)
-    def invoke(self, input: Input, config=None) -> Output:
-        """Main invoke method optimized for 4-bit quantized Phi-3-mini"""
-        start_invoke_time = time.perf_counter()
-        current_time = datetime.now()
-
-        # Handle different input types
-        if isinstance(input, dict):
-            if 'input' in input:
-                prompt = input['input']
-            elif 'messages' in input:
-                prompt = str(input['messages'])
-            else:
-                prompt = str(input)
-        else:
-            prompt = str(input)
-
-        try:
-            model = self._load_model_if_needed()
-            text = self._format_chat_template(prompt)
-
-            try:
-                max_input_length = 4096 - 300
-                inputs = self.tokenizer(
-                    text,
-                    return_tensors="pt",
-                    padding=True,
-                    truncation=True,
-                    max_length=max_input_length
-                )
-                if 'input_ids' not in inputs:
-                    logger.error("Tokenizer did not return input_ids")
-                    return "I encountered an error processing your request. Please try again."
-            except Exception as tokenizer_error:
-                logger.error(f"Tokenization error: {tokenizer_error}")
-                return "I encountered an error processing your request. Please try again."
-
-            try:
-                inputs = {k: v.to(model.device) for k, v in inputs.items()}
-            except Exception as device_error:
-                logger.error(f"Device transfer error: {device_error}")
-                return "I encountered an error processing your request. Please try again."
-
-            # Define stopping criteria after tokenizer initialization
-            stop_criteria = StoppingCriteriaList([StopOnSequence(self.tokenizer, "User:")])
-
-            with torch.no_grad():
-                try:
-                    outputs = model.generate(
-                        input_ids=inputs['input_ids'],
-                        attention_mask=inputs.get('attention_mask', None),
-                        max_new_tokens=300,
-                        do_sample=True,
-                        temperature=0.7,
-                        top_p=0.9,
-                        top_k=50,
-                        repetition_penalty=1.1,
-                        pad_token_id=self.tokenizer.eos_token_id,
-                        use_cache=False,
-                        past_key_values=None,
-                        stopping_criteria=stop_criteria
-                    )
-                except Exception as generation_error:
-                    logger.error(f"Generation error: {generation_error}")
-                    return "I encountered an error generating the response. Please try again."
-
-            try:
-                new_tokens = outputs[0][len(inputs['input_ids'][0]):]
-                result = self.tokenizer.decode(new_tokens, skip_special_tokens=True).strip()
-
-                # Apply soft-stop cleanup
-                for stop_word in ["User:", "\n\n", "###"]:
-                    if stop_word in result:
-                        result = result.split(stop_word)[0].strip()
-                        break
-            except Exception as decode_error:
-                logger.error(f"Decoding error: {decode_error}")
-                return "I encountered an error processing the response. Please try again."
-
-            end_invoke_time = time.perf_counter()
-            invoke_time = end_invoke_time - start_invoke_time
-            log_metric(
-                f"LLM Invoke time (4-bit): {invoke_time:0.4f} seconds. "
-                f"Input length: {len(prompt)} chars. "
-                f"Model: {self.model_name}. "
-                f"Timestamp: {current_time:%Y-%m-%d %H:%M:%S}"
-            )
-
-            return result if result else "I'm still learning how to respond to that properly."
-
-        except Exception as e:
-            logger.error(f"Generation error with 4-bit model: {e}")
-            end_invoke_time = time.perf_counter()
-            invoke_time = end_invoke_time - start_invoke_time
-            log_metric(
-                f"LLM Invoke time (error): {invoke_time:0.4f} seconds. "
-                f"Model: {self.model_name}. "
-                f"Timestamp: {current_time:%Y-%m-%d %H:%M:%S}"
-            )
-            return f"I encountered an error: {str(e)}"
-
-    @spaces.GPU(duration=240)
-    def stream_generate(self, input: Input, config=None):
-        """Streaming generation with 4-bit quantized model and expanded context"""
-        start_stream_time = time.perf_counter()
-        current_time = datetime.now()
-        logger.info("Starting stream_generate with 4-bit quantized model...")
-
-        # Handle input properly
-        if isinstance(input, dict):
-            prompt = input.get('input', str(input))
-        else:
-            prompt = str(input)
-
-        try:
-            model = self._load_model_if_needed()
-            if torch.cuda.is_available():
-                torch.cuda.empty_cache()
-            text = self._format_chat_template(prompt)
-
-            try:
-                inputs = self.tokenizer(
-                    text,
-                    return_tensors="pt",
-                    padding=True,
-                    truncation=True,
-                    max_length=4096
-                )
-                if 'input_ids' not in inputs:
-                    yield "I encountered an error processing your request. Please try again."
-                    return
-            except Exception as tokenizer_error:
-                logger.error(f"Streaming tokenization error: {tokenizer_error}")
-                yield "I encountered an error processing your request. Please try again."
-                return
-
-            try:
-                inputs = {k: v.to(model.device) for k, v in inputs.items()}
-            except Exception as device_error:
-                logger.error(f"Streaming device transfer error: {device_error}")
-                yield "I encountered an error processing your request. Please try again."
-                return
-
-            streamer = TextIteratorStreamer(
-                self.tokenizer,
-                skip_prompt=True,
-                skip_special_tokens=True
-            )
-
-            generation_kwargs = {
-                "input_ids": inputs['input_ids'],
-                "attention_mask": inputs.get('attention_mask', None),
-                "max_new_tokens": 1200,
-                "do_sample": True,
-                "temperature": 0.7,
-                "top_p": 0.9,
-                "top_k": 50,
-                "repetition_penalty": 1.2,
-                "pad_token_id": self.tokenizer.eos_token_id,
-                "streamer": streamer,
-                "use_cache": False,
-                "past_key_values": None
-            }
-
-            generation_thread = threading.Thread(
-                target=model.generate,
-                kwargs=generation_kwargs
-            )
-            generation_thread.start()
-
-            generated_text = ""
-            consecutive_repeats = 0
-            last_chunk = ""
-
-            try:
-                for new_token_text in streamer:
-                    if not new_token_text:
-                        continue
-                    generated_text += new_token_text
-                    if new_token_text == last_chunk:
-                        consecutive_repeats += 1
-                        if consecutive_repeats >= 5:
-                            logger.warning("Repetitive generation detected, stopping early")
-                            break
-                    else:
-                        consecutive_repeats = 0
-                    last_chunk = new_token_text
-                    yield generated_text
-            except Exception as e:
-                logger.error(f"Error in streaming iteration: {e}")
-                if not generated_text.strip():
-                    generated_text = "I apologize, but I'm having trouble generating a response. Please try rephrasing your question."
-                yield generated_text
-
-            generation_thread.join()
-            if not generated_text.strip():
-                generated_text = "I apologize, but I'm having trouble generating a response. Please try rephrasing your question."
-            yield generated_text
-
-            end_stream_time = time.perf_counter()
-            stream_time = end_stream_time - start_stream_time
-            log_metric(
-                f"LLM Stream time (4-bit): {stream_time:0.4f} seconds. "
-                f"Generated length: {len(generated_text)} chars. "
-                f"Model: {self.model_name}. "
-                f"Timestamp: {current_time:%Y-%m-%d %H:%M:%S}"
-            )
-        except Exception as e:
-            logger.error(f"4-bit streaming generation error: {e}")
-            end_stream_time = time.perf_counter()
-            stream_time = end_stream_time - start_stream_time
-            log_metric(
-                f"LLM Stream time (error): {stream_time:0.4f} seconds. "
-                f"Model: {self.model_name}. "
-                f"Timestamp: {current_time:%Y-%m-%d %H:%M:%S}"
-            )
-            yield "I encountered an error generating the response. Please try again."
-
-    @property
-    def InputType(self) -> Type[Input]:
-        return str
-
-    @property
-    def OutputType(self) -> Type[Output]:
-        return str
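
The removed block above defined StopOnSequence inline in the class body, yet the rewritten invoke() below still calls StopOnSequence(self.tokenizer, "User:"), so the class presumably remains defined earlier in app.py. For reference, a minimal self-contained sketch of the same pattern at module scope; the checkpoint name is taken from the class docstring, and the standalone wiring is an assumption rather than code from this commit:

from transformers import AutoTokenizer, StoppingCriteria, StoppingCriteriaList

class StopOnSequence(StoppingCriteria):
    """Stop generation once the newest tokens spell out a given string."""

    def __init__(self, tokenizer, stop_sequence):
        self.tokenizer = tokenizer
        # Encode the stop string once; add_special_tokens=False keeps it literal.
        self.stop_sequence = tokenizer.encode(stop_sequence, add_special_tokens=False)

    def __call__(self, input_ids, scores, **kwargs):
        # True as soon as the tail of the generated ids equals the stop tokens.
        return input_ids[0, -len(self.stop_sequence):].tolist() == self.stop_sequence

# Assumed wiring: build the criteria list once and pass it to
# model.generate(..., stopping_criteria=stop_criteria), as invoke() does.
tokenizer = AutoTokenizer.from_pretrained("microsoft/Phi-3-mini-4k-instruct")
stop_criteria = StoppingCriteriaList([StopOnSequence(tokenizer, "User:")])

The rewritten version of the same methods, as added by this commit, follows.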
 
+    @spaces.GPU(duration=180)
+    def invoke(self, input: Input, config=None) -> Output:
+        """Main invoke method optimized for 4-bit quantized Phi-3-mini"""
+        start_invoke_time = time.perf_counter()
+        current_time = datetime.now()
+
+        # Handle different input types
+        if isinstance(input, dict):
+            if 'input' in input:
+                prompt = input['input']
+            elif 'messages' in input:
+                prompt = str(input['messages'])
+            else:
+                prompt = str(input)
+        else:
+            prompt = str(input)
+
+        try:
+            model = self._load_model_if_needed()
+            text = self._format_chat_template(prompt)
+
+            try:
+                max_input_length = 4096 - 300
+                inputs = self.tokenizer(
+                    text,
+                    return_tensors="pt",
+                    padding=True,
+                    truncation=True,
+                    max_length=max_input_length
+                )
+                if 'input_ids' not in inputs:
+                    logger.error("Tokenizer did not return input_ids")
+                    return "I encountered an error processing your request. Please try again."
+            except Exception as tokenizer_error:
+                logger.error(f"Tokenization error: {tokenizer_error}")
+                return "I encountered an error processing your request. Please try again."
+
+            try:
+                inputs = {k: v.to(model.device) for k, v in inputs.items()}
+            except Exception as device_error:
+                logger.error(f"Device transfer error: {device_error}")
+                return "I encountered an error processing your request. Please try again."
+
+            # Define stopping criteria after tokenizer initialization
+            stop_criteria = StoppingCriteriaList([StopOnSequence(self.tokenizer, "User:")])
+
+            with torch.no_grad():
+                try:
+                    outputs = model.generate(
+                        input_ids=inputs['input_ids'],
+                        attention_mask=inputs.get('attention_mask', None),
+                        max_new_tokens=300,
+                        do_sample=True,
+                        temperature=0.7,
+                        top_p=0.9,
+                        top_k=50,
+                        repetition_penalty=1.1,
+                        pad_token_id=self.tokenizer.eos_token_id,
+                        use_cache=False,
+                        past_key_values=None,
+                        stopping_criteria=stop_criteria
+                    )
+                except Exception as generation_error:
+                    logger.error(f"Generation error: {generation_error}")
+                    return "I encountered an error generating the response. Please try again."
+
+            try:
+                new_tokens = outputs[0][len(inputs['input_ids'][0]):]
+                result = self.tokenizer.decode(new_tokens, skip_special_tokens=True).strip()
+
+                # Apply soft-stop cleanup
+                for stop_word in ["User:", "\n\n", "###"]:
+                    if stop_word in result:
+                        result = result.split(stop_word)[0].strip()
+                        break
+            except Exception as decode_error:
+                logger.error(f"Decoding error: {decode_error}")
+                return "I encountered an error processing the response. Please try again."
+
+            end_invoke_time = time.perf_counter()
+            invoke_time = end_invoke_time - start_invoke_time
+            log_metric(
+                f"LLM Invoke time (4-bit): {invoke_time:0.4f} seconds. "
+                f"Input length: {len(prompt)} chars. "
+                f"Model: {self.model_name}. "
+                f"Timestamp: {current_time:%Y-%m-%d %H:%M:%S}"
+            )
+
+            return result if result else "I'm still learning how to respond to that properly."
+
+        except Exception as e:
+            logger.error(f"Generation error with 4-bit model: {e}")
+            end_invoke_time = time.perf_counter()
+            invoke_time = end_invoke_time - start_invoke_time
+            log_metric(
+                f"LLM Invoke time (error): {invoke_time:0.4f} seconds. "
+                f"Model: {self.model_name}. "
+                f"Timestamp: {current_time:%Y-%m-%d %H:%M:%S}"
+            )
+            return f"I encountered an error: {str(e)}"
+
+    @spaces.GPU(duration=240)
+    def stream_generate(self, input: Input, config=None):
+        """Streaming generation with 4-bit quantized model and expanded context"""
+        start_stream_time = time.perf_counter()
+        current_time = datetime.now()
+        logger.info("Starting stream_generate with 4-bit quantized model...")
+
+        # Handle input properly
+        if isinstance(input, dict):
+            prompt = input.get('input', str(input))
+        else:
+            prompt = str(input)
+
+        try:
+            model = self._load_model_if_needed()
+            if torch.cuda.is_available():
+                torch.cuda.empty_cache()
+            text = self._format_chat_template(prompt)
+
+            try:
+                inputs = self.tokenizer(
+                    text,
+                    return_tensors="pt",
+                    padding=True,
+                    truncation=True,
+                    max_length=4096
+                )
+                if 'input_ids' not in inputs:
+                    yield "I encountered an error processing your request. Please try again."
+                    return
+            except Exception as tokenizer_error:
+                logger.error(f"Streaming tokenization error: {tokenizer_error}")
+                yield "I encountered an error processing your request. Please try again."
+                return
+
+            try:
+                inputs = {k: v.to(model.device) for k, v in inputs.items()}
+            except Exception as device_error:
+                logger.error(f"Streaming device transfer error: {device_error}")
+                yield "I encountered an error processing your request. Please try again."
+                return
+
+            streamer = TextIteratorStreamer(
+                self.tokenizer,
+                skip_prompt=True,
+                skip_special_tokens=True
+            )
+
+            generation_kwargs = {
+                "input_ids": inputs['input_ids'],
+                "attention_mask": inputs.get('attention_mask', None),
+                "max_new_tokens": 1200,
+                "do_sample": True,
+                "temperature": 0.7,
+                "top_p": 0.9,
+                "top_k": 50,
+                "repetition_penalty": 1.2,
+                "pad_token_id": self.tokenizer.eos_token_id,
+                "streamer": streamer,
+                "use_cache": False,
+                "past_key_values": None
+            }
+
+            generation_thread = threading.Thread(
+                target=model.generate,
+                kwargs=generation_kwargs
+            )
+            generation_thread.start()
+
+            generated_text = ""
+            consecutive_repeats = 0
+            last_chunk = ""
+
+            try:
+                for new_token_text in streamer:
+                    if not new_token_text:
+                        continue
+                    generated_text += new_token_text
+                    if new_token_text == last_chunk:
+                        consecutive_repeats += 1
+                        if consecutive_repeats >= 5:
+                            logger.warning("Repetitive generation detected, stopping early")
+                            break
+                    else:
+                        consecutive_repeats = 0
+                    last_chunk = new_token_text
+                    yield generated_text
+            except Exception as e:
+                logger.error(f"Error in streaming iteration: {e}")
+                if not generated_text.strip():
+                    generated_text = "I apologize, but I'm having trouble generating a response. Please try rephrasing your question."
+                yield generated_text
+
+            generation_thread.join()
+            if not generated_text.strip():
+                generated_text = "I apologize, but I'm having trouble generating a response. Please try rephrasing your question."
+            yield generated_text
+
+            end_stream_time = time.perf_counter()
+            stream_time = end_stream_time - start_stream_time
+            log_metric(
+                f"LLM Stream time (4-bit): {stream_time:0.4f} seconds. "
+                f"Generated length: {len(generated_text)} chars. "
+                f"Model: {self.model_name}. "
+                f"Timestamp: {current_time:%Y-%m-%d %H:%M:%S}"
+            )
+        except Exception as e:
+            logger.error(f"4-bit streaming generation error: {e}")
+            end_stream_time = time.perf_counter()
+            stream_time = end_stream_time - start_stream_time
+            log_metric(
+                f"LLM Stream time (error): {stream_time:0.4f} seconds. "
+                f"Model: {self.model_name}. "
+                f"Timestamp: {current_time:%Y-%m-%d %H:%M:%S}"
+            )
+            yield "I encountered an error generating the response. Please try again."
+
+    @property
+    def InputType(self) -> Type[Input]:
+        return str
+
+    @property
+    def OutputType(self) -> Type[Output]:
+        return str
 
 # LangGraph Agent Implementation with Tool Calling
 class Educational_Agent:
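
Phi3MiniEducationalLLM is a LangChain Runnable whose InputType and OutputType are both str. A hypothetical usage sketch follows; the constructor signature is not shown in this diff, so the bare instantiation is an assumption:

# Hypothetical driver for the Runnable above; constructor arguments are assumed.
llm = Phi3MiniEducationalLLM()

# Blocking call: invoke() returns one complete string.
answer = llm.invoke("What is a binary search tree?")
print(answer)

# Streaming call: stream_generate() yields the accumulated text so far on each
# iteration, not a delta, so the last value seen is the full response.
response = ""
for partial in llm.stream_generate("Explain recursion step by step."):
    response = partial
print(response)

Yielding cumulative text rather than deltas suits Gradio-style chat interfaces, which typically re-render the whole message on each update.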