# llm-inference-dashboard/services/request_tracer.py
"""Request tracing and latency analysis."""
import logging
import uuid
from dataclasses import dataclass
from typing import Dict, List, Optional, Any
from collections import deque
import statistics
from storage.database import MetricsDB
from storage.models import RequestTrace
logger = logging.getLogger(__name__)
@dataclass
class LatencyBreakdown:
"""Breakdown of request latency by phase."""
queue_ms: float
prefill_ms: float
decode_ms: float
total_ms: float
@property
def as_dict(self) -> Dict[str, float]:
return {
"queue": self.queue_ms,
"prefill": self.prefill_ms,
"decode": self.decode_ms,
"total": self.total_ms,
}
@dataclass
class TraceCorrelation:
"""Correlation analysis for a trace."""
memory_pressure: bool
likely_cause: str
memory_delta_gb: float
class RequestTracer:
"""Tracks and analyzes request latency."""
def __init__(self, db: Optional[MetricsDB] = None, p95_window: int = 100):
"""
Initialize request tracer.
Args:
db: Optional database for persisting traces
p95_window: Number of recent requests for P95 calculation
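
        Example (in-memory only; pass a MetricsDB instance to persist traces):

            tracer = RequestTracer(p95_window=200)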
"""
        self.db = db
        # Ring buffers: cap memory by keeping only recent traces/latencies
        self._traces: deque = deque(maxlen=1000)
        self._latency_window: deque = deque(maxlen=p95_window)
        self._baseline_p95: Optional[float] = None
        self._slow_threshold_ms: Optional[float] = None
def record_trace(
self,
request_id: Optional[str] = None,
prompt_tokens: int = 0,
output_tokens: int = 0,
queue_time_ms: float = 0,
prefill_time_ms: float = 0,
decode_time_ms: float = 0,
total_time_ms: Optional[float] = None,
gpu_memory_start: float = 0,
gpu_memory_end: float = 0,
) -> RequestTrace:
"""
Record a request trace.
Args:
request_id: Unique request identifier
prompt_tokens: Number of prompt tokens
output_tokens: Number of output tokens
            queue_time_ms: Time spent waiting in the queue (ms)
            prefill_time_ms: Time for prefill/prompt processing (ms)
            decode_time_ms: Time for token generation (ms)
            total_time_ms: Total end-to-end time (ms); defaults to the sum
                of the three phase times if omitted
            gpu_memory_start: GPU memory at request start (GB)
            gpu_memory_end: GPU memory at request end (GB)
Returns:
Created RequestTrace
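
        Example (synthetic timings, purely illustrative):

            trace = tracer.record_trace(
                prompt_tokens=128,
                output_tokens=64,
                queue_time_ms=5.0,
                prefill_time_ms=42.0,
                decode_time_ms=810.0,
            )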
"""
        if request_id is None:
            # A short random ID is sufficient for dashboard display
            request_id = str(uuid.uuid4())[:8]
if total_time_ms is None:
total_time_ms = queue_time_ms + prefill_time_ms + decode_time_ms
        # Decode throughput: output tokens per second of decode time
        tokens_per_sec = 0.0
        if decode_time_ms > 0:
            tokens_per_sec = (output_tokens / decode_time_ms) * 1000
        # Flag as slow once a baseline threshold has been established
        is_slow = (
            self._slow_threshold_ms is not None
            and total_time_ms > self._slow_threshold_ms
        )
trace = RequestTrace(
request_id=request_id,
prompt_tokens=prompt_tokens,
output_tokens=output_tokens,
queue_time_ms=queue_time_ms,
prefill_time_ms=prefill_time_ms,
decode_time_ms=decode_time_ms,
total_time_ms=total_time_ms,
tokens_per_second=tokens_per_sec,
gpu_memory_at_start=gpu_memory_start,
gpu_memory_at_end=gpu_memory_end,
is_slow=is_slow,
)
# Store in memory
self._traces.append(trace)
self._latency_window.append(total_time_ms)
# Update P95 baseline
self._update_baseline()
# Persist to database
if self.db:
try:
self.db.insert_trace(trace)
except Exception as e:
logger.error(f"Error persisting trace: {e}")
# Log slow requests
if is_slow:
logger.warning(
f"Slow request {request_id}: {total_time_ms:.1f}ms "
f"(threshold: {self._slow_threshold_ms:.1f}ms)"
)
return trace
    def _update_baseline(self) -> None:
        """Update the P95 baseline from recent requests."""
        # Require a minimum sample size before trusting the percentile
        if len(self._latency_window) >= 10:
            sorted_latencies = sorted(self._latency_window)
            # Nearest-rank P95; int() floors, so the index stays in bounds
            p95_idx = int(len(sorted_latencies) * 0.95)
            self._baseline_p95 = sorted_latencies[p95_idx]
            # Flag requests slower than 1.5x the P95 baseline
            self._slow_threshold_ms = self._baseline_p95 * 1.5
def get_recent_traces(
self, limit: int = 100, slow_only: bool = False
) -> List[RequestTrace]:
"""
Get recent traces.
Args:
limit: Maximum number of traces
slow_only: Only return slow requests
Returns:
List of RequestTrace objects
"""
traces = list(self._traces)
if slow_only:
traces = [t for t in traces if t.is_slow]
return traces[-limit:]
def get_latency_breakdown(self) -> LatencyBreakdown:
"""
Get average latency breakdown.
Returns:
LatencyBreakdown with average times
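
        Example:

            bd = tracer.get_latency_breakdown()
            print(bd.as_dict)  # {"queue": ..., "prefill": ..., "decode": ..., "total": ...}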
"""
if not self._traces:
return LatencyBreakdown(0, 0, 0, 0)
recent = list(self._traces)[-100:]
return LatencyBreakdown(
queue_ms=statistics.mean(t.queue_time_ms for t in recent),
prefill_ms=statistics.mean(t.prefill_time_ms for t in recent),
decode_ms=statistics.mean(t.decode_time_ms for t in recent),
total_ms=statistics.mean(t.total_time_ms for t in recent),
)
def correlate_with_gpu_pressure(self, trace: RequestTrace) -> TraceCorrelation:
"""
Correlate trace latency with GPU memory pressure.
Args:
trace: Request trace to analyze
Returns:
TraceCorrelation analysis
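
        Example (hypothetical trace object; thresholds are heuristics):

            corr = tracer.correlate_with_gpu_pressure(trace)
            if corr.memory_pressure:
                print(corr.likely_cause, f"{corr.memory_delta_gb:.1f} GB")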
"""
        memory_delta = trace.gpu_memory_at_end - trace.gpu_memory_at_start
        # Heuristic classification (thresholds in GB / fractions of total time)
        if memory_delta > 2.0:
            # Large memory growth suggests competing requests in the batch
            cause = "batch_contention"
        elif trace.queue_time_ms > trace.total_time_ms * 0.3:
            # Over 30% of the request was spent waiting in the queue
            cause = "queue_congestion"
        elif trace.prefill_time_ms > trace.decode_time_ms * 2:
            # Prefill dominating decode points to an unusually long prompt
            cause = "long_prompt"
        else:
            cause = "normal"
return TraceCorrelation(
memory_pressure=memory_delta > 1.0,
likely_cause=cause,
memory_delta_gb=memory_delta,
)
def get_percentiles(self) -> Dict[str, float]:
"""
        Get nearest-rank latency percentiles over the rolling window.

        Returns:
            Dictionary with p50, p95, and p99 latencies in ms
"""
if not self._latency_window:
return {"p50": 0, "p95": 0, "p99": 0}
        sorted_latencies = sorted(self._latency_window)
        n = len(sorted_latencies)
        # Nearest-rank indices, clamped so small windows never overflow
        return {
            "p50": sorted_latencies[min(int(n * 0.50), n - 1)],
            "p95": sorted_latencies[min(int(n * 0.95), n - 1)],
            "p99": sorted_latencies[min(int(n * 0.99), n - 1)],
        }
def get_stats(self) -> Dict[str, Any]:
"""
Get comprehensive statistics.
Returns:
Dictionary with various stats
"""
        if not self._traces:
            # Keep the same keys as the populated case so consumers can
            # render the empty state without special-casing missing fields
            return {
                "total_requests": 0,
                "slow_requests": 0,
                "slow_rate_percent": 0.0,
                "avg_latency_ms": 0,
                "percentiles": {"p50": 0, "p95": 0, "p99": 0},
                "breakdown": {"queue": 0, "prefill": 0, "decode": 0, "total": 0},
                "baseline_p95": None,
            }
traces = list(self._traces)
slow_count = sum(1 for t in traces if t.is_slow)
breakdown = self.get_latency_breakdown()
return {
"total_requests": len(traces),
"slow_requests": slow_count,
"slow_rate_percent": (slow_count / len(traces)) * 100,
"avg_latency_ms": breakdown.total_ms,
"percentiles": self.get_percentiles(),
"breakdown": breakdown.as_dict,
"baseline_p95": self._baseline_p95,
}
def clear(self) -> None:
"""Clear all traces."""
self._traces.clear()
self._latency_window.clear()
self._baseline_p95 = None
self._slow_threshold_ms = None
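

if __name__ == "__main__":
    # Minimal smoke-test sketch with synthetic timings. A real deployment
    # would feed measurements from the serving engine and pass a MetricsDB
    # instance for persistence; the values below are illustrative only.
    logging.basicConfig(level=logging.INFO)
    tracer = RequestTracer(db=None, p95_window=50)
    for i in range(20):
        tracer.record_trace(
            prompt_tokens=128,
            output_tokens=64,
            queue_time_ms=5.0,
            prefill_time_ms=40.0,
            decode_time_ms=400.0 + (i % 5) * 20.0,
            gpu_memory_start=10.0,
            gpu_memory_end=10.2,
        )
    # One outlier well past the learned 1.5x-P95 threshold gets flagged slow
    tracer.record_trace(
        prompt_tokens=4096,
        output_tokens=64,
        queue_time_ms=50.0,
        prefill_time_ms=900.0,
        decode_time_ms=5000.0,
    )
    print(tracer.get_percentiles())
    print(tracer.get_stats())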