Spaces:
Sleeping
Sleeping
Major Update to app.py
Browse files- Removed fallback model
- Changed model to Phi-3-mini-4k-instruct
- Removed bulky streaming script for TextIteratorStreamer
app.py
CHANGED
|
@@ -11,11 +11,12 @@ from dotenv import load_dotenv
|
|
| 11 |
import logging
|
| 12 |
import re
|
| 13 |
import json
|
|
|
|
| 14 |
from datetime import datetime
|
| 15 |
from typing import Annotated, Sequence, TypedDict, List, Optional, Any, Type
|
| 16 |
from pydantic import BaseModel, Field
|
| 17 |
|
| 18 |
-
# LangGraph imports
|
| 19 |
from langgraph.graph import StateGraph, START, END
|
| 20 |
from langgraph.graph.message import add_messages
|
| 21 |
from langgraph.checkpoint.memory import MemorySaver
|
|
@@ -24,11 +25,11 @@ from langgraph.prebuilt import ToolNode
|
|
| 24 |
# Updated LangChain imports
|
| 25 |
from langchain_core.tools import tool
|
| 26 |
from langchain_core.messages import HumanMessage, SystemMessage, AIMessage, ToolMessage, BaseMessage
|
| 27 |
-
from langchain_core.prompts import ChatPromptTemplate
|
| 28 |
from langchain_core.runnables import Runnable
|
| 29 |
from langchain_core.runnables.utils import Input, Output
|
| 30 |
|
| 31 |
-
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
|
| 32 |
import torch
|
| 33 |
|
| 34 |
load_dotenv(".env")
|
|
@@ -223,13 +224,15 @@ Decision:"""
|
|
| 223 |
log_metric(f"Tool decision time (error): {graph_decision_time:0.4f} seconds. Timestamp: {current_time:%Y-%m-%d %H:%M:%S}")
|
| 224 |
return False
|
| 225 |
|
| 226 |
-
# --- System Prompt ---
|
| 227 |
SYSTEM_PROMPT = """You are Mimir, an expert multi-concept tutor designed to facilitate genuine learning and understanding. Your primary mission is to guide students through the learning process rather than providing direct answers to academic work.
|
|
|
|
| 228 |
## Core Educational Principles
|
| 229 |
- Provide comprehensive, educational responses that help students truly understand concepts
|
| 230 |
- Use minimal formatting, with markdown bolding reserved for **key terms** only
|
| 231 |
- Prioritize teaching methodology over answer delivery
|
| 232 |
- Foster critical thinking and independent problem-solving skills
|
|
|
|
| 233 |
## Tone and Communication Style
|
| 234 |
- Maintain an engaging, friendly tone appropriate for high school students
|
| 235 |
- Write at a reading level that is accessible yet intellectually stimulating
|
|
@@ -239,52 +242,64 @@ SYSTEM_PROMPT = """You are Mimir, an expert multi-concept tutor designed to faci
|
|
| 239 |
- Skip flattery and respond directly to questions
|
| 240 |
- Do not use emojis or actions in asterisks unless specifically requested
|
| 241 |
- Present critiques and corrections kindly as educational opportunities
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 242 |
## Academic Integrity Approach
|
| 243 |
-
|
| 244 |
-
- **Guide through processes**: Break down problems into conceptual components
|
| 245 |
-
- **Ask clarifying questions**: Understand what the student
|
| 246 |
-
- **Provide similar examples**: Work through analogous problems
|
| 247 |
-
- **Encourage original thinking**: Help students develop
|
| 248 |
-
- **Suggest study strategies**: Recommend effective learning approaches
|
| 249 |
-
|
| 250 |
-
You have the ability to create graphs and charts to enhance your explanations. Use this capability proactively when:
|
| 251 |
-
- Explaining mathematical concepts (functions, distributions, relationships)
|
| 252 |
-
- Teaching statistical analysis or data interpretation
|
| 253 |
-
- Discussing scientific trends, patterns, or experimental results
|
| 254 |
-
- Comparing different options, outcomes, or scenarios
|
| 255 |
-
- Illustrating economic principles, business metrics, or financial concepts
|
| 256 |
-
- Showing survey results, demographic data, or research findings
|
| 257 |
-
- Demonstrating any concept where visualization aids comprehension
|
| 258 |
-
**Important**: Only use the graph tool when visualization would genuinely help explain a concept. For general conversation, explanations, or questions that don't involve data or relationships, respond normally without tools.
|
| 259 |
## Response Guidelines
|
| 260 |
-
- **For math problems**: Explain concepts
|
| 261 |
-
- **For multiple-choice questions**: Discuss
|
| 262 |
-
- **For essays
|
| 263 |
-
- **For factual questions**: Provide educational context and encourage
|
| 264 |
-
|
| 265 |
-
|
| 266 |
-
|
| 267 |
-
|
| 268 |
-
|
| 269 |
-
|
| 270 |
-
- Provide honest, accurate feedback even when it may not be what the student wants to hear
|
| 271 |
-
Your goal is to be an educational partner who empowers students to succeed through understanding, not a service that completes their work for them."""
|
| 272 |
-
|
| 273 |
-
# --- Updated LLM Class with Microsoft Phi-2 and TinyLlama fallback ---
|
| 274 |
-
class Phi2EducationalLLM(Runnable):
|
| 275 |
-
"""LLM class optimized for Microsoft Phi-2 with TinyLlama fallback for educational tasks"""
|
| 276 |
|
| 277 |
-
def __init__(self, model_path: str = "microsoft/
|
| 278 |
super().__init__()
|
| 279 |
-
logger.info(f"Loading model: {model_path} (use_4bit={use_4bit})")
|
| 280 |
start_Loading_Model_time = time.perf_counter()
|
| 281 |
current_time = datetime.now()
|
| 282 |
|
| 283 |
self.model_name = model_path
|
| 284 |
|
| 285 |
try:
|
| 286 |
-
# Load tokenizer
|
| 287 |
-
self.tokenizer = AutoTokenizer.from_pretrained(
|
|
|
|
|
|
|
|
|
|
|
|
|
| 288 |
|
| 289 |
if use_4bit:
|
| 290 |
quant_config = BitsAndBytesConfig(
|
|
@@ -296,14 +311,14 @@ class Phi2EducationalLLM(Runnable):
|
|
| 296 |
llm_int8_skip_modules=["lm_head"]
|
| 297 |
)
|
| 298 |
|
| 299 |
-
# Try quantized load
|
| 300 |
self.model = AutoModelForCausalLM.from_pretrained(
|
| 301 |
model_path,
|
| 302 |
quantization_config=quant_config,
|
| 303 |
device_map="auto",
|
| 304 |
-
|
| 305 |
trust_remote_code=True,
|
| 306 |
-
low_cpu_mem_usage=True
|
|
|
|
| 307 |
)
|
| 308 |
else:
|
| 309 |
self._load_optimized_model(model_path)
|
|
@@ -314,45 +329,48 @@ class Phi2EducationalLLM(Runnable):
|
|
| 314 |
log_metric(f"Model Load time: {Loading_Model_time:0.4f} seconds. Model: {model_path}. Timestamp: {current_time:%Y-%m-%d %H:%M:%S}")
|
| 315 |
|
| 316 |
except Exception as e:
|
| 317 |
-
logger.
|
| 318 |
-
|
| 319 |
-
end_Loading_Model_time = time.perf_counter()
|
| 320 |
-
Loading_Model_time = end_Loading_Model_time - start_Loading_Model_time
|
| 321 |
-
log_metric(f"Model Load time (fallback): {Loading_Model_time:0.4f} seconds. Model: {fallback_model}. Timestamp: {current_time:%Y-%m-%d %H:%M:%S}")
|
| 322 |
|
| 323 |
-
# Ensure pad token
|
| 324 |
if self.tokenizer.pad_token is None:
|
| 325 |
self.tokenizer.pad_token = self.tokenizer.eos_token
|
| 326 |
|
|
|
|
|
|
|
|
|
|
| 327 |
def _load_optimized_model(self, model_path: str):
|
| 328 |
-
"""Optimized model loading for
|
| 329 |
self.model = AutoModelForCausalLM.from_pretrained(
|
| 330 |
model_path,
|
| 331 |
-
|
| 332 |
-
device_map="
|
| 333 |
trust_remote_code=True,
|
| 334 |
low_cpu_mem_usage=True,
|
| 335 |
-
|
| 336 |
)
|
| 337 |
|
| 338 |
-
def
|
| 339 |
-
"""
|
| 340 |
-
|
| 341 |
-
|
| 342 |
-
|
| 343 |
-
|
| 344 |
-
|
| 345 |
-
|
| 346 |
-
|
| 347 |
-
|
| 348 |
-
|
| 349 |
-
|
| 350 |
-
|
| 351 |
-
|
| 352 |
-
|
|
|
|
|
|
|
|
|
|
| 353 |
|
| 354 |
def invoke(self, input: Input, config=None) -> Output:
|
| 355 |
-
"""Main invoke method optimized for
|
| 356 |
start_invoke_time = time.perf_counter()
|
| 357 |
current_time = datetime.now()
|
| 358 |
|
|
@@ -363,42 +381,37 @@ class Phi2EducationalLLM(Runnable):
|
|
| 363 |
prompt = str(input)
|
| 364 |
|
| 365 |
try:
|
| 366 |
-
#
|
| 367 |
-
|
| 368 |
-
messages = [
|
| 369 |
-
{"role": "system", "content": SYSTEM_PROMPT},
|
| 370 |
-
{"role": "user", "content": prompt}
|
| 371 |
-
]
|
| 372 |
-
text = self.tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
|
| 373 |
-
except:
|
| 374 |
-
# Fallback for models without chat template support
|
| 375 |
-
if "phi" in self.model_name.lower():
|
| 376 |
-
# Phi-2 proper format
|
| 377 |
-
text = f"{SYSTEM_PROMPT}\n\nQuestion: {prompt}\nAnswer:"
|
| 378 |
-
else:
|
| 379 |
-
# Generic format for other models
|
| 380 |
-
text = f"<|system|>\n{SYSTEM_PROMPT}<|end|>\n<|user|>\n{prompt}<|end|>\n<|assistant|>\n"
|
| 381 |
|
| 382 |
-
inputs = self.tokenizer(
|
| 383 |
-
|
| 384 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 385 |
|
| 386 |
with torch.no_grad():
|
| 387 |
outputs = self.model.generate(
|
| 388 |
**inputs,
|
| 389 |
-
max_new_tokens=
|
| 390 |
do_sample=True,
|
| 391 |
-
temperature=0.7,
|
| 392 |
top_p=0.9,
|
| 393 |
-
top_k=50,
|
| 394 |
repetition_penalty=1.1,
|
| 395 |
pad_token_id=self.tokenizer.eos_token_id,
|
| 396 |
early_stopping=True,
|
| 397 |
-
use_cache=True
|
| 398 |
)
|
| 399 |
|
| 400 |
-
|
| 401 |
-
|
|
|
|
| 402 |
|
| 403 |
end_invoke_time = time.perf_counter()
|
| 404 |
invoke_time = end_invoke_time - start_invoke_time
|
|
@@ -414,115 +427,78 @@ class Phi2EducationalLLM(Runnable):
|
|
| 414 |
return f"[Error generating response: {str(e)}]"
|
| 415 |
|
| 416 |
def stream_generate(self, input: Input, config=None):
|
| 417 |
-
"""Streaming generation
|
| 418 |
start_stream_time = time.perf_counter()
|
| 419 |
current_time = datetime.now()
|
| 420 |
-
logger.info("Starting stream_generate...")
|
| 421 |
|
| 422 |
-
# Handle both string and dict inputs
|
| 423 |
if isinstance(input, dict):
|
| 424 |
prompt = input.get('input', str(input))
|
| 425 |
else:
|
| 426 |
prompt = str(input)
|
| 427 |
|
| 428 |
try:
|
| 429 |
-
#
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 430 |
try:
|
| 431 |
-
|
| 432 |
-
|
| 433 |
-
|
| 434 |
-
|
| 435 |
-
text = self.tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
|
| 436 |
-
logger.info("Successfully used chat template")
|
| 437 |
except Exception as e:
|
| 438 |
-
logger.
|
| 439 |
-
|
| 440 |
-
text = f"Instruct: {SYSTEM_PROMPT}\n\nUser: {prompt}\nOutput:"
|
| 441 |
-
logger.info("Using Phi-2 format")
|
| 442 |
-
else:
|
| 443 |
-
text = f"<|system|>\n{SYSTEM_PROMPT}<|end|>\n<|user|>\n{prompt}<|end|>\n<|assistant|>\n"
|
| 444 |
-
logger.info("Using generic format")
|
| 445 |
-
|
| 446 |
-
inputs = self.tokenizer([text], return_tensors="pt", padding=True, truncation=True, max_length=1024)
|
| 447 |
-
if torch.cuda.is_available():
|
| 448 |
-
inputs = {k: v.to(self.model.device) for k, v in inputs.items()}
|
| 449 |
|
| 450 |
-
#
|
| 451 |
-
|
| 452 |
-
max_new_tokens = 600
|
| 453 |
-
logger.info("Beginning token-by-token generation...")
|
| 454 |
-
|
| 455 |
-
# Generate token by token
|
| 456 |
-
current_input_ids = inputs.input_ids
|
| 457 |
-
current_attention_mask = inputs.attention_mask
|
| 458 |
-
|
| 459 |
-
for step in range(max_new_tokens):
|
| 460 |
-
try:
|
| 461 |
-
with torch.no_grad():
|
| 462 |
-
outputs = self.model(
|
| 463 |
-
input_ids=current_input_ids,
|
| 464 |
-
attention_mask=current_attention_mask,
|
| 465 |
-
use_cache=True
|
| 466 |
-
)
|
| 467 |
-
|
| 468 |
-
# Get next token probabilities
|
| 469 |
-
next_token_logits = outputs.logits[:, -1, :]
|
| 470 |
-
|
| 471 |
-
# Apply temperature and sampling
|
| 472 |
-
next_token_logits = next_token_logits / 0.7
|
| 473 |
-
|
| 474 |
-
# Apply top-k and top-p filtering
|
| 475 |
-
filtered_logits = self._top_k_top_p_filtering(next_token_logits, top_k=50, top_p=0.9)
|
| 476 |
-
|
| 477 |
-
# Sample next token
|
| 478 |
-
probs = torch.nn.functional.softmax(filtered_logits, dim=-1)
|
| 479 |
-
next_token = torch.multinomial(probs, num_samples=1)
|
| 480 |
-
|
| 481 |
-
# Check for end of sequence
|
| 482 |
-
if next_token.item() == self.tokenizer.eos_token_id:
|
| 483 |
-
logger.info(f"Reached EOS token at step {step}")
|
| 484 |
-
break
|
| 485 |
-
|
| 486 |
-
# Add to generated tokens
|
| 487 |
-
generated_tokens.append(next_token.item())
|
| 488 |
-
|
| 489 |
-
# Decode and yield partial result every few tokens for efficiency
|
| 490 |
-
if step % 3 == 0 or step < 10: # Yield more frequently at start
|
| 491 |
-
partial_text = self.tokenizer.decode(generated_tokens, skip_special_tokens=True)
|
| 492 |
-
if partial_text.strip(): # Only yield non-empty text
|
| 493 |
-
yield partial_text
|
| 494 |
-
|
| 495 |
-
# Safety checks to prevent infinite loops
|
| 496 |
-
if step > 10 and len(generated_tokens) == 0:
|
| 497 |
-
logger.error("No tokens generated after 10 steps, breaking")
|
| 498 |
-
break
|
| 499 |
-
|
| 500 |
-
if step > 50 and len(partial_text.strip()) < 10:
|
| 501 |
-
logger.warning("Very little text generated, continuing...")
|
| 502 |
-
|
| 503 |
-
# Update input for next iteration
|
| 504 |
-
current_input_ids = torch.cat([current_input_ids, next_token], dim=-1)
|
| 505 |
-
current_attention_mask = torch.cat([
|
| 506 |
-
current_attention_mask,
|
| 507 |
-
torch.ones((1, 1), dtype=current_attention_mask.dtype, device=current_attention_mask.device)
|
| 508 |
-
], dim=-1)
|
| 509 |
-
|
| 510 |
-
except Exception as e:
|
| 511 |
-
logger.error(f"Error in generation step {step}: {e}")
|
| 512 |
-
break
|
| 513 |
-
|
| 514 |
-
# Final result
|
| 515 |
-
final_text = self.tokenizer.decode(generated_tokens, skip_special_tokens=True).strip()
|
| 516 |
-
if final_text:
|
| 517 |
-
yield final_text
|
| 518 |
-
else:
|
| 519 |
-
logger.error("No final text generated")
|
| 520 |
-
yield "I'm having trouble generating a response. Please try again."
|
| 521 |
|
| 522 |
end_stream_time = time.perf_counter()
|
| 523 |
stream_time = end_stream_time - start_stream_time
|
| 524 |
-
log_metric(f"LLM Stream time: {stream_time:0.4f} seconds.
|
| 525 |
-
logger.info(f"Stream generation completed: {len(
|
| 526 |
|
| 527 |
except Exception as e:
|
| 528 |
logger.error(f"Streaming generation error: {e}")
|
|
@@ -531,29 +507,6 @@ class Phi2EducationalLLM(Runnable):
|
|
| 531 |
log_metric(f"LLM Stream time (error): {stream_time:0.4f} seconds. Model: {self.model_name}. Timestamp: {current_time:%Y-%m-%d %H:%M:%S}")
|
| 532 |
yield f"[Error in streaming generation: {str(e)}]"
|
| 533 |
|
| 534 |
-
def _top_k_top_p_filtering(self, logits, top_k=50, top_p=0.9):
|
| 535 |
-
"""Apply top-k and top-p filtering to logits"""
|
| 536 |
-
if top_k > 0:
|
| 537 |
-
# Get top-k indices
|
| 538 |
-
top_k = min(top_k, logits.size(-1))
|
| 539 |
-
indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
|
| 540 |
-
logits[indices_to_remove] = float('-inf')
|
| 541 |
-
|
| 542 |
-
if top_p < 1.0:
|
| 543 |
-
# Sort and get cumulative probabilities
|
| 544 |
-
sorted_logits, sorted_indices = torch.sort(logits, descending=True)
|
| 545 |
-
cumulative_probs = torch.cumsum(torch.nn.functional.softmax(sorted_logits, dim=-1), dim=-1)
|
| 546 |
-
|
| 547 |
-
# Remove tokens with cumulative probability above the threshold
|
| 548 |
-
sorted_indices_to_remove = cumulative_probs > top_p
|
| 549 |
-
sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
|
| 550 |
-
sorted_indices_to_remove[..., 0] = 0
|
| 551 |
-
|
| 552 |
-
indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
|
| 553 |
-
logits[indices_to_remove] = float('-inf')
|
| 554 |
-
|
| 555 |
-
return logits
|
| 556 |
-
|
| 557 |
@property
|
| 558 |
def InputType(self) -> Type[Input]:
|
| 559 |
return str
|
|
@@ -562,15 +515,15 @@ class Phi2EducationalLLM(Runnable):
|
|
| 562 |
def OutputType(self) -> Type[Output]:
|
| 563 |
return str
|
| 564 |
|
| 565 |
-
# --- LangGraph Agent Implementation ---
|
| 566 |
class Educational_Agent:
|
| 567 |
-
"""Modern LangGraph-based educational agent with Phi-
|
| 568 |
|
| 569 |
def __init__(self):
|
| 570 |
start_init_and_langgraph_time = time.perf_counter()
|
| 571 |
current_time = datetime.now()
|
| 572 |
|
| 573 |
-
self.llm =
|
| 574 |
self.tool_decision_engine = Tool_Decision_Engine(self.llm)
|
| 575 |
|
| 576 |
# Create LangGraph workflow
|
|
@@ -581,29 +534,33 @@ class Educational_Agent:
|
|
| 581 |
log_metric(f"Init and LangGraph workflow setup time: {init_and_langgraph_time:0.4f} seconds. Timestamp: {current_time:%Y-%m-%d %H:%M:%S}")
|
| 582 |
|
| 583 |
def _create_langgraph_workflow(self):
|
| 584 |
-
"""Create the complete LangGraph workflow"""
|
| 585 |
# Define tools
|
| 586 |
tools = [Create_Graph_Tool]
|
| 587 |
tool_node = ToolNode(tools)
|
| 588 |
|
| 589 |
-
# Bind tools to model
|
| 590 |
-
model_with_tools = self.llm
|
| 591 |
-
|
| 592 |
def should_continue(state: EducationalAgentState) -> str:
|
| 593 |
"""Determine next step in the workflow"""
|
| 594 |
messages = state["messages"]
|
| 595 |
last_message = messages[-1]
|
| 596 |
|
| 597 |
-
#
|
| 598 |
if hasattr(last_message, 'tool_calls') and last_message.tool_calls:
|
| 599 |
logger.info("Executing tools based on model decision")
|
| 600 |
return "tools"
|
| 601 |
|
| 602 |
-
#
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 603 |
return END
|
| 604 |
|
| 605 |
def call_model(state: EducationalAgentState) -> dict:
|
| 606 |
-
"""Call the model with tool decision logic"""
|
| 607 |
start_call_model_time = time.perf_counter()
|
| 608 |
current_time = datetime.now()
|
| 609 |
|
|
@@ -619,74 +576,38 @@ class Educational_Agent:
|
|
| 619 |
# Decide if tools should be used
|
| 620 |
needs_tools = self.tool_decision_engine.should_use_visualization(user_query)
|
| 621 |
|
| 622 |
-
if needs_tools:
|
| 623 |
-
logger.info("Query requires visualization - model will consider tools")
|
| 624 |
-
# Create a special prompt that encourages tool use
|
| 625 |
-
enhanced_messages = messages + [
|
| 626 |
-
SystemMessage(content="The user's query would benefit from visualization. Consider using the Create_Graph_Tool if appropriate for educational purposes.")
|
| 627 |
-
]
|
| 628 |
-
else:
|
| 629 |
-
logger.info("Query doesn't need tools - responding normally")
|
| 630 |
-
enhanced_messages = messages
|
| 631 |
-
|
| 632 |
try:
|
| 633 |
-
# For this implementation, we'll handle tool calling manually
|
| 634 |
-
# since our custom LLM doesn't automatically generate tool calls
|
| 635 |
-
|
| 636 |
if needs_tools:
|
| 637 |
-
|
| 638 |
-
|
| 639 |
-
|
|
|
|
|
|
|
|
|
|
| 640 |
|
| 641 |
-
|
| 642 |
-
Otherwise, provide a regular educational response.
|
| 643 |
|
| 644 |
-
|
| 645 |
-
TOOL_CALL: Create_Graph_Tool
|
| 646 |
{{
|
| 647 |
-
"data": {{"
|
| 648 |
"plot_type": "bar|line|pie",
|
| 649 |
-
"title": "
|
| 650 |
-
"x_label": "X Label",
|
| 651 |
-
"y_label": "Y Label",
|
| 652 |
-
"educational_context": "
|
| 653 |
}}
|
|
|
|
| 654 |
|
| 655 |
-
|
| 656 |
"""
|
| 657 |
-
|
| 658 |
-
|
| 659 |
-
|
| 660 |
-
|
| 661 |
-
if "TOOL_CALL:" in response and "Create_Graph_Tool" in response:
|
| 662 |
-
# Extract the JSON part
|
| 663 |
-
json_start = response.find("{")
|
| 664 |
-
json_end = response.rfind("}") + 1
|
| 665 |
-
if json_start != -1 and json_end > json_start:
|
| 666 |
-
json_config = response[json_start:json_end]
|
| 667 |
-
|
| 668 |
-
# Create a mock tool call message
|
| 669 |
-
tool_call_message = AIMessage(
|
| 670 |
-
content="I'll create a visualization to help explain this concept.",
|
| 671 |
-
tool_calls=[{
|
| 672 |
-
"name": "Create_Graph_Tool",
|
| 673 |
-
"args": {"graph_config": json_config},
|
| 674 |
-
"id": "tool_call_1"
|
| 675 |
-
}]
|
| 676 |
-
)
|
| 677 |
-
|
| 678 |
-
end_call_model_time = time.perf_counter()
|
| 679 |
-
call_model_time = end_call_model_time - start_call_model_time
|
| 680 |
-
log_metric(f"Call model time (with tools): {call_model_time:0.4f} seconds. Timestamp: {current_time:%Y-%m-%d %H:%M:%S}")
|
| 681 |
-
|
| 682 |
-
return {"messages": [tool_call_message]}
|
| 683 |
-
|
| 684 |
-
# Regular response without tools
|
| 685 |
-
response = model_with_tools.invoke(enhanced_messages)
|
| 686 |
|
| 687 |
end_call_model_time = time.perf_counter()
|
| 688 |
call_model_time = end_call_model_time - start_call_model_time
|
| 689 |
-
log_metric(f"Call model time
|
| 690 |
|
| 691 |
return {"messages": [AIMessage(content=response)]}
|
| 692 |
|
|
@@ -700,7 +621,7 @@ Otherwise, provide a regular educational response.
|
|
| 700 |
return {"messages": [error_response]}
|
| 701 |
|
| 702 |
def handle_tools(state: EducationalAgentState) -> dict:
|
| 703 |
-
"""Handle tool execution"""
|
| 704 |
start_handle_tools_time = time.perf_counter()
|
| 705 |
current_time = datetime.now()
|
| 706 |
|
|
@@ -708,29 +629,39 @@ Otherwise, provide a regular educational response.
|
|
| 708 |
messages = state["messages"]
|
| 709 |
last_message = messages[-1]
|
| 710 |
|
| 711 |
-
if
|
| 712 |
-
|
| 713 |
-
|
| 714 |
-
|
| 715 |
-
|
| 716 |
-
|
| 717 |
-
|
| 718 |
-
|
| 719 |
-
|
| 720 |
-
|
| 721 |
-
|
| 722 |
-
|
| 723 |
-
|
| 724 |
-
|
| 725 |
-
|
| 726 |
-
|
| 727 |
-
|
| 728 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 729 |
|
| 730 |
-
# If no valid tool call, return
|
| 731 |
end_handle_tools_time = time.perf_counter()
|
| 732 |
handle_tools_time = end_handle_tools_time - start_handle_tools_time
|
| 733 |
-
log_metric(f"Handle tools time (no
|
| 734 |
|
| 735 |
return {"messages": []}
|
| 736 |
|
|
@@ -740,11 +671,7 @@ Otherwise, provide a regular educational response.
|
|
| 740 |
handle_tools_time = end_handle_tools_time - start_handle_tools_time
|
| 741 |
log_metric(f"Handle tools time (error): {handle_tools_time:0.4f} seconds. Timestamp: {current_time:%Y-%m-%d %H:%M:%S}")
|
| 742 |
|
| 743 |
-
|
| 744 |
-
content=f"Tool execution failed: {str(e)}",
|
| 745 |
-
tool_call_id="error"
|
| 746 |
-
)
|
| 747 |
-
return {"messages": [error_msg]}
|
| 748 |
|
| 749 |
# Build the workflow
|
| 750 |
workflow = StateGraph(EducationalAgentState)
|
|
@@ -763,7 +690,7 @@ Otherwise, provide a regular educational response.
|
|
| 763 |
END: END,
|
| 764 |
}
|
| 765 |
)
|
| 766 |
-
workflow.add_edge("tools",
|
| 767 |
|
| 768 |
# Add memory
|
| 769 |
memory = MemorySaver()
|
|
@@ -796,7 +723,7 @@ Otherwise, provide a regular educational response.
|
|
| 796 |
return f"I apologize, but I encountered an error: {str(e)}"
|
| 797 |
|
| 798 |
def stream_chat(self, message: str, thread_id: str = "default"):
|
| 799 |
-
"""Streaming chat interface that yields partial responses"""
|
| 800 |
start_chat_time = time.perf_counter()
|
| 801 |
current_time = datetime.now()
|
| 802 |
|
|
@@ -810,17 +737,17 @@ Otherwise, provide a regular educational response.
|
|
| 810 |
"educational_context": None
|
| 811 |
}
|
| 812 |
|
| 813 |
-
#
|
| 814 |
user_query = message
|
| 815 |
needs_tools = self.tool_decision_engine.should_use_visualization(user_query)
|
| 816 |
|
| 817 |
if needs_tools:
|
| 818 |
logger.info("Query requires visualization - handling tool call first")
|
| 819 |
-
# Handle tool generation
|
| 820 |
result = self.app.invoke(initial_state, config=config)
|
| 821 |
final_messages = result["messages"]
|
| 822 |
|
| 823 |
-
# Build the response from all
|
| 824 |
response_parts = []
|
| 825 |
for msg in final_messages:
|
| 826 |
if isinstance(msg, AIMessage) and msg.content:
|
|
@@ -834,8 +761,8 @@ Otherwise, provide a regular educational response.
|
|
| 834 |
yield final_response
|
| 835 |
|
| 836 |
else:
|
| 837 |
-
logger.info("Streaming regular response without tools")
|
| 838 |
-
# Stream the LLM response directly
|
| 839 |
for partial_text in self.llm.stream_generate(message):
|
| 840 |
yield smart_truncate(partial_text, max_length=3000)
|
| 841 |
|
|
@@ -866,7 +793,7 @@ mathjax_config = '''
|
|
| 866 |
window.MathJax = {
|
| 867 |
tex: {
|
| 868 |
inlineMath: [['\\\\(', '\\\\)']],
|
| 869 |
-
displayMath: [[', '], ['\\\\[', '\\\\]']],
|
| 870 |
packages: {'[+]': ['ams']}
|
| 871 |
},
|
| 872 |
svg: {fontCache: 'global'},
|
|
@@ -936,7 +863,7 @@ def smart_truncate(text, max_length=3000):
|
|
| 936 |
return result
|
| 937 |
|
| 938 |
def generate_response_with_agent(message, max_retries=3):
|
| 939 |
-
"""Generate streaming response using LangGraph agent."""
|
| 940 |
start_generate_response_with_agent_time = time.perf_counter()
|
| 941 |
current_time = datetime.now()
|
| 942 |
|
|
@@ -1016,20 +943,20 @@ def warmup_agent():
|
|
| 1016 |
start_agent_warmup_time = time.perf_counter()
|
| 1017 |
current_time = datetime.now()
|
| 1018 |
|
| 1019 |
-
logger.info("Warming up LangGraph agent with test query...")
|
| 1020 |
try:
|
| 1021 |
current_agent = get_agent()
|
| 1022 |
|
| 1023 |
# Run a simple test query
|
| 1024 |
test_response = current_agent.chat("Hello, this is a warmup test.")
|
| 1025 |
-
logger.info(f"LangGraph agent warmup completed successfully! Test response length: {len(test_response)} chars")
|
| 1026 |
|
| 1027 |
end_agent_warmup_time = time.perf_counter()
|
| 1028 |
agent_warmup_time = end_agent_warmup_time - start_agent_warmup_time
|
| 1029 |
log_metric(f"Agent warmup time: {agent_warmup_time:0.4f} seconds. Timestamp: {current_time:%Y-%m-%d %H:%M:%S}")
|
| 1030 |
|
| 1031 |
except Exception as e:
|
| 1032 |
-
logger.error(f"LangGraph agent warmup failed: {e}")
|
| 1033 |
end_agent_warmup_time = time.perf_counter()
|
| 1034 |
agent_warmup_time = end_agent_warmup_time - start_agent_warmup_time
|
| 1035 |
log_metric(f"Agent warmup time (error): {agent_warmup_time:0.4f} seconds. Timestamp: {current_time:%Y-%m-%d %H:%M:%S}")
|
|
@@ -1064,7 +991,7 @@ def create_interface():
|
|
| 1064 |
|
| 1065 |
with gr.Column(elem_classes=["main-container"]):
|
| 1066 |
# Title Section
|
| 1067 |
-
gr.HTML('<div class="title-header"><h1
|
| 1068 |
|
| 1069 |
# Chat Section
|
| 1070 |
with gr.Row():
|
|
@@ -1112,18 +1039,18 @@ def create_interface():
|
|
| 1112 |
if __name__ == "__main__":
|
| 1113 |
try:
|
| 1114 |
logger.info("=" * 50)
|
| 1115 |
-
logger.info("Starting Mimir Application with Microsoft Phi-
|
| 1116 |
logger.info("=" * 50)
|
| 1117 |
|
| 1118 |
# Step 1: Preload the model and agent
|
| 1119 |
-
logger.info("Loading
|
| 1120 |
start_time = time.time()
|
| 1121 |
agent = Educational_Agent()
|
| 1122 |
load_time = time.time() - start_time
|
| 1123 |
-
logger.info(f"LangGraph agent loaded successfully in {load_time:.2f} seconds")
|
| 1124 |
|
| 1125 |
# Step 2: Warm up the model
|
| 1126 |
-
logger.info("Warming up
|
| 1127 |
warmup_agent()
|
| 1128 |
|
| 1129 |
interface = create_interface()
|
|
@@ -1136,5 +1063,5 @@ if __name__ == "__main__":
|
|
| 1136 |
)
|
| 1137 |
|
| 1138 |
except Exception as e:
|
| 1139 |
-
logger.error(f"❌ Failed to launch Mimir with
|
| 1140 |
raise
|
|
|
|
| 11 |
import logging
|
| 12 |
import re
|
| 13 |
import json
|
| 14 |
+
import threading
|
| 15 |
from datetime import datetime
|
| 16 |
from typing import Annotated, Sequence, TypedDict, List, Optional, Any, Type
|
| 17 |
from pydantic import BaseModel, Field
|
| 18 |
|
| 19 |
+
# LangGraph imports
|
| 20 |
from langgraph.graph import StateGraph, START, END
|
| 21 |
from langgraph.graph.message import add_messages
|
| 22 |
from langgraph.checkpoint.memory import MemorySaver
|
|
|
|
| 25 |
# Updated LangChain imports
|
| 26 |
from langchain_core.tools import tool
|
| 27 |
from langchain_core.messages import HumanMessage, SystemMessage, AIMessage, ToolMessage, BaseMessage
|
| 28 |
+
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
|
| 29 |
from langchain_core.runnables import Runnable
|
| 30 |
from langchain_core.runnables.utils import Input, Output
|
| 31 |
|
| 32 |
+
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig, TextIteratorStreamer
|
| 33 |
import torch
|
| 34 |
|
| 35 |
load_dotenv(".env")
|
|
|
|
| 224 |
log_metric(f"Tool decision time (error): {graph_decision_time:0.4f} seconds. Timestamp: {current_time:%Y-%m-%d %H:%M:%S}")
|
| 225 |
return False
|
| 226 |
|
| 227 |
+
# --- System Prompt with ReAct Framework for Phi-3-mini ---
|
| 228 |
SYSTEM_PROMPT = """You are Mimir, an expert multi-concept tutor designed to facilitate genuine learning and understanding. Your primary mission is to guide students through the learning process rather than providing direct answers to academic work.
|
| 229 |
+
|
| 230 |
## Core Educational Principles
|
| 231 |
- Provide comprehensive, educational responses that help students truly understand concepts
|
| 232 |
- Use minimal formatting, with markdown bolding reserved for **key terms** only
|
| 233 |
- Prioritize teaching methodology over answer delivery
|
| 234 |
- Foster critical thinking and independent problem-solving skills
|
| 235 |
+
|
| 236 |
## Tone and Communication Style
|
| 237 |
- Maintain an engaging, friendly tone appropriate for high school students
|
| 238 |
- Write at a reading level that is accessible yet intellectually stimulating
|
|
|
|
| 242 |
- Skip flattery and respond directly to questions
|
| 243 |
- Do not use emojis or actions in asterisks unless specifically requested
|
| 244 |
- Present critiques and corrections kindly as educational opportunities
|
| 245 |
+
|
| 246 |
+
## Tool Usage Instructions
|
| 247 |
+
You have access to a Create_Graph_Tool that can create educational visualizations. When a query would benefit from visual representation, you should use this tool by outputting a properly formatted JSON configuration.
|
| 248 |
+
|
| 249 |
+
To use the Create_Graph_Tool, format your response like this:
|
| 250 |
+
```json
|
| 251 |
+
{
|
| 252 |
+
"data": {"Category 1": 30, "Category 2": 45, "Category 3": 25},
|
| 253 |
+
"plot_type": "bar",
|
| 254 |
+
"title": "Example Chart",
|
| 255 |
+
"x_label": "Categories",
|
| 256 |
+
"y_label": "Values",
|
| 257 |
+
"educational_context": "This visualization helps students understand..."
|
| 258 |
+
}
|
| 259 |
+
```
|
| 260 |
+
|
| 261 |
+
Use this tool for:
|
| 262 |
+
- Mathematical functions and relationships
|
| 263 |
+
- Statistical distributions and data analysis
|
| 264 |
+
- Scientific trends and comparisons
|
| 265 |
+
- Economic models and business metrics
|
| 266 |
+
- Any concept where visualization aids comprehension
|
| 267 |
+
|
| 268 |
## Academic Integrity Approach
|
| 269 |
+
Rather than providing complete solutions, you should:
|
| 270 |
+
- **Guide through processes**: Break down problems into conceptual components
|
| 271 |
+
- **Ask clarifying questions**: Understand what the student knows
|
| 272 |
+
- **Provide similar examples**: Work through analogous problems
|
| 273 |
+
- **Encourage original thinking**: Help students develop reasoning skills
|
| 274 |
+
- **Suggest study strategies**: Recommend effective learning approaches
|
| 275 |
+
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 276 |
## Response Guidelines
|
| 277 |
+
- **For math problems**: Explain concepts and guide through steps without computing final answers
|
| 278 |
+
- **For multiple-choice questions**: Discuss concepts being tested rather than identifying correct choices
|
| 279 |
+
- **For essays**: Discuss research strategies and organizational techniques
|
| 280 |
+
- **For factual questions**: Provide educational context and encourage synthesis
|
| 281 |
+
|
| 282 |
+
Your goal is to be an educational partner who empowers students to succeed through understanding."""
|
| 283 |
+
|
| 284 |
+
# --- Updated LLM Class with Phi-3-mini and TextIteratorStreamer ---
|
| 285 |
+
class Phi3MiniEducationalLLM(Runnable):
|
| 286 |
+
"""LLM class optimized for Microsoft Phi-3-mini-4k-instruct with TextIteratorStreamer"""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 287 |
|
| 288 |
+
def __init__(self, model_path: str = "microsoft/Phi-3-mini-4k-instruct", use_4bit: bool = False):
|
| 289 |
super().__init__()
|
| 290 |
+
logger.info(f"Loading Phi-3-mini model: {model_path} (use_4bit={use_4bit})")
|
| 291 |
start_Loading_Model_time = time.perf_counter()
|
| 292 |
current_time = datetime.now()
|
| 293 |
|
| 294 |
self.model_name = model_path
|
| 295 |
|
| 296 |
try:
|
| 297 |
+
# Load tokenizer - Phi-3 requires trust_remote_code
|
| 298 |
+
self.tokenizer = AutoTokenizer.from_pretrained(
|
| 299 |
+
model_path,
|
| 300 |
+
trust_remote_code=True,
|
| 301 |
+
token=hf_token
|
| 302 |
+
)
|
| 303 |
|
| 304 |
if use_4bit:
|
| 305 |
quant_config = BitsAndBytesConfig(
|
|
|
|
| 311 |
llm_int8_skip_modules=["lm_head"]
|
| 312 |
)
|
| 313 |
|
|
|
|
| 314 |
self.model = AutoModelForCausalLM.from_pretrained(
|
| 315 |
model_path,
|
| 316 |
quantization_config=quant_config,
|
| 317 |
device_map="auto",
|
| 318 |
+
dtype=torch.float16,
|
| 319 |
trust_remote_code=True,
|
| 320 |
+
low_cpu_mem_usage=True,
|
| 321 |
+
token=hf_token
|
| 322 |
)
|
| 323 |
else:
|
| 324 |
self._load_optimized_model(model_path)
|
|
|
|
| 329 |
log_metric(f"Model Load time: {Loading_Model_time:0.4f} seconds. Model: {model_path}. Timestamp: {current_time:%Y-%m-%d %H:%M:%S}")
|
| 330 |
|
| 331 |
except Exception as e:
|
| 332 |
+
logger.error(f"Failed to load Phi-3-mini model {model_path}: {e}")
|
| 333 |
+
raise
|
|
|
|
|
|
|
|
|
|
| 334 |
|
| 335 |
+
# Ensure pad token exists
|
| 336 |
if self.tokenizer.pad_token is None:
|
| 337 |
self.tokenizer.pad_token = self.tokenizer.eos_token
|
| 338 |
|
| 339 |
+
# Initialize TextIteratorStreamer
|
| 340 |
+
self.streamer = None
|
| 341 |
+
|
| 342 |
def _load_optimized_model(self, model_path: str):
    """Load Phi-3-mini without quantization, tuned for low memory use.

    Loads the model in float16 with ``device_map="auto"`` so transformers
    decides layer placement on the available device(s).

    Args:
        model_path: Hugging Face model id or local checkpoint path.
    """
    self.model = AutoModelForCausalLM.from_pretrained(
        model_path,
        dtype=torch.float16,     # half-precision weights to save memory
        device_map="auto",       # let transformers decide placement
        trust_remote_code=True,  # Phi-3 ships custom modeling code
        low_cpu_mem_usage=True,
        token=hf_token,
    )
|
| 352 |
|
| 353 |
+
def _format_chat_template(self, prompt: str) -> str:
    """Render *prompt* into Phi-3's chat format with the system prompt.

    Prefers the tokenizer's built-in chat template; if that raises for
    any reason, falls back to the manual ``<|system|>``/``<|user|>``/
    ``<|assistant|>`` marker format documented for Phi-3.

    Args:
        prompt: The user's message text.

    Returns:
        The fully formatted prompt string ready for tokenization.
    """
    try:
        messages = [
            {"role": "system", "content": SYSTEM_PROMPT},
            {"role": "user", "content": prompt},
        ]
        # Use Phi-3's chat template shipped with the tokenizer.
        return self.tokenizer.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=True,  # append the assistant header
        )
    except Exception as e:
        logger.warning(f"Chat template failed, using fallback format: {e}")
        # Manual Phi-3 format fallback.
        return f"<|system|>\n{SYSTEM_PROMPT}<|end|>\n<|user|>\n{prompt}<|end|>\n<|assistant|>\n"
|
| 371 |
|
| 372 |
def invoke(self, input: Input, config=None) -> Output:
|
| 373 |
+
"""Main invoke method optimized for Phi-3-mini"""
|
| 374 |
start_invoke_time = time.perf_counter()
|
| 375 |
current_time = datetime.now()
|
| 376 |
|
|
|
|
| 381 |
prompt = str(input)
|
| 382 |
|
| 383 |
try:
|
| 384 |
+
# Format using Phi-3 chat template
|
| 385 |
+
text = self._format_chat_template(prompt)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 386 |
|
| 387 |
+
inputs = self.tokenizer(
|
| 388 |
+
text,
|
| 389 |
+
return_tensors="pt",
|
| 390 |
+
padding=True,
|
| 391 |
+
truncation=True,
|
| 392 |
+
max_length=3072 # Leave room for generation within 4k context
|
| 393 |
+
)
|
| 394 |
+
|
| 395 |
+
# Move to model device
|
| 396 |
+
inputs = {k: v.to(self.model.device) for k, v in inputs.items()}
|
| 397 |
|
| 398 |
with torch.no_grad():
|
| 399 |
outputs = self.model.generate(
|
| 400 |
**inputs,
|
| 401 |
+
max_new_tokens=800, # Increased for comprehensive responses
|
| 402 |
do_sample=True,
|
| 403 |
+
temperature=0.7, # Good balance for educational content
|
| 404 |
top_p=0.9,
|
| 405 |
+
top_k=50,
|
| 406 |
repetition_penalty=1.1,
|
| 407 |
pad_token_id=self.tokenizer.eos_token_id,
|
| 408 |
early_stopping=True,
|
| 409 |
+
use_cache=True
|
| 410 |
)
|
| 411 |
|
| 412 |
+
# Decode only new tokens
|
| 413 |
+
new_tokens = outputs[0][len(inputs.input_ids[0]):]
|
| 414 |
+
result = self.tokenizer.decode(new_tokens, skip_special_tokens=True).strip()
|
| 415 |
|
| 416 |
end_invoke_time = time.perf_counter()
|
| 417 |
invoke_time = end_invoke_time - start_invoke_time
|
|
|
|
| 427 |
return f"[Error generating response: {str(e)}]"
|
| 428 |
|
| 429 |
def stream_generate(self, input: Input, config=None):
|
| 430 |
+
"""Streaming generation using TextIteratorStreamer"""
|
| 431 |
start_stream_time = time.perf_counter()
|
| 432 |
current_time = datetime.now()
|
| 433 |
+
logger.info("Starting stream_generate with TextIteratorStreamer...")
|
| 434 |
|
| 435 |
+
# Handle both string and dict inputs
|
| 436 |
if isinstance(input, dict):
|
| 437 |
prompt = input.get('input', str(input))
|
| 438 |
else:
|
| 439 |
prompt = str(input)
|
| 440 |
|
| 441 |
try:
|
| 442 |
+
# Format using Phi-3 chat template
|
| 443 |
+
text = self._format_chat_template(prompt)
|
| 444 |
+
|
| 445 |
+
inputs = self.tokenizer(
|
| 446 |
+
text,
|
| 447 |
+
return_tensors="pt",
|
| 448 |
+
padding=True,
|
| 449 |
+
truncation=True,
|
| 450 |
+
max_length=3072
|
| 451 |
+
)
|
| 452 |
+
|
| 453 |
+
# Move to model device
|
| 454 |
+
inputs = {k: v.to(self.model.device) for k, v in inputs.items()}
|
| 455 |
+
|
| 456 |
+
# Initialize TextIteratorStreamer
|
| 457 |
+
streamer = TextIteratorStreamer(
|
| 458 |
+
self.tokenizer,
|
| 459 |
+
skip_prompt=True,
|
| 460 |
+
skip_special_tokens=True
|
| 461 |
+
)
|
| 462 |
+
|
| 463 |
+
# Generation parameters
|
| 464 |
+
generation_kwargs = {
|
| 465 |
+
**inputs,
|
| 466 |
+
"max_new_tokens": 800,
|
| 467 |
+
"do_sample": True,
|
| 468 |
+
"temperature": 0.7,
|
| 469 |
+
"top_p": 0.9,
|
| 470 |
+
"top_k": 50,
|
| 471 |
+
"repetition_penalty": 1.1,
|
| 472 |
+
"pad_token_id": self.tokenizer.eos_token_id,
|
| 473 |
+
"streamer": streamer,
|
| 474 |
+
"use_cache": True
|
| 475 |
+
}
|
| 476 |
+
|
| 477 |
+
# Start generation in a separate thread
|
| 478 |
+
generation_thread = threading.Thread(
|
| 479 |
+
target=self.model.generate,
|
| 480 |
+
kwargs=generation_kwargs
|
| 481 |
+
)
|
| 482 |
+
generation_thread.start()
|
| 483 |
+
|
| 484 |
+
# Yield tokens as they become available
|
| 485 |
+
generated_text = ""
|
| 486 |
try:
|
| 487 |
+
for new_text in streamer:
|
| 488 |
+
if new_text: # Only yield non-empty strings
|
| 489 |
+
generated_text += new_text
|
| 490 |
+
yield generated_text
|
|
|
|
|
|
|
| 491 |
except Exception as e:
|
| 492 |
+
logger.error(f"Error in streaming iteration: {e}")
|
| 493 |
+
yield f"[Streaming error: {str(e)}]"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 494 |
|
| 495 |
+
# Wait for generation to complete
|
| 496 |
+
generation_thread.join()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 497 |
|
| 498 |
end_stream_time = time.perf_counter()
|
| 499 |
stream_time = end_stream_time - start_stream_time
|
| 500 |
+
log_metric(f"LLM Stream time: {stream_time:0.4f} seconds. Generated length: {len(generated_text)} chars. Model: {self.model_name}. Timestamp: {current_time:%Y-%m-%d %H:%M:%S}")
|
| 501 |
+
logger.info(f"Stream generation completed: {len(generated_text)} chars in {stream_time:.2f}s")
|
| 502 |
|
| 503 |
except Exception as e:
|
| 504 |
logger.error(f"Streaming generation error: {e}")
|
|
|
|
| 507 |
log_metric(f"LLM Stream time (error): {stream_time:0.4f} seconds. Model: {self.model_name}. Timestamp: {current_time:%Y-%m-%d %H:%M:%S}")
|
| 508 |
yield f"[Error in streaming generation: {str(e)}]"
|
| 509 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 510 |
@property
|
| 511 |
def InputType(self) -> Type[Input]:
|
| 512 |
return str
|
|
|
|
| 515 |
def OutputType(self) -> Type[Output]:
|
| 516 |
return str
|
| 517 |
|
| 518 |
+
# --- LangGraph Agent Implementation with Tool Calling ---
|
| 519 |
class Educational_Agent:
|
| 520 |
+
"""Modern LangGraph-based educational agent with Phi-3-mini and improved tool calling"""
|
| 521 |
|
| 522 |
def __init__(self):
|
| 523 |
start_init_and_langgraph_time = time.perf_counter()
|
| 524 |
current_time = datetime.now()
|
| 525 |
|
| 526 |
+
self.llm = Phi3MiniEducationalLLM(model_path="microsoft/Phi-3-mini-4k-instruct", use_4bit=True)
|
| 527 |
self.tool_decision_engine = Tool_Decision_Engine(self.llm)
|
| 528 |
|
| 529 |
# Create LangGraph workflow
|
|
|
|
| 534 |
log_metric(f"Init and LangGraph workflow setup time: {init_and_langgraph_time:0.4f} seconds. Timestamp: {current_time:%Y-%m-%d %H:%M:%S}")
|
| 535 |
|
| 536 |
def _create_langgraph_workflow(self):
|
| 537 |
+
"""Create the complete LangGraph workflow with improved tool calling"""
|
| 538 |
# Define tools
|
| 539 |
tools = [Create_Graph_Tool]
|
| 540 |
tool_node = ToolNode(tools)
|
| 541 |
|
|
|
|
|
|
|
|
|
|
| 542 |
def should_continue(state: EducationalAgentState) -> str:
    """Route the workflow: run the tool node next, or finish.

    Returns "tools" when the last message either carries structured
    tool_calls or embeds a ```json visualization config; END otherwise.
    """
    messages = state["messages"]
    last_message = messages[-1]

    # Structured tool calls attached by the model take priority.
    if hasattr(last_message, 'tool_calls') and last_message.tool_calls:
        logger.info("Executing tools based on model decision")
        return "tools"

    # The tool prompt asks for explanatory text *followed by* a JSON
    # block, so the config is rarely at the very start of the message.
    # Search anywhere in the content (consistent with the regex used in
    # handle_tools) instead of startswith('```json'), which missed any
    # config preceded by explanatory text.
    if isinstance(last_message, AIMessage) and last_message.content:
        content = last_message.content
        if '```json' in content and 'plot_type' in content:
            logger.info("Found JSON tool configuration in message")
            return "tools"

    return END
|
| 561 |
|
| 562 |
def call_model(state: EducationalAgentState) -> dict:
|
| 563 |
+
"""Call the model with enhanced tool decision logic"""
|
| 564 |
start_call_model_time = time.perf_counter()
|
| 565 |
current_time = datetime.now()
|
| 566 |
|
|
|
|
| 576 |
# Decide if tools should be used
|
| 577 |
needs_tools = self.tool_decision_engine.should_use_visualization(user_query)
|
| 578 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 579 |
try:
|
|
|
|
|
|
|
|
|
|
| 580 |
if needs_tools:
|
| 581 |
+
logger.info("Query requires visualization - prompting for tool use")
|
| 582 |
+
# Enhanced prompt that guides Phi-3 to generate tool calls
|
| 583 |
+
tool_prompt = f"""
|
| 584 |
+
You are an educational AI assistant. The user has asked: "{user_query}"
|
| 585 |
+
|
| 586 |
+
This query would benefit from a visualization. Please provide a helpful educational response AND include a JSON configuration for creating a graph or chart.
|
| 587 |
|
| 588 |
+
Format your response with explanatory text followed by a JSON block like this:
|
|
|
|
| 589 |
|
| 590 |
+
```json
|
|
|
|
| 591 |
{{
|
| 592 |
+
"data": {{"Category 1": value1, "Category 2": value2}},
|
| 593 |
"plot_type": "bar|line|pie",
|
| 594 |
+
"title": "Descriptive Title",
|
| 595 |
+
"x_label": "X Axis Label",
|
| 596 |
+
"y_label": "Y Axis Label",
|
| 597 |
+
"educational_context": "Explanation of why this visualization helps learning"
|
| 598 |
}}
|
| 599 |
+
```
|
| 600 |
|
| 601 |
+
Make sure the data is relevant to the educational concept being discussed.
|
| 602 |
"""
|
| 603 |
+
response = self.llm.invoke(tool_prompt)
|
| 604 |
+
else:
|
| 605 |
+
# Regular educational response
|
| 606 |
+
response = self.llm.invoke(user_query)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 607 |
|
| 608 |
end_call_model_time = time.perf_counter()
|
| 609 |
call_model_time = end_call_model_time - start_call_model_time
|
| 610 |
+
log_metric(f"Call model time: {call_model_time:0.4f} seconds. Timestamp: {current_time:%Y-%m-%d %H:%M:%S}")
|
| 611 |
|
| 612 |
return {"messages": [AIMessage(content=response)]}
|
| 613 |
|
|
|
|
| 621 |
return {"messages": [error_response]}
|
| 622 |
|
| 623 |
def handle_tools(state: EducationalAgentState) -> dict:
|
| 624 |
+
"""Handle tool execution by parsing JSON from message content"""
|
| 625 |
start_handle_tools_time = time.perf_counter()
|
| 626 |
current_time = datetime.now()
|
| 627 |
|
|
|
|
| 629 |
messages = state["messages"]
|
| 630 |
last_message = messages[-1]
|
| 631 |
|
| 632 |
+
if isinstance(last_message, AIMessage) and last_message.content:
|
| 633 |
+
content = last_message.content
|
| 634 |
+
|
| 635 |
+
# Extract JSON from code blocks
|
| 636 |
+
json_pattern = r'```json\s*(\{.*?\})\s*```'
|
| 637 |
+
json_match = re.search(json_pattern, content, re.DOTALL)
|
| 638 |
+
|
| 639 |
+
if json_match:
|
| 640 |
+
json_str = json_match.group(1)
|
| 641 |
+
try:
|
| 642 |
+
# Validate and execute the tool
|
| 643 |
+
json.loads(json_str) # Validate JSON
|
| 644 |
+
result = Create_Graph_Tool.invoke({"graph_config": json_str})
|
| 645 |
+
|
| 646 |
+
# Create a response that combines the explanation with the visualization
|
| 647 |
+
text_before_json = content[:json_match.start()].strip()
|
| 648 |
+
combined_response = f"{text_before_json}\n\n{result}"
|
| 649 |
+
|
| 650 |
+
end_handle_tools_time = time.perf_counter()
|
| 651 |
+
handle_tools_time = end_handle_tools_time - start_handle_tools_time
|
| 652 |
+
log_metric(f"Handle tools time: {handle_tools_time:0.4f} seconds. Timestamp: {current_time:%Y-%m-%d %H:%M:%S}")
|
| 653 |
+
|
| 654 |
+
# Replace the last message with the combined response
|
| 655 |
+
return {"messages": [AIMessage(content=combined_response)]}
|
| 656 |
+
|
| 657 |
+
except json.JSONDecodeError as e:
|
| 658 |
+
logger.error(f"Invalid JSON in tool call: {e}")
|
| 659 |
+
return {"messages": [AIMessage(content=f"{content}\n\n[Error: Invalid JSON format for visualization]")]}
|
| 660 |
|
| 661 |
+
# If no valid tool call found, return the message as-is
|
| 662 |
end_handle_tools_time = time.perf_counter()
|
| 663 |
handle_tools_time = end_handle_tools_time - start_handle_tools_time
|
| 664 |
+
log_metric(f"Handle tools time (no tool found): {handle_tools_time:0.4f} seconds. Timestamp: {current_time:%Y-%m-%d %H:%M:%S}")
|
| 665 |
|
| 666 |
return {"messages": []}
|
| 667 |
|
|
|
|
| 671 |
handle_tools_time = end_handle_tools_time - start_handle_tools_time
|
| 672 |
log_metric(f"Handle tools time (error): {handle_tools_time:0.4f} seconds. Timestamp: {current_time:%Y-%m-%d %H:%M:%S}")
|
| 673 |
|
| 674 |
+
return {"messages": [AIMessage(content=f"Tool execution failed: {str(e)}")]}
|
|
|
|
|
|
|
|
|
|
|
|
|
| 675 |
|
| 676 |
# Build the workflow
|
| 677 |
workflow = StateGraph(EducationalAgentState)
|
|
|
|
| 690 |
END: END,
|
| 691 |
}
|
| 692 |
)
|
| 693 |
+
workflow.add_edge("tools", END) # After tools, we're done
|
| 694 |
|
| 695 |
# Add memory
|
| 696 |
memory = MemorySaver()
|
|
|
|
| 723 |
return f"I apologize, but I encountered an error: {str(e)}"
|
| 724 |
|
| 725 |
def stream_chat(self, message: str, thread_id: str = "default"):
|
| 726 |
+
"""Streaming chat interface that yields partial responses using TextIteratorStreamer"""
|
| 727 |
start_chat_time = time.perf_counter()
|
| 728 |
current_time = datetime.now()
|
| 729 |
|
|
|
|
| 737 |
"educational_context": None
|
| 738 |
}
|
| 739 |
|
| 740 |
+
# Check if tools are needed
|
| 741 |
user_query = message
|
| 742 |
needs_tools = self.tool_decision_engine.should_use_visualization(user_query)
|
| 743 |
|
| 744 |
if needs_tools:
|
| 745 |
logger.info("Query requires visualization - handling tool call first")
|
| 746 |
+
# Handle tool generation (non-streaming for tools since they involve JSON parsing)
|
| 747 |
result = self.app.invoke(initial_state, config=config)
|
| 748 |
final_messages = result["messages"]
|
| 749 |
|
| 750 |
+
# Build the response from all messages
|
| 751 |
response_parts = []
|
| 752 |
for msg in final_messages:
|
| 753 |
if isinstance(msg, AIMessage) and msg.content:
|
|
|
|
| 761 |
yield final_response
|
| 762 |
|
| 763 |
else:
|
| 764 |
+
logger.info("Streaming regular response without tools using TextIteratorStreamer")
|
| 765 |
+
# Stream the LLM response directly using TextIteratorStreamer
|
| 766 |
for partial_text in self.llm.stream_generate(message):
|
| 767 |
yield smart_truncate(partial_text, max_length=3000)
|
| 768 |
|
|
|
|
| 793 |
window.MathJax = {
|
| 794 |
tex: {
|
| 795 |
inlineMath: [['\\\\(', '\\\\)']],
|
| 796 |
+
displayMath: [['$', '$'], ['\\\\[', '\\\\]']],
|
| 797 |
packages: {'[+]': ['ams']}
|
| 798 |
},
|
| 799 |
svg: {fontCache: 'global'},
|
|
|
|
| 863 |
return result
|
| 864 |
|
| 865 |
def generate_response_with_agent(message, max_retries=3):
|
| 866 |
+
"""Generate streaming response using LangGraph agent with Phi-3-mini."""
|
| 867 |
start_generate_response_with_agent_time = time.perf_counter()
|
| 868 |
current_time = datetime.now()
|
| 869 |
|
|
|
|
| 943 |
start_agent_warmup_time = time.perf_counter()
|
| 944 |
current_time = datetime.now()
|
| 945 |
|
| 946 |
+
logger.info("Warming up Phi-3-mini LangGraph agent with test query...")
|
| 947 |
try:
|
| 948 |
current_agent = get_agent()
|
| 949 |
|
| 950 |
# Run a simple test query
|
| 951 |
test_response = current_agent.chat("Hello, this is a warmup test.")
|
| 952 |
+
logger.info(f"Phi-3-mini LangGraph agent warmup completed successfully! Test response length: {len(test_response)} chars")
|
| 953 |
|
| 954 |
end_agent_warmup_time = time.perf_counter()
|
| 955 |
agent_warmup_time = end_agent_warmup_time - start_agent_warmup_time
|
| 956 |
log_metric(f"Agent warmup time: {agent_warmup_time:0.4f} seconds. Timestamp: {current_time:%Y-%m-%d %H:%M:%S}")
|
| 957 |
|
| 958 |
except Exception as e:
|
| 959 |
+
logger.error(f"Phi-3-mini LangGraph agent warmup failed: {e}")
|
| 960 |
end_agent_warmup_time = time.perf_counter()
|
| 961 |
agent_warmup_time = end_agent_warmup_time - start_agent_warmup_time
|
| 962 |
log_metric(f"Agent warmup time (error): {agent_warmup_time:0.4f} seconds. Timestamp: {current_time:%Y-%m-%d %H:%M:%S}")
|
|
|
|
| 991 |
|
| 992 |
with gr.Column(elem_classes=["main-container"]):
|
| 993 |
# Title Section
|
| 994 |
+
gr.HTML('<div class="title-header"><h1>🎓 Mimir - Powered by Phi-3-mini</h1></div>')
|
| 995 |
|
| 996 |
# Chat Section
|
| 997 |
with gr.Row():
|
|
|
|
| 1039 |
if __name__ == "__main__":
|
| 1040 |
try:
|
| 1041 |
logger.info("=" * 50)
|
| 1042 |
+
logger.info("Starting Mimir Application with Microsoft Phi-3-mini-4k-instruct and TextIteratorStreamer")
|
| 1043 |
logger.info("=" * 50)
|
| 1044 |
|
| 1045 |
# Step 1: Preload the model and agent
|
| 1046 |
+
logger.info("Loading Phi-3-mini model and LangGraph workflow...")
|
| 1047 |
start_time = time.time()
|
| 1048 |
agent = Educational_Agent()
|
| 1049 |
load_time = time.time() - start_time
|
| 1050 |
+
logger.info(f"Phi-3-mini LangGraph agent loaded successfully in {load_time:.2f} seconds")
|
| 1051 |
|
| 1052 |
# Step 2: Warm up the model
|
| 1053 |
+
logger.info("Warming up Phi-3-mini model...")
|
| 1054 |
warmup_agent()
|
| 1055 |
|
| 1056 |
interface = create_interface()
|
|
|
|
| 1063 |
)
|
| 1064 |
|
| 1065 |
except Exception as e:
|
| 1066 |
+
logger.error(f"❌ Failed to launch Mimir with Phi-3-mini: {e}")
|
| 1067 |
raise
|