jkottu Claude Opus 4.5 committed on
Commit aefabf0 · 0 parents

Initial commit: LLM Inference Dashboard


A production-grade Gradio dashboard for monitoring vLLM inference
on multi-GPU setups with:
- GPU/Rank monitoring (memory, utilization, temperature)
- Inference metrics (tokens/sec, TTFT, KV cache)
- Quantization detection (GPTQ, AWQ, bitsandbytes)
- Model loading progress tracking
- Alerting with Slack/webhook integration
- Request tracing with latency breakdown
- A/B deployment comparison
- Built-in load testing

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>

.gitignore ADDED
@@ -0,0 +1,53 @@
+ # Python
+ __pycache__/
+ *.py[cod]
+ *$py.class
+ *.so
+ .Python
+ build/
+ develop-eggs/
+ dist/
+ downloads/
+ eggs/
+ .eggs/
+ lib/
+ lib64/
+ parts/
+ sdist/
+ var/
+ wheels/
+ *.egg-info/
+ .installed.cfg
+ *.egg
+
+ # Virtual environments
+ venv/
+ ENV/
+ env/
+ .venv/
+
+ # IDE
+ .idea/
+ .vscode/
+ *.swp
+ *.swo
+ *~
+
+ # Data files
+ data/*.db
+ *.sqlite
+ *.sqlite3
+
+ # Logs
+ *.log
+
+ # OS
+ .DS_Store
+ Thumbs.db
+
+ # Environment
+ .env
+ .env.local
+
+ # Claude
+ .claude/
README.md ADDED
@@ -0,0 +1,103 @@
+ ---
+ title: LLM Inference Dashboard
+ emoji: 📊
+ colorFrom: blue
+ colorTo: purple
+ sdk: gradio
+ sdk_version: 5.9.1
+ app_file: app.py
+ pinned: false
+ license: mit
+ ---
+
+ # LLM Inference Dashboard
+
+ A production-grade Gradio dashboard for monitoring vLLM inference on multi-GPU setups with alerting, request tracing, A/B comparison, load testing, and historical analysis.
+
+ ## Features
+
+ | Feature | Description |
+ |---------|-------------|
+ | Core Monitoring | GPU stats, inference metrics, quantization info |
+ | Alerting | Configurable thresholds, Slack/webhook notifications |
+ | Request Tracing | Per-request latency breakdown, slow request logging |
+ | A/B Comparison | Side-by-side deployment comparison |
+ | Load Testing | Built-in load generator with saturation detection |
+ | Historical Analysis | SQLite storage, trend queries |
+
+ ## Tabs
+
+ 1. **GPU / Rank Status** - Real-time GPU memory, utilization, temperature, and tensor parallel rank mapping
+ 2. **Inference** - Tokens/sec, TTFT, batch size, KV cache utilization, latency metrics
+ 3. **Quantization** - Detect and display GPTQ, AWQ, bitsandbytes quantization settings
+ 4. **Loading** - Model loading progress with shard tracking
+ 5. **Alerts** - Configure alert thresholds and webhook notifications
+ 6. **Tracing** - Request-level latency breakdown and slow request analysis
+ 7. **A/B Compare** - Compare metrics between two vLLM deployments
+ 8. **Load Test** - Run load tests with configurable concurrency and RPS
+
+ ## Usage
+
+ ### Local Development
+
+ ```bash
+ pip install -r requirements.txt
+ python app.py
+ ```
+
+ ### With vLLM Server
+
+ ```bash
+ # Start vLLM server
+ python -m vllm.entrypoints.openai.api_server \
+     --model <model_name> \
+     --tensor-parallel-size <N> \
+     --port 8000
+
+ # Set environment variables (optional)
+ export VLLM_HOST=localhost
+ export VLLM_PORT=8000
+
+ # Launch dashboard
+ python app.py
+ ```
+
+ ## Environment Variables
+
+ | Variable | Default | Description |
+ |----------|---------|-------------|
+ | `VLLM_HOST` | localhost | vLLM server hostname |
+ | `VLLM_PORT` | 8000 | vLLM server port |
+ | `MODEL_PATH` | None | Path to model for quantization detection |
+ | `DB_PATH` | data/metrics.db | SQLite database path |
+ | `SLACK_WEBHOOK` | None | Slack webhook URL for alerts |
+ | `PAGERDUTY_KEY` | None | PagerDuty routing key |
+
+ ## Demo Mode
+
+ When no vLLM server is connected, the dashboard runs in demo mode with simulated GPU metrics.
+
+ ## Architecture
+
+ ```
+ ┌────────────────────────────────────────────────────────┐
+ │                    Gradio Frontend                     │
+ │ ┌─────────┐ ┌─────────┐ ┌─────────┐ ┌─────────────────┐│
+ │ │GPU Stats│ │ Loading │ │  Quant  │ │Inference Metrics││
+ │ │   Tab   │ │Progress │ │ Details │ │       Tab       ││
+ │ └─────────┘ └─────────┘ └─────────┘ └─────────────────┘│
+ └────────────────────────────────────────────────────────┘
+
+
+ ┌────────────────────────────────────────────────────────┐
+ │                   Metrics Collector                    │
+ │ ┌──────────┐ ┌──────────┐ ┌──────────┐ ┌────────────┐  │
+ │ │  pynvml  │ │Prometheus│ │ vLLM API │ │Model Config│  │
+ │ │  (GPUs)  │ │(/metrics)│ │ (status) │ │  (quant)   │  │
+ │ └──────────┘ └──────────┘ └──────────┘ └────────────┘  │
+ └────────────────────────────────────────────────────────┘
+ ```
+
+ ## License
+
+ MIT
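The environment variables above are consumed by the app's config module (`from config import config` in app.py). A minimal sketch of how such a config object can read them with the documented defaults — the class and field names here are assumptions, since `config.py` itself is not part of this excerpt:

```python
import os
from dataclasses import dataclass, field
from typing import Optional


@dataclass
class DashboardConfig:
    # Defaults mirror the Environment Variables table above.
    vllm_host: str = field(default_factory=lambda: os.getenv("VLLM_HOST", "localhost"))
    vllm_port: int = field(default_factory=lambda: int(os.getenv("VLLM_PORT", "8000")))
    model_path: Optional[str] = field(default_factory=lambda: os.getenv("MODEL_PATH"))
    db_path: str = field(default_factory=lambda: os.getenv("DB_PATH", "data/metrics.db"))
    slack_webhook: Optional[str] = field(default_factory=lambda: os.getenv("SLACK_WEBHOOK"))

    @property
    def metrics_endpoint(self) -> str:
        # vLLM exposes Prometheus metrics at /metrics on the API port
        return f"http://{self.vllm_host}:{self.vllm_port}/metrics"


config = DashboardConfig()
```

Using `default_factory` means the environment is read at instantiation time, so tests can override variables before constructing the config.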
app.py ADDED
@@ -0,0 +1,313 @@
+ """
+ LLM Inference Dashboard - Main Application
+
+ A production-grade Gradio dashboard for monitoring vLLM inference
+ on multi-GPU setups with alerting, request tracing, A/B comparison,
+ load testing, and historical analysis.
+ """
+
+ import asyncio
+ import logging
+ import os
+
+ import gradio as gr
+
+ from config import config
+ from collectors import GPUCollector, VLLMCollector, QuantizationCollector, LoadingTracker
+ from components import (
+     create_gpu_panel,
+     update_gpu_panel,
+     create_inference_panel,
+     update_inference_panel,
+     create_quant_panel,
+     update_quant_panel,
+     create_loading_panel,
+     update_loading_panel,
+     create_alerts_panel,
+     update_alerts_panel,
+     create_tracing_panel,
+     update_tracing_panel,
+     create_comparison_panel,
+     create_loadtest_panel,
+ )
+ from components.alerts_panel import get_alert_badge_html
+ from components.inference_panel import get_metrics_dict
+ from services import AlertEngine, AlertDispatcher, RequestTracer
+ from storage import MetricsDB
+ from utils import MetricHistory
+
+ # Configure logging
+ logging.basicConfig(
+     level=logging.INFO,
+     format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
+ )
+ logger = logging.getLogger(__name__)
+
+
+ # Initialize global instances
+ db = MetricsDB(config.db_path)
+ history = MetricHistory(max_length=config.history_length)
+
+ # Collectors
+ gpu_collector = GPUCollector()
+ vllm_collector = VLLMCollector(config.metrics_endpoint)
+ quant_collector = QuantizationCollector(config.model_path)
+ loading_tracker = LoadingTracker(config.model_path)
+
+ # Services
+ alert_engine = AlertEngine(db)
+ alert_dispatcher = AlertDispatcher(
+     slack_webhook=config.slack_webhook,
+     pagerduty_key=config.pagerduty_routing_key,
+     generic_webhooks=config.generic_webhooks,
+ )
+ request_tracer = RequestTracer(db)
+
+
+ def check_connection():
+     """Check connection to vLLM server."""
+     connected = vllm_collector.check_connection()
+     if connected:
+         return (
+             '<div style="display: flex; align-items: center;">'
+             '<span style="width: 12px; height: 12px; background: #4caf50; '
+             'border-radius: 50%; display: inline-block; margin-right: 8px;"></span>'
+             '<span style="color: #2e7d32;">Connected</span></div>'
+         )
+     return (
+         '<div style="display: flex; align-items: center;">'
+         '<span style="width: 12px; height: 12px; background: #f44336; '
+         'border-radius: 50%; display: inline-block; margin-right: 8px;"></span>'
+         '<span style="color: #c62828;">Disconnected</span></div>'
+     )
+
+
+ def get_model_name():
+     """Get current model name."""
+     metrics = vllm_collector.collect()
+     return metrics.model_name or "Demo Mode"
+
+
+ def update_all_metrics():
+     """Update all metrics from collectors."""
+     # GPU metrics
+     gpu_table, gpu_memory_plot, gpu_util_plot, nccl_status = update_gpu_panel(
+         gpu_collector, history
+     )
+
+     # Inference metrics
+     (
+         throughput, ttft, batch_size, kv_cache, throughput_plot,
+         prefill_pct, decode_pct, queue_depth, e2e_latency, latency_plot
+     ) = update_inference_panel(vllm_collector, history)
+
+     # Check alerts
+     metrics = vllm_collector.collect()
+     metrics_dict = get_metrics_dict(metrics)
+
+     # Add GPU metrics for alerting
+     gpu_stats = gpu_collector.collect()
+     if gpu_stats:
+         max_gpu_memory = max(s.memory_percent for s in gpu_stats)
+         metrics_dict["gpu_memory_percent"] = max_gpu_memory
+
+     new_alerts = alert_engine.evaluate(metrics_dict)
+
+     # Dispatch new alerts: schedule on the running loop if there is one,
+     # otherwise run synchronously
+     for alert in new_alerts:
+         try:
+             loop = asyncio.get_running_loop()
+             loop.create_task(alert_dispatcher.dispatch(alert))
+         except RuntimeError:
+             # No running event loop
+             asyncio.run(alert_dispatcher.dispatch(alert))
+
+     # Get alert badge
+     active_alerts = alert_engine.get_active_alerts()
+     alert_badge = get_alert_badge_html(active_alerts)
+
+     # Connection status
+     connection_status = check_connection()
+     model_name = get_model_name()
+
+     return (
+         # Header
+         connection_status,
+         model_name,
+         alert_badge,
+         # GPU tab
+         gpu_table,
+         gpu_memory_plot,
+         gpu_util_plot,
+         nccl_status,
+         # Inference tab
+         throughput,
+         ttft,
+         batch_size,
+         kv_cache,
+         throughput_plot,
+         prefill_pct,
+         decode_pct,
+         queue_depth,
+         e2e_latency,
+         latency_plot,
+     )
+
+
+ def create_dashboard():
+     """Create the main dashboard application."""
+
+     custom_css = """
+     .gradio-container { max-width: 1400px !important; }
+     .panel-header { font-size: 1.2em; font-weight: bold; margin-bottom: 10px; }
+     """
+
+     with gr.Blocks(title="LLM Inference Dashboard", css=custom_css) as app:
+         gr.Markdown("# LLM Inference Dashboard")
+         gr.Markdown("*Real-time monitoring for vLLM inference servers*")
+
+         # Header row: connection status, model info, active alerts
+         with gr.Row():
+             status_indicator = gr.HTML(
+                 value=check_connection(),
+                 label="Connection",
+             )
+             model_name_display = gr.Textbox(
+                 label="Model",
+                 value=get_model_name(),
+                 interactive=False,
+                 scale=2,
+             )
+             active_alerts_display = gr.HTML(
+                 value=get_alert_badge_html([]),
+                 label="Alerts",
+             )
+
+         # Main tabs
+         with gr.Tabs():
+             # Tab 1: GPU Status
+             with gr.Tab("GPU / Rank Status"):
+                 gpu_components = create_gpu_panel(history)
+
+             # Tab 2: Inference Metrics
+             with gr.Tab("Inference"):
+                 inference_components = create_inference_panel(history)
+
+             # Tab 3: Quantization
+             with gr.Tab("Quantization"):
+                 quant_components = create_quant_panel()
+
+                 # Initial update
+                 (
+                     quant_type, bits, group_size, quant_details, layer_table
+                 ) = update_quant_panel(quant_collector)
+
+                 quant_components["quant_type"].value = quant_type
+                 quant_components["bits"].value = bits
+                 quant_components["group_size"].value = group_size
+
+             # Tab 4: Loading Progress
+             with gr.Tab("Loading"):
+                 loading_components = create_loading_panel()
+
+             # Tab 5: Alerts
+             with gr.Tab("Alerts"):
+                 alerts_components = create_alerts_panel(alert_engine, alert_dispatcher)
+
+             # Tab 6: Request Tracing
+             with gr.Tab("Tracing"):
+                 tracing_components = create_tracing_panel(request_tracer)
+
+             # Tab 7: A/B Comparison
+             with gr.Tab("A/B Compare"):
+                 comparison_components = create_comparison_panel()
+
+             # Tab 8: Load Testing
+             with gr.Tab("Load Test"):
+                 loadtest_components = create_loadtest_panel()
+
+         # Auto-refresh timer
+         timer = gr.Timer(config.refresh_interval)
+
+         # Collect all outputs for timer update
+         timer_outputs = [
+             # Header
+             status_indicator,
+             model_name_display,
+             active_alerts_display,
+             # GPU tab
+             gpu_components["gpu_table"],
+             gpu_components["gpu_memory_plot"],
+             gpu_components["gpu_util_plot"],
+             gpu_components["nccl_status"],
+             # Inference tab
+             inference_components["throughput"],
+             inference_components["ttft"],
+             inference_components["batch_size"],
+             inference_components["kv_cache"],
+             inference_components["throughput_plot"],
+             inference_components["prefill_pct"],
+             inference_components["decode_pct"],
+             inference_components["queue_depth"],
+             inference_components["e2e_latency"],
+             inference_components["latency_plot"],
+         ]
+
+         timer.tick(fn=update_all_metrics, outputs=timer_outputs)
+
+         # Manual refresh for tabs that don't auto-update
+         def refresh_quant():
+             return update_quant_panel(quant_collector)
+
+         def refresh_loading():
+             return update_loading_panel(loading_tracker)
+
+         def refresh_alerts():
+             return update_alerts_panel(alert_engine, db)
+
+     return app
+
+
+ def main():
+     """Main entry point."""
+     logger.info("Starting LLM Inference Dashboard")
+     logger.info(f"vLLM endpoint: {config.metrics_endpoint}")
+     logger.info(f"Database: {config.db_path}")
+
+     # Check initial connection
+     if vllm_collector.check_connection():
+         logger.info("Successfully connected to vLLM server")
+
+         # Set model ready if connected
+         loading_tracker.set_ready()
+
+         # Try to detect quantization
+         metrics = vllm_collector.collect()
+         if metrics.model_name:
+             quant_collector.set_model_path(metrics.model_name)
+     else:
+         logger.warning("Could not connect to vLLM server - dashboard will show mock data")
+
+     # Create and launch the dashboard
+     app = create_dashboard()
+
+     # Check if running on HuggingFace Spaces
+     if os.getenv("SPACE_ID"):
+         app.launch()
+     else:
+         app.launch(
+             server_name="0.0.0.0",
+             server_port=7860,
+             share=False,
+             show_error=True,
+         )
+
+
+ # For HuggingFace Spaces - create demo instance
+ demo = create_dashboard()
+
+ if __name__ == "__main__":
+     main()
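The alert-dispatch code in `update_all_metrics` has to fire an async coroutine from a synchronous Gradio callback, whether or not an event loop is already running. A standalone sketch of that pattern, with a dummy coroutine standing in for `AlertDispatcher.dispatch` (the real one posts to Slack/webhooks):

```python
import asyncio

sent = []


async def dispatch(alert: str) -> None:
    # Stand-in for AlertDispatcher.dispatch: just record the alert
    sent.append(alert)


def dispatch_alert(alert: str) -> None:
    """Fire an async dispatch from sync code, with or without a running loop."""
    try:
        loop = asyncio.get_running_loop()
    except RuntimeError:
        # No event loop running (plain sync caller): run to completion
        asyncio.run(dispatch(alert))
    else:
        # Inside a running loop (e.g. a server callback): schedule, don't block
        loop.create_task(dispatch(alert))


dispatch_alert("gpu_memory_high")
```

`asyncio.get_running_loop()` raises `RuntimeError` when no loop is active, which makes the two cases explicit; the deprecated `asyncio.get_event_loop()` hides that distinction.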
collectors/__init__.py ADDED
@@ -0,0 +1,13 @@
+ """Data collectors for monitoring vLLM inference."""
+
+ from .gpu_collector import GPUCollector
+ from .vllm_collector import VLLMCollector
+ from .quant_collector import QuantizationCollector
+ from .loading_tracker import LoadingTracker
+
+ __all__ = [
+     "GPUCollector",
+     "VLLMCollector",
+     "QuantizationCollector",
+     "LoadingTracker",
+ ]
collectors/gpu_collector.py ADDED
@@ -0,0 +1,174 @@
+ """GPU statistics collector using pynvml."""
+
+ from dataclasses import dataclass
+ from typing import List, Optional
+ import logging
+
+ logger = logging.getLogger(__name__)
+
+ # Try to import pynvml, provide mock if unavailable
+ try:
+     import pynvml
+     PYNVML_AVAILABLE = True
+ except ImportError:
+     PYNVML_AVAILABLE = False
+     logger.warning("pynvml not available - GPU stats will be simulated")
+
+
+ @dataclass
+ class GPUStats:
+     """Statistics for a single GPU."""
+     gpu_id: int
+     name: str
+     memory_used_gb: float
+     memory_total_gb: float
+     memory_percent: float
+     gpu_util_percent: float
+     temperature_c: int
+     power_watts: float
+     power_limit_watts: float
+     tp_rank: Optional[int] = None
+
+
+ class GPUCollector:
+     """Collects GPU statistics via NVIDIA Management Library."""
+
+     def __init__(self):
+         """Initialize the GPU collector."""
+         self._initialized = False
+         self._gpu_count = 0
+         self._rank_mapping: dict = {}
+
+         if PYNVML_AVAILABLE:
+             try:
+                 pynvml.nvmlInit()
+                 self._initialized = True
+                 self._gpu_count = pynvml.nvmlDeviceGetCount()
+                 logger.info(f"Initialized pynvml with {self._gpu_count} GPUs")
+             except Exception as e:
+                 logger.error(f"Failed to initialize pynvml: {e}")
+
+     def set_rank_mapping(self, mapping: dict) -> None:
+         """
+         Set tensor parallel rank to GPU ID mapping.
+
+         Args:
+             mapping: Dictionary mapping TP rank to GPU ID
+         """
+         self._rank_mapping = mapping
+
+     def get_gpu_count(self) -> int:
+         """Get the number of available GPUs."""
+         return self._gpu_count
+
+     def collect(self) -> List[GPUStats]:
+         """
+         Collect stats for all GPUs.
+
+         Returns:
+             List of GPUStats for each GPU
+         """
+         if not self._initialized:
+             return self._get_mock_stats()
+
+         stats = []
+         for i in range(self._gpu_count):
+             try:
+                 stat = self._collect_single_gpu(i)
+                 stats.append(stat)
+             except Exception as e:
+                 logger.error(f"Error collecting stats for GPU {i}: {e}")
+
+         return stats
+
+     def _collect_single_gpu(self, gpu_id: int) -> GPUStats:
+         """Collect stats for a single GPU."""
+         handle = pynvml.nvmlDeviceGetHandleByIndex(gpu_id)
+
+         # Get device name
+         name = pynvml.nvmlDeviceGetName(handle)
+         if isinstance(name, bytes):
+             name = name.decode("utf-8")
+
+         # Memory info
+         mem_info = pynvml.nvmlDeviceGetMemoryInfo(handle)
+         memory_used_gb = mem_info.used / 1e9
+         memory_total_gb = mem_info.total / 1e9
+         memory_percent = (mem_info.used / mem_info.total) * 100
+
+         # Utilization
+         util = pynvml.nvmlDeviceGetUtilizationRates(handle)
+         gpu_util_percent = util.gpu
+
+         # Temperature
+         temperature_c = pynvml.nvmlDeviceGetTemperature(
+             handle, pynvml.NVML_TEMPERATURE_GPU
+         )
+
+         # Power
+         try:
+             power_watts = pynvml.nvmlDeviceGetPowerUsage(handle) / 1000.0
+             power_limit_watts = pynvml.nvmlDeviceGetEnforcedPowerLimit(handle) / 1000.0
+         except pynvml.NVMLError:
+             power_watts = 0
+             power_limit_watts = 0
+
+         # Find TP rank for this GPU
+         tp_rank = None
+         for rank, gid in self._rank_mapping.items():
+             if gid == gpu_id:
+                 tp_rank = rank
+                 break
+
+         return GPUStats(
+             gpu_id=gpu_id,
+             name=name,
+             memory_used_gb=memory_used_gb,
+             memory_total_gb=memory_total_gb,
+             memory_percent=memory_percent,
+             gpu_util_percent=gpu_util_percent,
+             temperature_c=temperature_c,
+             power_watts=power_watts,
+             power_limit_watts=power_limit_watts,
+             tp_rank=tp_rank,
+         )
+
+     def _get_mock_stats(self) -> List[GPUStats]:
+         """Return mock stats when pynvml is not available."""
+         import random
+
+         return [
+             GPUStats(
+                 gpu_id=i,
+                 name=f"Mock GPU {i}",
+                 memory_used_gb=random.uniform(10, 20),
+                 memory_total_gb=24.0,
+                 memory_percent=random.uniform(40, 80),
+                 gpu_util_percent=random.uniform(20, 90),
+                 temperature_c=random.randint(40, 70),
+                 power_watts=random.uniform(100, 300),
+                 power_limit_watts=350,
+                 tp_rank=i,
+             )
+             for i in range(2)
+         ]
+
+     def shutdown(self) -> None:
+         """Clean up NVML resources."""
+         if self._initialized and PYNVML_AVAILABLE:
+             try:
+                 pynvml.nvmlShutdown()
+             except Exception:
+                 pass
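The memory math in `_collect_single_gpu` converts NVML's raw byte counts into the decimal-GB and percent figures shown in the GPU table. Isolated as a pure function for clarity (the helper name is for illustration only):

```python
def memory_figures(used_bytes: int, total_bytes: int) -> tuple:
    """Convert NVML byte counts to (used_gb, total_gb, percent), as the collector does."""
    used_gb = used_bytes / 1e9          # decimal gigabytes, matching the collector
    total_gb = total_bytes / 1e9
    percent = used_bytes / total_bytes * 100
    return round(used_gb, 2), round(total_gb, 2), round(percent, 1)


# e.g. 18 GB used of a 24 GB card
print(memory_figures(18_000_000_000, 24_000_000_000))  # → (18.0, 24.0, 75.0)
```

Note the collector divides by 1e9 (decimal GB), not 2**30 (GiB); the two differ by about 7%, which matters when comparing against `nvidia-smi`, which reports MiB.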
collectors/loading_tracker.py ADDED
@@ -0,0 +1,224 @@
+ """Model loading progress tracker."""
+
+ import json
+ import re
+ import time
+ import logging
+ from dataclasses import dataclass
+ from typing import Optional, List, Dict
+ from pathlib import Path
+ from enum import Enum
+
+ logger = logging.getLogger(__name__)
+
+
+ class LoadingStatus(Enum):
+     """Status of model loading."""
+     NOT_STARTED = "not_started"
+     DOWNLOADING = "downloading"
+     LOADING = "loading"
+     READY = "ready"
+     ERROR = "error"
+
+
+ @dataclass
+ class ShardInfo:
+     """Information about a model shard file."""
+     filename: str
+     size_mb: float
+     status: str  # pending, loading, loaded
+     layers: List[str]
+
+
+ @dataclass
+ class LoadingProgress:
+     """Overall loading progress."""
+     status: LoadingStatus
+     total_shards: int
+     loaded_shards: int
+     current_shard: Optional[str]
+     progress_percent: float
+     layers_loaded: int
+     total_layers: int
+     estimated_remaining_seconds: Optional[float]
+     error_message: Optional[str] = None
+
+
+ class LoadingTracker:
+     """Tracks model loading progress."""
+
+     def __init__(self, model_path: Optional[str] = None):
+         """
+         Initialize loading tracker.
+
+         Args:
+             model_path: Path to model directory
+         """
+         self.model_path = model_path
+         self._shards: List[ShardInfo] = []
+         self._status = LoadingStatus.NOT_STARTED
+         self._progress = 0.0
+         self._current_shard: Optional[str] = None
+         self._layers_loaded = 0
+         self._total_layers = 0
+         self._start_time: Optional[float] = None
+
+     def set_model_path(self, model_path: str) -> None:
+         """Set or update the model path."""
+         self.model_path = model_path
+         self._load_shard_info()
+
+     def _load_shard_info(self) -> None:
+         """Load shard information from safetensors index."""
+         if not self.model_path:
+             return
+
+         index_path = self._resolve_path("model.safetensors.index.json")
+         if not index_path:
+             return
+
+         try:
+             with open(index_path) as f:
+                 index = json.load(f)
+
+             weight_map = index.get("weight_map", {})
+
+             # Group weights by shard file
+             shard_weights: Dict[str, List[str]] = {}
+             for weight_name, shard_file in weight_map.items():
+                 shard_weights.setdefault(shard_file, []).append(weight_name)
+
+             # Create shard info
+             self._shards = []
+             for shard_file, weights in sorted(shard_weights.items()):
+                 shard_path = self._resolve_path(shard_file)
+                 size_mb = 0
+                 if shard_path and shard_path.exists():
+                     size_mb = shard_path.stat().st_size / (1024 * 1024)
+
+                 # Extract layer names
+                 layers = list(set(
+                     ".".join(w.split(".")[:3])
+                     for w in weights
+                     if len(w.split(".")) >= 3
+                 ))
+
+                 self._shards.append(ShardInfo(
+                     filename=shard_file,
+                     size_mb=size_mb,
+                     status="pending",
+                     layers=layers,
+                 ))
+
+             # Count total layers
+             all_layers = set()
+             for shard in self._shards:
+                 all_layers.update(shard.layers)
+             self._total_layers = len(all_layers)
+
+         except Exception as e:
+             logger.error(f"Error loading shard info: {e}")
+
+     def _resolve_path(self, filename: str) -> Optional[Path]:
+         """Resolve path to a file in the model directory."""
+         if not self.model_path:
+             return None
+
+         local_path = Path(self.model_path) / filename
+         if local_path.exists():
+             return local_path
+
+         return None
+
+     def update_from_log(self, log_line: str) -> None:
+         """
+         Update progress from a vLLM log line.
+
+         Args:
+             log_line: Log line from vLLM server
+         """
+         # Detect loading start
+         if "Loading model" in log_line:
+             self._status = LoadingStatus.LOADING
+             self._start_time = time.time()
+
+         # Detect shard loading
+         match = re.search(r"Loading safetensors: (\d+)/(\d+)", log_line)
+         if match:
+             loaded = int(match.group(1))
+             total = int(match.group(2))
+             self._progress = loaded / total * 100
+             for i, shard in enumerate(self._shards):
+                 if i < loaded:
+                     shard.status = "loaded"
+                 elif i == loaded:
+                     shard.status = "loading"
+                     self._current_shard = shard.filename
+
+         # Detect completion
+         if "Model loaded" in log_line or "Running with" in log_line:
+             self._status = LoadingStatus.READY
+             self._progress = 100.0
+             for shard in self._shards:
+                 shard.status = "loaded"
+
+         # Detect errors
+         if "Error" in log_line or "Exception" in log_line:
+             self._status = LoadingStatus.ERROR
+
+     def get_progress(self) -> LoadingProgress:
+         """
+         Get current loading progress.
+
+         Returns:
+             LoadingProgress with current state
+         """
+         loaded_shards = sum(1 for s in self._shards if s.status == "loaded")
+         total_shards = len(self._shards) if self._shards else 1
+
+         # Estimate remaining time
+         remaining = None
+         if self._start_time and self._progress > 0:
+             elapsed = time.time() - self._start_time
+             remaining = (elapsed / self._progress) * (100 - self._progress)
+
+         # Count loaded layers
+         loaded_layers = set()
+         for shard in self._shards:
+             if shard.status == "loaded":
+                 loaded_layers.update(shard.layers)
+
+         return LoadingProgress(
+             status=self._status,
+             total_shards=total_shards,
+             loaded_shards=loaded_shards,
+             current_shard=self._current_shard,
+             progress_percent=self._progress,
+             layers_loaded=len(loaded_layers),
+             total_layers=self._total_layers,
+             estimated_remaining_seconds=remaining,
+         )
+
+     def get_shards(self) -> List[ShardInfo]:
+         """Get list of all shards with their status."""
+         return self._shards
+
+     def set_ready(self) -> None:
+         """Mark the model as fully loaded."""
+         self._status = LoadingStatus.READY
+         self._progress = 100.0
+         for shard in self._shards:
+             shard.status = "loaded"
+
+     def reset(self) -> None:
+         """Reset progress tracker."""
+         self._status = LoadingStatus.NOT_STARTED
+         self._progress = 0.0
+         self._current_shard = None
+         self._layers_loaded = 0
+         self._start_time = None
+         for shard in self._shards:
+             shard.status = "pending"
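The shard-progress regex in `update_from_log` can be exercised standalone. A small sketch against sample log lines (the exact `Loading safetensors: N/M` format is the tracker's own assumption about vLLM's output; real log text may differ between vLLM versions):

```python
import re

SHARD_RE = re.compile(r"Loading safetensors: (\d+)/(\d+)")


def parse_progress(log_line: str):
    """Return (loaded, total, percent) if the line reports shard progress, else None."""
    match = SHARD_RE.search(log_line)
    if not match:
        return None
    loaded, total = int(match.group(1)), int(match.group(2))
    return loaded, total, loaded / total * 100


print(parse_progress("INFO Loading safetensors: 3/4"))   # → (3, 4, 75.0)
print(parse_progress("INFO Avg generation throughput"))  # → None
```

Using `re.search` (not `re.match`) lets the pattern sit anywhere in the line, so timestamps and log-level prefixes don't need stripping first.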
collectors/quant_collector.py ADDED
@@ -0,0 +1,259 @@
+ """Quantization information collector."""
+
+ import json
+ import os
+ import logging
+ from dataclasses import dataclass
+ from typing import Optional, Dict, Any, List
+ from pathlib import Path
+
+ logger = logging.getLogger(__name__)
+
+
+ @dataclass
+ class QuantizationInfo:
+     """Quantization details for a model."""
+     method: str  # GPTQ, AWQ, bitsandbytes, None
+     bits: int
+     group_size: Optional[int] = None
+     desc_act: Optional[bool] = None
+     sym: Optional[bool] = None
+     compute_dtype: Optional[str] = None
+     quant_type: Optional[str] = None  # For bitsandbytes: nf4, fp4
+     double_quant: Optional[bool] = None
+     raw_config: Optional[Dict[str, Any]] = None
+
+     def to_dict(self) -> Dict[str, Any]:
+         """Convert to dictionary for JSON display."""
+         result = {
+             "method": self.method,
+             "bits": self.bits,
+         }
+         if self.group_size is not None:
+             result["group_size"] = self.group_size
+         if self.desc_act is not None:
+             result["desc_act"] = self.desc_act
+         if self.sym is not None:
+             result["symmetric"] = self.sym
+         if self.compute_dtype:
+             result["compute_dtype"] = self.compute_dtype
+         if self.quant_type:
+             result["quant_type"] = self.quant_type
+         if self.double_quant is not None:
+             result["double_quant"] = self.double_quant
+         return result
+
+
+ @dataclass
+ class LayerPrecision:
+     """Precision information for a model layer."""
+     layer_name: str
+     bits: int
+     group_size: Optional[int]
+     dtype: str
+
+
+ class QuantizationCollector:
+     """Detects and collects quantization information from model configs."""
+
+     def __init__(self, model_path: Optional[str] = None):
+         """
+         Initialize quantization collector.
+
+         Args:
+             model_path: Path to model directory (local or HF model ID)
+         """
+         self.model_path = model_path
+         self._cached_info: Optional[QuantizationInfo] = None
+
+     def set_model_path(self, model_path: str) -> None:
+         """Set or update the model path."""
+         self.model_path = model_path
+         self._cached_info = None
+
+     def detect(self) -> QuantizationInfo:
+         """
+         Detect quantization method and settings.
+
+         Returns:
+             QuantizationInfo with detected settings
+         """
+         if self._cached_info is not None:
+             return self._cached_info
+
+         if not self.model_path:
+             return QuantizationInfo(method="Unknown", bits=16)
+
+         # Try to load config files
+         config = self._load_config()
+         quant_config = self._load_quant_config()
+
+         info = self._detect_quantization(config, quant_config)
+         self._cached_info = info
+         return info
+
+     def _load_config(self) -> Optional[Dict[str, Any]]:
+         """Load config.json from model path."""
+         config_path = self._resolve_path("config.json")
+         if config_path and config_path.exists():
+             try:
+                 with open(config_path) as f:
+                     return json.load(f)
+             except Exception as e:
+                 logger.error(f"Error loading config.json: {e}")
+         return None
+
+     def _load_quant_config(self) -> Optional[Dict[str, Any]]:
+         """Load quantize_config.json (GPTQ/AWQ) from model path."""
+         config_path = self._resolve_path("quantize_config.json")
+         if config_path and config_path.exists():
+             try:
+                 with open(config_path) as f:
+                     return json.load(f)
+             except Exception as e:
+                 logger.error(f"Error loading quantize_config.json: {e}")
+         return None
+
+     def _resolve_path(self, filename: str) -> Optional[Path]:
+         """Resolve path to a file in the model directory."""
+         if not self.model_path:
+             return None
+
+         # Handle local paths
+         local_path = Path(self.model_path) / filename
+         if local_path.exists():
+             return local_path
+
+         # Handle HuggingFace cache paths
+         cache_dir = Path.home() / ".cache" / "huggingface" / "hub"
+         if cache_dir.exists():
+             # Search for model in cache by reconstructing the repo ID
+             # from the "models--org--name" directory naming scheme
+             for model_dir in cache_dir.glob("models--*"):
+                 model_name = model_dir.name.replace("models--", "").replace("--", "/")
+                 if model_name.lower() == self.model_path.lower():
+                     snapshot_path = model_dir / "snapshots"
+                     if snapshot_path.exists():
+                         # Get latest snapshot
+                         snapshots = list(snapshot_path.iterdir())
+                         if snapshots:
+                             file_path = snapshots[-1] / filename
+                             if file_path.exists():
+                                 return file_path
+
+         return None
+
+     def _detect_quantization(
+         self,
+         config: Optional[Dict[str, Any]],
+         quant_config: Optional[Dict[str, Any]],
+     ) -> QuantizationInfo:
+         """Detect quantization from config files."""
+
+         # Check for GPTQ via quantize_config.json
+         if quant_config:
+             if "bits" in quant_config:
+                 return QuantizationInfo(
+                     method="GPTQ",
+                     bits=quant_config.get("bits", 4),
+                     group_size=quant_config.get("group_size", 128),
+                     desc_act=quant_config.get("desc_act", False),
+                     sym=quant_config.get("sym", True),
+                     raw_config=quant_config,
+                 )
+
+         if not config:
+             return QuantizationInfo(method="Unknown", bits=16)
+
+         # Check for quantization_config in config.json
+         qc = config.get("quantization_config", {})
+
+         if qc:
+             quant_method = qc.get("quant_method", "").lower()
+
+             # AWQ
+             if quant_method == "awq":
175
+ return QuantizationInfo(
176
+ method="AWQ",
177
+ bits=qc.get("bits", 4),
178
+ group_size=qc.get("group_size", 128),
179
+ raw_config=qc,
180
+ )
181
+
182
+ # GPTQ (in config.json)
183
+ if quant_method == "gptq":
184
+ return QuantizationInfo(
185
+ method="GPTQ",
186
+ bits=qc.get("bits", 4),
187
+ group_size=qc.get("group_size", 128),
188
+ desc_act=qc.get("desc_act", False),
189
+ sym=qc.get("sym", True),
190
+ raw_config=qc,
191
+ )
192
+
193
+ # bitsandbytes
194
+ if qc.get("load_in_4bit") or qc.get("load_in_8bit"):
195
+ bits = 4 if qc.get("load_in_4bit") else 8
196
+ return QuantizationInfo(
197
+ method="bitsandbytes",
198
+ bits=bits,
199
+ compute_dtype=qc.get("bnb_4bit_compute_dtype", "float16"),
200
+ quant_type=qc.get("bnb_4bit_quant_type", "nf4"),
201
+ double_quant=qc.get("bnb_4bit_use_double_quant", False),
202
+ raw_config=qc,
203
+ )
204
+
205
+ # Check torch_dtype for fp16/bf16
206
+ torch_dtype = config.get("torch_dtype", "float16")
207
+ if torch_dtype in ("float16", "bfloat16"):
208
+ return QuantizationInfo(
209
+ method="None (FP16/BF16)",
210
+ bits=16,
211
+ compute_dtype=torch_dtype,
212
+ )
213
+
214
+ return QuantizationInfo(method="Unknown", bits=16)
215
+
216
+ def get_layer_precisions(self) -> List[LayerPrecision]:
217
+ """
218
+ Get per-layer precision information.
219
+
220
+ Returns:
221
+ List of LayerPrecision for each layer
222
+ """
223
+ info = self.detect()
224
+
225
+ # For quantized models, all layers typically have same precision
226
+ # This could be extended to parse safetensors index for more detail
227
+
228
+ index_path = self._resolve_path("model.safetensors.index.json")
229
+ if not index_path or not index_path.exists():
230
+ return []
231
+
232
+ try:
233
+ with open(index_path) as f:
234
+ index = json.load(f)
235
+
236
+ weight_map = index.get("weight_map", {})
237
+ layers = []
238
+ seen_layers = set()
239
+
240
+ for weight_name in weight_map.keys():
241
+ # Extract layer name
242
+ parts = weight_name.split(".")
243
+ if len(parts) >= 3:
244
+ layer_name = ".".join(parts[:3])
245
+ if layer_name not in seen_layers:
246
+ seen_layers.add(layer_name)
247
+ layers.append(
248
+ LayerPrecision(
249
+ layer_name=layer_name,
250
+ bits=info.bits,
251
+ group_size=info.group_size,
252
+ dtype=info.compute_dtype or "float16",
253
+ )
254
+ )
255
+
256
+ return layers
257
+ except Exception as e:
258
+ logger.error(f"Error parsing layer precisions: {e}")
259
+ return []
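The detection priority above (sidecar `quantize_config.json` first, then `quantization_config` in `config.json`, then `torch_dtype`) can be condensed into a standalone sketch. `detect_method` below is a hypothetical helper operating on plain dicts, not part of this repo:

```python
def detect_method(config, quant_config):
    """Return a short quantization label from parsed config dicts (sketch)."""
    # GPTQ ships a sidecar quantize_config.json with a "bits" key
    if quant_config and "bits" in quant_config:
        return f"GPTQ-{quant_config['bits']}bit"
    qc = (config or {}).get("quantization_config", {})
    method = qc.get("quant_method", "").lower()
    if method in ("awq", "gptq"):
        return f"{method.upper()}-{qc.get('bits', 4)}bit"
    if qc.get("load_in_4bit") or qc.get("load_in_8bit"):
        return "bitsandbytes-4bit" if qc.get("load_in_4bit") else "bitsandbytes-8bit"
    # No quantization config: fall back to the checkpoint dtype
    dtype = (config or {}).get("torch_dtype", "float16")
    return "None (FP16/BF16)" if dtype in ("float16", "bfloat16") else "Unknown"

print(detect_method({"quantization_config": {"quant_method": "awq", "bits": 4}}, None))
# → AWQ-4bit
```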
collectors/vllm_collector.py ADDED
@@ -0,0 +1,226 @@
+ """vLLM metrics collector via Prometheus endpoint."""
2
+
3
+ import requests
4
+ import logging
5
+ from dataclasses import dataclass, field
6
+ from typing import Optional, Dict, List, Any
7
+ from datetime import datetime
8
+
9
+ from utils.prometheus_parser import (
10
+ parse_prometheus_metrics,
11
+ get_metric_value,
12
+ get_histogram_quantile,
13
+ MetricSample,
14
+ )
15
+
16
+ logger = logging.getLogger(__name__)
17
+
18
+
19
+ @dataclass
20
+ class InferenceMetrics:
21
+ """Inference metrics from vLLM."""
22
+ timestamp: datetime = field(default_factory=datetime.now)
23
+
24
+ # Request counts
25
+ num_requests_running: int = 0
26
+ num_requests_waiting: int = 0
27
+ num_requests_swapped: int = 0
28
+
29
+ # Token throughput
30
+ prompt_tokens_total: int = 0
31
+ generation_tokens_total: int = 0
32
+ tokens_per_second: float = 0.0
33
+
34
+ # Latency
35
+ ttft_ms: float = 0.0 # Time to first token
36
+ tpot_ms: float = 0.0 # Time per output token
37
+ e2e_latency_ms: float = 0.0 # End-to-end latency
38
+
39
+ # Cache
40
+ kv_cache_usage_percent: float = 0.0
41
+ gpu_cache_usage_percent: float = 0.0
42
+ cpu_cache_usage_percent: float = 0.0
43
+
44
+ # Model info
45
+ model_name: str = ""
46
+ max_model_len: int = 0
47
+
48
+ # Derived
49
+ prefill_ratio: float = 0.0
50
+ batch_size: int = 0
51
+
52
+
53
+ class VLLMCollector:
54
+ """Collects metrics from vLLM Prometheus endpoint."""
55
+
56
+ def __init__(self, metrics_url: str = "http://localhost:8000/metrics"):
57
+ """
58
+ Initialize the vLLM collector.
59
+
60
+ Args:
61
+ metrics_url: URL to vLLM's /metrics endpoint
62
+ """
63
+ self.metrics_url = metrics_url
64
+ self._last_prompt_tokens = 0
65
+ self._last_generation_tokens = 0
66
+ self._last_collect_time: Optional[datetime] = None
67
+ self._connected = False
68
+
69
+ def check_connection(self) -> bool:
70
+ """Check if vLLM server is accessible."""
71
+ try:
72
+ response = requests.get(self.metrics_url, timeout=2)
73
+ self._connected = response.status_code == 200
74
+ return self._connected
75
+ except Exception:
76
+ self._connected = False
77
+ return False
78
+
79
+ @property
80
+ def is_connected(self) -> bool:
81
+ """Return connection status."""
82
+ return self._connected
83
+
84
+ def collect(self) -> InferenceMetrics:
85
+ """
86
+ Collect all inference metrics from vLLM.
87
+
88
+ Returns:
89
+ InferenceMetrics dataclass with current values
90
+ """
91
+ metrics = InferenceMetrics()
92
+
93
+ try:
94
+ response = requests.get(self.metrics_url, timeout=5)
95
+ response.raise_for_status()
96
+ self._connected = True
97
+
98
+ raw_metrics = parse_prometheus_metrics(response.text)
99
+ metrics = self._parse_metrics(raw_metrics)
100
+
101
+ except requests.exceptions.ConnectionError:
102
+ self._connected = False
103
+ logger.debug("Cannot connect to vLLM metrics endpoint")
104
+ except Exception as e:
105
+ logger.error(f"Error collecting vLLM metrics: {e}")
106
+
107
+ return metrics
108
+
109
+ def _parse_metrics(self, raw: Dict[str, List[MetricSample]]) -> InferenceMetrics:
110
+ """Parse raw Prometheus metrics into InferenceMetrics."""
111
+ now = datetime.now()
112
+ metrics = InferenceMetrics(timestamp=now)
113
+
114
+ # Request counts
115
+ metrics.num_requests_running = int(
116
+ get_metric_value(raw, "vllm:num_requests_running") or 0
117
+ )
118
+ metrics.num_requests_waiting = int(
119
+ get_metric_value(raw, "vllm:num_requests_waiting") or 0
120
+ )
121
+ metrics.num_requests_swapped = int(
122
+ get_metric_value(raw, "vllm:num_requests_swapped") or 0
123
+ )
124
+ metrics.batch_size = metrics.num_requests_running
125
+
126
+ # Token counts (counters)
127
+ prompt_tokens = int(get_metric_value(raw, "vllm:prompt_tokens_total") or 0)
128
+ generation_tokens = int(
129
+ get_metric_value(raw, "vllm:generation_tokens_total") or 0
130
+ )
131
+
132
+ # Calculate tokens per second
133
+ if self._last_collect_time:
134
+ time_delta = (now - self._last_collect_time).total_seconds()
135
+ if time_delta > 0:
136
+ token_delta = generation_tokens - self._last_generation_tokens
137
+ metrics.tokens_per_second = token_delta / time_delta
138
+
139
+ self._last_prompt_tokens = prompt_tokens
140
+ self._last_generation_tokens = generation_tokens
141
+ self._last_collect_time = now
142
+
143
+ metrics.prompt_tokens_total = prompt_tokens
144
+ metrics.generation_tokens_total = generation_tokens
145
+
146
+ # Calculate prefill ratio
147
+ total_tokens = prompt_tokens + generation_tokens
148
+ if total_tokens > 0:
149
+ metrics.prefill_ratio = prompt_tokens / total_tokens
150
+
151
+ # Latency metrics (from histograms, use P50)
152
+ ttft = get_histogram_quantile(raw, "vllm:time_to_first_token_seconds", 0.5)
153
+ if ttft is not None:
154
+ metrics.ttft_ms = ttft * 1000
155
+
156
+ tpot = get_histogram_quantile(raw, "vllm:time_per_output_token_seconds", 0.5)
157
+ if tpot is not None:
158
+ metrics.tpot_ms = tpot * 1000
159
+
160
+ e2e = get_histogram_quantile(raw, "vllm:e2e_request_latency_seconds", 0.5)
161
+ if e2e is not None:
162
+ metrics.e2e_latency_ms = e2e * 1000
163
+
164
+ # Cache usage
165
+ metrics.gpu_cache_usage_percent = (
166
+ get_metric_value(raw, "vllm:gpu_cache_usage_perc") or 0
167
+ ) * 100
168
+ metrics.cpu_cache_usage_percent = (
169
+ get_metric_value(raw, "vllm:cpu_cache_usage_perc") or 0
170
+ ) * 100
171
+ metrics.kv_cache_usage_percent = metrics.gpu_cache_usage_percent
172
+
173
+ # Model info
174
+ model_name = self._get_model_name(raw)
175
+ if model_name:
176
+ metrics.model_name = model_name
177
+
178
+ return metrics
179
+
180
+ def _get_model_name(self, raw: Dict[str, List[MetricSample]]) -> Optional[str]:
181
+ """Extract model name from metrics labels."""
182
+ # Look for model name in any metric with model_name label
183
+ for metric_name, samples in raw.items():
184
+ for sample in samples:
185
+ if "model_name" in sample.labels:
186
+ return sample.labels["model_name"]
187
+ return None
188
+
189
+ def get_rank_mapping(self) -> Dict[int, int]:
190
+ """
191
+ Get tensor parallel rank to GPU mapping.
192
+
193
+ Returns:
194
+ Dictionary mapping TP rank to GPU ID
195
+ """
196
+ # This would typically come from vLLM's internal state
197
+ # For now, return empty mapping - can be extended
198
+ return {}
199
+
200
+ def get_latency_percentiles(self) -> Dict[str, Dict[str, float]]:
201
+ """
202
+ Get latency percentiles for detailed analysis.
203
+
204
+ Returns:
205
+ Dictionary with P50, P95, P99 for each latency metric
206
+ """
207
+ try:
208
+ response = requests.get(self.metrics_url, timeout=5)
209
+ raw = parse_prometheus_metrics(response.text)
210
+
211
+ result = {}
212
+ for metric_base in [
213
+ "vllm:time_to_first_token_seconds",
214
+ "vllm:time_per_output_token_seconds",
215
+ "vllm:e2e_request_latency_seconds",
216
+ ]:
217
+ result[metric_base] = {
218
+ "p50": (get_histogram_quantile(raw, metric_base, 0.5) or 0) * 1000,
219
+ "p95": (get_histogram_quantile(raw, metric_base, 0.95) or 0) * 1000,
220
+ "p99": (get_histogram_quantile(raw, metric_base, 0.99) or 0) * 1000,
221
+ }
222
+
223
+ return result
224
+ except Exception as e:
225
+ logger.error(f"Error getting latency percentiles: {e}")
226
+ return {}
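The throughput calculation in `_parse_metrics` is a standard rate-from-counter pattern: vLLM exposes cumulative token counters, so the dashboard derives tokens/sec from the delta between two scrapes. A minimal standalone sketch of that arithmetic (hypothetical helper, not part of the repo):

```python
from datetime import datetime, timedelta

def tokens_per_second(prev_total, prev_time, curr_total, curr_time):
    """Rate from two cumulative-counter snapshots.

    A negative delta means the counter was reset (server restart),
    so report 0.0 rather than a bogus negative rate.
    """
    dt = (curr_time - prev_time).total_seconds()
    if dt <= 0:
        return 0.0
    return max(0, curr_total - prev_total) / dt

t0 = datetime(2024, 1, 1, 12, 0, 0)
print(tokens_per_second(10_000, t0, 12_500, t0 + timedelta(seconds=5)))  # → 500.0
```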
components/__init__.py ADDED
@@ -0,0 +1,27 @@
+ """UI components for the Gradio dashboard."""
2
+
3
+ from .gpu_panel import create_gpu_panel, update_gpu_panel
4
+ from .inference_panel import create_inference_panel, update_inference_panel
5
+ from .quant_panel import create_quant_panel, update_quant_panel
6
+ from .loading_panel import create_loading_panel, update_loading_panel
7
+ from .alerts_panel import create_alerts_panel, update_alerts_panel
8
+ from .tracing_panel import create_tracing_panel, update_tracing_panel
9
+ from .comparison_panel import create_comparison_panel
10
+ from .loadtest_panel import create_loadtest_panel
11
+
12
+ __all__ = [
13
+ "create_gpu_panel",
14
+ "update_gpu_panel",
15
+ "create_inference_panel",
16
+ "update_inference_panel",
17
+ "create_quant_panel",
18
+ "update_quant_panel",
19
+ "create_loading_panel",
20
+ "update_loading_panel",
21
+ "create_alerts_panel",
22
+ "update_alerts_panel",
23
+ "create_tracing_panel",
24
+ "update_tracing_panel",
25
+ "create_comparison_panel",
26
+ "create_loadtest_panel",
27
+ ]
components/alerts_panel.py ADDED
@@ -0,0 +1,253 @@
+ """Alerts configuration and history panel component."""
2
+
3
+ import gradio as gr
4
+ import pandas as pd
5
+ from datetime import datetime
6
+ from typing import Dict, Any, Tuple, List
7
+
8
+ from services.alerting import AlertEngine, AlertDispatcher, Alert, AlertSeverity
9
+
10
+
11
+ def create_alerts_panel(
12
+ alert_engine: AlertEngine,
13
+ alert_dispatcher: AlertDispatcher,
14
+ ) -> Dict[str, Any]:
15
+ """
16
+ Create the alerts panel.
17
+
18
+ Args:
19
+ alert_engine: Alert engine instance
20
+ alert_dispatcher: Alert dispatcher instance
21
+
22
+ Returns:
23
+ Dictionary of Gradio components
24
+ """
25
+ with gr.Column():
26
+ with gr.Row():
27
+ # Active alerts column
28
+ with gr.Column(scale=2):
29
+ gr.Markdown("### Active Alerts")
30
+ active_alerts_table = gr.Dataframe(
31
+ headers=["Time", "Severity", "Metric", "Value", "Threshold", "Message"],
32
+ datatype=["str", "str", "str", "number", "number", "str"],
33
+ label="Active Alerts",
34
+ interactive=False,
35
+ )
36
+
37
+ gr.Markdown("### Alert History")
38
+ alert_history_table = gr.Dataframe(
39
+ headers=["Time", "Severity", "Message", "Resolved"],
40
+ datatype=["str", "str", "str", "str"],
41
+ label="Recent Alerts",
42
+ interactive=False,
43
+ )
44
+
45
+ # Configuration column
46
+ with gr.Column(scale=1):
47
+ gr.Markdown("### Alert Configuration")
48
+
49
+ kv_threshold = gr.Slider(
50
+ label="KV Cache Alert Threshold (%)",
51
+ minimum=50,
52
+ maximum=100,
53
+ value=90,
54
+ step=5,
55
+ )
56
+
57
+ gpu_memory_threshold = gr.Slider(
58
+ label="GPU Memory Alert Threshold (%)",
59
+ minimum=70,
60
+ maximum=100,
61
+ value=95,
62
+ step=5,
63
+ )
64
+
65
+ ttft_multiplier = gr.Slider(
66
+ label="TTFT Spike Multiplier",
67
+ minimum=1.5,
68
+ maximum=5,
69
+ value=2,
70
+ step=0.5,
71
+ )
72
+
73
+ throughput_drop = gr.Slider(
74
+ label="Throughput Drop Alert (%)",
75
+ minimum=20,
76
+ maximum=80,
77
+ value=50,
78
+ step=10,
79
+ )
80
+
81
+ gr.Markdown("### Webhook Configuration")
82
+
83
+ slack_webhook = gr.Textbox(
84
+ label="Slack Webhook URL",
85
+ placeholder="https://hooks.slack.com/services/...",
86
+ type="password",
87
+ )
88
+
89
+ pagerduty_key = gr.Textbox(
90
+ label="PagerDuty Routing Key",
91
+ placeholder="Enter routing key...",
92
+ type="password",
93
+ )
94
+
95
+ with gr.Row():
96
+ save_config_btn = gr.Button("Save Configuration")
97
+ test_alert_btn = gr.Button("Send Test Alert", variant="secondary")
98
+
99
+ config_status = gr.Textbox(
100
+ label="Status",
101
+ interactive=False,
102
+ visible=True,
103
+ )
104
+
105
+ # Event handlers
106
+ def save_config(kv, gpu, ttft, tp_drop, slack, pd_key):
107
+ # Update alert thresholds
108
+ if "kv_cache_high" in alert_engine.rules:
109
+ alert_engine.rules["kv_cache_high"].threshold = kv
110
+ if "gpu_memory_critical" in alert_engine.rules:
111
+ alert_engine.rules["gpu_memory_critical"].threshold = gpu
112
+ if "ttft_spike" in alert_engine.rules:
113
+ alert_engine.rules["ttft_spike"].multiplier = ttft
114
+ if "throughput_drop" in alert_engine.rules:
115
+ alert_engine.rules["throughput_drop"].percent = tp_drop
116
+
117
+ # Update webhook config
118
+ alert_dispatcher.slack_webhook = slack if slack else None
119
+ alert_dispatcher.pagerduty_key = pd_key if pd_key else None
120
+
121
+ return "Configuration saved successfully"
122
+
123
+ save_config_btn.click(
124
+ fn=save_config,
125
+ inputs=[
126
+ kv_threshold,
127
+ gpu_memory_threshold,
128
+ ttft_multiplier,
129
+ throughput_drop,
130
+ slack_webhook,
131
+ pagerduty_key,
132
+ ],
133
+ outputs=config_status,
134
+ )
135
+
136
+ async def send_test():
137
+ success = await alert_dispatcher.send_test_alert()
138
+ if success:
139
+ return "Test alert sent successfully"
140
+ return "Failed to send test alert - check webhook configuration"
141
+
142
+ test_alert_btn.click(
143
+ fn=send_test,
144
+ outputs=config_status,
145
+ )
146
+
147
+ return {
148
+ "active_alerts_table": active_alerts_table,
149
+ "alert_history_table": alert_history_table,
150
+ "kv_threshold": kv_threshold,
151
+ "gpu_memory_threshold": gpu_memory_threshold,
152
+ "ttft_multiplier": ttft_multiplier,
153
+ "throughput_drop": throughput_drop,
154
+ "slack_webhook": slack_webhook,
155
+ "pagerduty_key": pagerduty_key,
156
+ "config_status": config_status,
157
+ }
158
+
159
+
160
+ def update_alerts_panel(
161
+ alert_engine: AlertEngine,
162
+ db=None,
163
+ ) -> Tuple[pd.DataFrame, pd.DataFrame]:
164
+ """
165
+ Update the alerts panel with current data.
166
+
167
+ Args:
168
+ alert_engine: Alert engine instance
169
+ db: Optional database for history
170
+
171
+ Returns:
172
+ Tuple of (active_alerts_df, history_df)
173
+ """
174
+ # Get active alerts
175
+ active = alert_engine.get_active_alerts()
176
+ active_rows = []
177
+ for alert in active:
178
+ active_rows.append({
179
+ "Time": alert.timestamp.strftime("%H:%M:%S"),
180
+ "Severity": _format_severity(alert.severity),
181
+ "Metric": alert.metric,
182
+ "Value": round(alert.value, 2),
183
+ "Threshold": round(alert.threshold, 2),
184
+ "Message": alert.message,
185
+ })
186
+
187
+ active_df = pd.DataFrame(active_rows) if active_rows else pd.DataFrame(
188
+ columns=["Time", "Severity", "Metric", "Value", "Threshold", "Message"]
189
+ )
190
+
191
+ # Get history from database
192
+ history_rows = []
193
+ if db:
194
+ recent = db.get_recent_alerts(limit=20)
195
+ for record in recent:
196
+ history_rows.append({
197
+ "Time": record.timestamp.strftime("%Y-%m-%d %H:%M:%S"),
198
+ "Severity": _format_severity_str(record.severity),
199
+ "Message": record.message,
200
+ "Resolved": "Yes" if record.resolved_at else "No",
201
+ })
202
+
203
+ history_df = pd.DataFrame(history_rows) if history_rows else pd.DataFrame(
204
+ columns=["Time", "Severity", "Message", "Resolved"]
205
+ )
206
+
207
+ return active_df, history_df
208
+
209
+
210
+ def _format_severity(severity: AlertSeverity) -> str:
211
+ """Format severity for display."""
212
+ icons = {
213
+ AlertSeverity.INFO: "INFO",
214
+ AlertSeverity.WARNING: "WARNING",
215
+ AlertSeverity.CRITICAL: "CRITICAL",
216
+ }
217
+ return icons.get(severity, "UNKNOWN")
218
+
219
+
220
+ def _format_severity_str(severity: str) -> str:
221
+ """Format severity string for display."""
222
+ return severity.upper()
223
+
224
+
225
+ def get_alert_badge_html(alerts: List[Alert]) -> str:
226
+ """
227
+ Generate HTML badge for active alerts.
228
+
229
+ Args:
230
+ alerts: List of active alerts
231
+
232
+ Returns:
233
+ HTML string for badge
234
+ """
235
+ if not alerts:
236
+ return '<span style="color: #2e7d32;">No Active Alerts</span>'
237
+
238
+ critical = sum(1 for a in alerts if a.severity == AlertSeverity.CRITICAL)
239
+ warning = sum(1 for a in alerts if a.severity == AlertSeverity.WARNING)
240
+
241
+ badges = []
242
+ if critical > 0:
243
+ badges.append(
244
+ f'<span style="background: #c62828; color: white; padding: 2px 8px; '
245
+ f'border-radius: 12px; margin-right: 5px;">{critical} Critical</span>'
246
+ )
247
+ if warning > 0:
248
+ badges.append(
249
+ f'<span style="background: #ff9800; color: white; padding: 2px 8px; '
250
+ f'border-radius: 12px;">{warning} Warning</span>'
251
+ )
252
+
253
+ return "".join(badges)
components/comparison_panel.py ADDED
@@ -0,0 +1,207 @@
+ """A/B comparison panel component."""
2
+
3
+ import gradio as gr
4
+ import pandas as pd
5
+ import asyncio
6
+ from typing import Dict, Any, Tuple
7
+
8
+ from services.comparator import ABComparator, DeploymentConfig, ComparisonResult
9
+
10
+
11
+ def create_comparison_panel() -> Dict[str, Any]:
12
+ """
13
+ Create the A/B comparison panel.
14
+
15
+ Returns:
16
+ Dictionary of Gradio components
17
+ """
18
+ with gr.Column():
19
+ gr.Markdown("### A/B Deployment Comparison")
20
+
21
+ # Endpoint configuration
22
+ with gr.Row():
23
+ endpoint_a = gr.Textbox(
24
+ label="Deployment A",
25
+ value="http://localhost:8000",
26
+ placeholder="http://host:port",
27
+ )
28
+ name_a = gr.Textbox(
29
+ label="Name A",
30
+ value="Baseline",
31
+ placeholder="e.g., FP16-baseline",
32
+ )
33
+
34
+ with gr.Row():
35
+ endpoint_b = gr.Textbox(
36
+ label="Deployment B",
37
+ value="http://localhost:8001",
38
+ placeholder="http://host:port",
39
+ )
40
+ name_b = gr.Textbox(
41
+ label="Name B",
42
+ value="Candidate",
43
+ placeholder="e.g., AWQ-4bit",
44
+ )
45
+
46
+ with gr.Row():
47
+ compare_btn = gr.Button("Compare Now", variant="primary")
48
+ collect_samples_btn = gr.Button("Collect Samples (30s)", variant="secondary")
49
+
50
+ # Status
51
+ comparison_status = gr.Textbox(
52
+ label="Status",
53
+ interactive=False,
54
+ )
55
+
56
+ # Results side by side
57
+ with gr.Row():
58
+ with gr.Column():
59
+ gr.Markdown("### Deployment A")
60
+ a_connected = gr.Checkbox(label="Connected", interactive=False)
61
+ a_throughput = gr.Number(label="Throughput (tok/s)", precision=1, interactive=False)
62
+ a_ttft = gr.Number(label="TTFT (ms)", precision=1, interactive=False)
63
+ a_latency = gr.Number(label="E2E Latency (ms)", precision=1, interactive=False)
64
+ a_kv_cache = gr.Number(label="KV Cache %", precision=1, interactive=False)
65
+ a_batch = gr.Number(label="Batch Size", precision=0, interactive=False)
66
+
67
+ with gr.Column():
68
+ gr.Markdown("### Deployment B")
69
+ b_connected = gr.Checkbox(label="Connected", interactive=False)
70
+ b_throughput = gr.Number(label="Throughput (tok/s)", precision=1, interactive=False)
71
+ b_ttft = gr.Number(label="TTFT (ms)", precision=1, interactive=False)
72
+ b_latency = gr.Number(label="E2E Latency (ms)", precision=1, interactive=False)
73
+ b_kv_cache = gr.Number(label="KV Cache %", precision=1, interactive=False)
74
+ b_batch = gr.Number(label="Batch Size", precision=0, interactive=False)
75
+
76
+ # Comparison table
77
+ gr.Markdown("### Comparison Summary")
78
+ comparison_table = gr.Dataframe(
79
+ headers=["Metric", "Deployment A", "Deployment B", "Difference"],
80
+ label="Comparison",
81
+ interactive=False,
82
+ )
83
+
84
+ # Recommendation
85
+ recommendation = gr.Markdown("")
86
+
87
+ # Statistical significance
88
+ with gr.Row():
89
+ significance_throughput = gr.Textbox(
90
+ label="Throughput Significance",
91
+ interactive=False,
92
+ )
93
+ significance_latency = gr.Textbox(
94
+ label="Latency Significance",
95
+ interactive=False,
96
+ )
97
+
98
+ # Event handlers
99
+ async def run_comparison(ep_a, name_a_val, ep_b, name_b_val):
100
+ config_a = DeploymentConfig(name=name_a_val, endpoint=ep_a)
101
+ config_b = DeploymentConfig(name=name_b_val, endpoint=ep_b)
102
+
103
+ comparator = ABComparator(config_a, config_b)
104
+ result = await comparator.compare()
105
+
106
+ return format_comparison_results(result, comparator)
107
+
108
+ compare_btn.click(
109
+ fn=run_comparison,
110
+ inputs=[endpoint_a, name_a, endpoint_b, name_b],
111
+ outputs=[
112
+ comparison_status,
113
+ a_connected, a_throughput, a_ttft, a_latency, a_kv_cache, a_batch,
114
+ b_connected, b_throughput, b_ttft, b_latency, b_kv_cache, b_batch,
115
+ comparison_table, recommendation,
116
+ significance_throughput, significance_latency,
117
+ ],
118
+ )
119
+
120
+ async def collect_and_compare(ep_a, name_a_val, ep_b, name_b_val):
121
+ config_a = DeploymentConfig(name=name_a_val, endpoint=ep_a)
122
+ config_b = DeploymentConfig(name=name_b_val, endpoint=ep_b)
123
+
124
+ comparator = ABComparator(config_a, config_b)
125
+
126
+ # Collect samples (this takes ~30 seconds)
127
+ yield (
128
+ "Collecting samples (0/30)...",
129
+ *[None] * 15 # Placeholder outputs
130
+ )
131
+
132
+ await comparator.collect_samples(count=30)
133
+ result = await comparator.compare()
134
+
135
+ yield format_comparison_results(result, comparator)
136
+
137
+ collect_samples_btn.click(
138
+ fn=collect_and_compare,
139
+ inputs=[endpoint_a, name_a, endpoint_b, name_b],
140
+ outputs=[
141
+ comparison_status,
142
+ a_connected, a_throughput, a_ttft, a_latency, a_kv_cache, a_batch,
143
+ b_connected, b_throughput, b_ttft, b_latency, b_kv_cache, b_batch,
144
+ comparison_table, recommendation,
145
+ significance_throughput, significance_latency,
146
+ ],
147
+ )
148
+
149
+ return {
150
+ "endpoint_a": endpoint_a,
151
+ "name_a": name_a,
152
+ "endpoint_b": endpoint_b,
153
+ "name_b": name_b,
154
+ "comparison_status": comparison_status,
155
+ "a_connected": a_connected,
156
+ "a_throughput": a_throughput,
157
+ "a_ttft": a_ttft,
158
+ "a_latency": a_latency,
159
+ "a_kv_cache": a_kv_cache,
160
+ "a_batch": a_batch,
161
+ "b_connected": b_connected,
162
+ "b_throughput": b_throughput,
163
+ "b_ttft": b_ttft,
164
+ "b_latency": b_latency,
165
+ "b_kv_cache": b_kv_cache,
166
+ "b_batch": b_batch,
167
+ "comparison_table": comparison_table,
168
+ "recommendation": recommendation,
169
+ "significance_throughput": significance_throughput,
170
+ "significance_latency": significance_latency,
171
+ }
172
+
173
+
174
+ def format_comparison_results(
175
+ result: ComparisonResult,
176
+ comparator: ABComparator,
177
+ ) -> Tuple:
178
+ """Format comparison results for UI components."""
179
+ a = result.deployment_a
180
+ b = result.deployment_b
181
+
182
+ # Build comparison table
183
+ table_data = comparator.get_comparison_table(result)
184
+ table_df = pd.DataFrame(table_data)
185
+
186
+ # Format recommendation
187
+ recommendation_md = f"**Recommendation:** {result.recommendation}"
188
+
189
+ # Format significance
190
+ sig_throughput = "Not tested"
191
+ sig_latency = "Not tested"
192
+
193
+ if result.p_value_throughput < 1.0:
194
+ sig_status = "Significant" if result.throughput_significant else "Not significant"
195
+ sig_throughput = f"{sig_status} (p={result.p_value_throughput:.4f})"
196
+
197
+ if result.p_value_latency < 1.0:
198
+ sig_status = "Significant" if result.latency_significant else "Not significant"
199
+ sig_latency = f"{sig_status} (p={result.p_value_latency:.4f})"
200
+
201
+ return (
202
+ "Comparison complete",
203
+ a.connected, a.tokens_per_second, a.ttft_ms, a.e2e_latency_ms, a.kv_cache_percent, a.batch_size,
204
+ b.connected, b.tokens_per_second, b.ttft_ms, b.e2e_latency_ms, b.kv_cache_percent, b.batch_size,
205
+ table_df, recommendation_md,
206
+ sig_throughput, sig_latency,
207
+ )
components/gpu_panel.py ADDED
@@ -0,0 +1,191 @@
 
+ """GPU status panel component."""
2
+
3
+ import gradio as gr
4
+ import pandas as pd
5
+ from typing import List, Dict, Any, Tuple
6
+
7
+ from collectors.gpu_collector import GPUCollector, GPUStats
8
+ from utils.history import MetricHistory
9
+
10
+
11
+ def create_gpu_panel(history: MetricHistory) -> Dict[str, Any]:
12
+ """
13
+ Create the GPU status panel.
14
+
15
+ Args:
16
+ history: Metric history for charting
17
+
18
+ Returns:
19
+ Dictionary of Gradio components
20
+ """
21
+ with gr.Column():
22
+ gr.Markdown("### GPU / Rank Status")
23
+
24
+ # GPU stats table
25
+ gpu_table = gr.Dataframe(
26
+ headers=["GPU", "Name", "Memory", "Memory %", "Util %", "Temp", "Power", "TP Rank"],
27
+ datatype=["number", "str", "str", "number", "number", "str", "str", "str"],
28
+ label="GPU Statistics",
29
+ interactive=False,
30
+ )
31
+
32
+ with gr.Row():
33
+ # Memory usage plot
34
+ gpu_memory_plot = gr.LinePlot(
35
+ x="time",
36
+ y="value",
37
+ color="gpu",
38
+ title="GPU Memory Usage (GB)",
39
+ x_title="Time",
40
+ y_title="Memory (GB)",
41
+ height=250,
42
+ )
43
+
44
+ # Utilization plot
45
+ gpu_util_plot = gr.LinePlot(
46
+ x="time",
47
+ y="value",
48
+ color="gpu",
49
+ title="GPU Utilization (%)",
50
+ x_title="Time",
51
+ y_title="Utilization %",
52
+ height=250,
53
+ )
54
+
55
+ # NCCL / Communication status
56
+ nccl_status = gr.HTML(
57
+ value='<div style="padding: 10px; background: #e8f5e9; border-radius: 5px;">'
58
+ '<span style="color: #2e7d32;">NCCL Status: Healthy</span></div>',
59
+ label="Communication Status",
60
+ )
61
+
62
+ return {
63
+ "gpu_table": gpu_table,
64
+ "gpu_memory_plot": gpu_memory_plot,
65
+ "gpu_util_plot": gpu_util_plot,
66
+ "nccl_status": nccl_status,
67
+ }
68
+
69
+
70
+ def update_gpu_panel(
71
+ collector: GPUCollector,
72
+ history: MetricHistory,
73
+ ) -> Tuple[pd.DataFrame, pd.DataFrame, pd.DataFrame, str]:
74
+ """
75
+ Update the GPU panel with current data.
76
+
77
+ Args:
78
+ collector: GPU collector instance
79
+ history: Metric history
80
+
81
+ Returns:
82
+ Tuple of (table_data, memory_plot_data, util_plot_data, nccl_html)
83
+ """
84
+ stats = collector.collect()
85
+
86
+ # Update history
87
+ for stat in stats:
88
+ history.add(
89
+ "gpu_memory_gb",
90
+ stat.memory_used_gb,
91
+ labels={"gpu": str(stat.gpu_id)},
92
+ )
93
+ history.add(
94
+ "gpu_util_percent",
95
+ stat.gpu_util_percent,
96
+ labels={"gpu": str(stat.gpu_id)},
97
+ )
98
+
99
+ # Build table data
100
+ table_data = _build_table(stats)
101
+
102
+ # Build chart data
103
+ memory_df = _build_memory_chart_data(history)
104
+ util_df = _build_util_chart_data(history)
105
+
106
+ # NCCL status (simplified - would need more complex detection)
107
+ nccl_html = _build_nccl_status(stats)
108
+
109
+ return table_data, memory_df, util_df, nccl_html
110
+
111
+
112
+ def _build_table(stats: List[GPUStats]) -> pd.DataFrame:
113
+ """Build GPU stats table."""
114
+ rows = []
115
+ for stat in stats:
116
+ rows.append({
117
+ "GPU": stat.gpu_id,
118
+ "Name": stat.name[:20] if len(stat.name) > 20 else stat.name,
119
+ "Memory": f"{stat.memory_used_gb:.1f}/{stat.memory_total_gb:.1f} GB",
120
+ "Memory %": round(stat.memory_percent, 1),
121
+ "Util %": round(stat.gpu_util_percent, 1),
122
+ "Temp": f"{stat.temperature_c}C",
123
+ "Power": f"{stat.power_watts:.0f}/{stat.power_limit_watts:.0f}W",
124
+ "TP Rank": str(stat.tp_rank) if stat.tp_rank is not None else "-",
125
+ })
126
+
127
+ return pd.DataFrame(rows)
128
+
129
+
130
+ def _build_memory_chart_data(history: MetricHistory) -> pd.DataFrame:
131
+ """Build memory usage chart data."""
132
+ all_series = history.get_all_series("gpu_memory_gb")
133
+
134
+ rows = []
135
+ for key, points in all_series.items():
136
+ gpu_id = key.split("=")[-1] if "=" in key else "0"
137
+ for point in points[-60:]: # Last 60 points
138
+ rows.append({
139
+ "time": point.timestamp,
140
+ "value": point.value,
141
+ "gpu": f"GPU {gpu_id}",
142
+ })
143
+
144
+ if not rows:
145
+ return pd.DataFrame({"time": [], "value": [], "gpu": []})
146
+
147
+ return pd.DataFrame(rows)
148
+
149
+
150
+ def _build_util_chart_data(history: MetricHistory) -> pd.DataFrame:
151
+ """Build utilization chart data."""
152
+ all_series = history.get_all_series("gpu_util_percent")
153
+
154
+ rows = []
155
+ for key, points in all_series.items():
156
+ gpu_id = key.split("=")[-1] if "=" in key else "0"
157
+ for point in points[-60:]:
158
+ rows.append({
159
+ "time": point.timestamp,
160
+ "value": point.value,
161
+ "gpu": f"GPU {gpu_id}",
162
+ })
163
+
164
+ if not rows:
165
+ return pd.DataFrame({"time": [], "value": [], "gpu": []})
166
+
167
+ return pd.DataFrame(rows)
168
+
169
+
170
+ def _build_nccl_status(stats: List[GPUStats]) -> str:
171
+ """Build NCCL status HTML."""
172
+ if not stats:
173
+ return (
174
+ '<div style="padding: 10px; background: #fff3e0; border-radius: 5px;">'
175
+ '<span style="color: #e65100;">NCCL Status: No GPUs detected</span></div>'
176
+ )
177
+
178
+ # Check for GPU communication health indicators
179
+ # In a real implementation, this would check vLLM metrics for NCCL errors
180
+ all_healthy = all(stat.gpu_util_percent > 0 or stat.memory_percent > 0 for stat in stats)
181
+
182
+ if all_healthy:
183
+ return (
184
+ '<div style="padding: 10px; background: #e8f5e9; border-radius: 5px;">'
185
+ f'<span style="color: #2e7d32;">NCCL Status: Healthy ({len(stats)} GPUs)</span></div>'
186
+ )
187
+ else:
188
+ return (
189
+ '<div style="padding: 10px; background: #ffebee; border-radius: 5px;">'
190
+ '<span style="color: #c62828;">NCCL Status: Communication issue detected</span></div>'
191
+ )
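The chart builders above recover a GPU id from each history series key via `key.split("=")[-1]`. A minimal sketch of that labeling rule, assuming `MetricHistory` keys look like `"gpu=0"` (the exact key format is not shown in this commit); `gpu_label` is a hypothetical helper, not part of the module:

```python
def gpu_label(series_key: str) -> str:
    # Hypothetical helper mirroring _build_memory_chart_data's key parsing.
    # Assumes "name=value" keys; anything without "=" falls back to GPU 0.
    gpu_id = series_key.split("=")[-1] if "=" in series_key else "0"
    return f"GPU {gpu_id}"

print(gpu_label("gpu=1"))  # GPU 1
```

Note that `split("=")[-1]` keeps only the last segment, so a multi-label key such as `"rank=0=3"` would still yield its final value.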
components/inference_panel.py ADDED
@@ -0,0 +1,209 @@
+ """Inference metrics panel component."""
2
+
3
+ import gradio as gr
4
+ import pandas as pd
5
+ from typing import Dict, Any, Tuple
6
+
7
+ from collectors.vllm_collector import VLLMCollector, InferenceMetrics
8
+ from utils.history import MetricHistory
9
+
10
+
11
+ def create_inference_panel(history: MetricHistory) -> Dict[str, Any]:
12
+ """
13
+ Create the inference metrics panel.
14
+
15
+ Args:
16
+ history: Metric history for charting
17
+
18
+ Returns:
19
+ Dictionary of Gradio components
20
+ """
21
+ with gr.Column():
22
+ gr.Markdown("### Inference Metrics")
23
+
24
+ # Key metrics row
25
+ with gr.Row():
26
+ throughput = gr.Number(
27
+ label="Tokens/sec",
28
+ precision=1,
29
+ interactive=False,
30
+ )
31
+ ttft = gr.Number(
32
+ label="TTFT (ms)",
33
+ precision=1,
34
+ interactive=False,
35
+ )
36
+ batch_size = gr.Number(
37
+ label="Batch Size",
38
+ precision=0,
39
+ interactive=False,
40
+ )
41
+ kv_cache = gr.Number(
42
+ label="KV Cache %",
43
+ precision=1,
44
+ interactive=False,
45
+ )
46
+
47
+ # Throughput plot
48
+ throughput_plot = gr.LinePlot(
49
+ x="time",
50
+ y="value",
51
+ title="Throughput Over Time",
52
+ x_title="Time",
53
+ y_title="Tokens/sec",
54
+ height=250,
55
+ )
56
+
57
+ # Secondary metrics row
58
+ with gr.Row():
59
+ prefill_pct = gr.Number(
60
+ label="Prefill %",
61
+ precision=1,
62
+ interactive=False,
63
+ )
64
+ decode_pct = gr.Number(
65
+ label="Decode %",
66
+ precision=1,
67
+ interactive=False,
68
+ )
69
+ queue_depth = gr.Number(
70
+ label="Queue Depth",
71
+ precision=0,
72
+ interactive=False,
73
+ )
74
+ e2e_latency = gr.Number(
75
+ label="E2E Latency (ms)",
76
+ precision=1,
77
+ interactive=False,
78
+ )
79
+
80
+ # Latency plot
81
+ latency_plot = gr.LinePlot(
82
+ x="time",
83
+ y="value",
84
+ color="metric",
85
+ title="Latency Over Time",
86
+ x_title="Time",
87
+ y_title="Latency (ms)",
88
+ height=250,
89
+ )
90
+
91
+ return {
92
+ "throughput": throughput,
93
+ "ttft": ttft,
94
+ "batch_size": batch_size,
95
+ "kv_cache": kv_cache,
96
+ "throughput_plot": throughput_plot,
97
+ "prefill_pct": prefill_pct,
98
+ "decode_pct": decode_pct,
99
+ "queue_depth": queue_depth,
100
+ "e2e_latency": e2e_latency,
101
+ "latency_plot": latency_plot,
102
+ }
103
+
104
+
105
+ def update_inference_panel(
106
+ collector: VLLMCollector,
107
+ history: MetricHistory,
108
+ ) -> Tuple[float, float, int, float, pd.DataFrame, float, float, int, float, pd.DataFrame]:
109
+ """
110
+ Update the inference panel with current data.
111
+
112
+ Args:
113
+ collector: vLLM collector instance
114
+ history: Metric history
115
+
116
+ Returns:
117
+ Tuple of all metric values and chart data
118
+ """
119
+ metrics = collector.collect()
120
+
121
+ # Update history
122
+ history.add("tokens_per_second", metrics.tokens_per_second)
123
+ history.add("ttft_ms", metrics.ttft_ms)
124
+ history.add("e2e_latency_ms", metrics.e2e_latency_ms)
125
+ history.add("kv_cache_percent", metrics.kv_cache_usage_percent)
126
+
127
+ # Build throughput chart
128
+ throughput_df = _build_throughput_chart(history)
129
+
130
+ # Build latency chart
131
+ latency_df = _build_latency_chart(history)
132
+
133
+ # Calculate prefill/decode percentages
134
+ prefill_pct = metrics.prefill_ratio * 100
135
+ decode_pct = 100 - prefill_pct
136
+
137
+ return (
138
+ metrics.tokens_per_second,
139
+ metrics.ttft_ms,
140
+ metrics.batch_size,
141
+ metrics.kv_cache_usage_percent,
142
+ throughput_df,
143
+ prefill_pct,
144
+ decode_pct,
145
+ metrics.num_requests_waiting,
146
+ metrics.e2e_latency_ms,
147
+ latency_df,
148
+ )
149
+
150
+
151
+ def _build_throughput_chart(history: MetricHistory) -> pd.DataFrame:
152
+ """Build throughput chart data."""
153
+ points = history.get("tokens_per_second", limit=60)
154
+
155
+ if not points:
156
+ return pd.DataFrame({"time": [], "value": []})
157
+
158
+ return pd.DataFrame([
159
+ {"time": p.timestamp, "value": p.value}
160
+ for p in points
161
+ ])
162
+
163
+
164
+ def _build_latency_chart(history: MetricHistory) -> pd.DataFrame:
165
+ """Build latency chart data with multiple series."""
166
+ ttft_points = history.get("ttft_ms", limit=60)
167
+ e2e_points = history.get("e2e_latency_ms", limit=60)
168
+
169
+ rows = []
170
+
171
+ for p in ttft_points:
172
+ rows.append({
173
+ "time": p.timestamp,
174
+ "value": p.value,
175
+ "metric": "TTFT",
176
+ })
177
+
178
+ for p in e2e_points:
179
+ rows.append({
180
+ "time": p.timestamp,
181
+ "value": p.value,
182
+ "metric": "E2E",
183
+ })
184
+
185
+ if not rows:
186
+ return pd.DataFrame({"time": [], "value": [], "metric": []})
187
+
188
+ return pd.DataFrame(rows)
189
+
190
+
191
+ def get_metrics_dict(metrics: InferenceMetrics) -> Dict[str, float]:
192
+ """
193
+ Convert metrics to dictionary for alerting.
194
+
195
+ Args:
196
+ metrics: InferenceMetrics instance
197
+
198
+ Returns:
199
+ Dictionary of metric name to value
200
+ """
201
+ return {
202
+ "tokens_per_second": metrics.tokens_per_second,
203
+ "ttft_ms": metrics.ttft_ms,
204
+ "e2e_latency_ms": metrics.e2e_latency_ms,
205
+ "kv_cache_percent": metrics.kv_cache_usage_percent,
206
+ "batch_size": metrics.batch_size,
207
+ "queue_depth": metrics.num_requests_waiting,
208
+ "gpu_cache_percent": metrics.gpu_cache_usage_percent,
209
+ }
components/loading_panel.py ADDED
@@ -0,0 +1,151 @@
+ """Model loading progress panel component."""
2
+
3
+ import gradio as gr
4
+ import pandas as pd
5
+ from typing import Dict, Any, Tuple
6
+
7
+ from collectors.loading_tracker import LoadingTracker, LoadingStatus
8
+
9
+
10
+ def create_loading_panel() -> Dict[str, Any]:
11
+ """
12
+ Create the loading progress panel.
13
+
14
+ Returns:
15
+ Dictionary of Gradio components
16
+ """
17
+ with gr.Column():
18
+ gr.Markdown("### Model Loading Progress")
19
+
20
+ # Status indicator
21
+ loading_status = gr.HTML(
22
+ value=_build_status_html(LoadingStatus.NOT_STARTED),
23
+ )
24
+
25
+ # Progress bar
26
+ loading_progress = gr.Slider(
27
+ label="Loading Progress",
28
+ minimum=0,
29
+ maximum=100,
30
+ value=0,
31
+ interactive=False,
32
+ )
33
+
34
+ with gr.Row():
35
+ shards_loaded = gr.Textbox(
36
+ label="Shards Loaded",
37
+ value="0 / 0",
38
+ interactive=False,
39
+ )
40
+ layers_loaded = gr.Textbox(
41
+ label="Layers Loaded",
42
+ value="0 / 0",
43
+ interactive=False,
44
+ )
45
+ eta = gr.Textbox(
46
+ label="ETA",
47
+ value="-",
48
+ interactive=False,
49
+ )
50
+
51
+ # Shard details table
52
+ gr.Markdown("#### Shard Details")
53
+ shard_table = gr.Dataframe(
54
+ headers=["Shard", "Size (MB)", "Status", "Layers"],
55
+ datatype=["str", "number", "str", "str"],
56
+ label="Shards",
57
+ interactive=False,
58
+ )
59
+
60
+ return {
61
+ "loading_status": loading_status,
62
+ "loading_progress": loading_progress,
63
+ "shards_loaded": shards_loaded,
64
+ "layers_loaded": layers_loaded,
65
+ "eta": eta,
66
+ "shard_table": shard_table,
67
+ }
68
+
69
+
70
+ def update_loading_panel(
71
+ tracker: LoadingTracker,
72
+ ) -> Tuple[str, float, str, str, str, pd.DataFrame]:
73
+ """
74
+ Update the loading panel with current data.
75
+
76
+ Args:
77
+ tracker: Loading tracker instance
78
+
79
+ Returns:
80
+ Tuple of (status_html, progress, shards_text, layers_text, eta_text, shard_table)
81
+ """
82
+ progress = tracker.get_progress()
83
+ shards = tracker.get_shards()
84
+
85
+ # Build status HTML
86
+ status_html = _build_status_html(progress.status)
87
+
88
+ # Build shard table
89
+ shard_rows = []
90
+ for shard in shards[:20]:
91
+ shard_rows.append({
92
+ "Shard": shard.filename,
93
+ "Size (MB)": round(shard.size_mb, 1),
94
+ "Status": _format_shard_status(shard.status),
95
+ "Layers": str(len(shard.layers)),
96
+ })
97
+
98
+ shard_df = pd.DataFrame(shard_rows) if shard_rows else pd.DataFrame(
99
+ columns=["Shard", "Size (MB)", "Status", "Layers"]
100
+ )
101
+
102
+ # Format text values
103
+ shards_text = f"{progress.loaded_shards} / {progress.total_shards}"
104
+ layers_text = f"{progress.layers_loaded} / {progress.total_layers}"
105
+
106
+ if progress.estimated_remaining_seconds:
107
+ minutes = int(progress.estimated_remaining_seconds // 60)
108
+ seconds = int(progress.estimated_remaining_seconds % 60)
109
+ eta_text = f"{minutes}m {seconds}s"
110
+ else:
111
+ eta_text = "-"
112
+
113
+ return (
114
+ status_html,
115
+ progress.progress_percent,
116
+ shards_text,
117
+ layers_text,
118
+ eta_text,
119
+ shard_df,
120
+ )
121
+
122
+
123
+ def _build_status_html(status: LoadingStatus) -> str:
124
+ """Build HTML for loading status."""
125
+ status_configs = {
126
+ LoadingStatus.NOT_STARTED: ("Not Started", "#9e9e9e", "#fafafa"),
127
+ LoadingStatus.DOWNLOADING: ("Downloading", "#1976d2", "#e3f2fd"),
128
+ LoadingStatus.LOADING: ("Loading", "#ff9800", "#fff3e0"),
129
+ LoadingStatus.READY: ("Ready", "#2e7d32", "#e8f5e9"),
130
+ LoadingStatus.ERROR: ("Error", "#c62828", "#ffebee"),
131
+ }
132
+
133
+ text, color, bg_color = status_configs.get(
134
+ status, ("Unknown", "#9e9e9e", "#fafafa")
135
+ )
136
+
137
+ return (
138
+ f'<div style="padding: 10px; background: {bg_color}; '
139
+ f'border-radius: 5px; text-align: center;">'
140
+ f'<span style="color: {color}; font-weight: bold; font-size: 1.2em;">'
141
+ f'{text}</span></div>'
142
+ )
143
+
144
+
145
+ def _format_shard_status(status: str) -> str:
146
+ """Format shard status with indicator."""
147
+ if status == "loaded":
148
+ return "Loaded"
149
+ if status == "loading":
150
+ return "Loading..."
151
+ return "Pending"
components/loadtest_panel.py ADDED
@@ -0,0 +1,220 @@
+ """Load testing panel component."""
2
+
3
+ import gradio as gr
4
+ import pandas as pd
5
+ import asyncio
6
+ from typing import Dict, Any, Optional
7
+
8
+ from services.load_tester import LoadTester, LoadTestConfig
9
+ from storage.models import LoadTestResult
10
+
11
+
12
+ # Global load tester instance (managed by the panel)
13
+ _active_load_tester: Optional[LoadTester] = None
14
+
15
+
16
+ def create_loadtest_panel() -> Dict[str, Any]:
17
+ """
18
+ Create the load testing panel.
19
+
20
+ Returns:
21
+ Dictionary of Gradio components
22
+ """
23
+ global _active_load_tester
24
+
25
+ with gr.Column():
26
+ gr.Markdown("### Load Testing")
27
+
28
+ # Configuration
29
+ with gr.Row():
30
+ target_endpoint = gr.Textbox(
31
+ label="Target Endpoint",
32
+ value="http://localhost:8000",
33
+ placeholder="http://host:port",
34
+ )
35
+
36
+ with gr.Row():
37
+ concurrent_users = gr.Slider(
38
+ label="Concurrent Users",
39
+ minimum=1,
40
+ maximum=100,
41
+ value=10,
42
+ step=1,
43
+ )
44
+ requests_per_second = gr.Slider(
45
+ label="Requests/Second",
46
+ minimum=0.1,
47
+ maximum=50,
48
+ value=5,
49
+ step=0.5,
50
+ )
51
+ duration = gr.Slider(
52
+ label="Duration (seconds)",
53
+ minimum=10,
54
+ maximum=300,
55
+ value=60,
56
+ step=10,
57
+ )
58
+
59
+ with gr.Row():
60
+ prompt_distribution = gr.Dropdown(
61
+ choices=["fixed", "realistic", "random"],
62
+ value="fixed",
63
+ label="Prompt Distribution",
64
+ )
65
+ max_tokens = gr.Slider(
66
+ label="Max Tokens",
67
+ minimum=10,
68
+ maximum=500,
69
+ value=100,
70
+ step=10,
71
+ )
72
+
73
+ # Control buttons
74
+ with gr.Row():
75
+ start_btn = gr.Button("Start Load Test", variant="primary")
76
+ stop_btn = gr.Button("Stop", variant="stop")
77
+
78
+ # Status
79
+ test_status = gr.Textbox(
80
+ label="Status",
81
+ interactive=False,
82
+ )
83
+
84
+ # Progress
85
+ with gr.Row():
86
+ progress_elapsed = gr.Number(label="Elapsed (s)", precision=0, interactive=False)
87
+ progress_requests = gr.Number(label="Total Requests", precision=0, interactive=False)
88
+ progress_success = gr.Number(label="Successful", precision=0, interactive=False)
89
+ progress_failed = gr.Number(label="Failed", precision=0, interactive=False)
90
+
91
+ # Results
92
+ gr.Markdown("### Results")
93
+ with gr.Row():
94
+ result_avg = gr.Number(label="Avg Latency (ms)", precision=1, interactive=False)
95
+ result_p50 = gr.Number(label="P50 (ms)", precision=1, interactive=False)
96
+ result_p95 = gr.Number(label="P95 (ms)", precision=1, interactive=False)
97
+ result_p99 = gr.Number(label="P99 (ms)", precision=1, interactive=False)
98
+
99
+ with gr.Row():
100
+ result_throughput = gr.Number(label="Throughput (req/s)", precision=2, interactive=False)
101
+ result_saturation = gr.Number(label="Saturation Point", precision=1, interactive=False)
102
+
103
+ # Latency over time chart
104
+ latency_chart = gr.LinePlot(
105
+ x="time",
106
+ y="latency_ms",
107
+ title="Latency Over Time",
108
+ x_title="Time",
109
+ y_title="Latency (ms)",
110
+ height=250,
111
+ )
112
+
113
+ # Event handlers
114
+ async def start_load_test(endpoint, users, rps, dur, dist, max_tok):
115
+ global _active_load_tester
116
+
117
+ config = LoadTestConfig(
118
+ target_endpoint=endpoint,
119
+ concurrent_users=int(users),
120
+ requests_per_second=rps,
121
+ duration_seconds=int(dur),
122
+ prompt_length_distribution=dist,
123
+ max_tokens=int(max_tok),
124
+ )
125
+
126
+ _active_load_tester = LoadTester(config)
127
+
128
+ # Initial status
129
+ yield (
130
+ "Starting load test...",
131
+ 0, 0, 0, 0,
132
+ 0, 0, 0, 0, 0, None,
133
+ pd.DataFrame({"time": [], "latency_ms": []}),
134
+ )
135
+
136
+ try:
137
+ result = await _active_load_tester.run()
138
+ yield format_load_test_results(result, _active_load_tester)
139
+ except Exception as e:
140
+ yield (
141
+ f"Error: {str(e)}",
142
+ 0, 0, 0, 0,
143
+ 0, 0, 0, 0, 0, None,
144
+ pd.DataFrame({"time": [], "latency_ms": []}),
145
+ )
146
+
147
+ def stop_load_test():
148
+ global _active_load_tester
149
+ if _active_load_tester:
150
+ _active_load_tester.stop()
151
+ return "Load test stopped"
152
+ return "No active load test"
153
+
154
+ start_btn.click(
155
+ fn=start_load_test,
156
+ inputs=[
157
+ target_endpoint, concurrent_users, requests_per_second,
158
+ duration, prompt_distribution, max_tokens
159
+ ],
160
+ outputs=[
161
+ test_status,
162
+ progress_elapsed, progress_requests, progress_success, progress_failed,
163
+ result_avg, result_p50, result_p95, result_p99, result_throughput, result_saturation,
164
+ latency_chart,
165
+ ],
166
+ )
167
+
168
+ stop_btn.click(
169
+ fn=stop_load_test,
170
+ outputs=test_status,
171
+ )
172
+
173
+ return {
174
+ "target_endpoint": target_endpoint,
175
+ "concurrent_users": concurrent_users,
176
+ "requests_per_second": requests_per_second,
177
+ "duration": duration,
178
+ "prompt_distribution": prompt_distribution,
179
+ "max_tokens": max_tokens,
180
+ "test_status": test_status,
181
+ "progress_elapsed": progress_elapsed,
182
+ "progress_requests": progress_requests,
183
+ "progress_success": progress_success,
184
+ "progress_failed": progress_failed,
185
+ "result_avg": result_avg,
186
+ "result_p50": result_p50,
187
+ "result_p95": result_p95,
188
+ "result_p99": result_p99,
189
+ "result_throughput": result_throughput,
190
+ "result_saturation": result_saturation,
191
+ "latency_chart": latency_chart,
192
+ }
193
+
194
+
195
+ def format_load_test_results(
196
+ result: LoadTestResult,
197
+ tester: LoadTester,
198
+ ) -> tuple:
199
+ """Format load test results for UI components."""
200
+ # Build latency timeseries
201
+ timeseries = tester.get_latency_timeseries()
202
+ if timeseries:
203
+ latency_df = pd.DataFrame(timeseries)
204
+ else:
205
+ latency_df = pd.DataFrame({"time": [], "latency_ms": []})
206
+
207
+ return (
208
+ f"Load test complete: {result.total_requests} requests",
209
+ result.duration_seconds,
210
+ result.total_requests,
211
+ result.successful_requests,
212
+ result.failed_requests,
213
+ result.avg_latency_ms,
214
+ result.p50_latency_ms,
215
+ result.p95_latency_ms,
216
+ result.p99_latency_ms,
217
+ result.throughput_rps,
218
+ result.saturation_point,
219
+ latency_df,
220
+ )
components/quant_panel.py ADDED
@@ -0,0 +1,118 @@
+ """Quantization details panel component."""
2
+
3
+ import gradio as gr
4
+ import pandas as pd
5
+ from typing import Dict, Any, Tuple, Optional
6
+
7
+ from collectors.quant_collector import QuantizationCollector, QuantizationInfo
8
+
9
+
10
+ def create_quant_panel() -> Dict[str, Any]:
11
+ """
12
+ Create the quantization details panel.
13
+
14
+ Returns:
15
+ Dictionary of Gradio components
16
+ """
17
+ with gr.Column():
18
+ gr.Markdown("### Quantization Details")
19
+
20
+ with gr.Row():
21
+ quant_type = gr.Textbox(
22
+ label="Quantization Method",
23
+ interactive=False,
24
+ )
25
+ bits = gr.Number(
26
+ label="Bits",
27
+ precision=0,
28
+ interactive=False,
29
+ )
30
+ group_size = gr.Number(
31
+ label="Group Size",
32
+ precision=0,
33
+ interactive=False,
34
+ )
35
+
36
+ # Full configuration JSON
37
+ quant_details = gr.JSON(
38
+ label="Full Configuration",
39
+ )
40
+
41
+ # Layer precision table
42
+ gr.Markdown("#### Per-Layer Precision")
43
+ layer_table = gr.Dataframe(
44
+ headers=["Layer", "Bits", "Group Size", "Dtype"],
45
+ datatype=["str", "number", "str", "str"],
46
+ label="Layer Precisions",
47
+ interactive=False,
48
+ )
49
+
50
+ return {
51
+ "quant_type": quant_type,
52
+ "bits": bits,
53
+ "group_size": group_size,
54
+ "quant_details": quant_details,
55
+ "layer_table": layer_table,
56
+ }
57
+
58
+
59
+ def update_quant_panel(
60
+ collector: QuantizationCollector,
61
+ ) -> Tuple[str, int, Optional[int], Dict, pd.DataFrame]:
62
+ """
63
+ Update the quantization panel with current data.
64
+
65
+ Args:
66
+ collector: Quantization collector instance
67
+
68
+ Returns:
69
+ Tuple of (method, bits, group_size, details_json, layer_table)
70
+ """
71
+ info = collector.detect()
72
+ layers = collector.get_layer_precisions()
73
+
74
+ # Build layer table
75
+ layer_rows = []
76
+ for layer in layers[:20]: # Limit to 20 rows
77
+ layer_rows.append({
78
+ "Layer": layer.layer_name,
79
+ "Bits": layer.bits,
80
+ "Group Size": str(layer.group_size) if layer.group_size else "-",
81
+ "Dtype": layer.dtype,
82
+ })
83
+
84
+ layer_df = pd.DataFrame(layer_rows) if layer_rows else pd.DataFrame(
85
+ columns=["Layer", "Bits", "Group Size", "Dtype"]
86
+ )
87
+
88
+ return (
89
+ info.method,
90
+ info.bits,
91
+ info.group_size,
92
+ info.to_dict(),
93
+ layer_df,
94
+ )
95
+
96
+
97
+ def get_quant_summary(info: QuantizationInfo) -> str:
98
+ """
99
+ Get a summary string for the quantization.
100
+
101
+ Args:
102
+ info: QuantizationInfo instance
103
+
104
+ Returns:
105
+ Human-readable summary string
106
+ """
107
+ if info.method == "None (FP16/BF16)":
108
+ return f"Full precision ({info.compute_dtype or 'float16'})"
109
+
110
+ summary = f"{info.method} {info.bits}-bit"
111
+
112
+ if info.group_size:
113
+ summary += f", group size {info.group_size}"
114
+
115
+ if info.quant_type:
116
+ summary += f" ({info.quant_type})"
117
+
118
+ return summary
components/tracing_panel.py ADDED
@@ -0,0 +1,186 @@
+ """Request tracing panel component."""
2
+
3
+ import gradio as gr
4
+ import pandas as pd
5
+ from typing import Dict, Any, Tuple
6
+
7
+ from services.request_tracer import RequestTracer
8
+
9
+
10
+ def create_tracing_panel(tracer: RequestTracer) -> Dict[str, Any]:
11
+ """
12
+ Create the request tracing panel.
13
+
14
+ Args:
15
+ tracer: Request tracer instance
16
+
17
+ Returns:
18
+ Dictionary of Gradio components
19
+ """
20
+ with gr.Column():
21
+ gr.Markdown("### Request Tracing")
22
+
23
+ # Filter controls
24
+ with gr.Row():
25
+ trace_filter = gr.Dropdown(
26
+ choices=["All Requests", "Slow Only"],
27
+ value="All Requests",
28
+ label="Filter",
29
+ )
30
+ trace_limit = gr.Slider(
31
+ minimum=10,
32
+ maximum=500,
33
+ value=100,
34
+ step=10,
35
+ label="Show Last N Requests",
36
+ )
37
+ refresh_btn = gr.Button("Refresh", size="sm")
38
+
39
+ # Summary stats
40
+ with gr.Row():
41
+ total_requests = gr.Number(
42
+ label="Total Requests",
43
+ precision=0,
44
+ interactive=False,
45
+ )
46
+ slow_requests = gr.Number(
47
+ label="Slow Requests",
48
+ precision=0,
49
+ interactive=False,
50
+ )
51
+ slow_rate = gr.Number(
52
+ label="Slow Rate %",
53
+ precision=1,
54
+ interactive=False,
55
+ )
56
+ baseline_p95 = gr.Number(
57
+ label="Baseline P95 (ms)",
58
+ precision=1,
59
+ interactive=False,
60
+ )
61
+
62
+ # Traces table
63
+ traces_table = gr.Dataframe(
64
+ headers=[
65
+ "ID", "Prompt Toks", "Output Toks",
66
+ "Queue (ms)", "Prefill (ms)", "Decode (ms)",
67
+ "Total (ms)", "Tok/s", "Slow?"
68
+ ],
69
+ datatype=[
70
+ "str", "number", "number",
71
+ "number", "number", "number",
72
+ "number", "number", "str"
73
+ ],
74
+ label="Request Traces",
75
+ interactive=False,
76
+ )
77
+
78
+ # Latency breakdown chart
79
+ gr.Markdown("#### Average Latency Breakdown")
80
+ latency_breakdown = gr.BarPlot(
81
+ x="phase",
82
+ y="ms",
83
+ title="Latency by Phase",
84
+ x_title="Phase",
85
+ y_title="Time (ms)",
86
+ height=200,
87
+ )
88
+
89
+ # Percentiles
90
+ gr.Markdown("#### Latency Percentiles")
91
+ with gr.Row():
92
+ p50 = gr.Number(label="P50 (ms)", precision=1, interactive=False)
93
+ p95 = gr.Number(label="P95 (ms)", precision=1, interactive=False)
94
+ p99 = gr.Number(label="P99 (ms)", precision=1, interactive=False)
95
+
96
+ # Event handlers
97
+ def refresh_traces(filter_val, limit):
98
+ slow_only = filter_val == "Slow Only"
99
+ return update_tracing_panel(tracer, slow_only, int(limit))
100
+
101
+ refresh_btn.click(
102
+ fn=refresh_traces,
103
+ inputs=[trace_filter, trace_limit],
104
+ outputs=[
105
+ total_requests, slow_requests, slow_rate, baseline_p95,
106
+ traces_table, latency_breakdown, p50, p95, p99
107
+ ],
108
+ )
109
+
110
+ return {
111
+ "trace_filter": trace_filter,
112
+ "trace_limit": trace_limit,
113
+ "total_requests": total_requests,
114
+ "slow_requests": slow_requests,
115
+ "slow_rate": slow_rate,
116
+ "baseline_p95": baseline_p95,
117
+ "traces_table": traces_table,
118
+ "latency_breakdown": latency_breakdown,
119
+ "p50": p50,
120
+ "p95": p95,
121
+ "p99": p99,
122
+ }
123
+
124
+
125
+ def update_tracing_panel(
126
+ tracer: RequestTracer,
127
+ slow_only: bool = False,
128
+ limit: int = 100,
129
+ ) -> Tuple[int, int, float, float, pd.DataFrame, pd.DataFrame, float, float, float]:
130
+ """
131
+ Update the tracing panel with current data.
132
+
133
+ Args:
134
+ tracer: Request tracer instance
135
+ slow_only: Only show slow requests
136
+ limit: Maximum number of traces to show
137
+
138
+ Returns:
139
+ Tuple of all component values
140
+ """
141
+ stats = tracer.get_stats()
142
+ traces = tracer.get_recent_traces(limit=limit, slow_only=slow_only)
143
+ breakdown = tracer.get_latency_breakdown()
144
+ percentiles = tracer.get_percentiles()
145
+
146
+ # Build traces table
147
+ trace_rows = []
148
+ for trace in reversed(traces): # Most recent first
149
+ trace_rows.append({
150
+ "ID": trace.request_id,
151
+ "Prompt Toks": trace.prompt_tokens,
152
+ "Output Toks": trace.output_tokens,
153
+ "Queue (ms)": round(trace.queue_time_ms, 1),
154
+ "Prefill (ms)": round(trace.prefill_time_ms, 1),
155
+ "Decode (ms)": round(trace.decode_time_ms, 1),
156
+ "Total (ms)": round(trace.total_time_ms, 1),
157
+ "Tok/s": round(trace.tokens_per_second, 1),
158
+ "Slow?": "Yes" if trace.is_slow else "",
159
+ })
160
+
161
+ traces_df = pd.DataFrame(trace_rows) if trace_rows else pd.DataFrame(
162
+ columns=[
163
+ "ID", "Prompt Toks", "Output Toks",
164
+ "Queue (ms)", "Prefill (ms)", "Decode (ms)",
165
+ "Total (ms)", "Tok/s", "Slow?"
166
+ ]
167
+ )
168
+
169
+ # Build breakdown chart
170
+ breakdown_df = pd.DataFrame([
171
+ {"phase": "Queue", "ms": breakdown.queue_ms},
172
+ {"phase": "Prefill", "ms": breakdown.prefill_ms},
173
+ {"phase": "Decode", "ms": breakdown.decode_ms},
174
+ ])
175
+
176
+ return (
177
+ stats["total_requests"],
178
+ stats["slow_requests"],
179
+ stats.get("slow_rate_percent", 0),
180
+ stats.get("baseline_p95", 0) or 0,
181
+ traces_df,
182
+ breakdown_df,
183
+ percentiles["p50"],
184
+ percentiles["p95"],
185
+ percentiles["p99"],
186
+ )
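The panel reads p50/p95/p99 from `tracer.get_percentiles()`, whose implementation is not part of this commit. A minimal nearest-rank sketch of one common percentile definition, for orientation only (the real `RequestTracer` may use interpolation or another method):

```python
def percentile(values, pct):
    """Nearest-rank percentile over a list of latencies.

    Returns 0.0 for an empty sample, matching the panel's habit of
    showing 0 when no data is available. Illustrative only; not the
    actual RequestTracer.get_percentiles implementation.
    """
    if not values:
        return 0.0
    ordered = sorted(values)
    # Nearest rank, clamped to a valid index
    k = max(0, min(len(ordered) - 1, int(round(pct / 100 * len(ordered))) - 1))
    return ordered[k]

latencies = list(range(1, 101))  # 1..100 ms
print(percentile(latencies, 95))  # 95
```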
config.py ADDED
@@ -0,0 +1,67 @@
+ """Configuration settings for LLM Inference Dashboard."""
2
+
3
+ from dataclasses import dataclass, field
4
+ from typing import Optional
5
+ import os
6
+
7
+
8
+ @dataclass
9
+ class Config:
10
+ """Dashboard configuration with sensible defaults."""
11
+
12
+ # vLLM Connection
13
+ vllm_host: str = "localhost"
14
+ vllm_port: int = 8000
15
+ model_path: Optional[str] = None
16
+
17
+ # Dashboard
18
+ refresh_interval: float = 1.0
19
+ history_length: int = 300 # 5 minutes at 1s intervals
20
+
21
+ # Database
22
+ db_path: str = "data/metrics.db"
23
+
24
+ # Alert Thresholds
25
+ alert_kv_cache_threshold: float = 90.0
26
+ alert_gpu_memory_threshold: float = 95.0
27
+ alert_ttft_multiplier: float = 2.0
28
+ alert_throughput_drop_pct: float = 50.0
29
+
30
+ # Webhooks
31
+ slack_webhook: Optional[str] = None
32
+ pagerduty_routing_key: Optional[str] = None
33
+ generic_webhooks: list = field(default_factory=list)
34
+
35
+ # Load Testing Defaults
36
+ loadtest_concurrent_users: int = 10
37
+ loadtest_rps: float = 5.0
38
+ loadtest_duration: int = 60
39
+
40
+ @property
41
+ def metrics_endpoint(self) -> str:
42
+ return f"http://{self.vllm_host}:{self.vllm_port}/metrics"
43
+
44
+ @property
45
+ def openai_endpoint(self) -> str:
46
+ return f"http://{self.vllm_host}:{self.vllm_port}/v1"
47
+
48
+ @property
49
+ def health_endpoint(self) -> str:
50
+ return f"http://{self.vllm_host}:{self.vllm_port}/health"
51
+
52
+ @classmethod
53
+ def from_env(cls) -> "Config":
54
+ """Create config from environment variables."""
55
+ return cls(
56
+ vllm_host=os.getenv("VLLM_HOST", "localhost"),
57
+ vllm_port=int(os.getenv("VLLM_PORT", "8000")),
58
+ model_path=os.getenv("MODEL_PATH"),
59
+ refresh_interval=float(os.getenv("REFRESH_INTERVAL", "1.0")),
60
+ db_path=os.getenv("DB_PATH", "data/metrics.db"),
61
+ slack_webhook=os.getenv("SLACK_WEBHOOK"),
62
+ pagerduty_routing_key=os.getenv("PAGERDUTY_KEY"),
63
+ )
64
+
65
+
66
+ # Global config instance
67
+ config = Config.from_env()
requirements.txt ADDED
@@ -0,0 +1,16 @@
+# Core
+gradio>=5.0.0
+requests>=2.28.0
+aiohttp>=3.9.0
+
+# Data processing
+pandas>=2.0.0
+numpy>=1.24.0
+scipy>=1.11.0
+
+# Model utilities
+safetensors>=0.4.0
+huggingface-hub>=0.20.0
+
+# GPU monitoring (optional - will use mock data if unavailable)
+nvidia-ml-py3>=7.352.0
services/__init__.py ADDED
@@ -0,0 +1,18 @@
+ """Services for alerting, tracing, comparison, and load testing."""
2
+
3
+ from .alerting import AlertEngine, AlertDispatcher, Alert
4
+ from .request_tracer import RequestTracer
5
+ from .comparator import ABComparator, DeploymentConfig, ComparisonResult
6
+ from .load_tester import LoadTester, LoadTestConfig
7
+
8
+ __all__ = [
9
+ "AlertEngine",
10
+ "AlertDispatcher",
11
+ "Alert",
12
+ "RequestTracer",
13
+ "ABComparator",
14
+ "DeploymentConfig",
15
+ "ComparisonResult",
16
+ "LoadTester",
17
+ "LoadTestConfig",
18
+ ]
services/alerting.py ADDED
@@ -0,0 +1,421 @@
+ """Alert engine and webhook dispatch for monitoring thresholds."""
2
+
3
+ import asyncio
4
+ import logging
5
+ from dataclasses import dataclass, field
6
+ from datetime import datetime
7
+ from typing import Dict, List, Optional, Any, Callable
8
+ from enum import Enum
9
+
10
+ import aiohttp
11
+
12
+ from storage.database import MetricsDB
13
+ from storage.models import AlertRecord
14
+
15
+ logger = logging.getLogger(__name__)
16
+
17
+
18
+ class AlertSeverity(Enum):
19
+ INFO = "info"
20
+ WARNING = "warning"
21
+ CRITICAL = "critical"
22
+
23
+
24
+ @dataclass
25
+ class AlertRule:
26
+ """Configuration for an alert rule."""
27
+ name: str
28
+ metric: str
29
+ condition: str # >, <, >=, <=, ==
30
+ threshold: float
31
+ severity: AlertSeverity
32
+ message: str
33
+ # For dynamic thresholds
34
+ threshold_type: str = "static" # static, baseline_multiplier, baseline_percent
35
+ multiplier: float = 1.0
36
+ percent: float = 100.0
37
+ cooldown_seconds: int = 60
38
+
39
+
40
+ @dataclass
41
+ class Alert:
42
+ """A triggered alert instance."""
43
+ rule_name: str
44
+ metric: str
45
+ value: float
46
+ threshold: float
47
+ severity: AlertSeverity
48
+ message: str
49
+ timestamp: datetime = field(default_factory=datetime.now)
50
+ resolved: bool = False
51
+
52
+ def to_dict(self) -> Dict[str, Any]:
53
+ return {
54
+ "rule_name": self.rule_name,
55
+ "metric": self.metric,
56
+ "value": self.value,
57
+ "threshold": self.threshold,
58
+ "severity": self.severity.value,
59
+ "message": self.message,
60
+ "timestamp": self.timestamp.isoformat(),
61
+ "resolved": self.resolved,
62
+ }
63
+
64
+
65
+ # Default alert rules
66
+ DEFAULT_RULES = {
67
+ "kv_cache_high": AlertRule(
68
+ name="kv_cache_high",
69
+ metric="kv_cache_percent",
70
+ condition=">",
71
+ threshold=90.0,
72
+ severity=AlertSeverity.WARNING,
73
+ message="KV cache utilization above 90%",
74
+ ),
75
+ "gpu_memory_critical": AlertRule(
76
+ name="gpu_memory_critical",
77
+ metric="gpu_memory_percent",
78
+ condition=">",
79
+ threshold=95.0,
80
+ severity=AlertSeverity.CRITICAL,
81
+ message="GPU memory critically high (>95%)",
82
+ ),
83
+ "ttft_spike": AlertRule(
84
+ name="ttft_spike",
85
+ metric="ttft_ms",
86
+ condition=">",
87
+ threshold=0, # Dynamic
88
+ threshold_type="baseline_multiplier",
89
+ multiplier=2.0,
90
+ severity=AlertSeverity.WARNING,
91
+ message="Time to first token spiked to 2x baseline",
92
+ ),
93
+ "throughput_drop": AlertRule(
94
+ name="throughput_drop",
95
+ metric="tokens_per_second",
96
+ condition="<",
97
+ threshold=0, # Dynamic
98
+ threshold_type="baseline_percent",
99
+ percent=50.0,
100
+ severity=AlertSeverity.WARNING,
101
+ message="Throughput dropped below 50% of baseline",
102
+ ),
103
+ "queue_buildup": AlertRule(
104
+ name="queue_buildup",
105
+ metric="queue_depth",
106
+ condition=">",
107
+ threshold=50.0,
108
+ severity=AlertSeverity.WARNING,
109
+ message="Request queue depth exceeds 50",
110
+ ),
111
+ }
112
+
113
+
114
+ class AlertEngine:
115
+ """Evaluates metrics against alert rules."""
116
+
117
+ def __init__(self, db: Optional[MetricsDB] = None):
118
+ """
119
+ Initialize alert engine.
120
+
121
+ Args:
122
+ db: Optional database for persisting alerts
123
+ """
124
+ self.db = db
125
+ self.rules: Dict[str, AlertRule] = dict(DEFAULT_RULES)
126
+ self.active_alerts: Dict[str, Alert] = {}
127
+ self.baselines: Dict[str, float] = {}
128
+ self._last_trigger_times: Dict[str, datetime] = {}
129
+ self._callbacks: List[Callable[[Alert], None]] = []
130
+
131
+ def add_rule(self, rule: AlertRule) -> None:
132
+ """Add or update an alert rule."""
133
+ self.rules[rule.name] = rule
134
+
135
+ def remove_rule(self, name: str) -> None:
136
+ """Remove an alert rule."""
137
+ self.rules.pop(name, None)
138
+
139
+ def set_baseline(self, metric: str, value: float) -> None:
140
+ """Set baseline value for a metric."""
141
+ self.baselines[metric] = value
142
+
143
+ def update_baselines(self, metrics: Dict[str, float]) -> None:
144
+ """Update baseline values from current metrics."""
145
+ for metric, value in metrics.items():
146
+ if metric not in self.baselines and value > 0:
147
+ self.baselines[metric] = value
148
+
149
+ def on_alert(self, callback: Callable[[Alert], None]) -> None:
150
+ """Register callback for new alerts."""
151
+ self._callbacks.append(callback)
152
+
153
+ def evaluate(self, metrics: Dict[str, float]) -> List[Alert]:
154
+ """
155
+ Evaluate metrics against all rules.
156
+
157
+ Args:
158
+ metrics: Current metric values
159
+
160
+ Returns:
161
+ List of newly triggered alerts
162
+ """
163
+ new_alerts = []
164
+
165
+ for rule_name, rule in self.rules.items():
166
+ if rule.metric not in metrics:
167
+ continue
168
+
169
+ value = metrics[rule.metric]
170
+ threshold = self._get_threshold(rule)
171
+
172
+ if threshold is None:
173
+ continue
174
+
175
+ triggered = self._check_condition(value, rule.condition, threshold)
176
+
177
+ if triggered:
178
+ # Check cooldown
179
+ if rule_name in self._last_trigger_times:
180
+ elapsed = (
181
+ datetime.now() - self._last_trigger_times[rule_name]
182
+ ).total_seconds()
183
+ if elapsed < rule.cooldown_seconds:
184
+ continue
185
+
186
+ # Create alert
187
+ alert = Alert(
188
+ rule_name=rule_name,
189
+ metric=rule.metric,
190
+ value=value,
191
+ threshold=threshold,
192
+ severity=rule.severity,
193
+ message=rule.message,
194
+ )
195
+
196
+ self.active_alerts[rule_name] = alert
197
+ self._last_trigger_times[rule_name] = datetime.now()
198
+ new_alerts.append(alert)
199
+
200
+ # Persist to database
201
+ if self.db:
202
+ record = AlertRecord(
203
+ rule_name=rule_name,
204
+ severity=rule.severity.value,
205
+ metric_name=rule.metric,
206
+ value=value,
207
+ threshold=threshold,
208
+ message=rule.message,
209
+ )
210
+ self.db.insert_alert(record)
211
+
212
+ # Notify callbacks
213
+ for callback in self._callbacks:
214
+ try:
215
+ callback(alert)
216
+ except Exception as e:
217
+ logger.error(f"Alert callback error: {e}")
218
+
219
+ elif rule_name in self.active_alerts:
220
+ # Resolve alert
221
+ self.active_alerts[rule_name].resolved = True
222
+ del self.active_alerts[rule_name]
223
+
224
+ return new_alerts
225
+
226
+ def _get_threshold(self, rule: AlertRule) -> Optional[float]:
227
+ """Calculate threshold for a rule."""
228
+ if rule.threshold_type == "static":
229
+ return rule.threshold
230
+
231
+ baseline = self.baselines.get(rule.metric)
232
+ if baseline is None:
233
+ return None
234
+
235
+ if rule.threshold_type == "baseline_multiplier":
236
+ return baseline * rule.multiplier
237
+
238
+ if rule.threshold_type == "baseline_percent":
239
+ return baseline * (rule.percent / 100.0)
240
+
241
+ return rule.threshold
242
+
243
+ def _check_condition(
244
+ self, value: float, condition: str, threshold: float
245
+ ) -> bool:
246
+ """Check if condition is met."""
247
+ if condition == ">":
248
+ return value > threshold
249
+ if condition == ">=":
250
+ return value >= threshold
251
+ if condition == "<":
252
+ return value < threshold
253
+ if condition == "<=":
254
+ return value <= threshold
255
+ if condition == "==":
256
+ return abs(value - threshold) < 0.001
257
+ return False
258
+
259
+ def get_active_alerts(self) -> List[Alert]:
260
+ """Get all active (unresolved) alerts."""
261
+ return list(self.active_alerts.values())
262
+
263
+ def clear_alert(self, rule_name: str) -> None:
264
+ """Manually clear an alert."""
265
+ if rule_name in self.active_alerts:
266
+ del self.active_alerts[rule_name]
267
+
268
+
269
+ class AlertDispatcher:
270
+ """Dispatches alerts to external services."""
271
+
272
+ def __init__(
273
+ self,
274
+ slack_webhook: Optional[str] = None,
275
+ pagerduty_key: Optional[str] = None,
276
+ generic_webhooks: Optional[List[str]] = None,
277
+ ):
278
+ """
279
+ Initialize alert dispatcher.
280
+
281
+ Args:
282
+ slack_webhook: Slack incoming webhook URL
283
+ pagerduty_key: PagerDuty routing key
284
+ generic_webhooks: List of generic webhook URLs
285
+ """
286
+ self.slack_webhook = slack_webhook
287
+ self.pagerduty_key = pagerduty_key
288
+ self.generic_webhooks = generic_webhooks or []
289
+
290
+ async def dispatch(self, alert: Alert) -> None:
291
+ """
292
+ Dispatch alert to all configured services.
293
+
294
+ Args:
295
+ alert: Alert to dispatch
296
+ """
297
+ tasks = []
298
+
299
+ if self.slack_webhook:
300
+ tasks.append(self._send_slack(alert))
301
+
302
+ if self.pagerduty_key and alert.severity == AlertSeverity.CRITICAL:
303
+ tasks.append(self._send_pagerduty(alert))
304
+
305
+ for webhook in self.generic_webhooks:
306
+ tasks.append(self._send_generic(webhook, alert))
307
+
308
+ if tasks:
309
+ await asyncio.gather(*tasks, return_exceptions=True)
310
+
311
+ async def _send_slack(self, alert: Alert) -> None:
312
+ """Send alert to Slack."""
313
+ color = "danger" if alert.severity == AlertSeverity.CRITICAL else "warning"
314
+ emoji = "🚨" if alert.severity == AlertSeverity.CRITICAL else "⚠️"
315
+
316
+ payload = {
317
+ "text": f"{emoji} *{alert.severity.value.upper()}*: {alert.message}",
318
+ "attachments": [
319
+ {
320
+ "color": color,
321
+ "fields": [
322
+ {
323
+ "title": "Metric",
324
+ "value": alert.metric,
325
+ "short": True,
326
+ },
327
+ {
328
+ "title": "Value",
329
+ "value": f"{alert.value:.2f}",
330
+ "short": True,
331
+ },
332
+ {
333
+ "title": "Threshold",
334
+ "value": f"{alert.threshold:.2f}",
335
+ "short": True,
336
+ },
337
+ {
338
+ "title": "Time",
339
+ "value": alert.timestamp.strftime("%Y-%m-%d %H:%M:%S"),
340
+ "short": True,
341
+ },
342
+ ],
343
+ }
344
+ ],
345
+ }
346
+
347
+ try:
348
+ async with aiohttp.ClientSession() as session:
349
+ async with session.post(
350
+ self.slack_webhook,
351
+ json=payload,
352
+ timeout=aiohttp.ClientTimeout(total=10),
353
+ ) as response:
354
+ if response.status != 200:
355
+ logger.error(f"Slack webhook failed: {response.status}")
356
+ except Exception as e:
357
+ logger.error(f"Error sending Slack alert: {e}")
358
+
359
+ async def _send_pagerduty(self, alert: Alert) -> None:
360
+ """Send alert to PagerDuty."""
361
+ payload = {
362
+ "routing_key": self.pagerduty_key,
363
+ "event_action": "trigger",
364
+ "dedup_key": f"llm-dashboard-{alert.rule_name}",
365
+ "payload": {
366
+ "summary": alert.message,
367
+ "severity": "critical",
368
+ "source": "llm-inference-dashboard",
369
+ "custom_details": {
370
+ "metric": alert.metric,
371
+ "value": alert.value,
372
+ "threshold": alert.threshold,
373
+ },
374
+ },
375
+ }
376
+
377
+ try:
378
+ async with aiohttp.ClientSession() as session:
379
+ async with session.post(
380
+ "https://events.pagerduty.com/v2/enqueue",
381
+ json=payload,
382
+ timeout=aiohttp.ClientTimeout(total=10),
383
+ ) as response:
384
+ if response.status != 202:
385
+ logger.error(f"PagerDuty failed: {response.status}")
386
+ except Exception as e:
387
+ logger.error(f"Error sending PagerDuty alert: {e}")
388
+
389
+ async def _send_generic(self, webhook_url: str, alert: Alert) -> None:
390
+ """Send alert to a generic webhook."""
391
+ payload = alert.to_dict()
392
+
393
+ try:
394
+ async with aiohttp.ClientSession() as session:
395
+ async with session.post(
396
+ webhook_url,
397
+ json=payload,
398
+ timeout=aiohttp.ClientTimeout(total=10),
399
+ ) as response:
400
+ if response.status >= 400:
401
+ logger.error(f"Webhook {webhook_url} failed: {response.status}")
402
+ except Exception as e:
403
+ logger.error(f"Error sending to webhook {webhook_url}: {e}")
404
+
405
+ async def send_test_alert(self) -> bool:
406
+ """Send a test alert to verify configuration."""
407
+ test_alert = Alert(
408
+ rule_name="test_alert",
409
+ metric="test_metric",
410
+ value=100.0,
411
+ threshold=50.0,
412
+ severity=AlertSeverity.INFO,
413
+ message="This is a test alert from LLM Inference Dashboard",
414
+ )
415
+
416
+ try:
417
+ await self.dispatch(test_alert)
418
+ return True
419
+ except Exception as e:
420
+ logger.error(f"Test alert failed: {e}")
421
+ return False
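The static vs. baseline-derived threshold logic in `AlertEngine._get_threshold` is the core of the dynamic rules (`ttft_spike`, `throughput_drop`). As a standalone sketch, here is a hypothetical mirror of that method using plain dicts instead of `AlertRule` objects — illustration only, not the dashboard's API:

```python
def resolve_threshold(rule: dict, baselines: dict):
    # Mirrors AlertEngine._get_threshold: static rules use their fixed
    # threshold; dynamic rules derive one from a recorded baseline.
    if rule["threshold_type"] == "static":
        return rule["threshold"]
    baseline = baselines.get(rule["metric"])
    if baseline is None:
        return None  # no baseline observed yet -> rule cannot fire
    if rule["threshold_type"] == "baseline_multiplier":
        return baseline * rule["multiplier"]
    if rule["threshold_type"] == "baseline_percent":
        return baseline * rule["percent"] / 100.0
    return rule["threshold"]


# Like the ttft_spike default rule: fire when TTFT exceeds 2x baseline.
ttft_rule = {
    "metric": "ttft_ms",
    "threshold_type": "baseline_multiplier",
    "multiplier": 2.0,
    "threshold": 0,
}
print(resolve_threshold(ttft_rule, {"ttft_ms": 120.0}))  # 240.0
```

Returning `None` when no baseline exists is what lets `evaluate` silently skip dynamic rules on a cold start instead of firing spurious alerts.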
services/comparator.py ADDED
@@ -0,0 +1,366 @@
+ """A/B comparison of vLLM deployments."""
+
+ import asyncio
+ import logging
+ from dataclasses import dataclass, field
+ from typing import Optional, Dict, List, Any
+ from datetime import datetime
+
+ import aiohttp
+ from scipy import stats
+
+ from utils.prometheus_parser import (
+     parse_prometheus_metrics,
+     get_metric_value,
+     get_histogram_quantile,
+ )
+
+ logger = logging.getLogger(__name__)
+
+
+ @dataclass
+ class DeploymentConfig:
+     """Configuration for a vLLM deployment."""
+     name: str
+     endpoint: str  # Base URL (e.g., http://localhost:8000)
+     model_name: str = ""
+     quantization: str = ""
+
+     @property
+     def metrics_url(self) -> str:
+         return f"{self.endpoint}/metrics"
+
+
+ @dataclass
+ class DeploymentMetrics:
+     """Metrics collected from a deployment."""
+     endpoint: str
+     timestamp: datetime = field(default_factory=datetime.now)
+     connected: bool = False
+
+     # Throughput
+     tokens_per_second: float = 0.0
+     throughput_samples: List[float] = field(default_factory=list)
+
+     # Latency
+     ttft_ms: float = 0.0
+     tpot_ms: float = 0.0
+     e2e_latency_ms: float = 0.0
+     latency_samples: List[float] = field(default_factory=list)
+
+     # Resources
+     gpu_memory_gb: float = 0.0
+     kv_cache_percent: float = 0.0
+     batch_size: int = 0
+
+     # Model info
+     model_name: str = ""
+
+
+ @dataclass
+ class ComparisonResult:
+     """Result of comparing two deployments."""
+     deployment_a: DeploymentMetrics
+     deployment_b: DeploymentMetrics
+     timestamp: datetime = field(default_factory=datetime.now)
+
+     # Differences
+     throughput_diff_pct: float = 0.0
+     ttft_diff_pct: float = 0.0
+     latency_diff_pct: float = 0.0
+     memory_diff_gb: float = 0.0
+
+     # Statistical significance
+     throughput_significant: bool = False
+     latency_significant: bool = False
+     p_value_throughput: float = 1.0
+     p_value_latency: float = 1.0
+
+     # Recommendation
+     recommendation: str = ""
+
+
+ class ABComparator:
+     """Compares metrics between two vLLM deployments."""
+
+     def __init__(
+         self,
+         deployment_a: DeploymentConfig,
+         deployment_b: DeploymentConfig,
+         sample_count: int = 30,
+     ):
+         """
+         Initialize comparator.
+
+         Args:
+             deployment_a: First deployment configuration
+             deployment_b: Second deployment configuration
+             sample_count: Number of samples to collect for statistical tests
+         """
+         self.deployment_a = deployment_a
+         self.deployment_b = deployment_b
+         self.sample_count = sample_count
+         self._samples_a: List[DeploymentMetrics] = []
+         self._samples_b: List[DeploymentMetrics] = []
+
+     async def collect_metrics(self, config: DeploymentConfig) -> DeploymentMetrics:
+         """
+         Collect current metrics from a deployment.
+
+         Args:
+             config: Deployment configuration
+
+         Returns:
+             DeploymentMetrics with current values
+         """
+         metrics = DeploymentMetrics(endpoint=config.endpoint)
+
+         try:
+             async with aiohttp.ClientSession() as session:
+                 async with session.get(
+                     config.metrics_url,
+                     timeout=aiohttp.ClientTimeout(total=5),
+                 ) as response:
+                     if response.status != 200:
+                         return metrics
+
+                     text = await response.text()
+                     raw = parse_prometheus_metrics(text)
+                     metrics.connected = True
+
+                     # Parse metrics
+                     metrics.tokens_per_second = self._calculate_tps(raw)
+                     metrics.ttft_ms = (
+                         get_histogram_quantile(
+                             raw, "vllm:time_to_first_token_seconds", 0.5
+                         )
+                         or 0
+                     ) * 1000
+                     metrics.tpot_ms = (
+                         get_histogram_quantile(
+                             raw, "vllm:time_per_output_token_seconds", 0.5
+                         )
+                         or 0
+                     ) * 1000
+                     metrics.e2e_latency_ms = (
+                         get_histogram_quantile(
+                             raw, "vllm:e2e_request_latency_seconds", 0.5
+                         )
+                         or 0
+                     ) * 1000
+                     metrics.kv_cache_percent = (
+                         get_metric_value(raw, "vllm:gpu_cache_usage_perc") or 0
+                     ) * 100
+                     metrics.batch_size = int(
+                         get_metric_value(raw, "vllm:num_requests_running") or 0
+                     )
+
+                     # Model name from labels
+                     for samples in raw.values():
+                         for sample in samples:
+                             if "model_name" in sample.labels:
+                                 metrics.model_name = sample.labels["model_name"]
+                                 break
+
+         except Exception as e:
+             logger.error(f"Error collecting metrics from {config.endpoint}: {e}")
+
+         return metrics
+
+     def _calculate_tps(self, raw: Dict) -> float:
+         """Calculate tokens per second from counter metrics."""
+         # This is a simplified calculation
+         # In practice, you'd track delta over time
+         generation_total = get_metric_value(raw, "vllm:generation_tokens_total") or 0
+         if generation_total > 0:
+             # Estimate based on running requests
+             running = get_metric_value(raw, "vllm:num_requests_running") or 1
+             tpot = (
+                 get_histogram_quantile(
+                     raw, "vllm:time_per_output_token_seconds", 0.5
+                 )
+                 or 0.05
+             )
+             if tpot > 0:
+                 return running / tpot
+         return 0
+
+     async def collect_samples(self, count: Optional[int] = None) -> None:
+         """
+         Collect multiple samples for statistical comparison.
+
+         Args:
+             count: Number of samples to collect
+         """
+         if count is None:
+             count = self.sample_count
+
+         self._samples_a.clear()
+         self._samples_b.clear()
+
+         for i in range(count):
+             metrics_a, metrics_b = await asyncio.gather(
+                 self.collect_metrics(self.deployment_a),
+                 self.collect_metrics(self.deployment_b),
+             )
+
+             if metrics_a.connected:
+                 metrics_a.throughput_samples = [metrics_a.tokens_per_second]
+                 metrics_a.latency_samples = [metrics_a.e2e_latency_ms]
+                 self._samples_a.append(metrics_a)
+
+             if metrics_b.connected:
+                 metrics_b.throughput_samples = [metrics_b.tokens_per_second]
+                 metrics_b.latency_samples = [metrics_b.e2e_latency_ms]
+                 self._samples_b.append(metrics_b)
+
+             # Wait between samples
+             if i < count - 1:
+                 await asyncio.sleep(1)
+
+     async def compare(self) -> ComparisonResult:
+         """
+         Perform comparison between deployments.
+
+         Returns:
+             ComparisonResult with comparison data
+         """
+         # Collect current metrics
+         metrics_a, metrics_b = await asyncio.gather(
+             self.collect_metrics(self.deployment_a),
+             self.collect_metrics(self.deployment_b),
+         )
+
+         result = ComparisonResult(
+             deployment_a=metrics_a,
+             deployment_b=metrics_b,
+         )
+
+         # Calculate differences
+         if metrics_a.tokens_per_second > 0:
+             result.throughput_diff_pct = (
+                 (metrics_b.tokens_per_second - metrics_a.tokens_per_second)
+                 / metrics_a.tokens_per_second
+             ) * 100
+
+         if metrics_a.ttft_ms > 0:
+             result.ttft_diff_pct = (
+                 (metrics_b.ttft_ms - metrics_a.ttft_ms) / metrics_a.ttft_ms
+             ) * 100
+
+         if metrics_a.e2e_latency_ms > 0:
+             result.latency_diff_pct = (
+                 (metrics_b.e2e_latency_ms - metrics_a.e2e_latency_ms)
+                 / metrics_a.e2e_latency_ms
+             ) * 100
+
+         result.memory_diff_gb = metrics_b.gpu_memory_gb - metrics_a.gpu_memory_gb
+
+         # Statistical significance (if we have samples)
+         if self._samples_a and self._samples_b:
+             result = self._add_significance(result)
+
+         # Generate recommendation
+         result.recommendation = self._generate_recommendation(result)
+
+         return result
+
+     def _add_significance(self, result: ComparisonResult) -> ComparisonResult:
+         """Add statistical significance tests to result."""
+         tps_a = [s.tokens_per_second for s in self._samples_a]
+         tps_b = [s.tokens_per_second for s in self._samples_b]
+
+         lat_a = [s.e2e_latency_ms for s in self._samples_a]
+         lat_b = [s.e2e_latency_ms for s in self._samples_b]
+
+         if len(tps_a) >= 2 and len(tps_b) >= 2:
+             try:
+                 _, p_tps = stats.ttest_ind(tps_a, tps_b)
+                 result.p_value_throughput = p_tps
+                 result.throughput_significant = p_tps < 0.05
+             except Exception:
+                 pass
+
+         if len(lat_a) >= 2 and len(lat_b) >= 2:
+             try:
+                 _, p_lat = stats.ttest_ind(lat_a, lat_b)
+                 result.p_value_latency = p_lat
+                 result.latency_significant = p_lat < 0.05
+             except Exception:
+                 pass
+
+         return result
+
+     def _generate_recommendation(self, result: ComparisonResult) -> str:
+         """Generate a human-readable recommendation."""
+         parts = []
+         a_name = self.deployment_a.name
+         b_name = self.deployment_b.name
+
+         # Throughput comparison
+         if abs(result.throughput_diff_pct) > 5:
+             faster = b_name if result.throughput_diff_pct > 0 else a_name
+             diff = abs(result.throughput_diff_pct)
+             sig = " (statistically significant)" if result.throughput_significant else ""
+             parts.append(f"{faster} has {diff:.1f}% higher throughput{sig}")
+
+         # Latency comparison
+         if abs(result.latency_diff_pct) > 5:
+             faster = a_name if result.latency_diff_pct > 0 else b_name
+             diff = abs(result.latency_diff_pct)
+             sig = " (statistically significant)" if result.latency_significant else ""
+             parts.append(f"{faster} has {diff:.1f}% lower latency{sig}")
+
+         # Memory comparison
+         if abs(result.memory_diff_gb) > 1:
+             lower = a_name if result.memory_diff_gb > 0 else b_name
+             diff = abs(result.memory_diff_gb)
+             parts.append(f"{lower} uses {diff:.1f}GB less GPU memory")
+
+         if not parts:
+             return "Both deployments show similar performance"
+
+         return ". ".join(parts) + "."
+
+     def get_comparison_table(self, result: ComparisonResult) -> List[Dict[str, Any]]:
+         """
+         Generate comparison table data.
+
+         Args:
+             result: Comparison result
+
+         Returns:
+             List of rows for comparison table
+         """
+         return [
+             {
+                 "Metric": "Throughput (tok/s)",
+                 self.deployment_a.name: f"{result.deployment_a.tokens_per_second:.1f}",
+                 self.deployment_b.name: f"{result.deployment_b.tokens_per_second:.1f}",
+                 "Diff": f"{result.throughput_diff_pct:+.1f}%",
+             },
+             {
+                 "Metric": "TTFT (ms)",
+                 self.deployment_a.name: f"{result.deployment_a.ttft_ms:.1f}",
+                 self.deployment_b.name: f"{result.deployment_b.ttft_ms:.1f}",
+                 "Diff": f"{result.ttft_diff_pct:+.1f}%",
+             },
+             {
+                 "Metric": "E2E Latency (ms)",
+                 self.deployment_a.name: f"{result.deployment_a.e2e_latency_ms:.1f}",
+                 self.deployment_b.name: f"{result.deployment_b.e2e_latency_ms:.1f}",
+                 "Diff": f"{result.latency_diff_pct:+.1f}%",
+             },
+             {
+                 "Metric": "KV Cache %",
+                 self.deployment_a.name: f"{result.deployment_a.kv_cache_percent:.1f}",
+                 self.deployment_b.name: f"{result.deployment_b.kv_cache_percent:.1f}",
+                 "Diff": "-",
+             },
+             {
+                 "Metric": "Batch Size",
+                 self.deployment_a.name: str(result.deployment_a.batch_size),
+                 self.deployment_b.name: str(result.deployment_b.batch_size),
+                 "Diff": "-",
+             },
+         ]
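The sign conventions in `_generate_recommendation` are easy to get backwards (a positive throughput diff favors B, a positive latency diff favors A). This standalone sketch mirrors just those first two checks with plain arguments; the function name and the deployment labels `"fp16"`/`"awq"` are made up for illustration:

```python
def recommend(name_a: str, name_b: str,
              tput_diff_pct: float, lat_diff_pct: float) -> str:
    # Diffs are computed as (B - A) / A * 100, so positive throughput
    # diff means B is faster, while positive latency diff means B is
    # slower (A wins). Changes under 5% are treated as noise.
    parts = []
    if abs(tput_diff_pct) > 5:
        faster = name_b if tput_diff_pct > 0 else name_a
        parts.append(f"{faster} has {abs(tput_diff_pct):.1f}% higher throughput")
    if abs(lat_diff_pct) > 5:
        faster = name_a if lat_diff_pct > 0 else name_b
        parts.append(f"{faster} has {abs(lat_diff_pct):.1f}% lower latency")
    if not parts:
        return "Both deployments show similar performance"
    return ". ".join(parts) + "."


# B is 12% faster and 8% lower latency, so both clauses favor "awq".
print(recommend("fp16", "awq", 12.0, -8.0))
# awq has 12.0% higher throughput. awq has 8.0% lower latency.
```

The 5% floor keeps the recommendation quiet for differences that a t-test on ~30 one-second samples would rarely be able to distinguish from noise anyway.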
services/load_tester.py ADDED
@@ -0,0 +1,359 @@
+ """Load testing engine for vLLM endpoints."""
+
+ import asyncio
+ import logging
+ import random
+ import statistics
+ import time
+ import uuid
+ from dataclasses import dataclass, field
+ from datetime import datetime
+ from typing import List, Optional, Dict, Any, Callable
+ from collections import deque
+
+ import aiohttp
+ import numpy as np
+
+ from storage.database import MetricsDB
+ from storage.models import LoadTestResult
+
+ logger = logging.getLogger(__name__)
+
+
+ @dataclass
+ class LoadTestConfig:
+     """Configuration for a load test."""
+     target_endpoint: str
+     concurrent_users: int = 10
+     requests_per_second: float = 5.0
+     duration_seconds: int = 60
+     prompt: str = "Hello, please write a short story about a robot."
+     max_tokens: int = 100
+     prompt_length_distribution: str = "fixed"  # fixed, realistic, random
+
+
+ @dataclass
+ class RequestResult:
+     """Result of a single request."""
+     success: bool
+     latency_ms: float
+     tokens: int
+     error: Optional[str] = None
+     timestamp: datetime = field(default_factory=datetime.now)
+
+
+ class LoadTester:
+     """Load testing engine for vLLM inference endpoints."""
+
+     def __init__(self, config: LoadTestConfig, db: Optional[MetricsDB] = None):
+         """
+         Initialize load tester.
+
+         Args:
+             config: Load test configuration
+             db: Optional database for storing results
+         """
+         self.config = config
+         self.db = db
+         self.running = False
+         self._results: List[RequestResult] = []
+         self._latency_over_time: deque = deque(maxlen=10000)
+         self._progress_callback: Optional[Callable[[Dict], None]] = None
+         self._start_time: Optional[float] = None
+
+     def set_config(self, config: LoadTestConfig) -> None:
+         """Update configuration."""
+         self.config = config
+
+     def on_progress(self, callback: Callable[[Dict], None]) -> None:
+         """Register progress callback."""
+         self._progress_callback = callback
+
+     async def run(self) -> LoadTestResult:
+         """
+         Run the load test.
+
+         Returns:
+             LoadTestResult with test results
+         """
+         self.running = True
+         self._results = []
+         self._latency_over_time.clear()
+         self._start_time = time.time()
+
+         test_id = str(uuid.uuid4())[:8]
+
+         logger.info(
+             f"Starting load test {test_id}: "
+             f"{self.config.concurrent_users} users, "
+             f"{self.config.requests_per_second} RPS, "
+             f"{self.config.duration_seconds}s"
+         )
+
+         # Create semaphore for concurrency control
+         semaphore = asyncio.Semaphore(self.config.concurrent_users)
+
+         # Calculate request interval
+         interval = 1.0 / self.config.requests_per_second
+
+         # Generate load
+         tasks = []
+         end_time = time.time() + self.config.duration_seconds
+
+         try:
+             while time.time() < end_time and self.running:
+                 # The semaphore is acquired inside the worker task so that
+                 # in-flight requests are actually capped; acquiring it here
+                 # would release it before the request even started.
+                 task = asyncio.create_task(self._make_request(semaphore))
+                 tasks.append(task)
+
+                 # Report progress
+                 if self._progress_callback:
+                     self._progress_callback(self._get_progress())
+
+                 await asyncio.sleep(interval)
+
+             # Wait for remaining tasks
+             if tasks:
+                 await asyncio.gather(*tasks, return_exceptions=True)
+
+         except asyncio.CancelledError:
+             logger.info("Load test cancelled")
+         except Exception as e:
+             logger.error(f"Load test error: {e}")
+         finally:
+             self.running = False
+
+         # Analyze results
+         result = self._analyze_results(test_id)
+
+         # Persist to database
+         if self.db:
+             try:
+                 self.db.insert_load_test(result)
+             except Exception as e:
+                 logger.error(f"Error persisting load test: {e}")
+
+         return result
+
+     def stop(self) -> None:
+         """Stop the running load test."""
+         self.running = False
+
+     async def _make_request(self, semaphore: asyncio.Semaphore) -> None:
+         """Make a single request, bounded by the concurrency semaphore."""
+         async with semaphore:
+             prompt = self._generate_prompt()
+             start = time.perf_counter()
+             tokens = 0
+             error = None
+             success = False
+
+             try:
+                 async with aiohttp.ClientSession() as session:
+                     payload = {
+                         "model": "default",
+                         "messages": [{"role": "user", "content": prompt}],
+                         "max_tokens": self.config.max_tokens,
+                         "stream": False,
+                     }
+
+                     async with session.post(
+                         f"{self.config.target_endpoint}/v1/chat/completions",
+                         json=payload,
+                         timeout=aiohttp.ClientTimeout(total=60),
+                     ) as response:
+                         if response.status == 200:
+                             data = await response.json()
+                             tokens = data.get("usage", {}).get("completion_tokens", 0)
+                             success = True
+                         else:
+                             error = f"HTTP {response.status}"
+
+             except asyncio.TimeoutError:
+                 error = "Timeout"
+             except Exception as e:
+                 error = str(e)
+
+             latency = (time.perf_counter() - start) * 1000
+
+             result = RequestResult(
+                 success=success,
+                 latency_ms=latency,
+                 tokens=tokens,
+                 error=error,
+             )
+
+             self._results.append(result)
+             self._latency_over_time.append({
+                 "time": datetime.now(),
+                 "latency_ms": latency,
+                 "success": success,
+             })
+
+     def _generate_prompt(self) -> str:
+         """Generate a prompt based on configuration."""
+         if self.config.prompt_length_distribution == "fixed":
+             return self.config.prompt
+
+         if self.config.prompt_length_distribution == "realistic":
+             # Simulate realistic prompt length distribution
+             prompts = [
+                 "Hello!",
+                 "Write a haiku about programming.",
+                 "Explain quantum computing in simple terms.",
+                 "Write a detailed technical analysis of transformer architectures and their impact on modern NLP systems.",
+                 "Compare and contrast the approaches of different programming paradigms including object-oriented, functional, and procedural programming. Provide examples in Python for each.",
+             ]
+             return random.choice(prompts)
+
+         if self.config.prompt_length_distribution == "random":
+             words = ["the", "quick", "brown", "fox", "jumps", "over", "lazy", "dog"]
+             length = random.randint(5, 100)
+             return " ".join(random.choices(words, k=length))
+
+         return self.config.prompt
+
+     def _get_progress(self) -> Dict[str, Any]:
+         """Get current progress."""
+         if not self._start_time:
+             return {}
+
+         elapsed = time.time() - self._start_time
+         successful = sum(1 for r in self._results if r.success)
+         latencies = [r.latency_ms for r in self._results if r.success]
+
+         return {
+             "elapsed_seconds": elapsed,
+             "total_requests": len(self._results),
+             "successful_requests": successful,
+             "failed_requests": len(self._results) - successful,
+             "avg_latency_ms": statistics.mean(latencies) if latencies else 0,
+             "current_rps": len(self._results) / elapsed if elapsed > 0 else 0,
+         }
+
+     def _analyze_results(self, test_id: str) -> LoadTestResult:
+         """Analyze test results."""
+         successful = [r for r in self._results if r.success]
+         failed = [r for r in self._results if not r.success]
+
+         latencies = [r.latency_ms for r in successful]
+
+         if not latencies:
+             return LoadTestResult(
+                 test_id=test_id,
+                 target_endpoint=self.config.target_endpoint,
+                 concurrent_users=self.config.concurrent_users,
+                 requests_per_second=self.config.requests_per_second,
+                 duration_seconds=self.config.duration_seconds,
+                 total_requests=len(self._results),
+                 successful_requests=0,
+                 failed_requests=len(failed),
+                 avg_latency_ms=0,
+                 p50_latency_ms=0,
+                 p95_latency_ms=0,
+                 p99_latency_ms=0,
+                 throughput_rps=0,
+             )
+
+         # Calculate percentiles
+         sorted_latencies = sorted(latencies)
+         n = len(sorted_latencies)
+
+         p50 = sorted_latencies[int(n * 0.50)]
+         p95 = sorted_latencies[int(n * 0.95)]
+         p99 = sorted_latencies[min(int(n * 0.99), n - 1)]
+
+         # Calculate throughput
+         duration = self.config.duration_seconds
+         if self._start_time:
+             duration = time.time() - self._start_time
+
+         throughput = len(successful) / duration if duration > 0 else 0
+
+         # Detect saturation point
+         saturation = self._find_saturation_point()
+
+         return LoadTestResult(
277
+ test_id=test_id,
278
+ target_endpoint=self.config.target_endpoint,
279
+ concurrent_users=self.config.concurrent_users,
280
+ requests_per_second=self.config.requests_per_second,
281
+ duration_seconds=self.config.duration_seconds,
282
+ total_requests=len(self._results),
283
+ successful_requests=len(successful),
284
+ failed_requests=len(failed),
285
+ avg_latency_ms=statistics.mean(latencies),
286
+ p50_latency_ms=p50,
287
+ p95_latency_ms=p95,
288
+ p99_latency_ms=p99,
289
+ throughput_rps=throughput,
290
+ saturation_point=saturation,
291
+ )
292
+
293
+ def _find_saturation_point(self) -> Optional[float]:
294
+ """
295
+ Find the point where latency starts increasing dramatically.
296
+
297
+ Returns:
298
+ Request rate at saturation point, or None if not found
299
+ """
300
+ if len(self._latency_over_time) < 20:
301
+ return None
302
+
303
+ # Group latencies by time buckets
304
+ latencies = list(self._latency_over_time)
305
+ bucket_size = len(latencies) // 10
306
+
307
+ bucket_avgs = []
308
+ for i in range(0, len(latencies), bucket_size):
309
+ bucket = latencies[i : i + bucket_size]
310
+ if bucket:
311
+ avg = statistics.mean(r["latency_ms"] for r in bucket)
312
+ bucket_avgs.append(avg)
313
+
314
+ if len(bucket_avgs) < 3:
315
+ return None
316
+
317
+ # Look for significant increase (2x)
318
+ baseline = bucket_avgs[0]
319
+ for i, avg in enumerate(bucket_avgs):
320
+ if avg > baseline * 2:
321
+ # Estimate RPS at this point
322
+ elapsed = self.config.duration_seconds * (i / len(bucket_avgs))
323
+ return len(self._results) / elapsed if elapsed > 0 else None
324
+
325
+ return None
326
+
327
+ def get_latency_timeseries(self) -> List[Dict[str, Any]]:
328
+ """
329
+ Get latency over time for charting.
330
+
331
+ Returns:
332
+ List of {time, latency_ms} points
333
+ """
334
+ return [
335
+ {"time": p["time"], "latency_ms": p["latency_ms"]}
336
+ for p in self._latency_over_time
337
+ ]
338
+
339
+ def get_latency_histogram(self, bins: int = 20) -> Dict[str, Any]:
340
+ """
341
+ Get latency histogram data.
342
+
343
+ Args:
344
+ bins: Number of histogram bins
345
+
346
+ Returns:
347
+ Dictionary with bin edges and counts
348
+ """
349
+ latencies = [r.latency_ms for r in self._results if r.success]
350
+
351
+ if not latencies:
352
+ return {"bins": [], "counts": []}
353
+
354
+ counts, edges = np.histogram(latencies, bins=bins)
355
+
356
+ return {
357
+ "bins": [(edges[i] + edges[i + 1]) / 2 for i in range(len(counts))],
358
+ "counts": counts.tolist(),
359
+ }
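The percentile computation in `_analyze_results` (sort, index at `int(n * q)`, clamp P99 to the last element) can be sketched standalone; the latency values below are synthetic, not real request results:

```python
import statistics

# Hypothetical successful-request latencies in milliseconds
latencies = [120.0, 95.0, 110.0, 300.0, 105.0, 98.0, 102.0, 115.0, 130.0, 99.0]

sorted_latencies = sorted(latencies)
n = len(sorted_latencies)

# Same nearest-rank style indexing as the load tester
p50 = sorted_latencies[int(n * 0.50)]
p95 = sorted_latencies[int(n * 0.95)]
p99 = sorted_latencies[min(int(n * 0.99), n - 1)]  # clamp so the index never overruns

print(p50, p95, p99)                         # 110.0 300.0 300.0
print(round(statistics.mean(latencies), 1))  # 127.4
```

Note that with ten samples the single 300 ms outlier dominates both P95 and P99 while barely moving the mean, which is why the dashboard reports both.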
services/request_tracer.py ADDED
@@ -0,0 +1,272 @@
"""Request tracing and latency analysis."""

import logging
import uuid
from dataclasses import dataclass
from datetime import datetime
from typing import Dict, List, Optional, Any
from collections import deque
import statistics

from storage.database import MetricsDB
from storage.models import RequestTrace

logger = logging.getLogger(__name__)


@dataclass
class LatencyBreakdown:
    """Breakdown of request latency by phase."""
    queue_ms: float
    prefill_ms: float
    decode_ms: float
    total_ms: float

    @property
    def as_dict(self) -> Dict[str, float]:
        return {
            "queue": self.queue_ms,
            "prefill": self.prefill_ms,
            "decode": self.decode_ms,
            "total": self.total_ms,
        }


@dataclass
class TraceCorrelation:
    """Correlation analysis for a trace."""
    memory_pressure: bool
    likely_cause: str
    memory_delta_gb: float


class RequestTracer:
    """Tracks and analyzes request latency."""

    def __init__(self, db: Optional[MetricsDB] = None, p95_window: int = 100):
        """
        Initialize request tracer.

        Args:
            db: Optional database for persisting traces
            p95_window: Number of recent requests for P95 calculation
        """
        self.db = db
        self._traces: deque = deque(maxlen=1000)
        self._latency_window: deque = deque(maxlen=p95_window)
        self._baseline_p95: Optional[float] = None
        self._slow_threshold_ms: Optional[float] = None

    def record_trace(
        self,
        request_id: Optional[str] = None,
        prompt_tokens: int = 0,
        output_tokens: int = 0,
        queue_time_ms: float = 0,
        prefill_time_ms: float = 0,
        decode_time_ms: float = 0,
        total_time_ms: Optional[float] = None,
        gpu_memory_start: float = 0,
        gpu_memory_end: float = 0,
    ) -> RequestTrace:
        """
        Record a request trace.

        Args:
            request_id: Unique request identifier
            prompt_tokens: Number of prompt tokens
            output_tokens: Number of output tokens
            queue_time_ms: Time spent in queue
            prefill_time_ms: Time for prefill/prompt processing
            decode_time_ms: Time for token generation
            total_time_ms: Total end-to-end time
            gpu_memory_start: GPU memory at request start
            gpu_memory_end: GPU memory at request end

        Returns:
            Created RequestTrace
        """
        if request_id is None:
            request_id = str(uuid.uuid4())[:8]

        if total_time_ms is None:
            total_time_ms = queue_time_ms + prefill_time_ms + decode_time_ms

        # Calculate tokens per second from the decode phase only
        tokens_per_sec = 0
        if decode_time_ms > 0:
            tokens_per_sec = (output_tokens / decode_time_ms) * 1000

        # Determine if slow relative to the rolling P95 threshold
        is_slow = False
        if self._slow_threshold_ms and total_time_ms > self._slow_threshold_ms:
            is_slow = True

        trace = RequestTrace(
            request_id=request_id,
            prompt_tokens=prompt_tokens,
            output_tokens=output_tokens,
            queue_time_ms=queue_time_ms,
            prefill_time_ms=prefill_time_ms,
            decode_time_ms=decode_time_ms,
            total_time_ms=total_time_ms,
            tokens_per_second=tokens_per_sec,
            gpu_memory_at_start=gpu_memory_start,
            gpu_memory_at_end=gpu_memory_end,
            is_slow=is_slow,
        )

        # Store in memory
        self._traces.append(trace)
        self._latency_window.append(total_time_ms)

        # Update P95 baseline
        self._update_baseline()

        # Persist to database
        if self.db:
            try:
                self.db.insert_trace(trace)
            except Exception as e:
                logger.error(f"Error persisting trace: {e}")

        # Log slow requests
        if is_slow:
            logger.warning(
                f"Slow request {request_id}: {total_time_ms:.1f}ms "
                f"(threshold: {self._slow_threshold_ms:.1f}ms)"
            )

        return trace

    def _update_baseline(self) -> None:
        """Update P95 baseline from recent requests."""
        if len(self._latency_window) >= 10:
            sorted_latencies = sorted(self._latency_window)
            p95_idx = int(len(sorted_latencies) * 0.95)
            self._baseline_p95 = sorted_latencies[p95_idx]
            # Set slow threshold at 1.5x P95
            self._slow_threshold_ms = self._baseline_p95 * 1.5

    def get_recent_traces(
        self, limit: int = 100, slow_only: bool = False
    ) -> List[RequestTrace]:
        """
        Get recent traces.

        Args:
            limit: Maximum number of traces
            slow_only: Only return slow requests

        Returns:
            List of RequestTrace objects
        """
        traces = list(self._traces)

        if slow_only:
            traces = [t for t in traces if t.is_slow]

        return traces[-limit:]

    def get_latency_breakdown(self) -> LatencyBreakdown:
        """
        Get average latency breakdown.

        Returns:
            LatencyBreakdown with average times
        """
        if not self._traces:
            return LatencyBreakdown(0, 0, 0, 0)

        recent = list(self._traces)[-100:]

        return LatencyBreakdown(
            queue_ms=statistics.mean(t.queue_time_ms for t in recent),
            prefill_ms=statistics.mean(t.prefill_time_ms for t in recent),
            decode_ms=statistics.mean(t.decode_time_ms for t in recent),
            total_ms=statistics.mean(t.total_time_ms for t in recent),
        )

    def correlate_with_gpu_pressure(self, trace: RequestTrace) -> TraceCorrelation:
        """
        Correlate trace latency with GPU memory pressure.

        Args:
            trace: Request trace to analyze

        Returns:
            TraceCorrelation analysis
        """
        memory_delta = trace.gpu_memory_at_end - trace.gpu_memory_at_start

        # Determine likely cause based on patterns
        if memory_delta > 2.0:
            cause = "batch_contention"
        elif trace.queue_time_ms > trace.total_time_ms * 0.3:
            cause = "queue_congestion"
        elif trace.prefill_time_ms > trace.decode_time_ms * 2:
            cause = "long_prompt"
        else:
            cause = "normal"

        return TraceCorrelation(
            memory_pressure=memory_delta > 1.0,
            likely_cause=cause,
            memory_delta_gb=memory_delta,
        )

    def get_percentiles(self) -> Dict[str, float]:
        """
        Get latency percentiles.

        Returns:
            Dictionary with P50, P95, P99 values
        """
        if not self._latency_window:
            return {"p50": 0, "p95": 0, "p99": 0}

        sorted_latencies = sorted(self._latency_window)
        n = len(sorted_latencies)

        return {
            "p50": sorted_latencies[int(n * 0.50)],
            "p95": sorted_latencies[int(n * 0.95)],
            "p99": sorted_latencies[min(int(n * 0.99), n - 1)],
        }

    def get_stats(self) -> Dict[str, Any]:
        """
        Get comprehensive statistics.

        Returns:
            Dictionary with various stats
        """
        if not self._traces:
            return {
                "total_requests": 0,
                "slow_requests": 0,
                "avg_latency_ms": 0,
                "percentiles": {"p50": 0, "p95": 0, "p99": 0},
                "breakdown": {"queue": 0, "prefill": 0, "decode": 0},
            }

        traces = list(self._traces)
        slow_count = sum(1 for t in traces if t.is_slow)
        breakdown = self.get_latency_breakdown()

        return {
            "total_requests": len(traces),
            "slow_requests": slow_count,
            "slow_rate_percent": (slow_count / len(traces)) * 100,
            "avg_latency_ms": breakdown.total_ms,
            "percentiles": self.get_percentiles(),
            "breakdown": breakdown.as_dict,
            "baseline_p95": self._baseline_p95,
        }

    def clear(self) -> None:
        """Clear all traces."""
        self._traces.clear()
        self._latency_window.clear()
        self._baseline_p95 = None
        self._slow_threshold_ms = None
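The tracer's rolling baseline — a fixed-size `deque`, nearest-rank P95, slow threshold at 1.5× P95 — can be reproduced standalone. The latencies below are synthetic stand-ins for real traces:

```python
from collections import deque

# Rolling window, same shape as RequestTracer._latency_window
window = deque(maxlen=100)

# 19 fast requests plus one outlier
for latency_ms in [100.0] * 19 + [400.0]:
    window.append(latency_ms)

sorted_latencies = sorted(window)
p95_idx = int(len(sorted_latencies) * 0.95)   # nearest-rank index
baseline_p95 = sorted_latencies[p95_idx]
slow_threshold = baseline_p95 * 1.5           # same 1.5x rule as the tracer

print(baseline_p95)    # the outlier sits at index 19 -> 400.0
print(slow_threshold)  # 600.0
```

Because the threshold chases the window's own P95, a sustained slowdown raises the baseline and stops flagging requests; that is a deliberate trade-off of this adaptive scheme versus a fixed SLO threshold.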
storage/__init__.py ADDED
@@ -0,0 +1,11 @@
"""Storage layer for persistent metrics and traces."""

from .database import MetricsDB
from .models import MetricRecord, AlertRecord, RequestTrace

__all__ = [
    "MetricsDB",
    "MetricRecord",
    "AlertRecord",
    "RequestTrace",
]
storage/database.py ADDED
@@ -0,0 +1,448 @@
"""SQLite database operations for metrics storage."""

import sqlite3
import json
import logging
from datetime import datetime, timedelta
from pathlib import Path
from typing import List, Optional, Dict, Any
from contextlib import contextmanager

from .models import MetricRecord, AlertRecord, RequestTrace, LoadTestResult

logger = logging.getLogger(__name__)


class MetricsDB:
    """SQLite database for storing metrics, alerts, and traces."""

    def __init__(self, db_path: str = "data/metrics.db"):
        """
        Initialize database connection.

        Args:
            db_path: Path to SQLite database file
        """
        self.db_path = db_path
        self._ensure_directory()
        self._init_schema()

    def _ensure_directory(self) -> None:
        """Ensure the database directory exists."""
        Path(self.db_path).parent.mkdir(parents=True, exist_ok=True)

    @contextmanager
    def _get_connection(self):
        """Get a database connection with context manager."""
        conn = sqlite3.connect(self.db_path)
        conn.row_factory = sqlite3.Row
        try:
            yield conn
            conn.commit()
        except Exception:
            conn.rollback()
            raise
        finally:
            conn.close()

    def _init_schema(self) -> None:
        """Initialize database schema."""
        with self._get_connection() as conn:
            cursor = conn.cursor()

            # Metrics table
            cursor.execute("""
                CREATE TABLE IF NOT EXISTS metrics (
                    id INTEGER PRIMARY KEY AUTOINCREMENT,
                    timestamp DATETIME DEFAULT CURRENT_TIMESTAMP,
                    metric_name TEXT NOT NULL,
                    value REAL NOT NULL,
                    labels TEXT
                )
            """)

            # Indexes for metrics
            cursor.execute("""
                CREATE INDEX IF NOT EXISTS idx_metrics_timestamp
                ON metrics(timestamp)
            """)
            cursor.execute("""
                CREATE INDEX IF NOT EXISTS idx_metrics_name_time
                ON metrics(metric_name, timestamp)
            """)

            # Alerts table
            cursor.execute("""
                CREATE TABLE IF NOT EXISTS alerts (
                    id INTEGER PRIMARY KEY AUTOINCREMENT,
                    timestamp DATETIME DEFAULT CURRENT_TIMESTAMP,
                    rule_name TEXT,
                    severity TEXT,
                    metric_name TEXT,
                    value REAL,
                    threshold REAL,
                    message TEXT,
                    resolved_at DATETIME
                )
            """)

            # Request traces table
            cursor.execute("""
                CREATE TABLE IF NOT EXISTS request_traces (
                    id INTEGER PRIMARY KEY AUTOINCREMENT,
                    request_id TEXT UNIQUE,
                    timestamp DATETIME DEFAULT CURRENT_TIMESTAMP,
                    prompt_tokens INTEGER,
                    output_tokens INTEGER,
                    queue_time_ms REAL,
                    prefill_time_ms REAL,
                    decode_time_ms REAL,
                    total_time_ms REAL,
                    tokens_per_second REAL,
                    is_slow BOOLEAN
                )
            """)

            # Load test results table
            cursor.execute("""
                CREATE TABLE IF NOT EXISTS load_tests (
                    id INTEGER PRIMARY KEY AUTOINCREMENT,
                    test_id TEXT UNIQUE,
                    timestamp DATETIME DEFAULT CURRENT_TIMESTAMP,
                    target_endpoint TEXT,
                    concurrent_users INTEGER,
                    requests_per_second REAL,
                    duration_seconds INTEGER,
                    total_requests INTEGER,
                    successful_requests INTEGER,
                    failed_requests INTEGER,
                    avg_latency_ms REAL,
                    p50_latency_ms REAL,
                    p95_latency_ms REAL,
                    p99_latency_ms REAL,
                    throughput_rps REAL,
                    saturation_point REAL
                )
            """)

    # Metrics operations

    def insert_metric(self, record: MetricRecord) -> int:
        """Insert a metric record."""
        with self._get_connection() as conn:
            cursor = conn.cursor()
            cursor.execute(
                """
                INSERT INTO metrics (timestamp, metric_name, value, labels)
                VALUES (?, ?, ?, ?)
                """,
                (
                    record.timestamp.isoformat(),
                    record.metric_name,
                    record.value,
                    json.dumps(record.labels) if record.labels else None,
                ),
            )
            return cursor.lastrowid

    def insert_metrics_batch(self, records: List[MetricRecord]) -> None:
        """Insert multiple metric records efficiently."""
        with self._get_connection() as conn:
            cursor = conn.cursor()
            cursor.executemany(
                """
                INSERT INTO metrics (timestamp, metric_name, value, labels)
                VALUES (?, ?, ?, ?)
                """,
                [
                    (
                        r.timestamp.isoformat(),
                        r.metric_name,
                        r.value,
                        json.dumps(r.labels) if r.labels else None,
                    )
                    for r in records
                ],
            )

    def query_metrics(
        self,
        metric_name: str,
        start: datetime,
        end: datetime,
        labels: Optional[Dict[str, str]] = None,
    ) -> List[MetricRecord]:
        """Query metrics by name and time range."""
        with self._get_connection() as conn:
            cursor = conn.cursor()
            cursor.execute(
                """
                SELECT id, timestamp, metric_name, value, labels
                FROM metrics
                WHERE metric_name = ? AND timestamp BETWEEN ? AND ?
                ORDER BY timestamp
                """,
                (metric_name, start.isoformat(), end.isoformat()),
            )

            records = []
            for row in cursor.fetchall():
                record = MetricRecord.from_row(tuple(row))
                if labels:
                    # Filter by labels if specified
                    if all(record.labels.get(k) == v for k, v in labels.items()):
                        records.append(record)
                else:
                    records.append(record)

            return records

    def query_aggregated(
        self,
        metric_name: str,
        start: datetime,
        end: datetime,
        aggregation: str = "avg",
        bucket_minutes: int = 1,
    ) -> List[Dict[str, Any]]:
        """Query metrics with time bucketing and aggregation."""
        agg_func = {
            "avg": "AVG",
            "max": "MAX",
            "min": "MIN",
            "sum": "SUM",
            "count": "COUNT",
        }.get(aggregation, "AVG")

        with self._get_connection() as conn:
            cursor = conn.cursor()
            cursor.execute(
                f"""
                SELECT
                    datetime(
                        strftime('%Y-%m-%d %H:', timestamp) ||
                        printf('%02d', (CAST(strftime('%M', timestamp) AS INTEGER) / {bucket_minutes}) * {bucket_minutes}) ||
                        ':00'
                    ) as bucket,
                    {agg_func}(value) as value
                FROM metrics
                WHERE metric_name = ? AND timestamp BETWEEN ? AND ?
                GROUP BY bucket
                ORDER BY bucket
                """,
                (metric_name, start.isoformat(), end.isoformat()),
            )

            return [
                {"time": row["bucket"], "value": row["value"]}
                for row in cursor.fetchall()
            ]

    # Alert operations

    def insert_alert(self, alert: AlertRecord) -> int:
        """Insert an alert record."""
        with self._get_connection() as conn:
            cursor = conn.cursor()
            cursor.execute(
                """
                INSERT INTO alerts
                (timestamp, rule_name, severity, metric_name, value, threshold, message, resolved_at)
                VALUES (?, ?, ?, ?, ?, ?, ?, ?)
                """,
                (
                    alert.timestamp.isoformat(),
                    alert.rule_name,
                    alert.severity,
                    alert.metric_name,
                    alert.value,
                    alert.threshold,
                    alert.message,
                    alert.resolved_at.isoformat() if alert.resolved_at else None,
                ),
            )
            return cursor.lastrowid

    def get_active_alerts(self) -> List[AlertRecord]:
        """Get all unresolved alerts."""
        with self._get_connection() as conn:
            cursor = conn.cursor()
            cursor.execute(
                """
                SELECT id, timestamp, rule_name, severity, metric_name, value, threshold, message, resolved_at
                FROM alerts
                WHERE resolved_at IS NULL
                ORDER BY timestamp DESC
                """
            )
            return [AlertRecord.from_row(tuple(row)) for row in cursor.fetchall()]

    def get_recent_alerts(self, limit: int = 100) -> List[AlertRecord]:
        """Get recent alerts."""
        with self._get_connection() as conn:
            cursor = conn.cursor()
            cursor.execute(
                """
                SELECT id, timestamp, rule_name, severity, metric_name, value, threshold, message, resolved_at
                FROM alerts
                ORDER BY timestamp DESC
                LIMIT ?
                """,
                (limit,),
            )
            return [AlertRecord.from_row(tuple(row)) for row in cursor.fetchall()]

    def resolve_alert(self, alert_id: int) -> None:
        """Mark an alert as resolved."""
        with self._get_connection() as conn:
            cursor = conn.cursor()
            cursor.execute(
                """
                UPDATE alerts SET resolved_at = ? WHERE id = ?
                """,
                (datetime.now().isoformat(), alert_id),
            )

    # Request trace operations

    def insert_trace(self, trace: RequestTrace) -> int:
        """Insert a request trace."""
        with self._get_connection() as conn:
            cursor = conn.cursor()
            cursor.execute(
                """
                INSERT OR REPLACE INTO request_traces
                (request_id, timestamp, prompt_tokens, output_tokens,
                 queue_time_ms, prefill_time_ms, decode_time_ms, total_time_ms,
                 tokens_per_second, is_slow)
                VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
                """,
                (
                    trace.request_id,
                    trace.timestamp.isoformat(),
                    trace.prompt_tokens,
                    trace.output_tokens,
                    trace.queue_time_ms,
                    trace.prefill_time_ms,
                    trace.decode_time_ms,
                    trace.total_time_ms,
                    trace.tokens_per_second,
                    trace.is_slow,
                ),
            )
            return cursor.lastrowid

    def get_recent_traces(
        self, limit: int = 100, slow_only: bool = False
    ) -> List[RequestTrace]:
        """Get recent request traces."""
        with self._get_connection() as conn:
            cursor = conn.cursor()
            query = """
                SELECT id, request_id, timestamp, prompt_tokens, output_tokens,
                       queue_time_ms, prefill_time_ms, decode_time_ms, total_time_ms,
                       tokens_per_second, is_slow
                FROM request_traces
            """
            if slow_only:
                query += " WHERE is_slow = 1"
            query += " ORDER BY timestamp DESC LIMIT ?"

            cursor.execute(query, (limit,))
            return [RequestTrace.from_row(tuple(row)) for row in cursor.fetchall()]

    def get_trace_stats(self) -> Dict[str, Any]:
        """Get aggregate statistics for traces."""
        with self._get_connection() as conn:
            cursor = conn.cursor()
            cursor.execute(
                """
                SELECT
                    COUNT(*) as total,
                    AVG(total_time_ms) as avg_latency,
                    AVG(queue_time_ms) as avg_queue,
                    AVG(prefill_time_ms) as avg_prefill,
                    AVG(decode_time_ms) as avg_decode,
                    SUM(CASE WHEN is_slow THEN 1 ELSE 0 END) as slow_count
                FROM request_traces
                WHERE timestamp > datetime('now', '-1 hour')
                """
            )
            row = cursor.fetchone()
            return {
                "total_requests": row["total"] or 0,
                "avg_latency_ms": row["avg_latency"] or 0,
                "avg_queue_ms": row["avg_queue"] or 0,
                "avg_prefill_ms": row["avg_prefill"] or 0,
                "avg_decode_ms": row["avg_decode"] or 0,
                "slow_request_count": row["slow_count"] or 0,
            }

    # Load test operations

    def insert_load_test(self, result: LoadTestResult) -> int:
        """Insert a load test result."""
        with self._get_connection() as conn:
            cursor = conn.cursor()
            cursor.execute(
                """
                INSERT INTO load_tests
                (test_id, timestamp, target_endpoint, concurrent_users,
                 requests_per_second, duration_seconds, total_requests,
                 successful_requests, failed_requests, avg_latency_ms,
                 p50_latency_ms, p95_latency_ms, p99_latency_ms,
                 throughput_rps, saturation_point)
                VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
                """,
                (
                    result.test_id,
                    result.timestamp.isoformat(),
                    result.target_endpoint,
                    result.concurrent_users,
                    result.requests_per_second,
                    result.duration_seconds,
                    result.total_requests,
                    result.successful_requests,
                    result.failed_requests,
                    result.avg_latency_ms,
                    result.p50_latency_ms,
                    result.p95_latency_ms,
                    result.p99_latency_ms,
                    result.throughput_rps,
                    result.saturation_point,
                ),
            )
            return cursor.lastrowid

    def get_recent_load_tests(self, limit: int = 10) -> List[Dict[str, Any]]:
        """Get recent load test results."""
        with self._get_connection() as conn:
            cursor = conn.cursor()
            cursor.execute(
                """
                SELECT * FROM load_tests
                ORDER BY timestamp DESC
                LIMIT ?
                """,
                (limit,),
            )
            return [dict(row) for row in cursor.fetchall()]

    # Cleanup operations

    def cleanup_old_data(self, days: int = 7) -> int:
        """Remove data older than specified days."""
        cutoff = (datetime.now() - timedelta(days=days)).isoformat()

        with self._get_connection() as conn:
            cursor = conn.cursor()
            total_deleted = 0

            for table in ["metrics", "alerts", "request_traces"]:
                cursor.execute(
                    f"DELETE FROM {table} WHERE timestamp < ?",
                    (cutoff,),
                )
                total_deleted += cursor.rowcount

            return total_deleted
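The minute-bucketing trick in `query_aggregated` (floor the minute field with integer division, then `GROUP BY`) works in plain SQLite; here is a self-contained sketch against an in-memory database with made-up rows:

```python
import sqlite3

conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE metrics (timestamp TEXT, metric_name TEXT, value REAL)")
conn.executemany(
    "INSERT INTO metrics VALUES (?, ?, ?)",
    [
        ("2024-01-01 10:00:10", "gpu_util", 40.0),
        ("2024-01-01 10:03:50", "gpu_util", 60.0),
        ("2024-01-01 10:07:20", "gpu_util", 80.0),
    ],
)

bucket_minutes = 5
# Floor each row's minute to its 5-minute bucket, then average per bucket
cur = conn.execute(f"""
    SELECT
        strftime('%Y-%m-%d %H:', timestamp) ||
        printf('%02d', (CAST(strftime('%M', timestamp) AS INTEGER) / {bucket_minutes}) * {bucket_minutes}) AS bucket,
        AVG(value) AS value
    FROM metrics
    WHERE metric_name = 'gpu_util'
    GROUP BY bucket
    ORDER BY bucket
""")
print(cur.fetchall())  # [('2024-01-01 10:00', 50.0), ('2024-01-01 10:05', 80.0)]
```

Minutes 00 and 03 collapse into the 10:00 bucket (average 50.0) while minute 07 lands in 10:05; the dashboard's version additionally wraps the bucket string in `datetime(... || ':00')` to normalize it back to a full timestamp.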
storage/models.py ADDED
@@ -0,0 +1,165 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Data models for storage layer."""
2
+
3
+ from dataclasses import dataclass, field
4
+ from datetime import datetime
5
+ from typing import Optional, Dict, Any
6
+ import json
7
+
8
+
9
+ @dataclass
10
+ class MetricRecord:
11
+ """A single metric record for storage."""
12
+ metric_name: str
13
+ value: float
14
+ timestamp: datetime = field(default_factory=datetime.now)
15
+ labels: Dict[str, str] = field(default_factory=dict)
16
+ id: Optional[int] = None
17
+
18
+ def to_dict(self) -> Dict[str, Any]:
19
+ return {
20
+ "id": self.id,
21
+ "metric_name": self.metric_name,
22
+ "value": self.value,
23
+ "timestamp": self.timestamp.isoformat(),
24
+ "labels": self.labels,
25
+ }
26
+
27
+ @classmethod
28
+ def from_row(cls, row: tuple) -> "MetricRecord":
29
+ return cls(
30
+ id=row[0],
31
+ timestamp=datetime.fromisoformat(row[1]),
32
+ metric_name=row[2],
33
+ value=row[3],
34
+ labels=json.loads(row[4]) if row[4] else {},
35
+ )
36
+
37
+
38
+ @dataclass
39
+ class AlertRecord:
40
+ """An alert record for storage."""
41
+ rule_name: str
42
+ severity: str
43
+ metric_name: str
44
+ value: float
45
+ threshold: float
46
+ message: str
47
+ timestamp: datetime = field(default_factory=datetime.now)
48
+ resolved_at: Optional[datetime] = None
49
+ id: Optional[int] = None
50
+
51
+ def to_dict(self) -> Dict[str, Any]:
52
+ return {
53
+ "id": self.id,
54
+ "rule_name": self.rule_name,
55
+ "severity": self.severity,
56
+ "metric_name": self.metric_name,
57
+ "value": self.value,
58
+ "threshold": self.threshold,
59
+ "message": self.message,
60
+ "timestamp": self.timestamp.isoformat(),
61
+ "resolved_at": self.resolved_at.isoformat() if self.resolved_at else None,
62
+ }
63
+
64
+ @classmethod
65
+ def from_row(cls, row: tuple) -> "AlertRecord":
66
+ return cls(
67
+ id=row[0],
68
+ timestamp=datetime.fromisoformat(row[1]),
69
+ rule_name=row[2],
70
+ severity=row[3],
71
+ metric_name=row[4],
72
+ value=row[5],
73
+ threshold=row[6],
74
+ message=row[7] if len(row) > 7 else "",
75
+ resolved_at=datetime.fromisoformat(row[8]) if len(row) > 8 and row[8] else None,
76
+ )
77
+
78
+
79
+ @dataclass
80
+ class RequestTrace:
81
+ """A request trace for latency analysis."""
82
+ request_id: str
83
+ prompt_tokens: int
84
+ output_tokens: int
85
+ queue_time_ms: float
86
+ prefill_time_ms: float
87
+ decode_time_ms: float
88
+ total_time_ms: float
89
+     tokens_per_second: float
+     gpu_memory_at_start: float = 0.0
+     gpu_memory_at_end: float = 0.0
+     is_slow: bool = False
+     timestamp: datetime = field(default_factory=datetime.now)
+     id: Optional[int] = None
+
+     def to_dict(self) -> Dict[str, Any]:
+         return {
+             "id": self.id,
+             "request_id": self.request_id,
+             "timestamp": self.timestamp.isoformat(),
+             "prompt_tokens": self.prompt_tokens,
+             "output_tokens": self.output_tokens,
+             "queue_time_ms": round(self.queue_time_ms, 2),
+             "prefill_time_ms": round(self.prefill_time_ms, 2),
+             "decode_time_ms": round(self.decode_time_ms, 2),
+             "total_time_ms": round(self.total_time_ms, 2),
+             "tokens_per_second": round(self.tokens_per_second, 2),
+             "is_slow": self.is_slow,
+         }
+
+     @classmethod
+     def from_row(cls, row: tuple) -> "RequestTrace":
+         return cls(
+             id=row[0],
+             request_id=row[1],
+             timestamp=datetime.fromisoformat(row[2]),
+             prompt_tokens=row[3],
+             output_tokens=row[4],
+             queue_time_ms=row[5],
+             prefill_time_ms=row[6],
+             decode_time_ms=row[7],
+             total_time_ms=row[8],
+             tokens_per_second=row[9] if len(row) > 9 else 0,
+             is_slow=bool(row[10]) if len(row) > 10 else False,
+         )
+
+
+ @dataclass
+ class LoadTestResult:
+     """Results from a load test run."""
+     test_id: str
+     target_endpoint: str
+     concurrent_users: int
+     requests_per_second: float
+     duration_seconds: int
+     total_requests: int
+     successful_requests: int
+     failed_requests: int
+     avg_latency_ms: float
+     p50_latency_ms: float
+     p95_latency_ms: float
+     p99_latency_ms: float
+     throughput_rps: float
+     saturation_point: Optional[float] = None
+     timestamp: datetime = field(default_factory=datetime.now)
+     id: Optional[int] = None
+
+     def to_dict(self) -> Dict[str, Any]:
+         return {
+             "test_id": self.test_id,
+             "target_endpoint": self.target_endpoint,
+             "concurrent_users": self.concurrent_users,
+             "requests_per_second": self.requests_per_second,
+             "duration_seconds": self.duration_seconds,
+             "total_requests": self.total_requests,
+             "successful_requests": self.successful_requests,
+             "failed_requests": self.failed_requests,
+             "avg_latency_ms": round(self.avg_latency_ms, 2),
+             "p50_latency_ms": round(self.p50_latency_ms, 2),
+             "p95_latency_ms": round(self.p95_latency_ms, 2),
+             "p99_latency_ms": round(self.p99_latency_ms, 2),
+             "throughput_rps": round(self.throughput_rps, 2),
+             "saturation_point": self.saturation_point,
+             "timestamp": self.timestamp.isoformat(),
+         }
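For context, here is a hypothetical standard-library sketch of how the percentile fields stored on `LoadTestResult` could be derived from raw per-request latencies. The sample data and variable names are illustrative only, not the repo's actual load-test code:

```python
import statistics

# Hypothetical per-request latencies collected during a load test run (ms)
latencies_ms = [12.0, 15.0, 14.0, 200.0, 16.0, 13.0, 15.5, 14.5, 13.5, 16.5]

# statistics.quantiles with n=100 returns the 1st..99th percentile cut points
cuts = statistics.quantiles(latencies_ms, n=100)
p50_latency_ms = cuts[49]
p95_latency_ms = cuts[94]
p99_latency_ms = cuts[98]
avg_latency_ms = statistics.fmean(latencies_ms)
```

The single 200 ms outlier barely moves p50 but dominates p99, which is why the dashboard stores all three percentiles rather than just the average.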
utils/__init__.py ADDED
@@ -0,0 +1,6 @@
+ """Utility modules for the dashboard."""
+
+ from .prometheus_parser import parse_prometheus_metrics
+ from .history import MetricHistory
+
+ __all__ = ["parse_prometheus_metrics", "MetricHistory"]
utils/history.py ADDED
@@ -0,0 +1,163 @@
+ """In-memory metric history buffer for time-series data."""
+
+ from collections import deque
+ from dataclasses import dataclass, field
+ from datetime import datetime
+ from typing import Dict, List, Any, Optional
+ import threading
+
+
+ @dataclass
+ class HistoryPoint:
+     """A single point in metric history."""
+     timestamp: datetime
+     value: float
+     labels: Dict[str, str] = field(default_factory=dict)
+
+
+ class MetricHistory:
+     """
+     Thread-safe in-memory buffer for metric history.
+
+     Maintains a rolling window of metric values for charting.
+     """
+
+     def __init__(self, max_length: int = 300):
+         """
+         Initialize history buffer.
+
+         Args:
+             max_length: Maximum number of points to retain per series
+         """
+         self.max_length = max_length
+         self._data: Dict[str, deque] = {}
+         self._lock = threading.Lock()
+
+     def add(self, metric_name: str, value: float, labels: Optional[Dict[str, str]] = None) -> None:
+         """
+         Add a data point to the history.
+
+         Args:
+             metric_name: Name of the metric
+             value: Metric value
+             labels: Optional labels for the metric
+         """
+         point = HistoryPoint(
+             timestamp=datetime.now(),
+             value=value,
+             labels=labels or {}
+         )
+
+         # Create key including labels for differentiation
+         key = self._make_key(metric_name, labels)
+
+         with self._lock:
+             if key not in self._data:
+                 self._data[key] = deque(maxlen=self.max_length)
+             self._data[key].append(point)
+
+     def get(
+         self,
+         metric_name: str,
+         labels: Optional[Dict[str, str]] = None,
+         limit: Optional[int] = None
+     ) -> List[HistoryPoint]:
+         """
+         Get history for a metric.
+
+         Args:
+             metric_name: Name of the metric
+             labels: Optional label filter
+             limit: Maximum number of points to return
+
+         Returns:
+             List of history points
+         """
+         key = self._make_key(metric_name, labels)
+
+         with self._lock:
+             if key not in self._data:
+                 return []
+
+             points = list(self._data[key])
+             if limit:
+                 points = points[-limit:]
+             return points
+
+     def get_latest(
+         self,
+         metric_name: str,
+         labels: Optional[Dict[str, str]] = None
+     ) -> Optional[HistoryPoint]:
+         """Get the most recent value for a metric."""
+         points = self.get(metric_name, labels, limit=1)
+         return points[-1] if points else None
+
+     def get_all_series(self, metric_name: str) -> Dict[str, List[HistoryPoint]]:
+         """
+         Get all label combinations for a metric.
+
+         Args:
+             metric_name: Base metric name
+
+         Returns:
+             Dictionary mapping label strings to history lists
+         """
+         result = {}
+         prefix = f"{metric_name}:"
+
+         with self._lock:
+             for key, points in self._data.items():
+                 if key == metric_name or key.startswith(prefix):
+                     result[key] = list(points)
+
+         return result
+
+     def to_dataframe(self, metric_name: str, labels: Optional[Dict[str, str]] = None):
+         """
+         Convert history to a pandas DataFrame.
+
+         Args:
+             metric_name: Name of the metric
+             labels: Optional label filter
+
+         Returns:
+             pandas DataFrame with time and value columns
+         """
+         import pandas as pd
+
+         points = self.get(metric_name, labels)
+
+         if not points:
+             return pd.DataFrame(columns=["time", "value"])
+
+         return pd.DataFrame([
+             {"time": p.timestamp, "value": p.value, **p.labels}
+             for p in points
+         ])
+
+     def clear(self, metric_name: Optional[str] = None) -> None:
+         """
+         Clear history.
+
+         Args:
+             metric_name: If provided, clear only this metric; otherwise clear all
+         """
+         with self._lock:
+             if metric_name:
+                 keys_to_remove = [
+                     k for k in self._data.keys()
+                     if k == metric_name or k.startswith(f"{metric_name}:")
+                 ]
+                 for key in keys_to_remove:
+                     del self._data[key]
+             else:
+                 self._data.clear()
+
+     def _make_key(self, metric_name: str, labels: Optional[Dict[str, str]]) -> str:
+         """Create a unique key from metric name and labels."""
+         if not labels:
+             return metric_name
+
+         label_str = ",".join(f"{k}={v}" for k, v in sorted(labels.items()))
+         return f"{metric_name}:{label_str}"
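As a standalone illustration of the series-key scheme `MetricHistory._make_key` uses (the metric and label names below are made up), labels are sorted before joining so that insertion order never produces two keys for the same series:

```python
def make_key(metric_name, labels=None):
    # Mirrors MetricHistory._make_key: "name" for an unlabeled series,
    # "name:k1=v1,k2=v2" with labels sorted by key otherwise
    if not labels:
        return metric_name
    label_str = ",".join(f"{k}={v}" for k, v in sorted(labels.items()))
    return f"{metric_name}:{label_str}"

# Sorting makes the key deterministic regardless of dict ordering
key_a = make_key("gpu_utilization", {"rank": "1", "gpu": "0"})
key_b = make_key("gpu_utilization", {"gpu": "0", "rank": "1"})
# key_a == key_b == "gpu_utilization:gpu=0,rank=1"
```

This is also why `get_all_series` can recover every labeled variant of a metric just by matching the `"{metric_name}:"` prefix.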
utils/prometheus_parser.py ADDED
@@ -0,0 +1,195 @@
+ """Parser for Prometheus text format metrics."""
+
+ import re
+ from typing import Dict, List, Any, Optional
+ from dataclasses import dataclass
+
+
+ @dataclass
+ class MetricSample:
+     """A single metric sample with labels and value."""
+     name: str
+     labels: Dict[str, str]
+     value: float
+     timestamp: Optional[float] = None
+
+
+ def parse_prometheus_metrics(text: str) -> Dict[str, List[MetricSample]]:
+     """
+     Parse Prometheus text format into structured metrics.
+
+     Args:
+         text: Raw Prometheus metrics text
+
+     Returns:
+         Dictionary mapping metric names to lists of samples
+     """
+     metrics: Dict[str, List[MetricSample]] = {}
+
+     for line in text.strip().split("\n"):
+         line = line.strip()
+
+         # Skip empty lines and comments
+         if not line or line.startswith("#"):
+             continue
+
+         # Parse metric line
+         sample = _parse_metric_line(line)
+         if sample:
+             if sample.name not in metrics:
+                 metrics[sample.name] = []
+             metrics[sample.name].append(sample)
+
+     return metrics
+
+
+ def _parse_metric_line(line: str) -> Optional[MetricSample]:
+     """Parse a single Prometheus metric line."""
+     # Pattern: metric_name{label1="value1",label2="value2"} value [timestamp]
+     # Or: metric_name value [timestamp]
+
+     # Match with labels (timestamps are int64 milliseconds, possibly negative)
+     match = re.match(
+         r'^([a-zA-Z_:][a-zA-Z0-9_:]*)\{([^}]*)\}\s+([^\s]+)(?:\s+(-?\d+))?$',
+         line
+     )
+
+     if match:
+         name = match.group(1)
+         labels_str = match.group(2)
+         value_str = match.group(3)
+         timestamp_str = match.group(4)
+
+         labels = _parse_labels(labels_str)
+         value = _parse_value(value_str)
+         timestamp = float(timestamp_str) if timestamp_str else None
+
+         return MetricSample(name=name, labels=labels, value=value, timestamp=timestamp)
+
+     # Match without labels
+     match = re.match(
+         r'^([a-zA-Z_:][a-zA-Z0-9_:]*)\s+([^\s]+)(?:\s+(-?\d+))?$',
+         line
+     )
+
+     if match:
+         name = match.group(1)
+         value_str = match.group(2)
+         timestamp_str = match.group(3)
+
+         value = _parse_value(value_str)
+         timestamp = float(timestamp_str) if timestamp_str else None
+
+         return MetricSample(name=name, labels={}, value=value, timestamp=timestamp)
+
+     return None
+
+
+ def _parse_labels(labels_str: str) -> Dict[str, str]:
+     """Parse label string into dictionary."""
+     labels = {}
+
+     # Pattern: key="value"
+     # Note: does not handle escaped quotes (\") inside label values
+     for match in re.finditer(r'([a-zA-Z_][a-zA-Z0-9_]*)="([^"]*)"', labels_str):
+         labels[match.group(1)] = match.group(2)
+
+     return labels
+
+
+ def _parse_value(value_str: str) -> float:
+     """Parse metric value, handling special cases."""
+     if value_str.lower() == "nan":
+         return float("nan")
+     if value_str.lower() == "+inf":
+         return float("inf")
+     if value_str.lower() == "-inf":
+         return float("-inf")
+     return float(value_str)
+
+
+ def get_metric_value(
+     metrics: Dict[str, List[MetricSample]],
+     name: str,
+     labels: Optional[Dict[str, str]] = None
+ ) -> Optional[float]:
+     """
+     Get a specific metric value by name and optional labels.
+
+     Args:
+         metrics: Parsed metrics dictionary
+         name: Metric name
+         labels: Optional label filter
+
+     Returns:
+         Metric value or None if not found
+     """
+     if name not in metrics:
+         return None
+
+     for sample in metrics[name]:
+         if labels is None:
+             return sample.value
+
+         # Check if all specified labels match
+         if all(sample.labels.get(k) == v for k, v in labels.items()):
+             return sample.value
+
+     return None
+
+
+ def get_histogram_quantile(
+     metrics: Dict[str, List[MetricSample]],
+     name: str,
+     quantile: float,
+     labels: Optional[Dict[str, str]] = None
+ ) -> Optional[float]:
+     """
+     Get an approximate quantile from a Prometheus histogram.
+
+     Args:
+         metrics: Parsed metrics dictionary
+         name: Base metric name (without _bucket suffix)
+         quantile: Desired quantile (e.g., 0.95 for P95)
+         labels: Optional label filter
+
+     Returns:
+         Approximate quantile value or None
+     """
+     bucket_name = f"{name}_bucket"
+     if bucket_name not in metrics:
+         return None
+
+     # Collect finite buckets; the +Inf bucket carries the total sample count
+     buckets = []
+     total = 0.0
+     for sample in metrics[bucket_name]:
+         if labels and not all(sample.labels.get(k) == v for k, v in labels.items()):
+             continue
+         le = sample.labels.get("le")
+         if le is None:
+             continue
+         if le == "+Inf":
+             total = sample.value
+         else:
+             buckets.append((float(le), sample.value))
+
+     if not buckets:
+         return None
+
+     # Sort by bucket boundary (bucket counts are cumulative)
+     buckets.sort(key=lambda x: x[0])
+
+     # Fall back to the largest finite bucket if the +Inf bucket was absent
+     if total == 0:
+         total = buckets[-1][1]
+     if total == 0:
+         return None
+
+     # Find the bucket containing the target quantile
+     target = quantile * total
+     prev_bound = 0.0
+     prev_count = 0.0
+
+     for bound, count in buckets:
+         if count >= target:
+             # Linear interpolation within the bucket
+             fraction = (target - prev_count) / (count - prev_count) if count > prev_count else 0
+             return prev_bound + fraction * (bound - prev_bound)
+         prev_bound = bound
+         prev_count = count
+
+     # Target lies beyond the largest finite bucket
+     return buckets[-1][0]
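A minimal standalone walk-through of the linear interpolation that `get_histogram_quantile` performs, using made-up cumulative bucket counts (boundary in seconds, cumulative count):

```python
# Made-up cumulative histogram: 40 requests <= 0.1s, 90 <= 0.5s, 100 <= 1.0s
buckets = [(0.1, 40.0), (0.5, 90.0), (1.0, 100.0)]
quantile = 0.95
total = buckets[-1][1]
target = quantile * total  # the 95th sample in cumulative order

prev_bound, prev_count = 0.0, 0.0
result = buckets[-1][0]
for bound, count in buckets:
    if count >= target:
        # Interpolate linearly inside the first bucket whose cumulative
        # count reaches the target
        fraction = (target - prev_count) / (count - prev_count)
        result = prev_bound + fraction * (bound - prev_bound)
        break
    prev_bound, prev_count = bound, count
# result == 0.75: the target sits halfway through the (0.5, 1.0] bucket
```

This assumes samples are spread uniformly within each bucket, so the answer is an approximation whose error shrinks as buckets get narrower.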