Spaces:
Sleeping
Sleeping
| """Request tracing and latency analysis.""" | |
| import logging | |
| import uuid | |
| from dataclasses import dataclass | |
| from datetime import datetime | |
| from typing import Dict, List, Optional, Any | |
| from collections import deque | |
| import statistics | |
| from storage.database import MetricsDB | |
| from storage.models import RequestTrace | |
| logger = logging.getLogger(__name__) | |
@dataclass
class LatencyBreakdown:
    """Breakdown of request latency by phase (all values in milliseconds).

    Fix: the original class declared bare annotated fields without the
    ``@dataclass`` decorator, so ``LatencyBreakdown(0, 0, 0, 0)`` (as used
    by ``RequestTracer.get_latency_breakdown``) raised ``TypeError``.
    """

    queue_ms: float    # time spent waiting in the queue
    prefill_ms: float  # prompt-processing time
    decode_ms: float   # token-generation time
    total_ms: float    # end-to-end time

    def as_dict(self) -> Dict[str, float]:
        """Return the breakdown as a phase-name -> milliseconds mapping."""
        return {
            "queue": self.queue_ms,
            "prefill": self.prefill_ms,
            "decode": self.decode_ms,
            "total": self.total_ms,
        }
@dataclass
class TraceCorrelation:
    """Correlation analysis for a single request trace.

    Fix: the original class declared bare annotated fields without the
    ``@dataclass`` decorator, so constructing it with keyword arguments
    (as ``RequestTracer.correlate_with_gpu_pressure`` does) raised
    ``TypeError``.
    """

    memory_pressure: bool   # True when GPU memory grew beyond the pressure threshold
    likely_cause: str       # heuristic label, e.g. "batch_contention" or "normal"
    memory_delta_gb: float  # GPU memory growth over the request (presumably GB — units set by caller)
class RequestTracer:
    """Tracks and analyzes request latency.

    Keeps a bounded in-memory history of request traces, maintains a rolling
    P95 latency baseline over recent requests, and flags requests slower than
    1.5x that baseline. Traces are optionally persisted to a database
    (best-effort: persistence errors are logged, never raised).
    """

    def __init__(self, db: Optional[MetricsDB] = None, p95_window: int = 100):
        """
        Initialize request tracer.

        Args:
            db: Optional database for persisting traces
            p95_window: Number of recent requests for P95 calculation
        """
        self.db = db
        # Bounded history so memory stays constant under sustained load.
        self._traces: deque = deque(maxlen=1000)
        # Rolling window of total latencies used for percentile math.
        self._latency_window: deque = deque(maxlen=p95_window)
        self._baseline_p95: Optional[float] = None
        # Slow threshold is derived from the baseline; None until warmed up.
        self._slow_threshold_ms: Optional[float] = None

    def record_trace(
        self,
        request_id: Optional[str] = None,
        prompt_tokens: int = 0,
        output_tokens: int = 0,
        queue_time_ms: float = 0,
        prefill_time_ms: float = 0,
        decode_time_ms: float = 0,
        total_time_ms: Optional[float] = None,
        gpu_memory_start: float = 0,
        gpu_memory_end: float = 0,
    ) -> RequestTrace:
        """
        Record a request trace.

        Args:
            request_id: Unique request identifier (auto-generated if None)
            prompt_tokens: Number of prompt tokens
            output_tokens: Number of output tokens
            queue_time_ms: Time spent in queue
            prefill_time_ms: Time for prefill/prompt processing
            decode_time_ms: Time for token generation
            total_time_ms: Total end-to-end time; defaults to the sum of phases
            gpu_memory_start: GPU memory at request start
            gpu_memory_end: GPU memory at request end

        Returns:
            Created RequestTrace
        """
        if request_id is None:
            # Short id is enough for log correlation; full UUID is overkill here.
            request_id = str(uuid.uuid4())[:8]

        if total_time_ms is None:
            total_time_ms = queue_time_ms + prefill_time_ms + decode_time_ms

        # Decode throughput; guard against zero to avoid ZeroDivisionError.
        tokens_per_sec = 0.0
        if decode_time_ms > 0:
            tokens_per_sec = (output_tokens / decode_time_ms) * 1000

        # A request is "slow" only once a baseline threshold exists (warm-up
        # requires at least 10 samples — see _update_baseline).
        is_slow = bool(
            self._slow_threshold_ms and total_time_ms > self._slow_threshold_ms
        )

        trace = RequestTrace(
            request_id=request_id,
            prompt_tokens=prompt_tokens,
            output_tokens=output_tokens,
            queue_time_ms=queue_time_ms,
            prefill_time_ms=prefill_time_ms,
            decode_time_ms=decode_time_ms,
            total_time_ms=total_time_ms,
            tokens_per_second=tokens_per_sec,
            gpu_memory_at_start=gpu_memory_start,
            gpu_memory_at_end=gpu_memory_end,
            is_slow=is_slow,
        )

        # Store in memory and refresh the rolling P95 baseline.
        self._traces.append(trace)
        self._latency_window.append(total_time_ms)
        self._update_baseline()

        # Persist to database (best-effort; tracing must never break serving).
        if self.db:
            try:
                self.db.insert_trace(trace)
            except Exception as e:
                logger.error(f"Error persisting trace: {e}")

        if is_slow:
            # Lazy %-args: formatting is skipped when WARNING is disabled.
            logger.warning(
                "Slow request %s: %.1fms (threshold: %.1fms)",
                request_id,
                total_time_ms,
                self._slow_threshold_ms,
            )

        return trace

    def _update_baseline(self) -> None:
        """Update P95 baseline from recent requests.

        Requires at least 10 samples before establishing a baseline so a few
        early outliers don't define "slow" for the whole run.
        """
        if len(self._latency_window) >= 10:
            sorted_latencies = sorted(self._latency_window)
            # int(n * 0.95) < n for all n >= 1, so this index is always valid.
            p95_idx = int(len(sorted_latencies) * 0.95)
            self._baseline_p95 = sorted_latencies[p95_idx]
            # Slow threshold is 1.5x P95: tolerates normal jitter above P95.
            self._slow_threshold_ms = self._baseline_p95 * 1.5

    def get_recent_traces(
        self, limit: int = 100, slow_only: bool = False
    ) -> List[RequestTrace]:
        """
        Get recent traces.

        Args:
            limit: Maximum number of traces (most recent are kept)
            slow_only: Only return slow requests

        Returns:
            List of RequestTrace objects, oldest first
        """
        traces = list(self._traces)
        if slow_only:
            traces = [t for t in traces if t.is_slow]
        return traces[-limit:]

    def get_latency_breakdown(self) -> LatencyBreakdown:
        """
        Get average latency breakdown over the last 100 traces.

        Returns:
            LatencyBreakdown with average times (all zeros when no traces)
        """
        if not self._traces:
            return LatencyBreakdown(0, 0, 0, 0)

        recent = list(self._traces)[-100:]
        return LatencyBreakdown(
            queue_ms=statistics.mean(t.queue_time_ms for t in recent),
            prefill_ms=statistics.mean(t.prefill_time_ms for t in recent),
            decode_ms=statistics.mean(t.decode_time_ms for t in recent),
            total_ms=statistics.mean(t.total_time_ms for t in recent),
        )

    def correlate_with_gpu_pressure(self, trace: RequestTrace) -> TraceCorrelation:
        """
        Correlate trace latency with GPU memory pressure.

        Heuristic classification, checked in priority order:
        large memory growth -> batch contention; queue time dominating
        (> 30% of total) -> queue congestion; prefill dwarfing decode
        (> 2x) -> long prompt; otherwise normal.

        Args:
            trace: Request trace to analyze

        Returns:
            TraceCorrelation analysis
        """
        memory_delta = trace.gpu_memory_at_end - trace.gpu_memory_at_start

        # Thresholds (2.0 / 1.0) assume memory values are in GB — matches the
        # memory_delta_gb field name; confirm against the metrics producer.
        if memory_delta > 2.0:
            cause = "batch_contention"
        elif trace.queue_time_ms > trace.total_time_ms * 0.3:
            cause = "queue_congestion"
        elif trace.prefill_time_ms > trace.decode_time_ms * 2:
            cause = "long_prompt"
        else:
            cause = "normal"

        return TraceCorrelation(
            memory_pressure=memory_delta > 1.0,
            likely_cause=cause,
            memory_delta_gb=memory_delta,
        )

    def get_percentiles(self) -> Dict[str, float]:
        """
        Get latency percentiles over the rolling window.

        Returns:
            Dictionary with P50, P95, P99 values (zeros when no data)
        """
        if not self._latency_window:
            return {"p50": 0, "p95": 0, "p99": 0}

        sorted_latencies = sorted(self._latency_window)
        n = len(sorted_latencies)
        # int(n * q) < n for q < 1, so p50/p95 indices are always in range;
        # p99 is clamped defensively anyway.
        return {
            "p50": sorted_latencies[int(n * 0.50)],
            "p95": sorted_latencies[int(n * 0.95)],
            "p99": sorted_latencies[min(int(n * 0.99), n - 1)],
        }

    def get_stats(self) -> Dict[str, Any]:
        """
        Get comprehensive statistics.

        Returns:
            Dictionary with request counts, slow-request rate, average
            latency, percentiles, phase breakdown, and the P95 baseline.
        """
        if not self._traces:
            # Empty-state shape mirrors the populated one (fix: the original
            # empty breakdown omitted the "total" key that as_dict() emits).
            return {
                "total_requests": 0,
                "slow_requests": 0,
                "avg_latency_ms": 0,
                "percentiles": {"p50": 0, "p95": 0, "p99": 0},
                "breakdown": {"queue": 0, "prefill": 0, "decode": 0, "total": 0},
            }

        traces = list(self._traces)
        slow_count = sum(1 for t in traces if t.is_slow)
        breakdown = self.get_latency_breakdown()

        return {
            "total_requests": len(traces),
            "slow_requests": slow_count,
            "slow_rate_percent": (slow_count / len(traces)) * 100,
            "avg_latency_ms": breakdown.total_ms,
            "percentiles": self.get_percentiles(),
            # Fix: the original stored the bound method (`breakdown.as_dict`)
            # instead of calling it, handing consumers a method object.
            "breakdown": breakdown.as_dict(),
            "baseline_p95": self._baseline_p95,
        }

    def clear(self) -> None:
        """Clear all traces and reset the baseline/threshold state."""
        self._traces.clear()
        self._latency_window.clear()
        self._baseline_p95 = None
        self._slow_threshold_ms = None