jdesiree committed on
Commit
393c789
·
verified ·
1 Parent(s): e4fcfe9

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +85 -98
app.py CHANGED
@@ -301,7 +301,7 @@ class Phi2EducationalLLM(Runnable):
301
  model_path,
302
  quantization_config=quant_config,
303
  device_map="auto",
304
- dtype=torch.float16,
305
  trust_remote_code=True,
306
  low_cpu_mem_usage=True
307
  )
@@ -328,7 +328,7 @@ class Phi2EducationalLLM(Runnable):
328
  """Optimized model loading for 16GB RAM systems."""
329
  self.model = AutoModelForCausalLM.from_pretrained(
330
  model_path,
331
- dtype=torch.float16, # Use float16 to save memory
332
  device_map="cpu", # Force CPU for stability
333
  trust_remote_code=True,
334
  low_cpu_mem_usage=True,
@@ -345,7 +345,7 @@ class Phi2EducationalLLM(Runnable):
345
 
346
  self.model = AutoModelForCausalLM.from_pretrained(
347
  fallback_model,
348
- dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
349
  device_map="cpu",
350
  trust_remote_code=True,
351
  low_cpu_mem_usage=True
@@ -373,8 +373,8 @@ class Phi2EducationalLLM(Runnable):
373
  except:
374
  # Fallback for models without chat template support
375
  if "phi" in self.model_name.lower():
376
- # Phi-2 format
377
- text = f"Instruct: {SYSTEM_PROMPT}\n\nUser: {prompt}\nOutput:"
378
  else:
379
  # Generic format for other models
380
  text = f"<|system|>\n{SYSTEM_PROMPT}<|end|>\n<|user|>\n{prompt}<|end|>\n<|assistant|>\n"
@@ -414,129 +414,116 @@ class Phi2EducationalLLM(Runnable):
414
  return f"[Error generating response: {str(e)}]"
415
 
416
  def stream_generate(self, input: Input, config=None):
417
- """Streaming generation method for real-time response display."""
418
- import time
419
- from datetime import datetime
420
-
421
  start_stream_time = time.perf_counter()
422
  current_time = datetime.now()
423
-
424
- # --- Debug Start ---
425
  logger.info("Starting stream_generate...")
426
- logger.debug(f"Input received: {input}")
427
- # -------------------
428
-
429
- # Handle input
430
  if isinstance(input, dict):
431
  prompt = input.get('input', str(input))
432
  else:
433
  prompt = str(input)
434
-
435
  try:
436
- # === Configurable Generation Parameters ===
437
- temperature = config.get("temperature", 0.7) if config else 0.7
438
- top_k = config.get("top_k", 50) if config else 50
439
- top_p = config.get("top_p", 0.9) if config else 0.9
440
- max_new_tokens = config.get("max_new_tokens", 600) if config else 600
441
- timeout_seconds = config.get("timeout_seconds", 15) if config else 15
442
-
443
- # === Prompt Construction ===
444
  try:
445
  messages = [
446
  {"role": "system", "content": SYSTEM_PROMPT},
447
  {"role": "user", "content": prompt}
448
  ]
449
  text = self.tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
 
450
  except Exception as e:
451
  logger.warning(f"Failed to use chat template: {e}")
452
  if "phi" in self.model_name.lower():
453
  text = f"Instruct: {SYSTEM_PROMPT}\n\nUser: {prompt}\nOutput:"
 
454
  else:
455
  text = f"<|system|>\n{SYSTEM_PROMPT}<|end|>\n<|user|>\n{prompt}<|end|>\n<|assistant|>\n"
 
456
 
457
- # === Tokenize ===
458
- inputs = self.tokenizer(
459
- [text],
460
- return_tensors="pt",
461
- padding=True,
462
- truncation=True,
463
- max_length=self.tokenizer.model_max_length
464
- )
465
  if torch.cuda.is_available():
466
  inputs = {k: v.to(self.model.device) for k, v in inputs.items()}
467
 
468
- input_ids = inputs["input_ids"]
469
- attention_mask = inputs["attention_mask"]
470
- input_length = input_ids.shape[1]
471
-
472
- # === Streaming Generation ===
473
  generated_tokens = []
474
- past_key_values = None
475
- eos_token_id = self.tokenizer.eos_token_id
476
- start_time = time.time()
477
-
478
  logger.info("Beginning token-by-token generation...")
479
-
 
 
 
 
480
  for step in range(max_new_tokens):
481
- if time.time() - start_time > timeout_seconds:
482
- logger.warning("Timeout reached. Ending stream.")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
483
  break
484
-
485
- with torch.no_grad():
486
- model_inputs = {
487
- "attention_mask": attention_mask,
488
- "use_cache": True
489
- }
490
-
491
- if past_key_values is None:
492
- model_inputs["input_ids"] = input_ids
493
- else:
494
- model_inputs["input_ids"] = next_token
495
- model_inputs["past_key_values"] = past_key_values
496
-
497
- outputs = self.model(**model_inputs)
498
- logits = outputs.logits[:, -1, :]
499
- past_key_values = outputs.past_key_values
500
-
501
- # Sampling
502
- logits = logits / temperature
503
- filtered_logits = self._top_k_top_p_filtering(logits, top_k=top_k, top_p=top_p)
504
- probs = torch.nn.functional.softmax(filtered_logits, dim=-1)
505
- next_token = torch.multinomial(probs, num_samples=1)
506
-
507
- token_id = next_token.item()
508
- logger.debug(f"Step {step}: Token ID = {token_id}")
509
-
510
- if eos_token_id is not None and token_id == eos_token_id:
511
- logger.info("EOS token encountered. Ending generation.")
512
- break
513
-
514
- generated_tokens.append(token_id)
515
-
516
- # Decode efficiently
517
- new_text = self.tokenizer.decode([token_id], skip_special_tokens=True)
518
- yield new_text
519
-
520
- # Optional heuristic: stop on sentence-ending punctuation
521
- if new_text.strip().endswith(('.', '?', '!')):
522
- logger.info("Sentence-ending punctuation hit. Ending early.")
523
- break
524
-
525
- # Prepare for next step
526
- input_ids = next_token
527
- attention_mask = torch.cat([
528
- attention_mask,
529
- torch.ones((1, 1), dtype=attention_mask.dtype, device=attention_mask.device)
530
- ], dim=-1)
531
-
532
- # Final output logging
533
  final_text = self.tokenizer.decode(generated_tokens, skip_special_tokens=True).strip()
534
- logger.info(f"Streaming complete. Tokens generated: {len(generated_tokens)}")
535
-
 
 
 
 
536
  end_stream_time = time.perf_counter()
537
  stream_time = end_stream_time - start_stream_time
538
  log_metric(f"LLM Stream time: {stream_time:0.4f} seconds. Tokens generated: {len(generated_tokens)}. Model: {self.model_name}. Timestamp: {current_time:%Y-%m-%d %H:%M:%S}")
539
-
 
540
  except Exception as e:
541
  logger.error(f"Streaming generation error: {e}")
542
  end_stream_time = time.perf_counter()
@@ -879,7 +866,7 @@ mathjax_config = '''
879
  window.MathJax = {
880
  tex: {
881
  inlineMath: [['\\\\(', '\\\\)']],
882
- displayMath: [['$', '$'], ['\\\\[', '\\\\]']],
883
  packages: {'[+]': ['ams']}
884
  },
885
  svg: {fontCache: 'global'},
@@ -1077,7 +1064,7 @@ def create_interface():
1077
 
1078
  with gr.Column(elem_classes=["main-container"]):
1079
  # Title Section
1080
- gr.HTML('<div class="title-header"><h1>🎓 Mimir 🎓</h1></div>')
1081
 
1082
  # Chat Section
1083
  with gr.Row():
 
301
  model_path,
302
  quantization_config=quant_config,
303
  device_map="auto",
304
+ torch_dtype=torch.float16,
305
  trust_remote_code=True,
306
  low_cpu_mem_usage=True
307
  )
 
328
  """Optimized model loading for 16GB RAM systems."""
329
  self.model = AutoModelForCausalLM.from_pretrained(
330
  model_path,
331
+ torch_dtype=torch.float16, # Use float16 to save memory
332
  device_map="cpu", # Force CPU for stability
333
  trust_remote_code=True,
334
  low_cpu_mem_usage=True,
 
345
 
346
  self.model = AutoModelForCausalLM.from_pretrained(
347
  fallback_model,
348
+ torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
349
  device_map="cpu",
350
  trust_remote_code=True,
351
  low_cpu_mem_usage=True
 
373
  except:
374
  # Fallback for models without chat template support
375
  if "phi" in self.model_name.lower():
376
+ # Phi-2 proper format
377
+ text = f"{SYSTEM_PROMPT}\n\nQuestion: {prompt}\nAnswer:"
378
  else:
379
  # Generic format for other models
380
  text = f"<|system|>\n{SYSTEM_PROMPT}<|end|>\n<|user|>\n{prompt}<|end|>\n<|assistant|>\n"
 
414
  return f"[Error generating response: {str(e)}]"
415
 
416
  def stream_generate(self, input: Input, config=None):
417
+ """Streaming generation method for real-time response display"""
 
 
 
418
  start_stream_time = time.perf_counter()
419
  current_time = datetime.now()
 
 
420
  logger.info("Starting stream_generate...")
421
+
422
+ # Handle both string and dict inputs for flexibility
 
 
423
  if isinstance(input, dict):
424
  prompt = input.get('input', str(input))
425
  else:
426
  prompt = str(input)
427
+
428
  try:
429
+ # Prepare input text with better error handling
 
 
 
 
 
 
 
430
  try:
431
  messages = [
432
  {"role": "system", "content": SYSTEM_PROMPT},
433
  {"role": "user", "content": prompt}
434
  ]
435
  text = self.tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
436
+ logger.info("Successfully used chat template")
437
  except Exception as e:
438
  logger.warning(f"Failed to use chat template: {e}")
439
  if "phi" in self.model_name.lower():
440
  text = f"Instruct: {SYSTEM_PROMPT}\n\nUser: {prompt}\nOutput:"
441
+ logger.info("Using Phi-2 format")
442
  else:
443
  text = f"<|system|>\n{SYSTEM_PROMPT}<|end|>\n<|user|>\n{prompt}<|end|>\n<|assistant|>\n"
444
+ logger.info("Using generic format")
445
 
446
+ inputs = self.tokenizer([text], return_tensors="pt", padding=True, truncation=True, max_length=1024)
 
 
 
 
 
 
 
447
  if torch.cuda.is_available():
448
  inputs = {k: v.to(self.model.device) for k, v in inputs.items()}
449
 
450
+ # Initialize for streaming
 
 
 
 
451
  generated_tokens = []
452
+ max_new_tokens = 600
 
 
 
453
  logger.info("Beginning token-by-token generation...")
454
+
455
+ # Generate token by token
456
+ current_input_ids = inputs.input_ids
457
+ current_attention_mask = inputs.attention_mask
458
+
459
  for step in range(max_new_tokens):
460
+ try:
461
+ with torch.no_grad():
462
+ outputs = self.model(
463
+ input_ids=current_input_ids,
464
+ attention_mask=current_attention_mask,
465
+ use_cache=True
466
+ )
467
+
468
+ # Get next token probabilities
469
+ next_token_logits = outputs.logits[:, -1, :]
470
+
471
+ # Apply temperature and sampling
472
+ next_token_logits = next_token_logits / 0.7
473
+
474
+ # Apply top-k and top-p filtering
475
+ filtered_logits = self._top_k_top_p_filtering(next_token_logits, top_k=50, top_p=0.9)
476
+
477
+ # Sample next token
478
+ probs = torch.nn.functional.softmax(filtered_logits, dim=-1)
479
+ next_token = torch.multinomial(probs, num_samples=1)
480
+
481
+ # Check for end of sequence
482
+ if next_token.item() == self.tokenizer.eos_token_id:
483
+ logger.info(f"Reached EOS token at step {step}")
484
+ break
485
+
486
+ # Add to generated tokens
487
+ generated_tokens.append(next_token.item())
488
+
489
+ # Decode and yield partial result every few tokens for efficiency
490
+ if step % 3 == 0 or step < 10: # Yield more frequently at start
491
+ partial_text = self.tokenizer.decode(generated_tokens, skip_special_tokens=True)
492
+ if partial_text.strip(): # Only yield non-empty text
493
+ yield partial_text
494
+
495
+ # Safety checks to prevent infinite loops
496
+ if step > 10 and len(generated_tokens) == 0:
497
+ logger.error("No tokens generated after 10 steps, breaking")
498
+ break
499
+
500
+ if step > 50 and len(partial_text.strip()) < 10:
501
+ logger.warning("Very little text generated, continuing...")
502
+
503
+ # Update input for next iteration
504
+ current_input_ids = torch.cat([current_input_ids, next_token], dim=-1)
505
+ current_attention_mask = torch.cat([
506
+ current_attention_mask,
507
+ torch.ones((1, 1), dtype=current_attention_mask.dtype, device=current_attention_mask.device)
508
+ ], dim=-1)
509
+
510
+ except Exception as e:
511
+ logger.error(f"Error in generation step {step}: {e}")
512
  break
513
+
514
+ # Final result
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
515
  final_text = self.tokenizer.decode(generated_tokens, skip_special_tokens=True).strip()
516
+ if final_text:
517
+ yield final_text
518
+ else:
519
+ logger.error("No final text generated")
520
+ yield "I'm having trouble generating a response. Please try again."
521
+
522
  end_stream_time = time.perf_counter()
523
  stream_time = end_stream_time - start_stream_time
524
  log_metric(f"LLM Stream time: {stream_time:0.4f} seconds. Tokens generated: {len(generated_tokens)}. Model: {self.model_name}. Timestamp: {current_time:%Y-%m-%d %H:%M:%S}")
525
+ logger.info(f"Stream generation completed: {len(generated_tokens)} tokens in {stream_time:.2f}s")
526
+
527
  except Exception as e:
528
  logger.error(f"Streaming generation error: {e}")
529
  end_stream_time = time.perf_counter()
 
866
  window.MathJax = {
867
  tex: {
868
  inlineMath: [['\\\\(', '\\\\)']],
869
+ displayMath: [['$$', '$$'], ['\\\\[', '\\\\]']],
870
  packages: {'[+]': ['ams']}
871
  },
872
  svg: {fontCache: 'global'},
 
1064
 
1065
  with gr.Column(elem_classes=["main-container"]):
1066
  # Title Section
1067
+ gr.HTML('<div class="title-header"><h1> Mimir 🎓</h1></div>')
1068
 
1069
  # Chat Section
1070
  with gr.Row():