File size: 5,139 Bytes
ba5110e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
"""
State definitions for the LangGraph multi-agent system.
Includes tracking/tracing fields for observability.
"""
from typing import Annotated, Literal, TypedDict, Optional, List
from dataclasses import dataclass, field
from langgraph.graph.message import add_messages
import time


@dataclass
class ToolCall:
    """Record of a single tool invocation.

    ``add_tool_call`` copies every field of this record into a plain dict
    and appends that dict to ``state["tools_called"]``.
    """
    tool: str                     # tool identifier (presumably "wolfram" or "code" -- confirm with callers)
    input: str                    # input passed to the tool
    output: Optional[str] = None  # tool output; None when not set
    success: bool = False         # defaults to False, so success must be set explicitly
    attempt: int = 1              # attempt number; defaults to 1
    duration_ms: int = 0          # duration in milliseconds (per the field name)
    error: Optional[str] = None   # error text; None when not set


@dataclass
class ModelCall:
    """Record of a single model (LLM) call.

    ``add_model_call`` appends a dict built from this record to
    ``state["model_calls"]`` and adds ``tokens_in + tokens_out`` to
    ``state["total_tokens"]``.
    """
    model: str                               # model identifier
    agent: str                               # name of the agent that made the call
    tokens_in: int                           # input token count (counted into total_tokens)
    tokens_out: int                          # output token count (counted into total_tokens)
    duration_ms: int                         # duration in milliseconds (per the field name)
    success: bool                            # required: no default, unlike ToolCall.success
    error: Optional[str] = None              # error text; None when not set
    tool_calls: Optional[List[dict]] = None  # tool calls attached to this model call, if any (dict schema not defined here)


class AgentState(TypedDict):
    """
    Shared state dict for the multi-agent algebra chatbot graph.

    Holds user-facing data (messages, images, final response), agent
    routing flags, tool retry counters, and tracking/tracing fields.
    ``create_initial_state`` populates every key declared here.
    """
    # Core messaging
    # add_messages (from langgraph) is attached as Annotated metadata; in
    # LangGraph it acts as the reducer for this key -- see LangGraph docs.
    messages: Annotated[list, add_messages]
    session_id: str
    
    # Image handling (multi-image support)
    image_data: Optional[str]          # Legacy: single image (kept for backward compatibility)
    image_data_list: List[str]         # List of base64-encoded images
    ocr_text: Optional[str]            # Legacy: single OCR result
    ocr_results: List[dict]            # List of {"image_index": int, "text": str}
    
    # Agent flow control
    current_agent: Literal["ocr", "planner", "executor", "synthetic", "wolfram", "code", "done"]
    should_use_tools: bool
    selected_tool: Optional[Literal["wolfram", "code"]]
    _tool_query: Optional[str]  # Internal field used to pass the query to tool nodes
    
    # Multi-question execution
    execution_plan: Optional[dict]     # Planner output: {"questions": [...]}
    question_results: List[dict]       # Results per question: [{"id": 1, "result": "...", "error": None}]
    
    # Tool state (the limits below are documented here only; nothing in
    # this module enforces them)
    wolfram_attempts: int      # Max 3 (1 initial + 2 retries)
    code_attempts: int         # Max 3 for codegen
    codefix_attempts: int      # Max 2 for fixing
    tool_result: Optional[str]
    tool_success: bool
    
    # Error handling
    error_message: Optional[str]
    
    # Tracking/Tracing (for observability)
    agents_used: List[str]     # Appended to by add_agent_used (no duplicates)
    tools_called: List[dict]   # List of ToolCall as dicts (see add_tool_call)
    model_calls: List[dict]    # List of ModelCall as dicts (see add_model_call)
    total_tokens: int          # Incremented by add_model_call (tokens_in + tokens_out)
    start_time: float          # time.time() value; read by get_total_duration_ms
    
    # Memory management
    session_token_count: int   # Cumulative tokens used in this session
    context_status: Literal["ok", "warning", "blocked"]
    context_message: Optional[str]  # Warning or error message for UI
    
    # Final response
    final_response: Optional[str]


def create_initial_state(
    session_id: str,
    image_data: Optional[str] = None,
    image_data_list: Optional[List[str]] = None,
) -> AgentState:
    """Build the state a new conversation turn starts from.

    Args:
        session_id: Identifier stored under ``"session_id"``.
        image_data: Optional single image (legacy input), stored as given.
        image_data_list: Optional list of images; a missing or empty value
            is stored as a new empty list.

    Returns:
        A state dict with every ``AgentState`` key set: counters at zero,
        collections empty, optional fields ``None``, ``context_status``
        ``"ok"`` and ``start_time`` set to the current ``time.time()``.
        ``current_agent`` is ``"ocr"`` when any image input is truthy,
        otherwise ``"planner"``.
    """
    # Any image input (legacy single image or the list) sends the turn
    # through OCR before planning.
    if image_data or image_data_list:
        first_agent = "ocr"
    else:
        first_agent = "planner"

    return {
        # Core messaging
        "messages": [],
        "session_id": session_id,
        # Image handling
        "image_data": image_data,
        "image_data_list": image_data_list or [],
        "ocr_text": None,
        "ocr_results": [],
        # Agent flow control
        "current_agent": first_agent,
        "should_use_tools": False,
        "selected_tool": None,
        "_tool_query": None,
        # Multi-question execution
        "execution_plan": None,
        "question_results": [],
        # Tool state
        "wolfram_attempts": 0,
        "code_attempts": 0,
        "codefix_attempts": 0,
        "tool_result": None,
        "tool_success": False,
        # Error handling
        "error_message": None,
        # Tracking/tracing
        "agents_used": [],
        "tools_called": [],
        "model_calls": [],
        "total_tokens": 0,
        "start_time": time.time(),
        # Memory management
        "session_token_count": 0,
        "context_status": "ok",
        "context_message": None,
        # Final response
        "final_response": None,
    }


def add_agent_used(state: AgentState, agent_name: str) -> None:
    """Note that ``agent_name`` took part in this turn.

    Appends the name to ``state["agents_used"]`` in place unless it is
    already listed, so each agent appears once, in order of first use.
    """
    used = state["agents_used"]
    if agent_name in used:
        return
    used.append(agent_name)


def add_tool_call(state: AgentState, tool_call: ToolCall) -> None:
    """Append a dict snapshot of ``tool_call`` to ``state["tools_called"]``.

    The snapshot carries the record's ``tool``, ``input``, ``output``,
    ``success``, ``attempt``, ``duration_ms`` and ``error`` values, in that
    key order. The list is modified in place.
    """
    history = state["tools_called"]
    recorded_fields = (
        "tool",
        "input",
        "output",
        "success",
        "attempt",
        "duration_ms",
        "error",
    )
    entry = {name: getattr(tool_call, name) for name in recorded_fields}
    history.append(entry)


def add_model_call(state: AgentState, model_call: ModelCall) -> None:
    """Record a model call and add its token usage to the running total.

    Appends a dict holding every ``ModelCall`` field to
    ``state["model_calls"]`` and increases ``state["total_tokens"]`` by
    ``tokens_in + tokens_out``. Both updates happen in place.

    Args:
        state: State dict to update; must already contain ``"model_calls"``
            and ``"total_tokens"``.
        model_call: The call record to store.
    """
    state["model_calls"].append({
        "model": model_call.model,
        "agent": model_call.agent,
        "tokens_in": model_call.tokens_in,
        "tokens_out": model_call.tokens_out,
        "duration_ms": model_call.duration_ms,
        "success": model_call.success,
        "error": model_call.error,
        # ModelCall declares this field, so keep it in the stored dict
        # rather than silently losing it from the trace.
        "tool_calls": model_call.tool_calls,
    })
    state["total_tokens"] += model_call.tokens_in + model_call.tokens_out


def get_total_duration_ms(state: AgentState) -> int:
    """Return whole milliseconds elapsed since ``state["start_time"]``.

    Yields 0 when the state has no ``"start_time"`` key or its value is
    ``None``. The elapsed time is measured against ``time.time()`` and
    truncated to an int.
    """
    started_at = state.get("start_time")
    if started_at is not None:
        elapsed_seconds = time.time() - started_at
        return int(elapsed_seconds * 1000)
    return 0