|
|
""" |
|
|
Trace reader for querying LangFuse observability data. |
|
|
|
|
|
Provides Python API for programmatic access to traces, spans, and generations. |
|
|
""" |
|
|
import logging |
|
|
from typing import List, Optional, Dict, Any |
|
|
from datetime import datetime, timedelta |
|
|
from pydantic import BaseModel, Field |
|
|
|
|
|
from utils.langfuse_client import get_langfuse_client, is_langfuse_enabled |
|
|
|
|
|
logger = logging.getLogger(__name__) |
|
|
|
|
|
|
|
|
class TraceInfo(BaseModel):
    """Pydantic model for trace information."""

    # Unique trace identifier assigned by LangFuse.
    id: str
    # Human-readable trace name.
    name: str
    # Optional user associated with the trace.
    user_id: Optional[str] = None
    # Optional session the trace belongs to.
    session_id: Optional[str] = None
    # Time the trace was recorded.
    timestamp: datetime
    # Arbitrary key/value metadata attached to the trace.
    metadata: Dict[str, Any] = Field(default_factory=dict)
    # Raw trace input payload, as returned by the LangFuse API.
    input: Optional[Any] = None
    # Raw trace output payload, as returned by the LangFuse API.
    output: Optional[Any] = None
    # Wall-clock duration in milliseconds (end_time - start_time), when available.
    duration_ms: Optional[float] = None
    # Aggregated cost reported by LangFuse, when available.
    total_cost: Optional[float] = None
    # Token counts keyed by "input" / "output" / "total" (see _extract_token_usage).
    token_usage: Dict[str, int] = Field(default_factory=dict)
|
|
|
|
|
|
|
|
class SpanInfo(BaseModel):
    """Pydantic model for span information."""

    # Unique observation (span) identifier.
    id: str
    # Identifier of the parent trace this span belongs to.
    trace_id: str
    # Span name (e.g. the agent or step name it was recorded under).
    name: str
    # When the span started.
    start_time: datetime
    # When the span ended; None if still open or not reported.
    end_time: Optional[datetime] = None
    # Wall-clock duration in milliseconds, computed from start/end when possible.
    duration_ms: Optional[float] = None
    # Arbitrary key/value metadata attached to the span.
    metadata: Dict[str, Any] = Field(default_factory=dict)
    # Raw span input payload.
    input: Optional[Any] = None
    # Raw span output payload.
    output: Optional[Any] = None
    # LangFuse severity level; falls back to "DEFAULT" when not reported.
    level: str = "DEFAULT"
|
|
|
|
|
|
|
|
class GenerationInfo(BaseModel):
    """Pydantic model for LLM generation information."""

    # Unique observation (generation) identifier.
    id: str
    # Identifier of the parent trace this generation belongs to.
    trace_id: str
    # Generation name as recorded in LangFuse.
    name: str
    # Model identifier used for the generation, when reported.
    model: Optional[str] = None
    # Prompt text; populated from the observation's raw input.
    prompt: Optional[str] = None
    # Completion text; populated from the observation's raw output.
    completion: Optional[str] = None
    # Token counts keyed by "input" / "output" / "total" (see _extract_token_usage).
    usage: Dict[str, int] = Field(default_factory=dict)
    # Calculated total cost of this generation, when reported.
    cost: Optional[float] = None
    # When the generation started.
    start_time: datetime
    # When the generation ended; None if not reported.
    end_time: Optional[datetime] = None
    # Wall-clock duration in milliseconds, computed from start/end when possible.
    duration_ms: Optional[float] = None
    # Arbitrary key/value metadata attached to the generation.
    metadata: Dict[str, Any] = Field(default_factory=dict)
|
|
|
|
|
|
|
|
class TraceReader:
    """
    Read and query LangFuse traces programmatically.

    All query methods degrade gracefully: when LangFuse is disabled or an
    API call fails, they log the problem and return an empty/None result
    rather than raising.

    Usage:
        reader = TraceReader()
        traces = reader.get_traces(limit=10)
        trace = reader.get_trace_by_id("trace-123")
        agent_traces = reader.filter_by_agent("retriever_agent")
    """

    def __init__(self):
        """Initialize trace reader with LangFuse client.

        If LangFuse is not enabled, ``self.client`` is set to None and
        every query method returns an empty result.
        """
        if not is_langfuse_enabled():
            logger.warning("LangFuse is not enabled. TraceReader will return empty results.")
            self.client = None
        else:
            self.client = get_langfuse_client()
            logger.info("TraceReader initialized with LangFuse client")

    def get_traces(
        self,
        limit: int = 50,
        user_id: Optional[str] = None,
        session_id: Optional[str] = None,
        from_timestamp: Optional[datetime] = None,
        to_timestamp: Optional[datetime] = None,
    ) -> List[TraceInfo]:
        """
        Get traces with optional filters.

        Args:
            limit: Maximum number of traces to return
            user_id: Filter by user ID
            session_id: Filter by session ID
            from_timestamp: Filter traces after this timestamp
            to_timestamp: Filter traces before this timestamp

        Returns:
            List of TraceInfo objects (empty on error or when disabled)
        """
        if not self.client:
            logger.warning("LangFuse client not available")
            return []

        try:
            # Only forward filters that were actually provided.
            params: Dict[str, Any] = {"limit": limit}
            if user_id:
                params["user_id"] = user_id
            if session_id:
                params["session_id"] = session_id
            if from_timestamp:
                params["from_timestamp"] = from_timestamp
            if to_timestamp:
                params["to_timestamp"] = to_timestamp

            traces_data = self.client.get_traces(**params)
            traces = [self._trace_to_info(trace) for trace in traces_data.data]

            logger.info("Retrieved %d traces", len(traces))
            return traces

        except Exception:
            logger.exception("Error fetching traces")
            return []

    def get_trace_by_id(self, trace_id: str) -> Optional[TraceInfo]:
        """
        Get a specific trace by ID.

        Args:
            trace_id: Trace identifier

        Returns:
            TraceInfo object or None if not found
        """
        if not self.client:
            logger.warning("LangFuse client not available")
            return None

        try:
            trace = self.client.get_trace(trace_id)

            if not trace:
                logger.warning("Trace %s not found", trace_id)
                return None

            trace_info = self._trace_to_info(trace)
            logger.info("Retrieved trace %s", trace_id)
            return trace_info

        except Exception:
            logger.exception("Error fetching trace %s", trace_id)
            return None

    def filter_by_agent(
        self,
        agent_name: str,
        limit: int = 50,
        from_timestamp: Optional[datetime] = None,
    ) -> List[SpanInfo]:
        """
        Filter traces by agent name.

        Args:
            agent_name: Name of the agent (e.g., "retriever_agent", "analyzer_agent")
            limit: Maximum number of results
            from_timestamp: Filter traces after this timestamp

        Returns:
            List of SpanInfo objects for the specified agent
        """
        if not self.client:
            logger.warning("LangFuse client not available")
            return []

        try:
            # Agent executions are recorded as SPAN observations named after the agent.
            params: Dict[str, Any] = {"limit": limit, "name": agent_name, "type": "SPAN"}
            if from_timestamp:
                params["from_timestamp"] = from_timestamp

            observations = self.client.get_observations(**params)
            spans = [self._obs_to_span(obs) for obs in observations.data]

            logger.info("Retrieved %d spans for agent '%s'", len(spans), agent_name)
            return spans

        except Exception:
            logger.exception("Error filtering by agent %s", agent_name)
            return []

    def filter_by_date_range(
        self,
        from_date: datetime,
        to_date: datetime,
        limit: int = 100,
    ) -> List[TraceInfo]:
        """
        Filter traces by date range.

        Args:
            from_date: Start date
            to_date: End date
            limit: Maximum number of traces

        Returns:
            List of TraceInfo objects within date range
        """
        return self.get_traces(
            limit=limit,
            from_timestamp=from_date,
            to_timestamp=to_date,
        )

    def get_generations(
        self,
        trace_id: Optional[str] = None,
        limit: int = 50,
    ) -> List[GenerationInfo]:
        """
        Get LLM generations (optionally filtered by trace).

        Args:
            trace_id: Optional trace ID to filter generations
            limit: Maximum number of generations

        Returns:
            List of GenerationInfo objects
        """
        if not self.client:
            logger.warning("LangFuse client not available")
            return []

        try:
            params: Dict[str, Any] = {"limit": limit, "type": "GENERATION"}
            if trace_id:
                params["trace_id"] = trace_id

            observations = self.client.get_observations(**params)
            generations = [self._obs_to_generation(obs) for obs in observations.data]

            logger.info("Retrieved %d generations", len(generations))
            return generations

        except Exception:
            logger.exception("Error fetching generations")
            return []

    def export_traces_to_json(
        self,
        traces: List[TraceInfo],
        output_file: str,
    ) -> bool:
        """
        Export traces to JSON file.

        Args:
            traces: List of TraceInfo objects
            output_file: Path to output JSON file

        Returns:
            True if successful, False otherwise
        """
        try:
            import json

            data = [self._model_to_dict(trace) for trace in traces]

            # default=str stringifies datetimes (and anything else non-serializable).
            with open(output_file, 'w') as f:
                json.dump(data, f, indent=2, default=str)

            logger.info("Exported %d traces to %s", len(traces), output_file)
            return True

        except Exception:
            logger.exception("Error exporting traces")
            return False

    def export_traces_to_csv(
        self,
        traces: List[TraceInfo],
        output_file: str,
    ) -> bool:
        """
        Export traces to CSV file.

        Args:
            traces: List of TraceInfo objects
            output_file: Path to output CSV file

        Returns:
            True if successful, False otherwise (including when `traces` is empty)
        """
        try:
            import csv

            if not traces:
                logger.warning("No traces to export")
                return False

            # Flat, spreadsheet-friendly subset of TraceInfo fields.
            fieldnames = [
                "id", "name", "user_id", "session_id", "timestamp",
                "duration_ms", "total_cost", "input_tokens", "output_tokens"
            ]

            with open(output_file, 'w', newline='') as f:
                writer = csv.DictWriter(f, fieldnames=fieldnames)
                writer.writeheader()

                for trace in traces:
                    writer.writerow({
                        "id": trace.id,
                        "name": trace.name,
                        "user_id": trace.user_id or "",
                        "session_id": trace.session_id or "",
                        "timestamp": trace.timestamp.isoformat(),
                        "duration_ms": trace.duration_ms or 0,
                        "total_cost": trace.total_cost or 0,
                        "input_tokens": trace.token_usage.get("input", 0),
                        "output_tokens": trace.token_usage.get("output", 0),
                    })

            logger.info("Exported %d traces to %s", len(traces), output_file)
            return True

        except Exception:
            logger.exception("Error exporting traces to CSV")
            return False

    # ------------------------------------------------------------------
    # Internal conversion helpers
    # ------------------------------------------------------------------

    def _trace_to_info(self, trace: Any) -> TraceInfo:
        """Convert a raw LangFuse trace object into a TraceInfo model."""
        return TraceInfo(
            id=trace.id,
            name=trace.name,
            user_id=trace.user_id,
            session_id=trace.session_id,
            timestamp=trace.timestamp,
            metadata=trace.metadata or {},
            input=trace.input,
            output=trace.output,
            duration_ms=self._calculate_duration(trace),
            total_cost=getattr(trace, "total_cost", None),
            token_usage=self._extract_token_usage(trace),
        )

    def _obs_to_span(self, obs: Any) -> SpanInfo:
        """Convert a raw LangFuse SPAN observation into a SpanInfo model."""
        return SpanInfo(
            id=obs.id,
            trace_id=obs.trace_id,
            name=obs.name,
            start_time=obs.start_time,
            end_time=obs.end_time,
            duration_ms=self._calculate_duration(obs),
            metadata=obs.metadata or {},
            input=obs.input,
            output=obs.output,
            level=getattr(obs, "level", "DEFAULT"),
        )

    def _obs_to_generation(self, obs: Any) -> GenerationInfo:
        """Convert a raw LangFuse GENERATION observation into a GenerationInfo model."""
        return GenerationInfo(
            id=obs.id,
            trace_id=obs.trace_id,
            name=obs.name,
            model=getattr(obs, "model", None),
            prompt=getattr(obs, "input", None),
            completion=getattr(obs, "output", None),
            usage=self._extract_token_usage(obs),
            cost=getattr(obs, "calculated_total_cost", None),
            start_time=obs.start_time,
            end_time=obs.end_time,
            duration_ms=self._calculate_duration(obs),
            metadata=obs.metadata or {},
        )

    @staticmethod
    def _model_to_dict(model: Any) -> Dict[str, Any]:
        """Serialize a pydantic model to a plain dict.

        pydantic v2 renamed ``.dict()`` to ``.model_dump()`` (the old name
        is deprecated/removed); support both so exports work either way.
        """
        if hasattr(model, "model_dump"):
            return model.model_dump()
        return model.dict()

    def _calculate_duration(self, obj: Any) -> Optional[float]:
        """Calculate duration in milliseconds from start and end times.

        Returns None when either timestamp is missing or subtraction fails.
        """
        try:
            if hasattr(obj, 'start_time') and hasattr(obj, 'end_time') and obj.end_time:
                return (obj.end_time - obj.start_time).total_seconds() * 1000
            return None
        except Exception:
            return None

    def _extract_token_usage(self, obj: Any) -> Dict[str, int]:
        """Extract token usage from an observation or trace.

        Tries OpenAI-style names (prompt_tokens/completion_tokens/total_tokens)
        first, then LangFuse-style (input/output/total). Returns {} when no
        usage object is present or extraction fails.
        """
        usage: Dict[str, int] = {}
        try:
            if hasattr(obj, 'usage') and obj.usage:
                usage["input"] = getattr(obj.usage, "prompt_tokens", 0) or getattr(obj.usage, "input", 0)
                usage["output"] = getattr(obj.usage, "completion_tokens", 0) or getattr(obj.usage, "output", 0)
                usage["total"] = getattr(obj.usage, "total_tokens", 0) or getattr(obj.usage, "total", 0)
        except Exception:
            pass
        return usage
|
|
|