jdesiree committed on
Commit
e4fcfe9
·
verified ·
1 Parent(s): 1c43355

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +117 -118
app.py CHANGED
@@ -414,136 +414,135 @@ class Phi2EducationalLLM(Runnable):
414
  return f"[Error generating response: {str(e)}]"
415
 
416
  def stream_generate(self, input: Input, config=None):
417
- """Streaming generation method for real-time response display."""
418
- import time
419
- from datetime import datetime
420
 
421
- start_stream_time = time.perf_counter()
422
- current_time = datetime.now()
423
 
424
- # --- Debug Start ---
425
- logger.info("Starting stream_generate...")
426
- logger.debug(f"Input received: {input}")
427
- # -------------------
428
 
429
- # Handle input
430
- if isinstance(input, dict):
431
- prompt = input.get('input', str(input))
432
- else:
433
- prompt = str(input)
434
 
435
- try:
436
- # === Configurable Generation Parameters ===
437
- temperature = config.get("temperature", 0.7) if config else 0.7
438
- top_k = config.get("top_k", 50) if config else 50
439
- top_p = config.get("top_p", 0.9) if config else 0.9
440
- max_new_tokens = config.get("max_new_tokens", 600) if config else 600
441
- timeout_seconds = config.get("timeout_seconds", 15) if config else 15
442
-
443
- # === Prompt Construction ===
444
  try:
445
- messages = [
446
- {"role": "system", "content": SYSTEM_PROMPT},
447
- {"role": "user", "content": prompt}
448
- ]
449
- text = self.tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
450
- except Exception as e:
451
- logger.warning(f"Failed to use chat template: {e}")
452
- if "phi" in self.model_name.lower():
453
- text = f"Instruct: {SYSTEM_PROMPT}\n\nUser: {prompt}\nOutput:"
454
- else:
455
- text = f"<|system|>\n{SYSTEM_PROMPT}<|end|>\n<|user|>\n{prompt}<|end|>\n<|assistant|>\n"
456
-
457
- # === Tokenize ===
458
- inputs = self.tokenizer(
459
- [text],
460
- return_tensors="pt",
461
- padding=True,
462
- truncation=True,
463
- max_length=self.tokenizer.model_max_length
464
- )
465
- if torch.cuda.is_available():
466
- inputs = {k: v.to(self.model.device) for k, v in inputs.items()}
467
-
468
- input_ids = inputs["input_ids"]
469
- attention_mask = inputs["attention_mask"]
470
- input_length = input_ids.shape[1]
471
-
472
- # === Streaming Generation ===
473
- generated_tokens = []
474
- past_key_values = None
475
- eos_token_id = self.tokenizer.eos_token_id
476
- start_time = time.time()
477
-
478
- logger.info("Beginning token-by-token generation...")
479
-
480
- for step in range(max_new_tokens):
481
- if time.time() - start_time > timeout_seconds:
482
- logger.warning("Timeout reached. Ending stream.")
483
- break
484
-
485
- with torch.no_grad():
486
- model_inputs = {
487
- "attention_mask": attention_mask,
488
- "use_cache": True
489
- }
490
-
491
- if past_key_values is None:
492
- model_inputs["input_ids"] = input_ids
493
  else:
494
- model_inputs["input_ids"] = next_token
495
- model_inputs["past_key_values"] = past_key_values
496
-
497
- outputs = self.model(**model_inputs)
498
- logits = outputs.logits[:, -1, :]
499
- past_key_values = outputs.past_key_values
500
-
501
- # Sampling
502
- logits = logits / temperature
503
- filtered_logits = self._top_k_top_p_filtering(logits, top_k=top_k, top_p=top_p)
504
- probs = torch.nn.functional.softmax(filtered_logits, dim=-1)
505
- next_token = torch.multinomial(probs, num_samples=1)
506
 
507
- token_id = next_token.item()
508
- logger.debug(f"Step {step}: Token ID = {token_id}")
 
 
 
 
 
 
 
 
509
 
510
- if eos_token_id is not None and token_id == eos_token_id:
511
- logger.info("EOS token encountered. Ending generation.")
512
- break
513
 
514
- generated_tokens.append(token_id)
 
 
 
 
515
 
516
- # Decode efficiently
517
- new_text = self.tokenizer.decode([token_id], skip_special_tokens=True)
518
- yield new_text
519
 
520
- # Optional heuristic: stop on sentence-ending punctuation
521
- if new_text.strip().endswith(('.', '?', '!')):
522
- logger.info("Sentence-ending punctuation hit. Ending early.")
523
  break
524
 
525
- # Prepare for next step
526
- input_ids = next_token
527
- attention_mask = torch.cat([
528
- attention_mask,
529
- torch.ones((1, 1), dtype=attention_mask.dtype, device=attention_mask.device)
530
- ], dim=-1)
531
-
532
- # Final output logging
533
- final_text = self.tokenizer.decode(generated_tokens, skip_special_tokens=True).strip()
534
- logger.info(f"Streaming complete. Tokens generated: {len(generated_tokens)}")
535
-
536
- end_stream_time = time.perf_counter()
537
- stream_time = end_stream_time - start_stream_time
538
- log_metric(f"LLM Stream time: {stream_time:0.4f} seconds. Tokens generated: {len(generated_tokens)}. Model: {self.model_name}. Timestamp: {current_time:%Y-%m-%d %H:%M:%S}")
539
-
540
- except Exception as e:
541
- logger.error(f"Streaming generation error: {e}")
542
- end_stream_time = time.perf_counter()
543
- stream_time = end_stream_time - start_stream_time
544
- log_metric(f"LLM Stream time (error): {stream_time:0.4f} seconds. Model: {self.model_name}. Timestamp: {current_time:%Y-%m-%d %H:%M:%S}")
545
- yield f"[Error in streaming generation: {str(e)}]"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
546
 
 
 
 
 
 
 
547
 
548
  def _top_k_top_p_filtering(self, logits, top_k=50, top_p=0.9):
549
  """Apply top-k and top-p filtering to logits"""
@@ -880,7 +879,7 @@ mathjax_config = '''
880
  window.MathJax = {
881
  tex: {
882
  inlineMath: [['\\\\(', '\\\\)']],
883
- displayMath: [[', '], ['\\\\[', '\\\\]']],
884
  packages: {'[+]': ['ams']}
885
  },
886
  svg: {fontCache: 'global'},
@@ -1078,7 +1077,7 @@ def create_interface():
1078
 
1079
  with gr.Column(elem_classes=["main-container"]):
1080
  # Title Section
1081
- gr.HTML('<div class="title-header"><h1> Mimir 🎓</h1></div>')
1082
 
1083
  # Chat Section
1084
  with gr.Row():
 
414
  return f"[Error generating response: {str(e)}]"
415
 
416
  def stream_generate(self, input: Input, config=None):
417
+ """Streaming generation method for real-time response display."""
418
+ import time
419
+ from datetime import datetime
420
 
421
+ start_stream_time = time.perf_counter()
422
+ current_time = datetime.now()
423
 
424
+ # --- Debug Start ---
425
+ logger.info("Starting stream_generate...")
426
+ logger.debug(f"Input received: {input}")
427
+ # -------------------
428
 
429
+ # Handle input
430
+ if isinstance(input, dict):
431
+ prompt = input.get('input', str(input))
432
+ else:
433
+ prompt = str(input)
434
 
 
 
 
 
 
 
 
 
 
435
  try:
436
+ # === Configurable Generation Parameters ===
437
+ temperature = config.get("temperature", 0.7) if config else 0.7
438
+ top_k = config.get("top_k", 50) if config else 50
439
+ top_p = config.get("top_p", 0.9) if config else 0.9
440
+ max_new_tokens = config.get("max_new_tokens", 600) if config else 600
441
+ timeout_seconds = config.get("timeout_seconds", 15) if config else 15
442
+
443
+ # === Prompt Construction ===
444
+ try:
445
+ messages = [
446
+ {"role": "system", "content": SYSTEM_PROMPT},
447
+ {"role": "user", "content": prompt}
448
+ ]
449
+ text = self.tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
450
+ except Exception as e:
451
+ logger.warning(f"Failed to use chat template: {e}")
452
+ if "phi" in self.model_name.lower():
453
+ text = f"Instruct: {SYSTEM_PROMPT}\n\nUser: {prompt}\nOutput:"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
454
  else:
455
+ text = f"<|system|>\n{SYSTEM_PROMPT}<|end|>\n<|user|>\n{prompt}<|end|>\n<|assistant|>\n"
 
 
 
 
 
 
 
 
 
 
 
456
 
457
+ # === Tokenize ===
458
+ inputs = self.tokenizer(
459
+ [text],
460
+ return_tensors="pt",
461
+ padding=True,
462
+ truncation=True,
463
+ max_length=self.tokenizer.model_max_length
464
+ )
465
+ if torch.cuda.is_available():
466
+ inputs = {k: v.to(self.model.device) for k, v in inputs.items()}
467
 
468
+ input_ids = inputs["input_ids"]
469
+ attention_mask = inputs["attention_mask"]
470
+ input_length = input_ids.shape[1]
471
 
472
+ # === Streaming Generation ===
473
+ generated_tokens = []
474
+ past_key_values = None
475
+ eos_token_id = self.tokenizer.eos_token_id
476
+ start_time = time.time()
477
 
478
+ logger.info("Beginning token-by-token generation...")
 
 
479
 
480
+ for step in range(max_new_tokens):
481
+ if time.time() - start_time > timeout_seconds:
482
+ logger.warning("Timeout reached. Ending stream.")
483
  break
484
 
485
+ with torch.no_grad():
486
+ model_inputs = {
487
+ "attention_mask": attention_mask,
488
+ "use_cache": True
489
+ }
490
+
491
+ if past_key_values is None:
492
+ model_inputs["input_ids"] = input_ids
493
+ else:
494
+ model_inputs["input_ids"] = next_token
495
+ model_inputs["past_key_values"] = past_key_values
496
+
497
+ outputs = self.model(**model_inputs)
498
+ logits = outputs.logits[:, -1, :]
499
+ past_key_values = outputs.past_key_values
500
+
501
+ # Sampling
502
+ logits = logits / temperature
503
+ filtered_logits = self._top_k_top_p_filtering(logits, top_k=top_k, top_p=top_p)
504
+ probs = torch.nn.functional.softmax(filtered_logits, dim=-1)
505
+ next_token = torch.multinomial(probs, num_samples=1)
506
+
507
+ token_id = next_token.item()
508
+ logger.debug(f"Step {step}: Token ID = {token_id}")
509
+
510
+ if eos_token_id is not None and token_id == eos_token_id:
511
+ logger.info("EOS token encountered. Ending generation.")
512
+ break
513
+
514
+ generated_tokens.append(token_id)
515
+
516
+ # Decode efficiently
517
+ new_text = self.tokenizer.decode([token_id], skip_special_tokens=True)
518
+ yield new_text
519
+
520
+ # Optional heuristic: stop on sentence-ending punctuation
521
+ if new_text.strip().endswith(('.', '?', '!')):
522
+ logger.info("Sentence-ending punctuation hit. Ending early.")
523
+ break
524
+
525
+ # Prepare for next step
526
+ input_ids = next_token
527
+ attention_mask = torch.cat([
528
+ attention_mask,
529
+ torch.ones((1, 1), dtype=attention_mask.dtype, device=attention_mask.device)
530
+ ], dim=-1)
531
+
532
+ # Final output logging
533
+ final_text = self.tokenizer.decode(generated_tokens, skip_special_tokens=True).strip()
534
+ logger.info(f"Streaming complete. Tokens generated: {len(generated_tokens)}")
535
+
536
+ end_stream_time = time.perf_counter()
537
+ stream_time = end_stream_time - start_stream_time
538
+ log_metric(f"LLM Stream time: {stream_time:0.4f} seconds. Tokens generated: {len(generated_tokens)}. Model: {self.model_name}. Timestamp: {current_time:%Y-%m-%d %H:%M:%S}")
539
 
540
+ except Exception as e:
541
+ logger.error(f"Streaming generation error: {e}")
542
+ end_stream_time = time.perf_counter()
543
+ stream_time = end_stream_time - start_stream_time
544
+ log_metric(f"LLM Stream time (error): {stream_time:0.4f} seconds. Model: {self.model_name}. Timestamp: {current_time:%Y-%m-%d %H:%M:%S}")
545
+ yield f"[Error in streaming generation: {str(e)}]"
546
 
547
  def _top_k_top_p_filtering(self, logits, top_k=50, top_p=0.9):
548
  """Apply top-k and top-p filtering to logits"""
 
879
  window.MathJax = {
880
  tex: {
881
  inlineMath: [['\\\\(', '\\\\)']],
882
+ displayMath: [['$', '$'], ['\\\\[', '\\\\]']],
883
  packages: {'[+]': ['ams']}
884
  },
885
  svg: {fontCache: 'global'},
 
1077
 
1078
  with gr.Column(elem_classes=["main-container"]):
1079
  # Title Section
1080
+ gr.HTML('<div class="title-header"><h1>🎓 Mimir 🎓</h1></div>')
1081
 
1082
  # Chat Section
1083
  with gr.Row():