rohitdeshmukh318's picture
initial commit
abd4352
"""
agent/trace.py
Real-time agent execution tracing.
Collects timing, token usage, and status for each node in the LangGraph pipeline.
Trace events are streamed to the frontend via SSE for live pipeline visualization.
"""
import functools
import threading
import time
from dataclasses import dataclass, field, asdict
from typing import Any, Dict, List, Optional
@dataclass
class TraceEvent:
    """One status record emitted for a pipeline node.

    Instances are plain data holders; ``to_dict`` turns them into the
    dictionary form that carries an extra ``"type": "trace"`` marker.
    """

    # Name of the node this record refers to.
    node: str
    # One of "started", "completed" or "failed".
    status: str
    # Elapsed time in milliseconds; stays 0 unless the creator supplies it.
    latency_ms: int = 0
    # Token count reported for the node; stays 0 unless supplied.
    tokens_used: int = 0
    # Free-form extra data attached to the record.
    metadata: Dict[str, Any] = field(default_factory=dict)
    # Wall-clock ``time.time()`` reading taken when the record is created.
    timestamp: float = field(default_factory=time.time)

    def to_dict(self) -> Dict[str, Any]:
        """Return the record's fields as a new dict tagged with ``type="trace"``.

        ``asdict`` deep-copies nested containers, so the result does not share
        ``metadata`` with this instance.
        """
        return {**asdict(self), "type": "trace"}
class AgentTracer:
    """
    Collects trace events for a single agent run.

    The event list and the per-node start times are both guarded by one
    lock, so the recording methods can be called from several threads.
    Timers are keyed by node name only: starting the same node again before
    it has ended overwrites the earlier start time.
    """

    def __init__(self):
        self.events: List[TraceEvent] = []
        self._lock = threading.Lock()
        # node name -> time.perf_counter() reading taken in start_node.
        self._timers: Dict[str, float] = {}

    def start_node(self, node: str, metadata: Optional[Dict[str, Any]] = None):
        """Mark a node as started.

        Args:
            node: Name of the node; also the key used to match the later
                ``end_node`` / ``fail_node`` call.
            metadata: Optional extra data stored on the event (shallow-copied
                so later caller-side mutation does not alter the record).
        """
        # perf_counter is monotonic, so the measured latency can never be
        # negative or skewed by wall-clock adjustments.
        started_at = time.perf_counter()
        event = TraceEvent(
            node=node,
            status="started",
            metadata=dict(metadata) if metadata else {},
        )
        with self._lock:
            self._timers[node] = started_at
            self.events.append(event)

    def end_node(
        self,
        node: str,
        tokens_used: int = 0,
        metadata: Optional[Dict[str, Any]] = None,
    ):
        """Mark a node as completed with timing info.

        Args:
            node: Name previously passed to ``start_node``. If no start was
                recorded, the event is still added with a latency of 0.
            tokens_used: Token count to attach to the event.
            metadata: Optional extra data stored on the event (shallow-copied).
        """
        self._finish_node(
            node,
            status="completed",
            tokens_used=tokens_used,
            metadata=dict(metadata) if metadata else {},
        )

    def fail_node(self, node: str, error: str):
        """Mark a node as failed.

        Args:
            node: Name previously passed to ``start_node``. If no start was
                recorded, the event is still added with a latency of 0.
            error: Error text, stored under the ``"error"`` metadata key.
        """
        self._finish_node(
            node,
            status="failed",
            tokens_used=0,
            metadata={"error": error},
        )

    def _finish_node(
        self,
        node: str,
        status: str,
        tokens_used: int,
        metadata: Dict[str, Any],
    ) -> None:
        """Pop the node's timer and append a terminal event with its latency."""
        finished_at = time.perf_counter()
        with self._lock:
            started_at = self._timers.pop(node, None)
            if started_at is None:
                latency_ms = 0
            else:
                latency_ms = int((finished_at - started_at) * 1000)
            self.events.append(
                TraceEvent(
                    node=node,
                    status=status,
                    latency_ms=latency_ms,
                    tokens_used=tokens_used,
                    metadata=metadata,
                )
            )

    def get_events(self) -> List[Dict[str, Any]]:
        """Return all trace events as dicts, in the order they were recorded."""
        with self._lock:
            return [e.to_dict() for e in self.events]

    def get_summary(self) -> Dict[str, Any]:
        """Return a summary of the trace.

        Returns:
            A dict with the summed latency and token counts of completed
            events, the number of completed events, the names of failed
            nodes, and every event in dict form.
        """
        # Copy the list under the lock, then aggregate without holding it.
        with self._lock:
            snapshot = list(self.events)
        completed = [e for e in snapshot if e.status == "completed"]
        return {
            "total_latency_ms": sum(e.latency_ms for e in completed),
            "total_tokens": sum(e.tokens_used for e in completed),
            "nodes_executed": len(completed),
            "failed_nodes": [e.node for e in snapshot if e.status == "failed"],
            "events": [e.to_dict() for e in snapshot],
        }
# ── Thread-local tracer storage ───────────────────────────────────────────────
# Attributes set on a threading.local() object are visible only to the thread
# that set them, so each thread has its own independent `tracer` slot here.
_tracer_local = threading.local()
def set_tracer(tracer: AgentTracer) -> None:
    """Set the tracer for the current thread.

    The tracer is stored on the module's thread-local object, so it replaces
    any tracer previously set from this thread and is not visible to other
    threads.
    """
    _tracer_local.tracer = tracer
def get_tracer() -> Optional[AgentTracer]:
    """Return the tracer stored for the calling thread.

    Returns ``None`` when this thread has no ``tracer`` attribute on the
    thread-local storage yet.
    """
    try:
        return _tracer_local.tracer
    except AttributeError:
        # Nothing has been stored for this thread.
        return None
def trace_node(node_name: str):
    """
    Decorator to automatically trace a node function.

    When the current thread has a tracer, the wrapped call records a
    "started" event, then either a "completed" event on normal return or a
    "failed" event (with ``str(exception)``) before re-raising. When there is
    no tracer the function is simply called.

    The wrapper keeps the wrapped function's name, docstring and other
    metadata via ``functools.wraps``.

    Usage:
        @trace_node("sql_generator")
        def sql_generator(state: AgentState) -> AgentState:
            ...
    """
    def decorator(func):
        @functools.wraps(func)
        def wrapper(state, *args, **kwargs):
            tracer = get_tracer()
            if tracer is None:
                return func(state, *args, **kwargs)

            tracer.start_node(node_name)
            try:
                result = func(state, *args, **kwargs)
            except Exception as exc:
                tracer.fail_node(node_name, str(exc))
                raise
            # Recorded outside the try so an error while recording success is
            # not also reported as a failure of the node itself.
            tracer.end_node(node_name)
            return result

        return wrapper

    return decorator