Spaces:
Sleeping
Sleeping
File size: 5,139 Bytes
ba5110e |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 |
"""
State definitions for the LangGraph multi-agent system.
Includes tracking/tracing fields for observability.
"""
from typing import Annotated, Literal, TypedDict, Optional, List
from dataclasses import dataclass, field
from langgraph.graph.message import add_messages
import time
@dataclass
class ToolCall:
    """Trace entry describing one invocation of an external tool.

    Only ``tool`` and ``input`` are required; the remaining fields start at
    "not yet run / no result" values and are filled in by the caller once
    the invocation finishes.
    """

    tool: str  # name of the tool that was invoked
    input: str  # query/payload sent to the tool
    output: Optional[str] = field(default=None)  # tool output, if any
    success: bool = field(default=False)  # True once the call succeeded
    attempt: int = field(default=1)  # 1-based attempt number for retries
    duration_ms: int = field(default=0)  # elapsed time in milliseconds
    error: Optional[str] = field(default=None)  # error text on failure
@dataclass
class ModelCall:
    """Record of a single model invocation made by an agent.

    All measurements (token counts, duration, outcome) are supplied by the
    caller; ``error`` and ``tool_calls`` are optional extras.
    """

    model: str  # model identifier
    agent: str  # name of the agent that issued the call
    tokens_in: int  # input token count
    tokens_out: int  # output token count
    duration_ms: int  # duration in milliseconds, as measured by the caller
    success: bool  # whether the call completed successfully
    error: Optional[str] = field(default=None)  # error text when the call failed
    tool_calls: Optional[List[dict]] = field(default=None)  # tool calls attached to this call, if any
class AgentState(TypedDict):
    """
    State for the multi-agent algebra chatbot.

    Combines user-facing data (messages, images, final answer), routing and
    tool-control fields, and tracking/tracing fields used for observability.
    A fully populated instance for a new turn is built by
    ``create_initial_state``.
    """
    # --- Core messaging ---
    # LangGraph applies the ``add_messages`` reducer to updates of this key,
    # so node outputs are merged into the list instead of replacing it.
    messages: Annotated[list, add_messages]
    session_id: str  # identifier of the chat session this turn belongs to
    # --- Image handling (multi-image support) ---
    image_data: Optional[str]  # Legacy: single image (kept for backward compatibility)
    image_data_list: List[str]  # List of base64-encoded images
    ocr_text: Optional[str]  # Legacy: single OCR result
    ocr_results: List[dict]  # One entry per image: {"image_index": int, "text": str}
    # --- Agent flow control ---
    # Node to run next. create_initial_state sets "ocr" when any image is
    # supplied and "planner" otherwise.
    current_agent: Literal["ocr", "planner", "executor", "synthetic", "wolfram", "code", "done"]
    should_use_tools: bool
    selected_tool: Optional[Literal["wolfram", "code"]]  # tool chosen for the current step, if any
    _tool_query: Optional[str]  # Internal field used to pass the query to tool nodes
    # --- Multi-question execution ---
    execution_plan: Optional[dict]  # Planner output: {"questions": [...]}
    question_results: List[dict]  # Results per question: [{"id": 1, "result": "...", "error": None}]
    # --- Tool state ---
    # Attempt counters, all starting at 0. The maxima below are documented
    # limits; presumably they are enforced by the tool nodes (not in this module).
    wolfram_attempts: int  # Max 3 (1 initial + 2 retries)
    code_attempts: int  # Max 3 for codegen
    codefix_attempts: int  # Max 2 for fixing
    tool_result: Optional[str]
    tool_success: bool
    # --- Error handling ---
    error_message: Optional[str]
    # --- Tracking/Tracing (for observability) ---
    agents_used: List[str]  # Distinct agent names in first-use order (see add_agent_used)
    tools_called: List[dict]  # ToolCall records serialised as dicts
    model_calls: List[dict]  # ModelCall records serialised as dicts
    total_tokens: int  # Running sum of tokens_in + tokens_out over recorded model calls
    start_time: float  # Turn start as a time.time() timestamp (seconds)
    # --- Memory management ---
    session_token_count: int  # Cumulative tokens used in this session
    context_status: Literal["ok", "warning", "blocked"]
    context_message: Optional[str]  # Warning or error message for the UI
    # --- Final response ---
    final_response: Optional[str]
def create_initial_state(
    session_id: str,
    image_data: Optional[str] = None,
    image_data_list: Optional[List[str]] = None
) -> AgentState:
    """Build a fresh ``AgentState`` for one conversation turn.

    Args:
        session_id: Identifier of the chat session.
        image_data: Optional single image (legacy single-image callers).
        image_data_list: Optional list of images; ``None`` (or an empty
            list) is stored as a new empty list.

    Returns:
        A state dict with every field populated. Routing starts at the
        ``"ocr"`` agent when any image was supplied, else at ``"planner"``.
    """
    # Any image, in either the legacy or the list form, means OCR runs first.
    if image_data or image_data_list:
        entry_agent = "ocr"
    else:
        entry_agent = "planner"

    state: AgentState = {
        # Core messaging
        "messages": [],
        "session_id": session_id,
        # Image handling
        "image_data": image_data,
        "image_data_list": image_data_list or [],
        "ocr_text": None,
        "ocr_results": [],
        # Agent flow control
        "current_agent": entry_agent,
        "should_use_tools": False,
        "selected_tool": None,
        "_tool_query": None,
        # Multi-question execution
        "execution_plan": None,
        "question_results": [],
        # Tool state: all counters start at zero
        "wolfram_attempts": 0,
        "code_attempts": 0,
        "codefix_attempts": 0,
        "tool_result": None,
        "tool_success": False,
        # Error handling
        "error_message": None,
        # Tracking/tracing
        "agents_used": [],
        "tools_called": [],
        "model_calls": [],
        "total_tokens": 0,
        "start_time": time.time(),
        # Memory management
        "session_token_count": 0,
        "context_status": "ok",
        "context_message": None,
        # Final response
        "final_response": None,
    }
    return state
def add_agent_used(state: AgentState, agent_name: str) -> None:
    """Append *agent_name* to ``state["agents_used"]`` unless already present.

    The list is mutated in place and keeps first-use order without duplicates.
    """
    seen = state["agents_used"]
    if agent_name in seen:
        return  # already recorded; keep the list duplicate-free
    seen.append(agent_name)
def add_tool_call(state: AgentState, tool_call: ToolCall) -> None:
    """Append a plain-dict snapshot of *tool_call* to ``state["tools_called"]``.

    The snapshot carries the tool name, input, output, success flag, attempt
    number, duration and error, in that key order.
    """
    log = state["tools_called"]
    recorded = ("tool", "input", "output", "success", "attempt", "duration_ms", "error")
    log.append({name: getattr(tool_call, name) for name in recorded})
def add_model_call(state: AgentState, model_call: ModelCall) -> None:
    """Record a model call and update the running token total.

    Appends a plain-dict snapshot of *model_call* to ``state["model_calls"]``
    and adds ``tokens_in + tokens_out`` to ``state["total_tokens"]``.

    The snapshot mirrors every ``ModelCall`` field, including the optional
    ``tool_calls`` list, so no information on the dataclass is lost from
    the trace.
    """
    state["model_calls"].append({
        "model": model_call.model,
        "agent": model_call.agent,
        "tokens_in": model_call.tokens_in,
        "tokens_out": model_call.tokens_out,
        "duration_ms": model_call.duration_ms,
        "success": model_call.success,
        "error": model_call.error,
        # Previously omitted even though ModelCall defines it; kept last to
        # match the dataclass field order.
        "tool_calls": model_call.tool_calls,
    })
    state["total_tokens"] += model_call.tokens_in + model_call.tokens_out
def get_total_duration_ms(state: AgentState) -> int:
    """Return whole milliseconds elapsed since ``state["start_time"]``.

    A missing or ``None`` start time yields 0. The elapsed time is measured
    against ``time.time()`` and truncated toward zero by ``int()``.
    """
    started_at = state.get("start_time")
    return 0 if started_at is None else int((time.time() - started_at) * 1000)
|