Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
|
@@ -38,7 +38,6 @@ os.environ['HF_DATASETS_CACHE'] = '/tmp/huggingface'
|
|
| 38 |
warnings.filterwarnings("ignore", message="Special tokens have been added")
|
| 39 |
warnings.filterwarnings("ignore", category=UserWarning, module="transformers")
|
| 40 |
warnings.filterwarnings("ignore", category=FutureWarning, module="huggingface_hub")
|
| 41 |
-
# Suppress transformer warnings specifically
|
| 42 |
warnings.filterwarnings("ignore", message=".*TracerWarning.*")
|
| 43 |
warnings.filterwarnings("ignore", message=".*flash-attention.*")
|
| 44 |
|
|
@@ -46,11 +45,43 @@ load_dotenv(".env")
|
|
| 46 |
HF_TOKEN = os.getenv("HF_TOKEN") or os.getenv("HUGGINGFACEHUB_API_TOKEN")
|
| 47 |
print("Environment variables loaded.")
|
| 48 |
|
| 49 |
-
#
|
| 50 |
logging.basicConfig(level=logging.INFO)
|
| 51 |
logger = logging.getLogger(__name__)
|
| 52 |
|
| 53 |
-
#
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 54 |
def setup_metrics_logger():
|
| 55 |
"""Setup a simple file logger for human-readable metrics"""
|
| 56 |
metrics_logger = logging.getLogger('metrics')
|
|
@@ -86,7 +117,7 @@ hf_token = HF_TOKEN
|
|
| 86 |
if not hf_token:
|
| 87 |
logger.warning("Neither HF_TOKEN nor HUGGINGFACEHUB_API_TOKEN is set, the application may not work.")
|
| 88 |
|
| 89 |
-
#
|
| 90 |
class EducationalAgentState(TypedDict):
|
| 91 |
messages: Annotated[Sequence[BaseMessage], add_messages]
|
| 92 |
needs_tools: bool
|
|
@@ -156,7 +187,7 @@ def Create_Graph_Tool(graph_config: str) -> str:
|
|
| 156 |
logger.error(f"Error in graph generation: {e}")
|
| 157 |
return f'<p style="color:red;">Error creating graph: {str(e)}</p>'
|
| 158 |
|
| 159 |
-
#
|
| 160 |
class Tool_Decision_Engine:
|
| 161 |
"""Uses LLM to intelligently decide when visualization tools would be beneficial"""
|
| 162 |
|
|
@@ -234,7 +265,7 @@ Decision:"""
|
|
| 234 |
log_metric(f"Tool decision time (error): {graph_decision_time:0.4f} seconds. Timestamp: {current_time:%Y-%m-%d %H:%M:%S}")
|
| 235 |
return False
|
| 236 |
|
| 237 |
-
#
|
| 238 |
SYSTEM_PROMPT = """You are Mimir, an expert multi-concept tutor designed to facilitate genuine learning and understanding. Your primary mission is to guide students through the learning process rather than providing direct answers to academic work.
|
| 239 |
|
| 240 |
## Core Educational Principles
|
|
@@ -291,8 +322,7 @@ Rather than providing complete solutions, you should:
|
|
| 291 |
|
| 292 |
Your goal is to be an educational partner who empowers students to succeed through understanding."""
|
| 293 |
|
| 294 |
-
#
|
| 295 |
-
|
| 296 |
class Phi3MiniEducationalLLM(Runnable):
|
| 297 |
"""LLM class optimized for Microsoft Phi-3-mini-4k-instruct with 4-bit quantization"""
|
| 298 |
|
|
@@ -381,8 +411,15 @@ class Phi3MiniEducationalLLM(Runnable):
|
|
| 381 |
start_invoke_time = time.perf_counter()
|
| 382 |
current_time = datetime.now()
|
| 383 |
|
|
|
|
| 384 |
if isinstance(input, dict):
|
| 385 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 386 |
else:
|
| 387 |
prompt = str(input)
|
| 388 |
|
|
@@ -393,36 +430,59 @@ class Phi3MiniEducationalLLM(Runnable):
|
|
| 393 |
# Format using Phi-3 chat template
|
| 394 |
text = self._format_chat_template(prompt)
|
| 395 |
|
| 396 |
-
|
| 397 |
-
|
| 398 |
-
|
| 399 |
-
|
| 400 |
-
|
| 401 |
-
|
| 402 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 403 |
|
| 404 |
# Move inputs to model device
|
| 405 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 406 |
|
| 407 |
# Generate with optimized parameters for quantized model
|
| 408 |
with torch.no_grad():
|
| 409 |
-
|
| 410 |
-
|
| 411 |
-
|
| 412 |
-
|
| 413 |
-
|
| 414 |
-
|
| 415 |
-
|
| 416 |
-
|
| 417 |
-
|
| 418 |
-
|
| 419 |
-
|
| 420 |
-
|
| 421 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 422 |
|
| 423 |
# Decode only new tokens
|
| 424 |
-
|
| 425 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 426 |
|
| 427 |
end_invoke_time = time.perf_counter()
|
| 428 |
invoke_time = end_invoke_time - start_invoke_time
|
|
@@ -435,7 +495,7 @@ class Phi3MiniEducationalLLM(Runnable):
|
|
| 435 |
end_invoke_time = time.perf_counter()
|
| 436 |
invoke_time = end_invoke_time - start_invoke_time
|
| 437 |
log_metric(f"LLM Invoke time (error): {invoke_time:0.4f} seconds. Model: {self.model_name}. Timestamp: {current_time:%Y-%m-%d %H:%M:%S}")
|
| 438 |
-
return f"
|
| 439 |
|
| 440 |
@spaces.GPU(duration=240)
|
| 441 |
def stream_generate(self, input: Input, config=None):
|
|
@@ -444,8 +504,12 @@ class Phi3MiniEducationalLLM(Runnable):
|
|
| 444 |
current_time = datetime.now()
|
| 445 |
logger.info("Starting stream_generate with 4-bit quantized model...")
|
| 446 |
|
|
|
|
| 447 |
if isinstance(input, dict):
|
| 448 |
-
|
|
|
|
|
|
|
|
|
|
| 449 |
else:
|
| 450 |
prompt = str(input)
|
| 451 |
|
|
@@ -459,69 +523,88 @@ class Phi3MiniEducationalLLM(Runnable):
|
|
| 459 |
|
| 460 |
text = self._format_chat_template(prompt)
|
| 461 |
|
| 462 |
-
|
| 463 |
-
|
| 464 |
-
|
| 465 |
-
|
| 466 |
-
|
| 467 |
-
|
| 468 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 469 |
|
| 470 |
# Move inputs to model device
|
| 471 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 472 |
|
| 473 |
-
# Initialize TextIteratorStreamer
|
| 474 |
streamer = TextIteratorStreamer(
|
| 475 |
self.tokenizer,
|
| 476 |
-
skip_prompt=True,
|
| 477 |
skip_special_tokens=True
|
| 478 |
)
|
| 479 |
|
| 480 |
# Generation parameters optimized for 4-bit
|
| 481 |
generation_kwargs = {
|
| 482 |
-
|
| 483 |
-
"
|
|
|
|
| 484 |
"do_sample": True,
|
| 485 |
"temperature": 0.7,
|
| 486 |
"top_p": 0.9,
|
| 487 |
"top_k": 50,
|
| 488 |
"repetition_penalty": 1.2,
|
| 489 |
"pad_token_id": self.tokenizer.eos_token_id,
|
| 490 |
-
"streamer": streamer,
|
| 491 |
"use_cache": False,
|
| 492 |
"past_key_values": None
|
| 493 |
}
|
| 494 |
|
| 495 |
-
# Start generation in background
|
| 496 |
generation_thread = threading.Thread(
|
| 497 |
target=model.generate,
|
| 498 |
kwargs=generation_kwargs
|
| 499 |
)
|
| 500 |
generation_thread.start()
|
| 501 |
|
| 502 |
-
#
|
| 503 |
generated_text = ""
|
| 504 |
consecutive_repeats = 0
|
| 505 |
last_chunk = ""
|
| 506 |
|
| 507 |
try:
|
| 508 |
-
|
| 509 |
-
|
|
|
|
| 510 |
continue
|
| 511 |
|
| 512 |
-
|
|
|
|
| 513 |
|
| 514 |
# Simple repetition detection
|
| 515 |
-
if
|
| 516 |
consecutive_repeats += 1
|
| 517 |
if consecutive_repeats >= 5:
|
| 518 |
logger.warning("Repetitive generation detected, stopping early")
|
| 519 |
break
|
| 520 |
else:
|
| 521 |
consecutive_repeats = 0
|
| 522 |
-
last_chunk =
|
| 523 |
|
| 524 |
-
#
|
| 525 |
yield generated_text
|
| 526 |
|
| 527 |
except Exception as e:
|
|
@@ -555,8 +638,8 @@ class Phi3MiniEducationalLLM(Runnable):
|
|
| 555 |
@property
|
| 556 |
def OutputType(self) -> Type[Output]:
|
| 557 |
return str
|
| 558 |
-
|
| 559 |
-
#
|
| 560 |
class Educational_Agent:
|
| 561 |
"""Modern LangGraph-based educational agent with Phi-3-mini and improved tool calling"""
|
| 562 |
|
|
@@ -592,87 +675,7 @@ class Educational_Agent:
|
|
| 592 |
|
| 593 |
# Check if the message content contains JSON for tool calling
|
| 594 |
if isinstance(last_message, AIMessage) and last_message.content:
|
| 595 |
-
|
| 596 |
-
# Look for JSON blocks that might be tool calls
|
| 597 |
-
if content.startswith('```json') and 'plot_type' in content:
|
| 598 |
-
logger.info("Found JSON tool configuration in message")
|
| 599 |
-
return "tools"
|
| 600 |
-
|
| 601 |
-
return END
|
| 602 |
-
|
| 603 |
-
def call_model(state: EducationalAgentState) -> dict:
|
| 604 |
-
"""Call the model using the tool decision already made in state"""
|
| 605 |
-
start_call_model_time = time.perf_counter()
|
| 606 |
-
current_time = datetime.now()
|
| 607 |
-
|
| 608 |
-
messages = state["messages"]
|
| 609 |
-
needs_tools = state.get("needs_tools", False) # Use the decision from state
|
| 610 |
-
|
| 611 |
-
# Extract original user query from messages
|
| 612 |
-
user_query = ""
|
| 613 |
-
for msg in reversed(messages):
|
| 614 |
-
if isinstance(msg, HumanMessage):
|
| 615 |
-
user_query = msg.content
|
| 616 |
-
break
|
| 617 |
-
|
| 618 |
-
if not user_query:
|
| 619 |
-
logger.error("No user query found in state messages")
|
| 620 |
-
return {"messages": [AIMessage(content="I didn't receive your message properly. Please try again.")]}
|
| 621 |
-
|
| 622 |
-
try:
|
| 623 |
-
if needs_tools:
|
| 624 |
-
logger.info("Generating response with tool instructions based on state decision")
|
| 625 |
-
# Create tool prompt but preserve original user query
|
| 626 |
-
tool_prompt = f"""
|
| 627 |
-
You are an educational AI assistant. The user has asked: "{user_query}"
|
| 628 |
-
|
| 629 |
-
This query would benefit from a visualization. Please provide a helpful educational response AND include a JSON configuration for creating a graph or chart.
|
| 630 |
-
|
| 631 |
-
Format your response with explanatory text followed by a JSON block like this:
|
| 632 |
-
|
| 633 |
-
```json
|
| 634 |
-
{{
|
| 635 |
-
"data": {{"Category 1": value1, "Category 2": value2}},
|
| 636 |
-
"plot_type": "bar|line|pie",
|
| 637 |
-
"title": "Descriptive Title",
|
| 638 |
-
"x_label": "X Axis Label",
|
| 639 |
-
"y_label": "Y Axis Label",
|
| 640 |
-
"educational_context": "Explanation of why this visualization helps learning"
|
| 641 |
-
}}
|
| 642 |
-
```
|
| 643 |
-
|
| 644 |
-
Provide your educational response followed by the JSON configuration.
|
| 645 |
-
"""
|
| 646 |
-
response = self.llm.invoke(tool_prompt)
|
| 647 |
-
else:
|
| 648 |
-
logger.info("Generating standard educational response")
|
| 649 |
-
response = self.llm.invoke(user_query)
|
| 650 |
-
|
| 651 |
-
end_call_model_time = time.perf_counter()
|
| 652 |
-
call_model_time = end_call_model_time - start_call_model_time
|
| 653 |
-
log_metric(f"Call model time: {call_model_time:0.4f} seconds. Tool decision: {needs_tools}. Timestamp: {current_time:%Y-%m-%d %H:%M:%S}")
|
| 654 |
-
|
| 655 |
-
return {"messages": [AIMessage(content=response)]}
|
| 656 |
-
|
| 657 |
-
except Exception as e:
|
| 658 |
-
logger.error(f"Error in call_model: {e}")
|
| 659 |
-
end_call_model_time = time.perf_counter()
|
| 660 |
-
call_model_time = end_call_model_time - start_call_model_time
|
| 661 |
-
log_metric(f"Call model time (error): {call_model_time:0.4f} seconds. Timestamp: {current_time:%Y-%m-%d %H:%M:%S}")
|
| 662 |
-
return {"messages": [AIMessage(content=f"I encountered an error: {str(e)}")]}
|
| 663 |
-
|
| 664 |
-
def process_json_tools(state: EducationalAgentState) -> dict:
|
| 665 |
-
"""Extract and process JSON tool configurations from AI messages"""
|
| 666 |
-
start_process_tools_time = time.perf_counter()
|
| 667 |
-
current_time = datetime.now()
|
| 668 |
-
|
| 669 |
-
messages = state["messages"]
|
| 670 |
-
last_message = messages[-1]
|
| 671 |
-
|
| 672 |
-
if not isinstance(last_message, AIMessage):
|
| 673 |
-
return {"messages": []}
|
| 674 |
-
|
| 675 |
-
content = last_message.content
|
| 676 |
|
| 677 |
# Look for JSON blocks in the message
|
| 678 |
json_pattern = r'```json\s*\n?(.*?)\n?```'
|
|
@@ -887,7 +890,7 @@ Provide your educational response followed by the JSON configuration.
|
|
| 887 |
log_metric(f"Stream query total time (error): {stream_query_time:0.4f} seconds. Timestamp: {current_time:%Y-%m-%d %H:%M:%S}")
|
| 888 |
yield f"I encountered an error: {str(e)}"
|
| 889 |
|
| 890 |
-
#
|
| 891 |
def warmup_agent():
|
| 892 |
"""Warm up the agent with a simple test query"""
|
| 893 |
try:
|
|
@@ -904,6 +907,35 @@ def warmup_agent():
|
|
| 904 |
|
| 905 |
except Exception as e:
|
| 906 |
logger.error(f"Warmup failed: {e}")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 907 |
|
| 908 |
# --- UI: Interface Creation ---
|
| 909 |
def create_interface():
|
|
|
|
| 38 |
warnings.filterwarnings("ignore", message="Special tokens have been added")
|
| 39 |
warnings.filterwarnings("ignore", category=UserWarning, module="transformers")
|
| 40 |
warnings.filterwarnings("ignore", category=FutureWarning, module="huggingface_hub")
|
|
|
|
| 41 |
warnings.filterwarnings("ignore", message=".*TracerWarning.*")
|
| 42 |
warnings.filterwarnings("ignore", message=".*flash-attention.*")
|
| 43 |
|
|
|
|
| 45 |
HF_TOKEN = os.getenv("HF_TOKEN") or os.getenv("HUGGINGFACEHUB_API_TOKEN")
|
| 46 |
print("Environment variables loaded.")
|
| 47 |
|
| 48 |
+
# Setup main logger first
|
| 49 |
logging.basicConfig(level=logging.INFO)
|
| 50 |
logger = logging.getLogger(__name__)
|
| 51 |
|
| 52 |
+
# MISSING HTML CONTENT DEFINITIONS - FIX FOR UNDEFINED VARIABLES
|
| 53 |
+
html_head_content = """
|
| 54 |
+
<meta charset="UTF-8">
|
| 55 |
+
<meta name="viewport" content="width=device-width, initial-scale=1.0">
|
| 56 |
+
<title>Mimir - Educational AI Assistant</title>
|
| 57 |
+
"""
|
| 58 |
+
|
| 59 |
+
force_light_mode = """
|
| 60 |
+
<script>
|
| 61 |
+
// Force light mode
|
| 62 |
+
if (document.documentElement) {
|
| 63 |
+
document.documentElement.setAttribute('data-theme', 'light');
|
| 64 |
+
}
|
| 65 |
+
</script>
|
| 66 |
+
"""
|
| 67 |
+
|
| 68 |
+
mathjax_config = """
|
| 69 |
+
<script>
|
| 70 |
+
window.MathJax = {
|
| 71 |
+
tex: {
|
| 72 |
+
inlineMath: [['$', '$'], ['\\(', '\\)']],
|
| 73 |
+
displayMath: [['$$', '$$'], ['\\[', '\\]']],
|
| 74 |
+
processEscapes: true,
|
| 75 |
+
processEnvironments: true
|
| 76 |
+
},
|
| 77 |
+
options: {
|
| 78 |
+
skipHtmlTags: ['script', 'noscript', 'style', 'textarea', 'pre']
|
| 79 |
+
}
|
| 80 |
+
};
|
| 81 |
+
</script>
|
| 82 |
+
"""
|
| 83 |
+
|
| 84 |
+
# Environment and Logging Setup
|
| 85 |
def setup_metrics_logger():
|
| 86 |
"""Setup a simple file logger for human-readable metrics"""
|
| 87 |
metrics_logger = logging.getLogger('metrics')
|
|
|
|
| 117 |
if not hf_token:
|
| 118 |
logger.warning("Neither HF_TOKEN nor HUGGINGFACEHUB_API_TOKEN is set, the application may not work.")
|
| 119 |
|
| 120 |
+
# LangGraph State Definition
|
| 121 |
class EducationalAgentState(TypedDict):
|
| 122 |
messages: Annotated[Sequence[BaseMessage], add_messages]
|
| 123 |
needs_tools: bool
|
|
|
|
| 187 |
logger.error(f"Error in graph generation: {e}")
|
| 188 |
return f'<p style="color:red;">Error creating graph: {str(e)}</p>'
|
| 189 |
|
| 190 |
+
# Tool Decision Engine (Updated for LangGraph)
|
| 191 |
class Tool_Decision_Engine:
|
| 192 |
"""Uses LLM to intelligently decide when visualization tools would be beneficial"""
|
| 193 |
|
|
|
|
| 265 |
log_metric(f"Tool decision time (error): {graph_decision_time:0.4f} seconds. Timestamp: {current_time:%Y-%m-%d %H:%M:%S}")
|
| 266 |
return False
|
| 267 |
|
| 268 |
+
# System Prompt with ReAct Framework for Phi-3-mini
|
| 269 |
SYSTEM_PROMPT = """You are Mimir, an expert multi-concept tutor designed to facilitate genuine learning and understanding. Your primary mission is to guide students through the learning process rather than providing direct answers to academic work.
|
| 270 |
|
| 271 |
## Core Educational Principles
|
|
|
|
| 322 |
|
| 323 |
Your goal is to be an educational partner who empowers students to succeed through understanding."""
|
| 324 |
|
| 325 |
+
# FIXED LLM Class with Phi-3-mini
|
|
|
|
| 326 |
class Phi3MiniEducationalLLM(Runnable):
|
| 327 |
"""LLM class optimized for Microsoft Phi-3-mini-4k-instruct with 4-bit quantization"""
|
| 328 |
|
|
|
|
| 411 |
start_invoke_time = time.perf_counter()
|
| 412 |
current_time = datetime.now()
|
| 413 |
|
| 414 |
+
# FIX: Handle different input types properly
|
| 415 |
if isinstance(input, dict):
|
| 416 |
+
if 'input' in input:
|
| 417 |
+
prompt = input['input']
|
| 418 |
+
elif 'messages' in input:
|
| 419 |
+
# Handle messages format
|
| 420 |
+
prompt = str(input['messages'])
|
| 421 |
+
else:
|
| 422 |
+
prompt = str(input)
|
| 423 |
else:
|
| 424 |
prompt = str(input)
|
| 425 |
|
|
|
|
| 430 |
# Format using Phi-3 chat template
|
| 431 |
text = self._format_chat_template(prompt)
|
| 432 |
|
| 433 |
+
# FIX: Proper tokenization with error handling
|
| 434 |
+
try:
|
| 435 |
+
inputs = self.tokenizer(
|
| 436 |
+
text,
|
| 437 |
+
return_tensors="pt",
|
| 438 |
+
padding=True,
|
| 439 |
+
truncation=True,
|
| 440 |
+
max_length=4096
|
| 441 |
+
)
|
| 442 |
+
|
| 443 |
+
# Ensure inputs are properly formatted
|
| 444 |
+
if not hasattr(inputs, 'input_ids'):
|
| 445 |
+
logger.error("Tokenizer did not return input_ids")
|
| 446 |
+
return "I encountered an error processing your request. Please try again."
|
| 447 |
+
|
| 448 |
+
except Exception as tokenizer_error:
|
| 449 |
+
logger.error(f"Tokenization error: {tokenizer_error}")
|
| 450 |
+
return "I encountered an error processing your request. Please try again."
|
| 451 |
|
| 452 |
# Move inputs to model device
|
| 453 |
+
try:
|
| 454 |
+
inputs = {k: v.to(model.device) for k, v in inputs.items()}
|
| 455 |
+
except Exception as device_error:
|
| 456 |
+
logger.error(f"Device transfer error: {device_error}")
|
| 457 |
+
return "I encountered an error processing your request. Please try again."
|
| 458 |
|
| 459 |
# Generate with optimized parameters for quantized model
|
| 460 |
with torch.no_grad():
|
| 461 |
+
try:
|
| 462 |
+
outputs = model.generate(
|
| 463 |
+
input_ids=inputs['input_ids'],
|
| 464 |
+
attention_mask=inputs.get('attention_mask', None),
|
| 465 |
+
max_new_tokens=1200,
|
| 466 |
+
do_sample=True,
|
| 467 |
+
temperature=0.7,
|
| 468 |
+
top_p=0.9,
|
| 469 |
+
top_k=50,
|
| 470 |
+
repetition_penalty=1.1,
|
| 471 |
+
pad_token_id=self.tokenizer.eos_token_id,
|
| 472 |
+
use_cache=False,
|
| 473 |
+
past_key_values=None
|
| 474 |
+
)
|
| 475 |
+
except Exception as generation_error:
|
| 476 |
+
logger.error(f"Generation error: {generation_error}")
|
| 477 |
+
return "I encountered an error generating the response. Please try again."
|
| 478 |
|
| 479 |
# Decode only new tokens
|
| 480 |
+
try:
|
| 481 |
+
new_tokens = outputs[0][len(inputs['input_ids'][0]):]
|
| 482 |
+
result = self.tokenizer.decode(new_tokens, skip_special_tokens=True).strip()
|
| 483 |
+
except Exception as decode_error:
|
| 484 |
+
logger.error(f"Decoding error: {decode_error}")
|
| 485 |
+
return "I encountered an error processing the response. Please try again."
|
| 486 |
|
| 487 |
end_invoke_time = time.perf_counter()
|
| 488 |
invoke_time = end_invoke_time - start_invoke_time
|
|
|
|
| 495 |
end_invoke_time = time.perf_counter()
|
| 496 |
invoke_time = end_invoke_time - start_invoke_time
|
| 497 |
log_metric(f"LLM Invoke time (error): {invoke_time:0.4f} seconds. Model: {self.model_name}. Timestamp: {current_time:%Y-%m-%d %H:%M:%S}")
|
| 498 |
+
return f"I encountered an error: {str(e)}"
|
| 499 |
|
| 500 |
@spaces.GPU(duration=240)
|
| 501 |
def stream_generate(self, input: Input, config=None):
|
|
|
|
| 504 |
current_time = datetime.now()
|
| 505 |
logger.info("Starting stream_generate with 4-bit quantized model...")
|
| 506 |
|
| 507 |
+
# Handle input properly
|
| 508 |
if isinstance(input, dict):
|
| 509 |
+
if 'input' in input:
|
| 510 |
+
prompt = input['input']
|
| 511 |
+
else:
|
| 512 |
+
prompt = str(input)
|
| 513 |
else:
|
| 514 |
prompt = str(input)
|
| 515 |
|
|
|
|
| 523 |
|
| 524 |
text = self._format_chat_template(prompt)
|
| 525 |
|
| 526 |
+
# Proper tokenization with error handling
|
| 527 |
+
try:
|
| 528 |
+
inputs = self.tokenizer(
|
| 529 |
+
text,
|
| 530 |
+
return_tensors="pt",
|
| 531 |
+
padding=True,
|
| 532 |
+
truncation=True,
|
| 533 |
+
max_length=4096
|
| 534 |
+
)
|
| 535 |
+
|
| 536 |
+
if not hasattr(inputs, 'input_ids'):
|
| 537 |
+
yield "I encountered an error processing your request. Please try again."
|
| 538 |
+
return
|
| 539 |
+
|
| 540 |
+
except Exception as tokenizer_error:
|
| 541 |
+
logger.error(f"Streaming tokenization error: {tokenizer_error}")
|
| 542 |
+
yield "I encountered an error processing your request. Please try again."
|
| 543 |
+
return
|
| 544 |
|
| 545 |
# Move inputs to model device
|
| 546 |
+
try:
|
| 547 |
+
inputs = {k: v.to(model.device) for k, v in inputs.items()}
|
| 548 |
+
except Exception as device_error:
|
| 549 |
+
logger.error(f"Streaming device transfer error: {device_error}")
|
| 550 |
+
yield "I encountered an error processing your request. Please try again."
|
| 551 |
+
return
|
| 552 |
|
| 553 |
+
# Initialize TextIteratorStreamer - this streams the GENERATED TOKENS, not the input
|
| 554 |
streamer = TextIteratorStreamer(
|
| 555 |
self.tokenizer,
|
| 556 |
+
skip_prompt=True, # Skip the input prompt in output
|
| 557 |
skip_special_tokens=True
|
| 558 |
)
|
| 559 |
|
| 560 |
# Generation parameters optimized for 4-bit
|
| 561 |
generation_kwargs = {
|
| 562 |
+
"input_ids": inputs['input_ids'],
|
| 563 |
+
"attention_mask": inputs.get('attention_mask', None),
|
| 564 |
+
"max_new_tokens": 1200,
|
| 565 |
"do_sample": True,
|
| 566 |
"temperature": 0.7,
|
| 567 |
"top_p": 0.9,
|
| 568 |
"top_k": 50,
|
| 569 |
"repetition_penalty": 1.2,
|
| 570 |
"pad_token_id": self.tokenizer.eos_token_id,
|
| 571 |
+
"streamer": streamer, # This streams the OUTPUT tokens as they're generated
|
| 572 |
"use_cache": False,
|
| 573 |
"past_key_values": None
|
| 574 |
}
|
| 575 |
|
| 576 |
+
# Start generation in background thread
|
| 577 |
generation_thread = threading.Thread(
|
| 578 |
target=model.generate,
|
| 579 |
kwargs=generation_kwargs
|
| 580 |
)
|
| 581 |
generation_thread.start()
|
| 582 |
|
| 583 |
+
# Stream the generated tokens as they come from the model
|
| 584 |
generated_text = ""
|
| 585 |
consecutive_repeats = 0
|
| 586 |
last_chunk = ""
|
| 587 |
|
| 588 |
try:
|
| 589 |
+
# This loop receives tokens as they're generated by the model
|
| 590 |
+
for new_token_text in streamer:
|
| 591 |
+
if not new_token_text:
|
| 592 |
continue
|
| 593 |
|
| 594 |
+
# Accumulate the generated text
|
| 595 |
+
generated_text += new_token_text
|
| 596 |
|
| 597 |
# Simple repetition detection
|
| 598 |
+
if new_token_text == last_chunk:
|
| 599 |
consecutive_repeats += 1
|
| 600 |
if consecutive_repeats >= 5:
|
| 601 |
logger.warning("Repetitive generation detected, stopping early")
|
| 602 |
break
|
| 603 |
else:
|
| 604 |
consecutive_repeats = 0
|
| 605 |
+
last_chunk = new_token_text
|
| 606 |
|
| 607 |
+
# Yield the accumulated generated text (not the input prompt)
|
| 608 |
yield generated_text
|
| 609 |
|
| 610 |
except Exception as e:
|
|
|
|
| 638 |
@property
|
| 639 |
def OutputType(self) -> Type[Output]:
|
| 640 |
return str
|
| 641 |
+
|
| 642 |
+
# LangGraph Agent Implementation with Tool Calling
|
| 643 |
class Educational_Agent:
|
| 644 |
"""Modern LangGraph-based educational agent with Phi-3-mini and improved tool calling"""
|
| 645 |
|
|
|
|
| 675 |
|
| 676 |
# Check if the message content contains JSON for tool calling
|
| 677 |
if isinstance(last_message, AIMessage) and last_message.content:
|
| 678 |
+
content = last_message.content
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 679 |
|
| 680 |
# Look for JSON blocks in the message
|
| 681 |
json_pattern = r'```json\s*\n?(.*?)\n?```'
|
|
|
|
| 890 |
log_metric(f"Stream query total time (error): {stream_query_time:0.4f} seconds. Timestamp: {current_time:%Y-%m-%d %H:%M:%S}")
|
| 891 |
yield f"I encountered an error: {str(e)}"
|
| 892 |
|
| 893 |
+
# Gradio Interface Functions
|
| 894 |
def warmup_agent():
|
| 895 |
"""Warm up the agent with a simple test query"""
|
| 896 |
try:
|
|
|
|
| 907 |
|
| 908 |
except Exception as e:
|
| 909 |
logger.error(f"Warmup failed: {e}")
|
| 910 |
+
|
| 911 |
+
def respond_and_update(message, history):
|
| 912 |
+
"""Handle user input and generate streaming response"""
|
| 913 |
+
if not message.strip():
|
| 914 |
+
return history, ""
|
| 915 |
+
|
| 916 |
+
# Add user message to history
|
| 917 |
+
history.append({"role": "user", "content": message})
|
| 918 |
+
|
| 919 |
+
# Add empty assistant message that will be updated
|
| 920 |
+
history.append({"role": "assistant", "content": ""})
|
| 921 |
+
|
| 922 |
+
try:
|
| 923 |
+
# Generate streaming response
|
| 924 |
+
full_response = ""
|
| 925 |
+
for chunk in agent.stream_query(message):
|
| 926 |
+
full_response = chunk
|
| 927 |
+
# Update the last message in history
|
| 928 |
+
history[-1]["content"] = full_response
|
| 929 |
+
yield history, ""
|
| 930 |
+
|
| 931 |
+
except Exception as e:
|
| 932 |
+
logger.error(f"Error in respond_and_update: {e}")
|
| 933 |
+
history[-1]["content"] = f"I encountered an error: {str(e)}"
|
| 934 |
+
yield history, ""
|
| 935 |
+
|
| 936 |
+
def clear_chat():
|
| 937 |
+
"""Clear the chat history"""
|
| 938 |
+
return [], ""
|
| 939 |
|
| 940 |
# --- UI: Interface Creation ---
|
| 941 |
def create_interface():
|