Spaces:
Sleeping
Sleeping
| """ | |
| State definitions for the LangGraph multi-agent system. | |
| Includes tracking/tracing fields for observability. | |
| """ | |
| from typing import Annotated, Literal, TypedDict, Optional, List | |
| from dataclasses import dataclass, field | |
| from langgraph.graph.message import add_messages | |
| import time | |
@dataclass
class ToolCall:
    """Record of a single tool invocation, kept for tracing.

    Instances are flattened into plain dicts (one key per field) and appended
    to ``AgentState["tools_called"]``.
    """

    tool: str  # Tool name, e.g. "wolfram" or "code"
    input: str  # Input/query that was sent to the tool
    output: Optional[str] = None  # Tool output, or None if nothing was produced
    success: bool = False  # Whether the invocation succeeded
    attempt: int = 1  # 1-based attempt number (retries increment this)
    duration_ms: int = 0  # Wall-clock duration of the call in milliseconds
    error: Optional[str] = None  # Error message when the call failed
@dataclass
class ModelCall:
    """Record of a single LLM call, kept for tracing.

    Instances are flattened into plain dicts and appended to
    ``AgentState["model_calls"]``; ``tokens_in + tokens_out`` is added to the
    running ``total_tokens`` counter when recorded.
    """

    model: str  # Model identifier used for the call
    agent: str  # Name of the agent that issued the call
    tokens_in: int  # Prompt/input tokens consumed
    tokens_out: int  # Completion/output tokens produced
    duration_ms: int  # Wall-clock duration of the call in milliseconds
    success: bool  # Whether the call succeeded
    error: Optional[str] = None  # Error message when the call failed
    tool_calls: Optional[List[dict]] = None  # Tool-call payloads returned by the model, if any
class AgentState(TypedDict):
    """
    State for the multi-agent algebra chatbot.

    A single dict of this shape is passed through the LangGraph graph for one
    conversation turn. It carries user-facing data (messages, images, final
    answer) together with tracking/tracing fields for observability.
    """
    # --- Core messaging ---
    # Chat history. The add_messages annotation registers LangGraph's
    # add_messages reducer for this key, so node updates are merged into the
    # list rather than replacing it.
    messages: Annotated[list, add_messages]
    session_id: str  # Identifier of the chat session this turn belongs to

    # --- Image handling (multi-image support) ---
    image_data: Optional[str]  # Legacy: single image (kept for backward compatibility)
    image_data_list: List[str]  # List of base64-encoded images (multi-image path)
    ocr_text: Optional[str]  # Legacy: single OCR result
    ocr_results: List[dict]  # Per-image OCR: list of {"image_index": int, "text": str}

    # --- Agent flow control ---
    # Current position in the agent flow. "done" presumably marks the end of
    # the turn; the routing that reads this lives outside this class.
    current_agent: Literal["ocr", "planner", "executor", "synthetic", "wolfram", "code", "done"]
    should_use_tools: bool  # Whether a tool (wolfram/code) should be invoked
    selected_tool: Optional[Literal["wolfram", "code"]]  # Which tool was chosen, if any
    _tool_query: Optional[str]  # Internal field to pass query to tool nodes

    # --- Multi-question execution ---
    execution_plan: Optional[dict]  # Planner output: {"questions": [...]}
    question_results: List[dict]  # Results per question: [{"id": 1, "result": "...", "error": None}]

    # --- Tool state ---
    # The attempt caps noted below are conventions; this class does not
    # enforce them.
    wolfram_attempts: int  # Max 3 (1 initial + 2 retries)
    code_attempts: int  # Max 3 for codegen
    codefix_attempts: int  # Max 2 for fixing
    tool_result: Optional[str]  # Output of the most recent tool run
    tool_success: bool  # Whether the most recent tool run succeeded

    # --- Error handling ---
    error_message: Optional[str]

    # --- Tracking/Tracing (for observability) ---
    agents_used: List[str]  # Distinct agent names, in first-use order (see add_agent_used)
    tools_called: List[dict]  # List of ToolCall records as dicts
    model_calls: List[dict]  # List of ModelCall records as dicts
    total_tokens: int  # Running sum of tokens_in + tokens_out over model_calls
    start_time: float  # time.time() timestamp taken when the state was created

    # --- Memory management ---
    session_token_count: int  # Cumulative tokens used in this session
    context_status: Literal["ok", "warning", "blocked"]  # Context-budget status
    context_message: Optional[str]  # Warning or error message for UI

    # --- Final response ---
    final_response: Optional[str]
def create_initial_state(
    session_id: str,
    image_data: Optional[str] = None,
    image_data_list: Optional[List[str]] = None
) -> AgentState:
    """Build a fresh AgentState for a new conversation turn.

    Args:
        session_id: Identifier of the chat session this turn belongs to.
        image_data: Optional single image (legacy single-image path).
        image_data_list: Optional list of images (multi-image path).

    Returns:
        A fully populated AgentState. Counters are 0, lists are empty,
        optional fields are None, ``context_status`` is "ok", and
        ``start_time`` is the current ``time.time()``. The flow starts at
        "ocr" when any image was supplied, otherwise at "planner".
    """
    images = image_data_list or []

    # Any image (legacy single or list) routes the turn through OCR first.
    if image_data or image_data_list:
        first_agent = "ocr"
    else:
        first_agent = "planner"

    return AgentState(
        # Core messaging
        messages=[],
        session_id=session_id,
        # Image handling
        image_data=image_data,
        image_data_list=images,
        ocr_text=None,
        ocr_results=[],
        # Agent flow control
        current_agent=first_agent,
        should_use_tools=False,
        selected_tool=None,
        _tool_query=None,
        # Multi-question execution
        execution_plan=None,
        question_results=[],
        # Tool state
        wolfram_attempts=0,
        code_attempts=0,
        codefix_attempts=0,
        tool_result=None,
        tool_success=False,
        # Error handling
        error_message=None,
        # Tracking / tracing
        agents_used=[],
        tools_called=[],
        model_calls=[],
        total_tokens=0,
        start_time=time.time(),
        # Memory management
        session_token_count=0,
        context_status="ok",
        context_message=None,
        # Final response
        final_response=None,
    )
def add_agent_used(state: AgentState, agent_name: str) -> None:
    """Append ``agent_name`` to ``state["agents_used"]`` unless already present.

    Mutates ``state`` in place. Each name is kept once, in first-use order.
    """
    seen = state["agents_used"]
    if agent_name in seen:
        return
    seen.append(agent_name)
def add_tool_call(state: AgentState, tool_call: ToolCall) -> None:
    """Serialize ``tool_call`` to a dict and append it to ``state["tools_called"]``.

    The dict has exactly the keys tool, input, output, success, attempt,
    duration_ms and error (in that order), copied from the record's
    attributes. Mutates ``state`` in place.
    """
    history = state["tools_called"]
    field_names = ("tool", "input", "output", "success", "attempt", "duration_ms", "error")
    history.append({name: getattr(tool_call, name) for name in field_names})
def add_model_call(state: AgentState, model_call: ModelCall) -> None:
    """Record a model call in the trace and update the running token total.

    Appends a dict with one key per ModelCall field to ``state["model_calls"]``
    and adds ``tokens_in + tokens_out`` to ``state["total_tokens"]``.
    Mutates ``state`` in place.

    Args:
        state: The agent state to update.
        model_call: The call record to serialize.
    """
    state["model_calls"].append({
        "model": model_call.model,
        "agent": model_call.agent,
        "tokens_in": model_call.tokens_in,
        "tokens_out": model_call.tokens_out,
        "duration_ms": model_call.duration_ms,
        "success": model_call.success,
        "error": model_call.error,
        # Include tool_calls so the stored dict carries every ModelCall field
        # instead of silently dropping the model's tool-call payloads.
        "tool_calls": model_call.tool_calls,
    })
    state["total_tokens"] += model_call.tokens_in + model_call.tokens_out
def get_total_duration_ms(state: AgentState) -> int:
    """Return whole milliseconds elapsed since ``state["start_time"]``.

    Returns 0 when ``start_time`` is missing or None. Elapsed time is
    ``time.time()`` minus ``start_time``, truncated to an int.
    """
    began = state.get("start_time")
    return 0 if began is None else int((time.time() - began) * 1000)