""" LangGraph node implementations for the multi-agent algebra chatbot. Agents: ocr_agent, planner, parallel_executor, synthetic_agent Tools: wolfram_tool_node, code_tool_node """ import os import time import json import re import asyncio from typing import List, Dict, Any, Optional from langchain_core.messages import HumanMessage, AIMessage, SystemMessage from backend.agent.state import ( AgentState, ToolCall, ModelCall, add_agent_used, add_tool_call, add_model_call ) from backend.agent.models import model_manager, get_model from backend.tools.wolfram import query_wolfram_alpha from backend.tools.code_executor import CodeTool from backend.utils.memory import ( memory_tracker, estimate_tokens, estimate_message_tokens, TokenOverflowError, truncate_history_to_fit ) from backend.agent.prompts import ( OCR_PROMPT, SYNTHETIC_PROMPT, CODEGEN_PROMPT, CODEGEN_FIX_PROMPT, PLANNER_SYSTEM_PROMPT, PLANNER_USER_PROMPT ) # ============================================================================ # HELPER FUNCTIONS FOR OUTPUT FORMATTING # ============================================================================ def format_latex_for_markdown(text: str) -> str: """ Format LaTeX content for proper Markdown rendering. Key principle: - Add paragraph breaks (double newlines) OUTSIDE of $$...$$ blocks - NEVER modify content INSIDE $$...$$ blocks (preserves aligned, matrix, etc.) - Ensure $$ is on its own line for block rendering Args: text: Raw text containing LaTeX expressions Returns: Formatted text suitable for Markdown rendering """ if not text: return text # Split by $$ to separate math blocks from text parts = text.split('$$') formatted_parts = [] for i, part in enumerate(parts): if i % 2 == 0: # OUTSIDE math block (text content) # Add paragraph spacing for better readability # But be careful not to add excessive whitespace formatted_parts.append(part) else: # INSIDE math block - preserve exactly as-is # Just wrap with $$ and ensure it's on its own line formatted_parts.append(f'\n$$\n{part.strip()}\n$$\n') # Rejoin: even parts are text, odd parts are already formatted with $$ result = '' for i, part in enumerate(formatted_parts): if i % 2 == 0: result += part else: # This is the formatted math block, append directly result += part # Clean up excessive whitespace (more than 2 consecutive newlines) result = re.sub(r'\n{3,}', '\n\n', result) return result.strip() # ============================================================================ # AGENT NODES # ============================================================================ async def ocr_agent_node(state: AgentState) -> AgentState: """ OCR Agent: Extract text from images using vision model. Supports multiple images with parallel processing. Primary: llama-4-maverick, Fallback: llama-4-scout """ import asyncio add_agent_used(state, "ocr_agent") # Check for images (new list or legacy single image) image_list = state.get("image_data_list", []) if not image_list and state.get("image_data"): image_list = [state["image_data"]] # Backward compatibility if not image_list: # No images - proceed directly to planner (OCR skipped) state["current_agent"] = "planner" return state start_time = time.time() primary_model = "llama-4-maverick" fallback_model = "llama-4-scout" async def ocr_single_image(image_data: str, index: int) -> dict: """Process a single image and return result dict.""" content = [ {"type": "text", "text": OCR_PROMPT}, {"type": "image_url", "image_url": {"url": f"data:image/jpeg;base64,{image_data}"}} ] messages = [HumanMessage(content=content)] model_used = primary_model try: # Check rate limit for primary can_use, error = model_manager.check_rate_limit(primary_model) if not can_use: model_used = fallback_model can_use, error = model_manager.check_rate_limit(fallback_model) if not can_use: return {"image_index": index + 1, "text": None, "error": error} llm = get_model(model_used) response = await llm.ainvoke(messages) return {"image_index": index + 1, "text": response.content, "error": None} except Exception as e: return {"image_index": index + 1, "text": None, "error": str(e)} # Process all images in parallel tasks = [ocr_single_image(img, i) for i, img in enumerate(image_list)] results = await asyncio.gather(*tasks) duration_ms = int((time.time() - start_time) * 1000) # Store results state["ocr_results"] = results # Build combined OCR text for backward compatibility successful_texts = [] for r in results: if r["text"]: if len(image_list) > 1: successful_texts.append(f"[Ảnh {r['image_index']}]:\n{r['text']}") else: successful_texts.append(r["text"]) state["ocr_text"] = "\n\n".join(successful_texts) if successful_texts else None # Log model calls add_model_call(state, ModelCall( model=primary_model, agent="ocr_agent", tokens_in=500 * len(image_list), tokens_out=sum(len(r.get("text", "") or "") // 4 for r in results), duration_ms=duration_ms, success=any(r["text"] for r in results) )) # Report any errors but continue errors = [f"Ảnh {r['image_index']}: {r['error']}" for r in results if r["error"]] if errors and not successful_texts: state["error_message"] = "OCR failed: " + "; ".join(errors) # Route to planner for multi-question analysis state["current_agent"] = "planner" return state async def planner_node(state: AgentState) -> AgentState: """ Planner Node: Analyze all content (text + OCR) and identify individual questions. Creates an execution plan for parallel processing. NOW WITH FULL CONVERSATION HISTORY FOR MEMORY! """ import asyncio add_agent_used(state, "planner") start_time = time.time() model_name = "kimi-k2" # Get user text from last message user_text = "" for msg in reversed(state["messages"]): if isinstance(msg, HumanMessage): user_text = msg.content if isinstance(msg.content, str) else str(msg.content) break ocr_text = state.get("ocr_text") or "(Không có ảnh)" # Build user prompt for current request current_prompt = PLANNER_USER_PROMPT.format( user_text=user_text or "(Không có text)", ocr_text=ocr_text ) # ======================================== # NEW: Build messages WITH conversation history # ======================================== llm_messages = [] # 1. Add system prompt with memory-awareness instructions llm_messages.append(SystemMessage(content=PLANNER_SYSTEM_PROMPT)) # 2. Add truncated conversation history (smart token management) history_messages = state.get("messages", []) # Exclude the last message since we'll add current_prompt separately if history_messages: history_to_include = history_messages[:-1] if len(history_messages) > 1 else [] else: history_to_include = [] # Truncate history to fit within token limits system_tokens = estimate_tokens(PLANNER_SYSTEM_PROMPT) current_tokens = estimate_tokens(current_prompt) truncated_history = truncate_history_to_fit( history_to_include, system_tokens=system_tokens, current_tokens=current_tokens, max_context_tokens=200000 # Leave room within 256K limit ) # Add history messages for msg in truncated_history: llm_messages.append(msg) # 3. Add current user request as last message llm_messages.append(HumanMessage(content=current_prompt)) # Calculate total input tokens for tracking total_input_tokens = system_tokens + estimate_message_tokens(truncated_history) + current_tokens try: llm = get_model(model_name) response = await llm.ainvoke(llm_messages) content = response.content.strip() duration_ms = int((time.time() - start_time) * 1000) add_model_call(state, ModelCall( model=model_name, agent="planner", tokens_in=total_input_tokens, tokens_out=len(content) // 4, duration_ms=duration_ms, success=True )) # Parse JSON from response # Handle markdown code blocks if "```json" in content: content = content.split("```json")[1].split("```")[0].strip() elif "```" in content: content = content.split("```")[1].split("```")[0].strip() try: # Try to parse JSON (Mixed/Tool Case) plan = json.loads(content) except json.JSONDecodeError: try: # Try repair: Fix invalid escapes for LaTeX (e.g., \frac -> \\frac) # Matches backslash NOT followed by valid JSON escape chars (excluding \\ itself) fixed_content = re.sub(r'\\(?![unrtbf"\/])', r'\\\\', content) plan = json.loads(fixed_content) except Exception: # If JSON parsing fails completely, try Regex Fallback # This catches cases where LLM returns valid-looking JSON but with syntax errors if content.strip().startswith("{") and '"questions"' in content: # Attempt to extract answers using Regex # Pattern: "answer": "..." (handling escaped quotes is hard in regex, simplified) import re # Extract individual question blocks (simplified assumption) # Use a rough scan for "answer": "..." # Find all "answer": "(.*?)" where content is non-greedy until next quote # Note: this is fragile but better than raw JSON # Better fallback: Just treat it as raw text but tell user format error pass # If JSON fails, it means Planner returned Direct Text Answer (All Direct Case) # OR malformed JSON that looks like text. # Check directly if it looks like the raw JSON output if content.strip().startswith('{') and '"type": "direct"' in content: # This is likely the malformed JSON case the user saw # Use Regex to extract answers answers = re.findall(r'"answer":\s*"(.*?)(? Check if all questions are direct (LLM didn't follow prompt correctly) all_direct = all(q.get("type") == "direct" for q in plan.get("questions", [])) if all_direct: # LLM returned JSON for all-direct case (should have returned text) # Check if answers are provided questions = plan.get("questions", []) has_valid_answers = all(q.get("answer") for q in questions) if has_valid_answers: # Answers are in the JSON, extract them final_parts = [] for q in questions: q_id = q.get("id", "?") q_answer = q.get("answer", "") # Use helper to properly format LaTeX for Markdown formatted_answer = format_latex_for_markdown(q_answer) final_parts.append(f"## Bài {q_id}:\n{formatted_answer}\n") final_response = "\n".join(final_parts) else: # No answers provided - LLM didn't follow prompt correctly # Route to executor to re-process these as direct questions # For now, mark as needing tool (wolfram) so they get solved for q in questions: if not q.get("answer"): q["type"] = "wolfram" # Force tool use if not q.get("tool_input"): q["tool_input"] = q.get("content", "") state["execution_plan"] = plan state["current_agent"] = "executor" # Update memory tracking session_id = state["session_id"] tokens_in = total_input_tokens tokens_out = len(content) // 4 total_turn_tokens = tokens_in + tokens_out memory_tracker.add_usage(session_id, total_turn_tokens) new_status = memory_tracker.check_status(session_id) state["session_token_count"] = new_status.used_tokens state["context_status"] = new_status.status state["context_message"] = new_status.message return state state["execution_plan"] = None state["final_response"] = final_response state["messages"].append(AIMessage(content=final_response)) state["current_agent"] = "done" # Update memory tracking session_id = state["session_id"] tokens_in = total_input_tokens tokens_out = len(content) // 4 total_turn_tokens = tokens_in + tokens_out memory_tracker.add_usage(session_id, total_turn_tokens) new_status = memory_tracker.check_status(session_id) state["session_token_count"] = new_status.used_tokens state["context_status"] = new_status.status state["context_message"] = new_status.message return state # Mixed/Tool Case -> Route to Executor state["execution_plan"] = plan state["current_agent"] = "executor" # Update memory tracking (consistent with other agents) session_id = state["session_id"] tokens_in = total_input_tokens tokens_out = len(content) // 4 total_turn_tokens = tokens_in + tokens_out memory_tracker.add_usage(session_id, total_turn_tokens) new_status = memory_tracker.check_status(session_id) state["session_token_count"] = new_status.used_tokens state["context_status"] = new_status.status state["context_message"] = new_status.message # Check for memory overflow if new_status.status == "blocked": state["final_response"] = new_status.message state["current_agent"] = "done" except Exception as e: add_model_call(state, ModelCall( model=model_name, agent="planner", tokens_in=0, tokens_out=0, duration_ms=int((time.time() - start_time) * 1000), success=False, error=str(e) )) # Fallback: Planner failed, return error to user error_msg = str(e) user_friendly_msg = "Xin lỗi, đã có lỗi xảy ra khi phân tích câu hỏi." if "413" in error_msg or "Request too large" in error_msg: user_friendly_msg = "Nội dung lịch sử trò chuyện vượt quá giới hạn mô hình. Vui lòng tạo hội thoại mới để tiếp tục." elif "rate_limit" in error_msg or "TPM" in error_msg: user_friendly_msg = "Hệ thống đang quá tải (Rate Limit). Bạn vui lòng đợi khoảng 10-20 giây rồi thử lại nhé!" elif "context_length_exceeded" in error_msg: user_friendly_msg = "Hội thoại đã quá dài. Vui lòng tạo hội thoại mới để tiếp tục." else: user_friendly_msg = f"Xin lỗi, đã có lỗi kỹ thuật: {error_msg}." state["execution_plan"] = None state["final_response"] = user_friendly_msg state["current_agent"] = "done" return state async def parallel_executor_node(state: AgentState) -> AgentState: """ Parallel Executor: Execute multiple questions in parallel. - Direct questions: Process with kimi-k2 - Wolfram questions: Call API in parallel - Code questions: Execute code in parallel """ import asyncio add_agent_used(state, "parallel_executor") plan = state.get("execution_plan") if not plan or not plan.get("questions"): # No plan - planner should have handled this, go to done state["current_agent"] = "done" return state questions = plan["questions"] start_time = time.time() async def execute_single_question(q: dict) -> dict: """Execute a single question and return result.""" q_id = q.get("id", 0) q_type = q.get("type", "direct") q_content = q.get("content", "") q_tool_input = q.get("tool_input", "") result = { "id": q_id, "content": q_content, "type": q_type, "result": None, "error": None } async def solve_with_code(task_description: str, retries: int = 3) -> dict: """Helper to run code tool with retries.""" code_tool = CodeTool() out = {"result": None, "error": None} last_code = "" last_error = "" for attempt in range(retries): try: llm = get_model("qwen3-32b") # SMART RETRY: If we have an error, ask LLM to FIX it if attempt > 0 and last_error: code_prompt = CODEGEN_FIX_PROMPT.format(code=last_code, error=last_error) else: code_prompt = CODEGEN_PROMPT.format(task=task_description) code_response = await llm.ainvoke([HumanMessage(content=code_prompt)]) # Extract code code = code_response.content if "```python" in code: code = code.split("```python")[1].split("```")[0] elif "```" in code: code = code.split("```")[1].split("```")[0] last_code = code # Save for next retry if needed # Execute exec_result = code_tool.execute(code) if exec_result.get("success"): out["result"] = exec_result.get("output", "") return out else: last_error = exec_result.get("error", "Unknown error") if attempt == retries - 1: out["error"] = last_error except Exception as e: last_error = str(e) if attempt == retries - 1: out["error"] = str(e) return out try: if q_type == "wolfram": wolfram_done = False # Call Wolfram Alpha (with retry logic) # Call Wolfram Alpha (1 attempt only) for attempt in range(1): try: can_use, err = model_manager.check_rate_limit("wolfram") if not can_use: if attempt == 0: break await asyncio.sleep(1) continue wolfram_success, wolfram_result = await query_wolfram_alpha(q_tool_input) if wolfram_success: result["result"] = wolfram_result wolfram_done = True break else: # Treat logical failure as exception to trigger retry/fallback if attempt == 0: raise Exception(wolfram_result) except Exception as e: if attempt == 0: result["error"] = f"Wolfram failed: {str(e)}" await asyncio.sleep(0.5) # --- FALLBACK TO CODE IF WOLFRAM FAILED --- if not wolfram_done: # Append status to result fallback_note = f"\n(Wolfram failed, tried Code fallback)" code_out = await solve_with_code(q_tool_input) if code_out["result"]: result["result"] = code_out["result"] + fallback_note result["error"] = None # Clear error if fallback succeeded result["type"] = "wolfram+code" # Indicate hybrid path else: result["error"] += f" | Code Fallback also failed: {code_out['error']}" elif q_type == "code": # Execute code directly code_out = await solve_with_code(q_tool_input) result["result"] = code_out["result"] result["error"] = code_out["error"] else: # direct # User Optimization: If planner provided answer, use it directly (Save API) if q.get("answer"): result["result"] = q.get("answer") else: # Fallback: Solve directly with kimi-k2 (if planner forgot answer) llm = get_model("kimi-k2") solve_prompt = f"Giải bài toán sau một cách chi tiết:\n{q_content}" response = await llm.ainvoke([ SystemMessage(content="Bạn là chuyên gia giải toán. Trả lời ngắn gọn, đúng trọng tâm."), HumanMessage(content=solve_prompt) ]) result["result"] = format_latex_for_markdown(response.content) # Direct result except Exception as e: result["error"] = str(e) return result # Execute all questions in parallel tasks = [execute_single_question(q) for q in questions] results = await asyncio.gather(*tasks, return_exceptions=True) # Process results and collect metrics question_results = [] total_tokens_in = 0 total_tokens_out = 0 for i, r in enumerate(results): q = questions[i] q_type = q.get("type", "direct") # Prepare result entry res_entry = { "id": q.get("id", i+1), "content": q.get("content", ""), "result": None, "error": None, "type": q_type } if isinstance(r, Exception): error_msg = str(r) if "413" in error_msg or "Request too large" in error_msg: friendly = "Nội dung quá dài, vui lòng gửi ngắn hơn." elif "rate_limit" in error_msg or "TPM" in error_msg: friendly = "Rate Limit (Quá tải), vui lòng đợi giây lát." else: friendly = f"Lỗi kỹ thuật: {error_msg}" res_entry["error"] = friendly success = False r_content = friendly else: # r is the result dict from execute_single_question res_entry.update(r) success = not bool(r.get("error")) r_content = str(r.get("result", "")) # Use friendly error if present in result dict raw_err = r.get("error") if raw_err: error_msg = str(raw_err) if "413" in error_msg or "Request too large" in error_msg: friendly = "Nội dung quá dài, vui lòng gửi ngắn hơn." elif "rate_limit" in error_msg or "TPM" in error_msg: friendly = "Rate Limit (Quá tải), vui lòng đợi giây lát." else: friendly = f"Lỗi kỹ thuật: {error_msg}" res_entry["error"] = friendly r_content = friendly question_results.append(res_entry) # Add individual model call trace for each parallel task # This allows the frontend to show "Wolfram", "Code", "Kimi" calls clearly # Estimate tokens for metrics (rough check) t_in = len(q.get("content", "")) // 4 t_out = len(r_content) // 4 total_tokens_in += t_in total_tokens_out += t_out model_name_trace = "unknown" if q_type == "wolfram": model_name_trace = "wolfram-alpha" elif q_type == "code": model_name_trace = "python-code-executor" else: model_name_trace = "kimi-k2" add_model_call(state, ModelCall( model=model_name_trace, agent=f"parallel_executor_q{res_entry['id']}", tokens_in=t_in, tokens_out=t_out, duration_ms=int((time.time() - start_time) * 1000), # Approx sharing total time success=success, tool_calls=[{ "tool": q_type, "input": q.get("tool_input") or q.get("content"), "output": r_content[:200] + "..." if len(r_content) > 200 else r_content }] )) state["question_results"] = question_results # --- UI COMPATIBILITY FIX --- # Populate legacy fields so the Tracing UI (which expects single tool per turn) shows SOMETHING. # We aggregate all parallel results into a single string. start_time_ms = int(start_time * 1000) # 1. Selected Tool tool_names = list(set(r["type"] for r in question_results)) state["selected_tool"] = f"parallel({','.join(tool_names)})" state["should_use_tools"] = True # 2. Tool Result (Aggregated) agg_result = [] for r in question_results: status = "✅" if not r.get("error") else "❌" val = r.get("result") or r.get("error") agg_result.append(f"[{status} {r['type'].upper()}]: {str(val)[:100]}...") state["tool_result"] = "\n".join(agg_result) # 3. Tools Called (List of ToolCall objects) tools_called_list = [] for r in question_results: tools_called_list.append({ "tool": r["type"], "tool_input": str(questions[next((i for i, q in enumerate(questions) if q.get("id") == r["id"]), 0)].get("tool_input", "") or r.get("content")), "tool_output": str(r.get("result") or r.get("error")) }) state["tools_called"] = tools_called_list state["tool_success"] = any(not r.get("error") for r in question_results) # --------------------------- duration_ms = int((time.time() - start_time) * 1000) add_model_call(state, ModelCall( model="parallel_orchestrator", agent="parallel_executor", tokens_in=total_tokens_in, tokens_out=total_tokens_out, duration_ms=duration_ms, success=state["tool_success"] )) # Go to synthesizer to combine results state["current_agent"] = "synthetic" return state # NOTE: reasoning_agent_node has been DEPRECATED and REMOVED. # The workflow now flows: OCR -> Planner -> Executor -> Synthetic # (See user's workflow diagram for reference) async def synthetic_agent_node(state: AgentState) -> AgentState: """ Synthetic Agent: Synthesize tool results into final response. Handles both single-tool results and multi-question parallel results. Uses kimi-k2. """ add_agent_used(state, "synthetic_agent") start_time = time.time() model_name = "kimi-k2" session_id = state["session_id"] # Check memory status before processing mem_status = memory_tracker.check_status(session_id) if mem_status.status == "blocked": state["context_status"] = "blocked" state["context_message"] = mem_status.message state["final_response"] = mem_status.message state["current_agent"] = "done" return state # Check if we have multi-question results from parallel executor question_results = state.get("question_results", []) if question_results: # Multi-question mode: combine all results # Use LLM to synthesize a natural response instead of raw concatenation # Prepare context for synthesis results_context = [] for r in question_results: q_id = r.get("id", 0) q_content = r.get("content", "") q_result = r.get("result", "Không có kết quả") q_error = r.get("error") status = "Thành công" if not q_error else f"Lỗi: {q_error}" results_context.append(f"--- BÀI TOÁN {q_id} ---\nNội dung: {q_content}\nTrạng thái: {status}\nKết quả gốc:\n{q_result}\n\n") combined_context = "".join(results_context) # Get original question text for context original_q_text = "Nhiều câu hỏi (xem chi tiết bên trên)" if state.get("ocr_text"): original_q_text = f"[OCR]: {state['ocr_text']}" elif state["messages"]: for m in reversed(state["messages"]): if isinstance(m, HumanMessage): original_q_text = str(m.content) break # Use Standard SYNTHETIC_PROMPT synth_prompt = SYNTHETIC_PROMPT.format( tool_result=combined_context, original_question=original_q_text ) # ======================================== # NEW: Include recent conversation history for contextual synthesis # ======================================== llm_messages = [ SystemMessage(content="""Bạn là chuyên gia toán học Việt Nam. Hãy giải thích lời giải một cách sư phạm, dễ hiểu. VỀ BỘ NHỚ HỘI THOẠI: - Bạn có thể tham chiếu đến các câu hỏi trước đó trong hội thoại. - Nếu người dùng đề cập đến "bài trước", "câu đó", hãy hiểu ngữ cảnh. - Trả lời tự nhiên như một cuộc trò chuyện liên tục."""), ] # Add recent conversation history (last 3 turns = 6 messages) recent_history = state.get("messages", [])[-6:] for msg in recent_history: llm_messages.append(msg) # Add synthesis prompt llm_messages.append(HumanMessage(content=synth_prompt)) try: llm = get_model("kimi-k2") response = await llm.ainvoke(llm_messages) final_response = format_latex_for_markdown(response.content) except Exception as e: # Fallback manual synthesis if LLM fails error_msg = str(e) if "413" in error_msg or "Request too large" in error_msg: friendly_err = "Nội dung quá dài để tổng hợp." elif "rate_limit" in error_msg or "TPM" in error_msg: friendly_err = "Hệ thống đang bận (Rate Limit)." else: friendly_err = f"Lỗi kỹ thuật: {error_msg}" final_response = f"**Kết quả (Tổng hợp tự động thất bại do {friendly_err}):**\n\n" + combined_context state["final_response"] = final_response state["messages"].append(AIMessage(content=final_response)) state["current_agent"] = "done" # Update memory tokens_out = len(final_response) // 4 memory_tracker.add_usage(session_id, tokens_out) new_status = memory_tracker.check_status(session_id) state["session_token_count"] = new_status.used_tokens state["context_status"] = new_status.status state["context_message"] = new_status.message return state # Single-question mode: original logic # Get original question original_question = "" if state["messages"]: for msg in state["messages"]: if hasattr(msg, "content") and isinstance(msg, HumanMessage): original_question = msg.content if isinstance(msg.content, str) else str(msg.content) break # Add OCR context if available if state.get("ocr_text"): original_question = f"[Từ ảnh]: {state['ocr_text']}\n\n{original_question}" # Build prompt tool_result = state.get("tool_result", "Không có kết quả") if not state.get("tool_success"): tool_result = f"[Công cụ thất bại]: {state.get('error_message', 'Unknown error')}\n\nHãy cố gắng trả lời dựa trên kiến thức của bạn." prompt = SYNTHETIC_PROMPT.format( tool_result=tool_result, original_question=original_question ) messages = [HumanMessage(content=prompt)] tokens_in = estimate_tokens(prompt) try: llm = get_model(model_name) response = await llm.ainvoke(messages) duration_ms = int((time.time() - start_time) * 1000) tokens_out = len(response.content) // 4 add_model_call(state, ModelCall( model=model_name, agent="synthetic_agent", tokens_in=tokens_in, tokens_out=tokens_out, duration_ms=duration_ms, success=True )) # Update session memory tracker total_turn_tokens = tokens_in + tokens_out memory_tracker.add_usage(session_id, total_turn_tokens) new_status = memory_tracker.check_status(session_id) state["session_token_count"] = new_status.used_tokens state["context_status"] = new_status.status state["context_message"] = new_status.message # Format the synthesis with standard helper formatted_response = format_latex_for_markdown(response.content) state["final_response"] = formatted_response state["messages"].append(AIMessage(content=formatted_response)) state["current_agent"] = "done" except Exception as e: # Fallback to raw tool result if synthesis fails fallback_response = f"**Kết quả tính toán:**\n{state.get('tool_result', 'Không có kết quả')}" state["final_response"] = fallback_response state["messages"].append(AIMessage(content=fallback_response)) state["current_agent"] = "done" return state # ============================================================================ # TOOL NODES # ============================================================================ async def wolfram_tool_node(state: AgentState) -> AgentState: """ Wolfram Tool: Query Wolfram Alpha. Max 3 attempts (1 initial + 2 retries). """ add_agent_used(state, "wolfram_tool") query = state.get("_tool_query", "") state["wolfram_attempts"] += 1 start_time = time.time() success, result = await query_wolfram_alpha(query) duration_ms = int((time.time() - start_time) * 1000) tool_call = ToolCall( tool="wolfram", input=query, output=result if success else None, success=success, attempt=state["wolfram_attempts"], duration_ms=duration_ms, error=None if success else result ) add_tool_call(state, tool_call) if success: state["tool_result"] = result state["tool_success"] = True state["current_agent"] = "synthetic" else: if state["wolfram_attempts"] < 1: # Retry state["current_agent"] = "wolfram" else: # Fallback to code tool state["selected_tool"] = "code" state["current_agent"] = "code" return state async def code_tool_node(state: AgentState) -> AgentState: """ Code Tool: Generate and execute Python code. codegen_agent: qwen3-32b codefix_agent: gpt-oss-120b (max 2 fixes) """ add_agent_used(state, "code_tool") task = state.get("_tool_query", "") state["code_attempts"] += 1 code_tool = CodeTool() start_time = time.time() # Generate code using qwen3-32b codegen_start = time.time() try: llm = get_model("qwen3-32b") prompt = CODEGEN_PROMPT.format(task=task) response = await llm.ainvoke([HumanMessage(content=prompt)]) code = _extract_code(response.content) add_model_call(state, ModelCall( model="qwen3-32b", agent="codegen_agent", tokens_in=len(prompt) // 4, tokens_out=len(response.content) // 4, duration_ms=int((time.time() - codegen_start) * 1000), success=True )) except Exception as e: add_model_call(state, ModelCall( model="qwen3-32b", agent="codegen_agent", tokens_in=0, tokens_out=0, duration_ms=int((time.time() - codegen_start) * 1000), success=False, error=str(e) )) state["error_message"] = f"Code generation failed: {str(e)}" state["tool_success"] = False state["current_agent"] = "synthetic" return state # Execute code with correction loop (max 2 fixes) exec_result = code_tool.execute(code) while not exec_result["success"] and state["codefix_attempts"] < 2: state["codefix_attempts"] += 1 # Fix code using gpt-oss-120b fix_start = time.time() try: llm = get_model("gpt-oss-120b") fix_prompt = CODEGEN_FIX_PROMPT.format(code=code, error=exec_result["error"]) response = await llm.ainvoke([HumanMessage(content=fix_prompt)]) code = _extract_code(response.content) add_model_call(state, ModelCall( model="gpt-oss-120b", agent="codefix_agent", tokens_in=len(fix_prompt) // 4, tokens_out=len(response.content) // 4, duration_ms=int((time.time() - fix_start) * 1000), success=True )) exec_result = code_tool.execute(code) except Exception as e: add_model_call(state, ModelCall( model="gpt-oss-120b", agent="codefix_agent", tokens_in=0, tokens_out=0, duration_ms=int((time.time() - fix_start) * 1000), success=False, error=str(e) )) break duration_ms = int((time.time() - start_time) * 1000) tool_call = ToolCall( tool="code", input=task, output=exec_result.get("output") if exec_result["success"] else None, success=exec_result["success"], attempt=state["code_attempts"], duration_ms=duration_ms, error=exec_result.get("error") if not exec_result["success"] else None ) add_tool_call(state, tool_call) if exec_result["success"]: state["tool_result"] = exec_result["output"] state["tool_success"] = True else: state["tool_result"] = f"Code execution failed after {state['codefix_attempts']} fixes: {exec_result.get('error')}" state["tool_success"] = False state["error_message"] = exec_result.get("error") state["current_agent"] = "synthetic" return state def _extract_code(response: str) -> str: """Extract Python code from LLM response.""" if "```python" in response: return response.split("```python")[1].split("```")[0].strip() elif "```" in response: return response.split("```")[1].split("```")[0].strip() return response.strip() # ============================================================================ # ROUTER # ============================================================================ def route_agent(state: AgentState) -> str: """Route to the next agent/node based on current state.""" current = state.get("current_agent", "done") if current == "ocr": return "ocr_agent" elif current == "planner": return "planner" elif current == "executor": return "executor" elif current == "wolfram": return "wolfram_tool" elif current == "code": return "code_tool" elif current == "synthetic": return "synthetic_agent" elif current == "done": return "done" else: return "end"