jdesiree committed on
Commit
293ae98
·
verified ·
1 Parent(s): 10ba1d2

Added Streaming

Browse files
Files changed (1) hide show
  1. app.py +256 -76
app.py CHANGED
@@ -270,16 +270,18 @@ You have the ability to create graphs and charts to enhance your explanations. U
270
  - Provide honest, accurate feedback even when it may not be what the student wants to hear
271
  Your goal is to be an educational partner who empowers students to succeed through understanding, not a service that completes their work for them."""
272
 
273
- # --- Fixed LLM Class with Runnable inheritance ---
274
- class Qwen25SmallLLM(Runnable):
275
- """LLM class that properly inherits from Runnable for LangChain compatibility"""
276
 
277
- def __init__(self, model_path: str = "Qwen/Qwen2.5-3B-Instruct", use_4bit: bool = True):
278
  super().__init__()
279
  logger.info(f"Loading model: {model_path} (use_4bit={use_4bit})")
280
  start_Loading_Model_time = time.perf_counter()
281
  current_time = datetime.now()
282
 
 
 
283
  try:
284
  # Load tokenizer
285
  self.tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
@@ -294,46 +296,63 @@ class Qwen25SmallLLM(Runnable):
294
  llm_int8_skip_modules=["lm_head"]
295
  )
296
 
297
- # Try quantized load with updated dtype parameter
298
  self.model = AutoModelForCausalLM.from_pretrained(
299
  model_path,
300
  quantization_config=quant_config,
301
  device_map="auto",
302
- dtype=torch.bfloat16,
303
  trust_remote_code=True,
304
  low_cpu_mem_usage=True
305
  )
306
  else:
307
- self._load_fallback_model(model_path)
308
 
309
  # Success path - log timing
310
  end_Loading_Model_time = time.perf_counter()
311
  Loading_Model_time = end_Loading_Model_time - start_Loading_Model_time
312
- log_metric(f"Model Load time: {Loading_Model_time:0.4f} seconds. Timestamp: {current_time:%Y-%m-%d %H:%M:%S}")
313
 
314
  except Exception as e:
315
- logger.warning(f"Quantized load failed, falling back: {e}")
316
- self._load_fallback_model(model_path)
317
  end_Loading_Model_time = time.perf_counter()
318
  Loading_Model_time = end_Loading_Model_time - start_Loading_Model_time
319
- log_metric(f"Model Load time (fallback): {Loading_Model_time:0.4f} seconds. Timestamp: {current_time:%Y-%m-%d %H:%M:%S}")
320
 
321
  # Ensure pad token
322
  if self.tokenizer.pad_token is None:
323
  self.tokenizer.pad_token = self.tokenizer.eos_token
324
 
325
- def _load_fallback_model(self, model_path: str):
326
- """Fallback if quantization fails."""
327
  self.model = AutoModelForCausalLM.from_pretrained(
328
  model_path,
329
- dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
330
- device_map="auto" if torch.cuda.is_available() else None,
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
331
  trust_remote_code=True,
332
  low_cpu_mem_usage=True
333
  )
334
 
335
  def invoke(self, input: Input, config=None) -> Output:
336
- """Main invoke method for Runnable compatibility"""
337
  start_invoke_time = time.perf_counter()
338
  current_time = datetime.now()
339
 
@@ -344,26 +363,38 @@ class Qwen25SmallLLM(Runnable):
344
  prompt = str(input)
345
 
346
  try:
347
- messages = [
348
- {"role": "system", "content": SYSTEM_PROMPT},
349
- {"role": "user", "content": prompt}
350
- ]
351
- text = self.tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
 
 
 
 
 
 
 
 
 
 
352
 
353
- inputs = self.tokenizer([text], return_tensors="pt", padding=True, truncation=True, max_length=2048)
354
  if torch.cuda.is_available():
355
  inputs = {k: v.to(self.model.device) for k, v in inputs.items()}
356
 
357
  with torch.no_grad():
358
  outputs = self.model.generate(
359
  **inputs,
360
- max_new_tokens=800,
361
  do_sample=True,
362
- temperature=0.7,
363
  top_p=0.9,
364
- top_k=50,
365
  repetition_penalty=1.1,
366
- pad_token_id=self.tokenizer.eos_token_id
 
 
367
  )
368
 
369
  new_tokens = [out[len(inp):] for inp, out in zip(inputs.input_ids, outputs)]
@@ -371,17 +402,131 @@ class Qwen25SmallLLM(Runnable):
371
 
372
  end_invoke_time = time.perf_counter()
373
  invoke_time = end_invoke_time - start_invoke_time
374
- log_metric(f"LLM Invoke time: {invoke_time:0.4f} seconds. Input length: {len(prompt)} chars. Timestamp: {current_time:%Y-%m-%d %H:%M:%S}")
375
 
376
- return result
377
 
378
  except Exception as e:
379
  logger.error(f"Generation error: {e}")
380
  end_invoke_time = time.perf_counter()
381
  invoke_time = end_invoke_time - start_invoke_time
382
- log_metric(f"LLM Invoke time (error): {invoke_time:0.4f} seconds. Timestamp: {current_time:%Y-%m-%d %H:%M:%S}")
383
  return f"[Error generating response: {str(e)}]"
384
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
385
  @property
386
  def InputType(self) -> Type[Input]:
387
  return str
@@ -392,13 +537,13 @@ class Qwen25SmallLLM(Runnable):
392
 
393
  # --- LangGraph Agent Implementation ---
394
  class Educational_Agent:
395
- """Modern LangGraph-based educational agent"""
396
 
397
  def __init__(self):
398
  start_init_and_langgraph_time = time.perf_counter()
399
  current_time = datetime.now()
400
 
401
- self.llm = Qwen25SmallLLM(model_path="Qwen/Qwen2.5-1.5B-Instruct", use_4bit=True)
402
  self.tool_decision_engine = Tool_Decision_Engine(self.llm)
403
 
404
  # Create LangGraph workflow
@@ -409,7 +554,7 @@ class Educational_Agent:
409
  log_metric(f"Init and LangGraph workflow setup time: {init_and_langgraph_time:0.4f} seconds. Timestamp: {current_time:%Y-%m-%d %H:%M:%S}")
410
 
411
  def _create_langgraph_workflow(self):
412
- """Create the LangGraph workflow"""
413
  # Define tools
414
  tools = [Create_Graph_Tool]
415
  tool_node = ToolNode(tools)
@@ -600,7 +745,31 @@ Otherwise, provide a regular educational response.
600
  return workflow.compile(checkpointer=memory)
601
 
602
  def chat(self, message: str, thread_id: str = "default") -> str:
603
- """Main chat interface"""
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
604
  start_chat_time = time.perf_counter()
605
  current_time = datetime.now()
606
 
@@ -614,37 +783,45 @@ Otherwise, provide a regular educational response.
614
  "educational_context": None
615
  }
616
 
617
- # Run the workflow
618
- result = self.app.invoke(initial_state, config=config)
619
-
620
- # Extract the final response
621
- final_messages = result["messages"]
622
-
623
- # Build the response from all assistant and tool messages
624
- response_parts = []
625
- for msg in final_messages:
626
- if isinstance(msg, AIMessage) and msg.content:
627
- response_parts.append(msg.content)
628
- elif isinstance(msg, ToolMessage) and msg.content:
629
- response_parts.append(msg.content)
630
 
631
- if response_parts:
632
- final_response = "\n\n".join(response_parts)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
633
  else:
634
- final_response = "I apologize, but I couldn't generate a proper response."
 
 
 
635
 
636
  end_chat_time = time.perf_counter()
637
  chat_time = end_chat_time - start_chat_time
638
- log_metric(f"Complete chat time: {chat_time:0.4f} seconds. Response length: {len(final_response)} chars. Timestamp: {current_time:%Y-%m-%d %H:%M:%S}")
639
-
640
- return final_response
641
 
642
  except Exception as e:
643
- logger.error(f"Error in LangGraph chat: {e}")
644
  end_chat_time = time.perf_counter()
645
  chat_time = end_chat_time - start_chat_time
646
- log_metric(f"Complete chat time (error): {chat_time:0.4f} seconds. Timestamp: {current_time:%Y-%m-%d %H:%M:%S}")
647
- return f"I apologize, but I encountered an error: {str(e)}"
648
 
649
  # --- Global Agent Instance ---
650
  agent = None
@@ -662,7 +839,7 @@ mathjax_config = '''
662
  window.MathJax = {
663
  tex: {
664
  inlineMath: [['\\\\(', '\\\\)']],
665
- displayMath: [['$', '$'], ['\\\\[', '\\\\]']],
666
  packages: {'[+]': ['ams']}
667
  },
668
  svg: {fontCache: 'global'},
@@ -732,7 +909,7 @@ def smart_truncate(text, max_length=3000):
732
  return result
733
 
734
  def generate_response_with_agent(message, max_retries=3):
735
- """Generate response using LangGraph agent."""
736
  start_generate_response_with_agent_time = time.perf_counter()
737
  current_time = datetime.now()
738
 
@@ -741,16 +918,15 @@ def generate_response_with_agent(message, max_retries=3):
741
  # Get the agent
742
  current_agent = get_agent()
743
 
744
- # Use the agent's chat method
745
- response = current_agent.chat(message)
746
-
747
- result = smart_truncate(response)
748
 
749
  end_generate_response_with_agent_time = time.perf_counter()
750
  generate_response_with_agent_time = end_generate_response_with_agent_time - start_generate_response_with_agent_time
751
  log_metric(f"Generate response with agent time: {generate_response_with_agent_time:0.4f} seconds. Timestamp: {current_time:%Y-%m-%d %H:%M:%S}")
752
 
753
- return result
754
 
755
  except Exception as e:
756
  logger.error(f"Agent error (attempt {attempt + 1}): {e}")
@@ -761,32 +937,33 @@ def generate_response_with_agent(message, max_retries=3):
761
  end_generate_response_with_agent_time = time.perf_counter()
762
  generate_response_with_agent_time = end_generate_response_with_agent_time - start_generate_response_with_agent_time
763
  log_metric(f"Generate response with agent time (error): {generate_response_with_agent_time:0.4f} seconds. Timestamp: {current_time:%Y-%m-%d %H:%M:%S}")
764
- return f"I apologize, but I encountered an error while processing your message: {str(e)}"
765
 
766
  def chat_response(message, history=None):
767
- """Process chat message and return response."""
768
  start_chat_response_time = time.perf_counter()
769
  current_time = datetime.now()
770
 
771
  try:
772
- # Generate response with LangGraph agent
773
- response = generate_response_with_agent(message)
 
 
 
774
 
775
  end_chat_response_time = time.perf_counter()
776
  chat_response_time = end_chat_response_time - start_chat_response_time
777
  log_metric(f"Chat response time: {chat_response_time:0.4f} seconds. Timestamp: {current_time:%Y-%m-%d %H:%M:%S}")
778
 
779
- return response
780
-
781
  except Exception as e:
782
  logger.error(f"Error in chat_response: {e}")
783
  end_chat_response_time = time.perf_counter()
784
  chat_response_time = end_chat_response_time - start_chat_response_time
785
  log_metric(f"Chat response time (error): {chat_response_time:0.4f} seconds. Timestamp: {current_time:%Y-%m-%d %H:%M:%S}")
786
- return f"I apologize, but I encountered an error while processing your message: {str(e)}"
787
 
788
  def respond_and_update(message, history):
789
- """Main function to handle user submission."""
790
  if not message.strip():
791
  return history, ""
792
 
@@ -794,11 +971,14 @@ def respond_and_update(message, history):
794
  history.append({"role": "user", "content": message})
795
  yield history, ""
796
 
797
- # Generate response
798
- response = chat_response(message)
799
 
800
- history.append({"role": "assistant", "content": response})
801
- yield history, ""
 
 
 
802
 
803
  def clear_chat():
804
  """Clear the chat history."""
@@ -905,7 +1085,7 @@ def create_interface():
905
  if __name__ == "__main__":
906
  try:
907
  logger.info("=" * 50)
908
- logger.info("Starting Mimir Application with LangGraph")
909
  logger.info("=" * 50)
910
 
911
  # Step 1: Preload the model and agent
@@ -929,5 +1109,5 @@ if __name__ == "__main__":
929
  )
930
 
931
  except Exception as e:
932
- logger.error(f"❌ Failed to launch Mimir with LangGraph: {e}")
933
  raise
 
270
  - Provide honest, accurate feedback even when it may not be what the student wants to hear
271
  Your goal is to be an educational partner who empowers students to succeed through understanding, not a service that completes their work for them."""
272
 
273
+ # --- Updated LLM Class with Microsoft Phi-2 and TinyLlama fallback ---
274
+ class Phi2EducationalLLM(Runnable):
275
+ """LLM class optimized for Microsoft Phi-2 with TinyLlama fallback for educational tasks"""
276
 
277
+ def __init__(self, model_path: str = "microsoft/phi-2", fallback_model: str = "TinyLlama/TinyLlama-1.1B-Chat-v1.0", use_4bit: bool = False):
278
  super().__init__()
279
  logger.info(f"Loading model: {model_path} (use_4bit={use_4bit})")
280
  start_Loading_Model_time = time.perf_counter()
281
  current_time = datetime.now()
282
 
283
+ self.model_name = model_path
284
+
285
  try:
286
  # Load tokenizer
287
  self.tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
 
296
  llm_int8_skip_modules=["lm_head"]
297
  )
298
 
299
+ # Try quantized load
300
  self.model = AutoModelForCausalLM.from_pretrained(
301
  model_path,
302
  quantization_config=quant_config,
303
  device_map="auto",
304
+ torch_dtype=torch.float16,
305
  trust_remote_code=True,
306
  low_cpu_mem_usage=True
307
  )
308
  else:
309
+ self._load_optimized_model(model_path)
310
 
311
  # Success path - log timing
312
  end_Loading_Model_time = time.perf_counter()
313
  Loading_Model_time = end_Loading_Model_time - start_Loading_Model_time
314
+ log_metric(f"Model Load time: {Loading_Model_time:0.4f} seconds. Model: {model_path}. Timestamp: {current_time:%Y-%m-%d %H:%M:%S}")
315
 
316
  except Exception as e:
317
+ logger.warning(f"Primary model {model_path} failed, trying fallback {fallback_model}: {e}")
318
+ self._load_fallback_model(fallback_model)
319
  end_Loading_Model_time = time.perf_counter()
320
  Loading_Model_time = end_Loading_Model_time - start_Loading_Model_time
321
+ log_metric(f"Model Load time (fallback): {Loading_Model_time:0.4f} seconds. Model: {fallback_model}. Timestamp: {current_time:%Y-%m-%d %H:%M:%S}")
322
 
323
  # Ensure pad token
324
  if self.tokenizer.pad_token is None:
325
  self.tokenizer.pad_token = self.tokenizer.eos_token
326
 
327
def _load_optimized_model(self, model_path: str):
    """Load the primary model fully on CPU for memory-constrained (16GB RAM) hosts.

    Uses float16 weights to roughly halve the resident size and
    ``low_cpu_mem_usage`` to avoid a second full copy during loading.
    NOTE(review): fp16 inference on CPU can be slow on some backends —
    confirm acceptable latency on the target machine.

    :param model_path: Hugging Face model id or local path.
    """
    # Fix: the previous version passed max_memory={0: "14GB"} here. In
    # accelerate's scheme the key 0 addresses *GPU 0*, which contradicts
    # the explicit CPU placement; with a plain string device_map ("cpu")
    # max_memory is not used at all, so it is dropped entirely.
    self.model = AutoModelForCausalLM.from_pretrained(
        model_path,
        torch_dtype=torch.float16,  # halve memory footprint vs float32
        device_map="cpu",           # force CPU for stability
        trust_remote_code=True,
        low_cpu_mem_usage=True,
    )
337
+
338
def _load_fallback_model(self, fallback_model: str):
    """Swap in the lightweight fallback checkpoint (TinyLlama) when the
    primary model cannot be loaded.

    Replaces both the tokenizer and the model, and records the new name in
    ``self.model_name`` so prompt formatting elsewhere picks the right template.

    :param fallback_model: Hugging Face model id of the fallback checkpoint.
    """
    logger.info(f"Loading fallback model: {fallback_model}")

    # The fallback ships its own vocabulary, so the tokenizer must be replaced too.
    self.tokenizer = AutoTokenizer.from_pretrained(fallback_model, trust_remote_code=True)
    self.model_name = fallback_model

    # Half precision only pays off when CUDA is present; stay in float32 otherwise.
    weight_dtype = torch.float16 if torch.cuda.is_available() else torch.float32
    self.model = AutoModelForCausalLM.from_pretrained(
        fallback_model,
        torch_dtype=weight_dtype,
        device_map="cpu",
        trust_remote_code=True,
        low_cpu_mem_usage=True,
    )
353
 
354
def invoke(self, input: Input, config=None) -> Output:
    """Generate one complete (non-streaming) response for *input*.

    Accepts either a plain string or a dict with an ``'input'`` key (LangChain
    Runnable convention). Returns the decoded model text, or a bracketed error
    string on failure (callers rely on getting a string either way).
    """
    start_invoke_time = time.perf_counter()
    current_time = datetime.now()

    # Handle both string and dict inputs for flexibility.
    if isinstance(input, dict):
        prompt = input.get('input', str(input))
    else:
        prompt = str(input)

    try:
        # Prefer the tokenizer's chat template; fall back to hand-built
        # prompt formats for models that do not define one.
        try:
            messages = [
                {"role": "system", "content": SYSTEM_PROMPT},
                {"role": "user", "content": prompt}
            ]
            text = self.tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
        except Exception:  # narrowed from bare `except:` (kept KeyboardInterrupt uncatchable)
            if "phi" in self.model_name.lower():
                # Phi-2 instruct format
                text = f"Instruct: {SYSTEM_PROMPT}\n\nUser: {prompt}\nOutput:"
            else:
                # Generic chat-markup format for other models
                text = f"<|system|>\n{SYSTEM_PROMPT}<|end|>\n<|user|>\n{prompt}<|end|>\n<|assistant|>\n"

        encoded = self.tokenizer([text], return_tensors="pt", padding=True, truncation=True, max_length=1024)
        if torch.cuda.is_available():
            encoded = {k: v.to(self.model.device) for k, v in encoded.items()}

        with torch.no_grad():
            outputs = self.model.generate(
                **encoded,
                max_new_tokens=600,      # sufficient for comprehensive educational responses
                do_sample=True,
                temperature=0.7,         # balance between focus and variety
                top_p=0.9,
                top_k=50,
                repetition_penalty=1.1,
                pad_token_id=self.tokenizer.eos_token_id,
                use_cache=True,          # KV cache for faster generation
                # Fix: dropped early_stopping=True — it only applies to beam
                # search and just triggers a warning under sampling.
            )

        # Fix: index with ["input_ids"] instead of attribute access — after the
        # CUDA branch above, `encoded` is a plain dict and `.input_ids` raises
        # AttributeError. Indexing works for both dict and BatchEncoding.
        new_tokens = [out[len(inp):] for inp, out in zip(encoded["input_ids"], outputs)]
        result = self.tokenizer.batch_decode(new_tokens, skip_special_tokens=True)[0].strip()

        end_invoke_time = time.perf_counter()
        invoke_time = end_invoke_time - start_invoke_time
        log_metric(f"LLM Invoke time: {invoke_time:0.4f} seconds. Input length: {len(prompt)} chars. Model: {self.model_name}. Timestamp: {current_time:%Y-%m-%d %H:%M:%S}")

        return result if result else "I'm still learning how to respond to that properly."

    except Exception as e:
        logger.error(f"Generation error: {e}")
        end_invoke_time = time.perf_counter()
        invoke_time = end_invoke_time - start_invoke_time
        log_metric(f"LLM Invoke time (error): {invoke_time:0.4f} seconds. Model: {self.model_name}. Timestamp: {current_time:%Y-%m-%d %H:%M:%S}")
        return f"[Error generating response: {str(e)}]"
415
 
416
def stream_generate(self, input: Input, config=None):
    """Streaming generation: yields progressively longer decoded text.

    Samples one token at a time (temperature 0.7, top-k 50, top-p 0.9 — the
    same settings as ``invoke``) and yields the decoded text-so-far after each
    token, so the UI can render the response as it grows. On error, yields a
    single bracketed error string.
    """
    start_stream_time = time.perf_counter()
    current_time = datetime.now()

    # Handle both string and dict inputs for flexibility.
    if isinstance(input, dict):
        prompt = input.get('input', str(input))
    else:
        prompt = str(input)

    try:
        # Prefer the tokenizer's chat template; fall back to manual formats.
        try:
            messages = [
                {"role": "system", "content": SYSTEM_PROMPT},
                {"role": "user", "content": prompt}
            ]
            text = self.tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
        except Exception:  # narrowed from bare `except:`
            if "phi" in self.model_name.lower():
                text = f"Instruct: {SYSTEM_PROMPT}\n\nUser: {prompt}\nOutput:"
            else:
                text = f"<|system|>\n{SYSTEM_PROMPT}<|end|>\n<|user|>\n{prompt}<|end|>\n<|assistant|>\n"

        encoded = self.tokenizer([text], return_tensors="pt", padding=True, truncation=True, max_length=1024)
        # Fix: keep plain tensors — the old code rebuilt `inputs` as a dict on
        # CUDA and then crashed on `inputs.input_ids` attribute access.
        input_ids = encoded["input_ids"]
        attention_mask = encoded["attention_mask"]
        if torch.cuda.is_available():
            input_ids = input_ids.to(self.model.device)
            attention_mask = attention_mask.to(self.model.device)

        generated_tokens = []
        max_new_tokens = 600
        # Fix: actually reuse the KV cache. The old loop passed use_cache=True
        # but re-fed the whole growing sequence every step and ignored
        # past_key_values, re-running the full prefix per token (quadratic).
        past_key_values = None
        step_input = input_ids  # full prompt on the first step, one token after

        for _ in range(max_new_tokens):
            with torch.no_grad():
                outputs = self.model(
                    input_ids=step_input,
                    attention_mask=attention_mask,
                    past_key_values=past_key_values,
                    use_cache=True,
                )
            past_key_values = outputs.past_key_values

            # Temperature, then top-k/top-p filtering, then sample.
            next_token_logits = outputs.logits[:, -1, :] / 0.7
            filtered_logits = self._top_k_top_p_filtering(next_token_logits, top_k=50, top_p=0.9)
            probs = torch.nn.functional.softmax(filtered_logits, dim=-1)
            next_token = torch.multinomial(probs, num_samples=1)

            # Stop at end-of-sequence.
            if next_token.item() == self.tokenizer.eos_token_id:
                break

            generated_tokens.append(next_token.item())

            # Decode everything generated so far and yield the partial text.
            partial_text = self.tokenizer.decode(generated_tokens, skip_special_tokens=True)
            yield partial_text

            # With a KV cache only the new token is fed next iteration; the
            # attention mask still covers the whole sequence.
            step_input = next_token
            attention_mask = torch.cat([
                attention_mask,
                torch.ones((1, 1), dtype=attention_mask.dtype, device=attention_mask.device)
            ], dim=-1)

        end_stream_time = time.perf_counter()
        stream_time = end_stream_time - start_stream_time
        log_metric(f"LLM Stream time: {stream_time:0.4f} seconds. Tokens generated: {len(generated_tokens)}. Model: {self.model_name}. Timestamp: {current_time:%Y-%m-%d %H:%M:%S}")

    except Exception as e:
        logger.error(f"Streaming generation error: {e}")
        end_stream_time = time.perf_counter()
        stream_time = end_stream_time - start_stream_time
        log_metric(f"LLM Stream time (error): {stream_time:0.4f} seconds. Model: {self.model_name}. Timestamp: {current_time:%Y-%m-%d %H:%M:%S}")
        yield f"[Error in streaming generation: {str(e)}]"
506
+
507
+ def _top_k_top_p_filtering(self, logits, top_k=50, top_p=0.9):
508
+ """Apply top-k and top-p filtering to logits"""
509
+ if top_k > 0:
510
+ # Get top-k indices
511
+ top_k = min(top_k, logits.size(-1))
512
+ indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
513
+ logits[indices_to_remove] = float('-inf')
514
+
515
+ if top_p < 1.0:
516
+ # Sort and get cumulative probabilities
517
+ sorted_logits, sorted_indices = torch.sort(logits, descending=True)
518
+ cumulative_probs = torch.cumsum(torch.nn.functional.softmax(sorted_logits, dim=-1), dim=-1)
519
+
520
+ # Remove tokens with cumulative probability above the threshold
521
+ sorted_indices_to_remove = cumulative_probs > top_p
522
+ sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
523
+ sorted_indices_to_remove[..., 0] = 0
524
+
525
+ indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
526
+ logits[indices_to_remove] = float('-inf')
527
+
528
+ return logits
529
+
530
  @property
531
  def InputType(self) -> Type[Input]:
532
  return str
 
537
 
538
  # --- LangGraph Agent Implementation ---
539
  class Educational_Agent:
540
+ """Modern LangGraph-based educational agent with Phi-2 and streaming"""
541
 
542
  def __init__(self):
543
  start_init_and_langgraph_time = time.perf_counter()
544
  current_time = datetime.now()
545
 
546
+ self.llm = Phi2EducationalLLM(model_path="microsoft/phi-2", fallback_model="TinyLlama/TinyLlama-1.1B-Chat-v1.0", use_4bit=False)
547
  self.tool_decision_engine = Tool_Decision_Engine(self.llm)
548
 
549
  # Create LangGraph workflow
 
554
  log_metric(f"Init and LangGraph workflow setup time: {init_and_langgraph_time:0.4f} seconds. Timestamp: {current_time:%Y-%m-%d %H:%M:%S}")
555
 
556
  def _create_langgraph_workflow(self):
557
+ """Create the complete LangGraph workflow"""
558
  # Define tools
559
  tools = [Create_Graph_Tool]
560
  tool_node = ToolNode(tools)
 
745
  return workflow.compile(checkpointer=memory)
746
 
747
def chat(self, message: str, thread_id: str = "default") -> str:
    """Blocking chat interface kept for backward compatibility.

    Drains the streaming interface and returns only the final (longest)
    chunk; on failure, returns an apology string rather than raising.
    """
    t0 = time.perf_counter()
    stamp = datetime.now()

    try:
        # Each yielded chunk supersedes the previous one, so only the
        # last chunk survives the loop.
        final_response = ""
        for chunk in self.stream_chat(message, thread_id):
            final_response = chunk

        elapsed = time.perf_counter() - t0
        log_metric(f"Complete chat time: {elapsed:0.4f} seconds. Response length: {len(final_response)} chars. Timestamp: {stamp:%Y-%m-%d %H:%M:%S}")
        return final_response

    except Exception as e:
        logger.error(f"Error in LangGraph chat: {e}")
        elapsed = time.perf_counter() - t0
        log_metric(f"Complete chat time (error): {elapsed:0.4f} seconds. Timestamp: {stamp:%Y-%m-%d %H:%M:%S}")
        return f"I apologize, but I encountered an error: {str(e)}"
770
+
771
+ def stream_chat(self, message: str, thread_id: str = "default"):
772
+ """Streaming chat interface that yields partial responses"""
773
  start_chat_time = time.perf_counter()
774
  current_time = datetime.now()
775
 
 
783
  "educational_context": None
784
  }
785
 
786
+ # First check if tools are needed
787
+ user_query = message
788
+ needs_tools = self.tool_decision_engine.should_use_visualization(user_query)
 
 
 
 
 
 
 
 
 
 
789
 
790
+ if needs_tools:
791
+ logger.info("Query requires visualization - handling tool call first")
792
+ # Handle tool generation first (non-streaming for tools)
793
+ result = self.app.invoke(initial_state, config=config)
794
+ final_messages = result["messages"]
795
+
796
+ # Build the response from all assistant and tool messages
797
+ response_parts = []
798
+ for msg in final_messages:
799
+ if isinstance(msg, AIMessage) and msg.content:
800
+ response_parts.append(msg.content)
801
+ elif isinstance(msg, ToolMessage) and msg.content:
802
+ response_parts.append(msg.content)
803
+
804
+ final_response = "\n\n".join(response_parts) if response_parts else "I couldn't generate a proper response."
805
+
806
+ # For tool responses, yield the complete result at once
807
+ yield final_response
808
+
809
  else:
810
+ logger.info("Streaming regular response without tools")
811
+ # Stream the LLM response directly
812
+ for partial_text in self.llm.stream_generate(message):
813
+ yield smart_truncate(partial_text, max_length=3000)
814
 
815
  end_chat_time = time.perf_counter()
816
  chat_time = end_chat_time - start_chat_time
817
+ log_metric(f"Complete streaming chat time: {chat_time:0.4f} seconds. Timestamp: {current_time:%Y-%m-%d %H:%M:%S}")
 
 
818
 
819
  except Exception as e:
820
+ logger.error(f"Error in streaming chat: {e}")
821
  end_chat_time = time.perf_counter()
822
  chat_time = end_chat_time - start_chat_time
823
+ log_metric(f"Complete streaming chat time (error): {chat_time:0.4f} seconds. Timestamp: {current_time:%Y-%m-%d %H:%M:%S}")
824
+ yield f"I apologize, but I encountered an error: {str(e)}"
825
 
826
  # --- Global Agent Instance ---
827
  agent = None
 
839
  window.MathJax = {
840
  tex: {
841
  inlineMath: [['\\\\(', '\\\\)']],
842
+ displayMath: [['$', '$'], ['\\\\[', '\\\\]']],
843
  packages: {'[+]': ['ams']}
844
  },
845
  svg: {fontCache: 'global'},
 
909
  return result
910
 
911
  def generate_response_with_agent(message, max_retries=3):
912
+ """Generate streaming response using LangGraph agent."""
913
  start_generate_response_with_agent_time = time.perf_counter()
914
  current_time = datetime.now()
915
 
 
918
  # Get the agent
919
  current_agent = get_agent()
920
 
921
+ # Use the agent's streaming chat method
922
+ for partial_response in current_agent.stream_chat(message):
923
+ yield partial_response
 
924
 
925
  end_generate_response_with_agent_time = time.perf_counter()
926
  generate_response_with_agent_time = end_generate_response_with_agent_time - start_generate_response_with_agent_time
927
  log_metric(f"Generate response with agent time: {generate_response_with_agent_time:0.4f} seconds. Timestamp: {current_time:%Y-%m-%d %H:%M:%S}")
928
 
929
+ return
930
 
931
  except Exception as e:
932
  logger.error(f"Agent error (attempt {attempt + 1}): {e}")
 
937
  end_generate_response_with_agent_time = time.perf_counter()
938
  generate_response_with_agent_time = end_generate_response_with_agent_time - start_generate_response_with_agent_time
939
  log_metric(f"Generate response with agent time (error): {generate_response_with_agent_time:0.4f} seconds. Timestamp: {current_time:%Y-%m-%d %H:%M:%S}")
940
+ yield f"I apologize, but I encountered an error while processing your message: {str(e)}"
941
 
942
def chat_response(message, history=None):
    """Stream the agent's reply to *message*, yielding each partial response.

    *history* is accepted for interface compatibility with the UI callback
    but is not used here (the agent keeps its own thread state). On error,
    yields a single apology string instead of raising.
    """
    start_chat_response_time = time.perf_counter()
    current_time = datetime.now()

    try:
        # Relay each partial response as it arrives; the caller renders them
        # live. (Fix: dropped the dead `final_response` accumulator — it was
        # assigned every iteration but never read.)
        for partial_response in generate_response_with_agent(message):
            yield partial_response

        end_chat_response_time = time.perf_counter()
        chat_response_time = end_chat_response_time - start_chat_response_time
        log_metric(f"Chat response time: {chat_response_time:0.4f} seconds. Timestamp: {current_time:%Y-%m-%d %H:%M:%S}")

    except Exception as e:
        logger.error(f"Error in chat_response: {e}")
        end_chat_response_time = time.perf_counter()
        chat_response_time = end_chat_response_time - start_chat_response_time
        log_metric(f"Chat response time (error): {chat_response_time:0.4f} seconds. Timestamp: {current_time:%Y-%m-%d %H:%M:%S}")
        yield f"I apologize, but I encountered an error while processing your message: {str(e)}"
964
 
965
  def respond_and_update(message, history):
966
+ """Main function to handle user submission with streaming."""
967
  if not message.strip():
968
  return history, ""
969
 
 
971
  history.append({"role": "user", "content": message})
972
  yield history, ""
973
 
974
+ # Start with empty assistant message
975
+ history.append({"role": "assistant", "content": ""})
976
 
977
+ # Stream the response
978
+ for partial_response in chat_response(message):
979
+ # Update the last message (assistant) with the partial response
980
+ history[-1]["content"] = partial_response
981
+ yield history, ""
982
 
983
  def clear_chat():
984
  """Clear the chat history."""
 
1085
  if __name__ == "__main__":
1086
  try:
1087
  logger.info("=" * 50)
1088
+ logger.info("Starting Mimir Application with Microsoft Phi-2 and Streaming")
1089
  logger.info("=" * 50)
1090
 
1091
  # Step 1: Preload the model and agent
 
1109
  )
1110
 
1111
  except Exception as e:
1112
+ logger.error(f"❌ Failed to launch Mimir with Microsoft Phi-2: {e}")
1113
  raise