jdesiree committed on
Commit
1c43355
·
verified ·
1 Parent(s): 324c8ba

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +126 -85
app.py CHANGED
@@ -414,95 +414,136 @@ class Phi2EducationalLLM(Runnable):
414
  return f"[Error generating response: {str(e)}]"
415
 
416
  def stream_generate(self, input: Input, config=None):
417
- """Streaming generation method for real-time response display"""
418
- start_stream_time = time.perf_counter()
419
- current_time = datetime.now()
420
-
421
- # Handle both string and dict inputs for flexibility
422
- if isinstance(input, dict):
423
- prompt = input.get('input', str(input))
424
- else:
425
- prompt = str(input)
426
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
427
  try:
428
- # Prepare input text
429
- try:
430
- messages = [
431
- {"role": "system", "content": SYSTEM_PROMPT},
432
- {"role": "user", "content": prompt}
433
- ]
434
- text = self.tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
435
- except:
436
- if "phi" in self.model_name.lower():
437
- text = f"Instruct: {SYSTEM_PROMPT}\n\nUser: {prompt}\nOutput:"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
438
  else:
439
- text = f"<|system|>\n{SYSTEM_PROMPT}<|end|>\n<|user|>\n{prompt}<|end|>\n<|assistant|>\n"
 
440
 
441
- inputs = self.tokenizer([text], return_tensors="pt", padding=True, truncation=True, max_length=1024)
442
- if torch.cuda.is_available():
443
- inputs = {k: v.to(self.model.device) for k, v in inputs.items()}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
444
 
445
- # Initialize for streaming
446
- generated_tokens = []
447
- input_length = inputs.input_ids.shape[1]
448
- max_new_tokens = 600
449
-
450
- # Generate token by token
451
- current_input_ids = inputs.input_ids
452
- current_attention_mask = inputs.attention_mask
453
-
454
- for step in range(max_new_tokens):
455
- with torch.no_grad():
456
- outputs = self.model(
457
- input_ids=current_input_ids,
458
- attention_mask=current_attention_mask,
459
- use_cache=True
460
- )
461
-
462
- # Get next token probabilities
463
- next_token_logits = outputs.logits[:, -1, :]
464
-
465
- # Apply temperature and sampling
466
- next_token_logits = next_token_logits / 0.7
467
-
468
- # Apply top-k and top-p filtering
469
- filtered_logits = self._top_k_top_p_filtering(next_token_logits, top_k=50, top_p=0.9)
470
-
471
- # Sample next token
472
- probs = torch.nn.functional.softmax(filtered_logits, dim=-1)
473
- next_token = torch.multinomial(probs, num_samples=1)
474
-
475
- # Check for end of sequence
476
- if next_token.item() == self.tokenizer.eos_token_id:
477
- break
478
-
479
- # Add to generated tokens
480
- generated_tokens.append(next_token.item())
481
-
482
- # Decode and yield partial result
483
- partial_text = self.tokenizer.decode(generated_tokens, skip_special_tokens=True)
484
- yield partial_text
485
-
486
- # Update input for next iteration
487
- current_input_ids = torch.cat([current_input_ids, next_token], dim=-1)
488
- current_attention_mask = torch.cat([
489
- current_attention_mask,
490
- torch.ones((1, 1), dtype=current_attention_mask.dtype, device=current_attention_mask.device)
491
- ], dim=-1)
492
-
493
- # Final result
494
- final_text = self.tokenizer.decode(generated_tokens, skip_special_tokens=True).strip()
495
-
496
- end_stream_time = time.perf_counter()
497
- stream_time = end_stream_time - start_stream_time
498
- log_metric(f"LLM Stream time: {stream_time:0.4f} seconds. Tokens generated: {len(generated_tokens)}. Model: {self.model_name}. Timestamp: {current_time:%Y-%m-%d %H:%M:%S}")
499
-
500
- except Exception as e:
501
- logger.error(f"Streaming generation error: {e}")
502
- end_stream_time = time.perf_counter()
503
- stream_time = end_stream_time - start_stream_time
504
- log_metric(f"LLM Stream time (error): {stream_time:0.4f} seconds. Model: {self.model_name}. Timestamp: {current_time:%Y-%m-%d %H:%M:%S}")
505
- yield f"[Error in streaming generation: {str(e)}]"
506
 
507
  def _top_k_top_p_filtering(self, logits, top_k=50, top_p=0.9):
508
  """Apply top-k and top-p filtering to logits"""
 
414
  return f"[Error generating response: {str(e)}]"
415
 
416
  def stream_generate(self, input: Input, config=None):
417
+ """Streaming generation method for real-time response display."""
418
+ import time
419
+ from datetime import datetime
420
+
421
+ start_stream_time = time.perf_counter()
422
+ current_time = datetime.now()
423
+
424
+ # --- Debug Start ---
425
+ logger.info("Starting stream_generate...")
426
+ logger.debug(f"Input received: {input}")
427
+ # -------------------
428
+
429
+ # Handle input
430
+ if isinstance(input, dict):
431
+ prompt = input.get('input', str(input))
432
+ else:
433
+ prompt = str(input)
434
+
435
+ try:
436
+ # === Configurable Generation Parameters ===
437
+ temperature = config.get("temperature", 0.7) if config else 0.7
438
+ top_k = config.get("top_k", 50) if config else 50
439
+ top_p = config.get("top_p", 0.9) if config else 0.9
440
+ max_new_tokens = config.get("max_new_tokens", 600) if config else 600
441
+ timeout_seconds = config.get("timeout_seconds", 15) if config else 15
442
+
443
+ # === Prompt Construction ===
444
  try:
445
+ messages = [
446
+ {"role": "system", "content": SYSTEM_PROMPT},
447
+ {"role": "user", "content": prompt}
448
+ ]
449
+ text = self.tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
450
+ except Exception as e:
451
+ logger.warning(f"Failed to use chat template: {e}")
452
+ if "phi" in self.model_name.lower():
453
+ text = f"Instruct: {SYSTEM_PROMPT}\n\nUser: {prompt}\nOutput:"
454
+ else:
455
+ text = f"<|system|>\n{SYSTEM_PROMPT}<|end|>\n<|user|>\n{prompt}<|end|>\n<|assistant|>\n"
456
+
457
+ # === Tokenize ===
458
+ inputs = self.tokenizer(
459
+ [text],
460
+ return_tensors="pt",
461
+ padding=True,
462
+ truncation=True,
463
+ max_length=self.tokenizer.model_max_length
464
+ )
465
+ if torch.cuda.is_available():
466
+ inputs = {k: v.to(self.model.device) for k, v in inputs.items()}
467
+
468
+ input_ids = inputs["input_ids"]
469
+ attention_mask = inputs["attention_mask"]
470
+ input_length = input_ids.shape[1]
471
+
472
+ # === Streaming Generation ===
473
+ generated_tokens = []
474
+ past_key_values = None
475
+ eos_token_id = self.tokenizer.eos_token_id
476
+ start_time = time.time()
477
+
478
+ logger.info("Beginning token-by-token generation...")
479
+
480
+ for step in range(max_new_tokens):
481
+ if time.time() - start_time > timeout_seconds:
482
+ logger.warning("Timeout reached. Ending stream.")
483
+ break
484
+
485
+ with torch.no_grad():
486
+ model_inputs = {
487
+ "attention_mask": attention_mask,
488
+ "use_cache": True
489
+ }
490
+
491
+ if past_key_values is None:
492
+ model_inputs["input_ids"] = input_ids
493
  else:
494
+ model_inputs["input_ids"] = next_token
495
+ model_inputs["past_key_values"] = past_key_values
496
 
497
+ outputs = self.model(**model_inputs)
498
+ logits = outputs.logits[:, -1, :]
499
+ past_key_values = outputs.past_key_values
500
+
501
+ # Sampling
502
+ logits = logits / temperature
503
+ filtered_logits = self._top_k_top_p_filtering(logits, top_k=top_k, top_p=top_p)
504
+ probs = torch.nn.functional.softmax(filtered_logits, dim=-1)
505
+ next_token = torch.multinomial(probs, num_samples=1)
506
+
507
+ token_id = next_token.item()
508
+ logger.debug(f"Step {step}: Token ID = {token_id}")
509
+
510
+ if eos_token_id is not None and token_id == eos_token_id:
511
+ logger.info("EOS token encountered. Ending generation.")
512
+ break
513
+
514
+ generated_tokens.append(token_id)
515
+
516
+ # Decode efficiently
517
+ new_text = self.tokenizer.decode([token_id], skip_special_tokens=True)
518
+ yield new_text
519
+
520
+ # Optional heuristic: stop on sentence-ending punctuation
521
+ if new_text.strip().endswith(('.', '?', '!')):
522
+ logger.info("Sentence-ending punctuation hit. Ending early.")
523
+ break
524
+
525
+ # Prepare for next step
526
+ input_ids = next_token
527
+ attention_mask = torch.cat([
528
+ attention_mask,
529
+ torch.ones((1, 1), dtype=attention_mask.dtype, device=attention_mask.device)
530
+ ], dim=-1)
531
+
532
+ # Final output logging
533
+ final_text = self.tokenizer.decode(generated_tokens, skip_special_tokens=True).strip()
534
+ logger.info(f"Streaming complete. Tokens generated: {len(generated_tokens)}")
535
+
536
+ end_stream_time = time.perf_counter()
537
+ stream_time = end_stream_time - start_stream_time
538
+ log_metric(f"LLM Stream time: {stream_time:0.4f} seconds. Tokens generated: {len(generated_tokens)}. Model: {self.model_name}. Timestamp: {current_time:%Y-%m-%d %H:%M:%S}")
539
+
540
+ except Exception as e:
541
+ logger.error(f"Streaming generation error: {e}")
542
+ end_stream_time = time.perf_counter()
543
+ stream_time = end_stream_time - start_stream_time
544
+ log_metric(f"LLM Stream time (error): {stream_time:0.4f} seconds. Model: {self.model_name}. Timestamp: {current_time:%Y-%m-%d %H:%M:%S}")
545
+ yield f"[Error in streaming generation: {str(e)}]"
546
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
547
 
548
  def _top_k_top_p_filtering(self, logits, top_k=50, top_p=0.9):
549
  """Apply top-k and top-p filtering to logits"""