Spaces:
Sleeping
Sleeping
Commit ·
aefabf0
0
Parent(s):
Initial commit: LLM Inference Dashboard
Browse files

A production-grade Gradio dashboard for monitoring vLLM inference
on multi-GPU setups with:
- GPU/Rank monitoring (memory, utilization, temperature)
- Inference metrics (tokens/sec, TTFT, KV cache)
- Quantization detection (GPTQ, AWQ, bitsandbytes)
- Model loading progress tracking
- Alerting with Slack/webhook integration
- Request tracing with latency breakdown
- A/B deployment comparison
- Built-in load testing
Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
- .gitignore +53 -0
- README.md +103 -0
- app.py +313 -0
- collectors/__init__.py +13 -0
- collectors/gpu_collector.py +174 -0
- collectors/loading_tracker.py +224 -0
- collectors/quant_collector.py +259 -0
- collectors/vllm_collector.py +226 -0
- components/__init__.py +27 -0
- components/alerts_panel.py +253 -0
- components/comparison_panel.py +207 -0
- components/gpu_panel.py +191 -0
- components/inference_panel.py +209 -0
- components/loading_panel.py +151 -0
- components/loadtest_panel.py +220 -0
- components/quant_panel.py +118 -0
- components/tracing_panel.py +186 -0
- config.py +67 -0
- requirements.txt +16 -0
- services/__init__.py +18 -0
- services/alerting.py +421 -0
- services/comparator.py +366 -0
- services/load_tester.py +359 -0
- services/request_tracer.py +272 -0
- storage/__init__.py +11 -0
- storage/database.py +448 -0
- storage/models.py +165 -0
- utils/__init__.py +6 -0
- utils/history.py +163 -0
- utils/prometheus_parser.py +195 -0
.gitignore
ADDED
|
@@ -0,0 +1,53 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Python
|
| 2 |
+
__pycache__/
|
| 3 |
+
*.py[cod]
|
| 4 |
+
*$py.class
|
| 5 |
+
*.so
|
| 6 |
+
.Python
|
| 7 |
+
build/
|
| 8 |
+
develop-eggs/
|
| 9 |
+
dist/
|
| 10 |
+
downloads/
|
| 11 |
+
eggs/
|
| 12 |
+
.eggs/
|
| 13 |
+
lib/
|
| 14 |
+
lib64/
|
| 15 |
+
parts/
|
| 16 |
+
sdist/
|
| 17 |
+
var/
|
| 18 |
+
wheels/
|
| 19 |
+
*.egg-info/
|
| 20 |
+
.installed.cfg
|
| 21 |
+
*.egg
|
| 22 |
+
|
| 23 |
+
# Virtual environments
|
| 24 |
+
venv/
|
| 25 |
+
ENV/
|
| 26 |
+
env/
|
| 27 |
+
.venv/
|
| 28 |
+
|
| 29 |
+
# IDE
|
| 30 |
+
.idea/
|
| 31 |
+
.vscode/
|
| 32 |
+
*.swp
|
| 33 |
+
*.swo
|
| 34 |
+
*~
|
| 35 |
+
|
| 36 |
+
# Data files
|
| 37 |
+
data/*.db
|
| 38 |
+
*.sqlite
|
| 39 |
+
*.sqlite3
|
| 40 |
+
|
| 41 |
+
# Logs
|
| 42 |
+
*.log
|
| 43 |
+
|
| 44 |
+
# OS
|
| 45 |
+
.DS_Store
|
| 46 |
+
Thumbs.db
|
| 47 |
+
|
| 48 |
+
# Environment
|
| 49 |
+
.env
|
| 50 |
+
.env.local
|
| 51 |
+
|
| 52 |
+
# Claude
|
| 53 |
+
.claude/
|
README.md
ADDED
|
@@ -0,0 +1,103 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
---
|
| 2 |
+
title: LLM Inference Dashboard
|
| 3 |
+
emoji: 📊
|
| 4 |
+
colorFrom: blue
|
| 5 |
+
colorTo: purple
|
| 6 |
+
sdk: gradio
|
| 7 |
+
sdk_version: 5.9.1
|
| 8 |
+
app_file: app.py
|
| 9 |
+
pinned: false
|
| 10 |
+
license: mit
|
| 11 |
+
---
|
| 12 |
+
|
| 13 |
+
# LLM Inference Dashboard
|
| 14 |
+
|
| 15 |
+
A production-grade Gradio dashboard for monitoring vLLM inference on multi-GPU setups with alerting, request tracing, A/B comparison, load testing, and historical analysis.
|
| 16 |
+
|
| 17 |
+
## Features
|
| 18 |
+
|
| 19 |
+
| Feature | Description |
|
| 20 |
+
|---------|-------------|
|
| 21 |
+
| Core Monitoring | GPU stats, inference metrics, quantization info |
|
| 22 |
+
| Alerting | Configurable thresholds, Slack/webhook notifications |
|
| 23 |
+
| Request Tracing | Per-request latency breakdown, slow request logging |
|
| 24 |
+
| A/B Comparison | Side-by-side deployment comparison |
|
| 25 |
+
| Load Testing | Built-in load generator with saturation detection |
|
| 26 |
+
| Historical Analysis | SQLite storage, trend queries |
|
| 27 |
+
|
| 28 |
+
## Tabs
|
| 29 |
+
|
| 30 |
+
1. **GPU / Rank Status** - Real-time GPU memory, utilization, temperature, and tensor parallel rank mapping
|
| 31 |
+
2. **Inference** - Tokens/sec, TTFT, batch size, KV cache utilization, latency metrics
|
| 32 |
+
3. **Quantization** - Detect and display GPTQ, AWQ, bitsandbytes quantization settings
|
| 33 |
+
4. **Loading** - Model loading progress with shard tracking
|
| 34 |
+
5. **Alerts** - Configure alert thresholds and webhook notifications
|
| 35 |
+
6. **Tracing** - Request-level latency breakdown and slow request analysis
|
| 36 |
+
7. **A/B Compare** - Compare metrics between two vLLM deployments
|
| 37 |
+
8. **Load Test** - Run load tests with configurable concurrency and RPS
|
| 38 |
+
|
| 39 |
+
## Usage
|
| 40 |
+
|
| 41 |
+
### Local Development
|
| 42 |
+
|
| 43 |
+
```bash
|
| 44 |
+
pip install -r requirements.txt
|
| 45 |
+
python app.py
|
| 46 |
+
```
|
| 47 |
+
|
| 48 |
+
### With vLLM Server
|
| 49 |
+
|
| 50 |
+
```bash
|
| 51 |
+
# Start vLLM server
|
| 52 |
+
python -m vllm.entrypoints.openai.api_server \
|
| 53 |
+
--model <model_name> \
|
| 54 |
+
--tensor-parallel-size <N> \
|
| 55 |
+
--port 8000
|
| 56 |
+
|
| 57 |
+
# Set environment variables (optional)
|
| 58 |
+
export VLLM_HOST=localhost
|
| 59 |
+
export VLLM_PORT=8000
|
| 60 |
+
|
| 61 |
+
# Launch dashboard
|
| 62 |
+
python app.py
|
| 63 |
+
```
|
| 64 |
+
|
| 65 |
+
## Environment Variables
|
| 66 |
+
|
| 67 |
+
| Variable | Default | Description |
|
| 68 |
+
|----------|---------|-------------|
|
| 69 |
+
| `VLLM_HOST` | localhost | vLLM server hostname |
|
| 70 |
+
| `VLLM_PORT` | 8000 | vLLM server port |
|
| 71 |
+
| `MODEL_PATH` | None | Path to model for quantization detection |
|
| 72 |
+
| `DB_PATH` | data/metrics.db | SQLite database path |
|
| 73 |
+
| `SLACK_WEBHOOK` | None | Slack webhook URL for alerts |
|
| 74 |
+
| `PAGERDUTY_KEY` | None | PagerDuty routing key |
|
| 75 |
+
|
| 76 |
+
## Demo Mode
|
| 77 |
+
|
| 78 |
+
When no vLLM server is connected, the dashboard runs in demo mode with simulated GPU metrics.
|
| 79 |
+
|
| 80 |
+
## Architecture
|
| 81 |
+
|
| 82 |
+
```
|
| 83 |
+
┌─────────────────────────────────────────────────────────┐
|
| 84 |
+
│ Gradio Frontend │
|
| 85 |
+
│ ┌─────────┐ ┌─────────┐ ┌─────────┐ ┌─────────────────┐│
|
| 86 |
+
│ │GPU Stats│ │Loading │ │Quant │ │Inference Metrics││
|
| 87 |
+
│ │ Tab │ │Progress │ │Details │ │ Tab ││
|
| 88 |
+
│ └─────────┘ └─────────┘ └─────────┘ └─────────────────┘│
|
| 89 |
+
└─────────────────────────────────────────────────────────┘
|
| 90 |
+
│
|
| 91 |
+
▼
|
| 92 |
+
┌─────────────────────────────────────────────────────────┐
|
| 93 |
+
│ Metrics Collector │
|
| 94 |
+
│ ┌──────────┐ ┌──────────┐ ┌──────────┐ ┌────────────┐ │
|
| 95 |
+
│ │ pynvml │ │Prometheus│ │ vLLM API │ │Model Config│ │
|
| 96 |
+
│ │ (GPUs) │ │ (/metrics)│ │ (status) │ │ (quant) │ │
|
| 97 |
+
│ └──────────┘ └──────────┘ └──────────┘ └────────────┘ │
|
| 98 |
+
└─────────────────────────────────────────────────────────┘
|
| 99 |
+
```
|
| 100 |
+
|
| 101 |
+
## License
|
| 102 |
+
|
| 103 |
+
MIT
|
app.py
ADDED
|
@@ -0,0 +1,313 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
LLM Inference Dashboard - Main Application
|
| 3 |
+
|
| 4 |
+
A production-grade Gradio dashboard for monitoring vLLM inference
|
| 5 |
+
on multi-GPU setups with alerting, request tracing, A/B comparison,
|
| 6 |
+
load testing, and historical analysis.
|
| 7 |
+
"""
|
| 8 |
+
|
| 9 |
+
import asyncio
|
| 10 |
+
import logging
|
| 11 |
+
import os
|
| 12 |
+
from datetime import datetime
|
| 13 |
+
|
| 14 |
+
import gradio as gr
|
| 15 |
+
|
| 16 |
+
from config import config
|
| 17 |
+
from collectors import GPUCollector, VLLMCollector, QuantizationCollector, LoadingTracker
|
| 18 |
+
from components import (
|
| 19 |
+
create_gpu_panel,
|
| 20 |
+
update_gpu_panel,
|
| 21 |
+
create_inference_panel,
|
| 22 |
+
update_inference_panel,
|
| 23 |
+
create_quant_panel,
|
| 24 |
+
update_quant_panel,
|
| 25 |
+
create_loading_panel,
|
| 26 |
+
update_loading_panel,
|
| 27 |
+
create_alerts_panel,
|
| 28 |
+
update_alerts_panel,
|
| 29 |
+
create_tracing_panel,
|
| 30 |
+
update_tracing_panel,
|
| 31 |
+
create_comparison_panel,
|
| 32 |
+
create_loadtest_panel,
|
| 33 |
+
)
|
| 34 |
+
from components.alerts_panel import get_alert_badge_html
|
| 35 |
+
from components.inference_panel import get_metrics_dict
|
| 36 |
+
from services import AlertEngine, AlertDispatcher, RequestTracer
|
| 37 |
+
from storage import MetricsDB
|
| 38 |
+
from utils import MetricHistory
|
| 39 |
+
|
| 40 |
+
# Configure logging
|
| 41 |
+
logging.basicConfig(
|
| 42 |
+
level=logging.INFO,
|
| 43 |
+
format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
|
| 44 |
+
)
|
| 45 |
+
logger = logging.getLogger(__name__)
|
| 46 |
+
|
| 47 |
+
|
| 48 |
+
# Initialize global instances
|
| 49 |
+
db = MetricsDB(config.db_path)
|
| 50 |
+
history = MetricHistory(max_length=config.history_length)
|
| 51 |
+
|
| 52 |
+
# Collectors
|
| 53 |
+
gpu_collector = GPUCollector()
|
| 54 |
+
vllm_collector = VLLMCollector(config.metrics_endpoint)
|
| 55 |
+
quant_collector = QuantizationCollector(config.model_path)
|
| 56 |
+
loading_tracker = LoadingTracker(config.model_path)
|
| 57 |
+
|
| 58 |
+
# Services
|
| 59 |
+
alert_engine = AlertEngine(db)
|
| 60 |
+
alert_dispatcher = AlertDispatcher(
|
| 61 |
+
slack_webhook=config.slack_webhook,
|
| 62 |
+
pagerduty_key=config.pagerduty_routing_key,
|
| 63 |
+
generic_webhooks=config.generic_webhooks,
|
| 64 |
+
)
|
| 65 |
+
request_tracer = RequestTracer(db)
|
| 66 |
+
|
| 67 |
+
|
| 68 |
+
def check_connection():
    """Check connection to the vLLM server.

    Returns:
        An HTML snippet with a colored status dot and label:
        green/"Connected" when the collector can reach the server,
        red/"Disconnected" otherwise.
    """

    def _badge(dot_color, text_color, label):
        # Both states share the same markup; only colors and text differ.
        return (
            '<div style="display: flex; align-items: center;">'
            f'<span style="width: 12px; height: 12px; background: {dot_color}; '
            'border-radius: 50%; display: inline-block; margin-right: 8px;"></span>'
            f'<span style="color: {text_color};">{label}</span></div>'
        )

    if vllm_collector.check_connection():
        return _badge("#4caf50", "#2e7d32", "Connected")
    return _badge("#f44336", "#c62828", "Disconnected")
|
| 84 |
+
|
| 85 |
+
|
| 86 |
+
def get_model_name():
    """Return the model name reported by the vLLM collector.

    Falls back to "Demo Mode" when the collector reports no model
    (e.g. no server is connected).
    """
    return vllm_collector.collect().model_name or "Demo Mode"
|
| 90 |
+
|
| 91 |
+
|
| 92 |
+
def update_all_metrics():
    """Update all metrics from collectors.

    Called by the dashboard's refresh timer. Collects GPU and inference
    metrics, evaluates and dispatches alerts, and returns a tuple in the
    exact order expected by the timer's output list: header widgets,
    then the GPU tab, then the inference tab.
    """
    # GPU metrics
    gpu_table, gpu_memory_plot, gpu_util_plot, nccl_status = update_gpu_panel(
        gpu_collector, history
    )

    # Inference metrics
    (
        throughput, ttft, batch_size, kv_cache, throughput_plot,
        prefill_pct, decode_pct, queue_depth, e2e_latency, latency_plot
    ) = update_inference_panel(vllm_collector, history)

    # Check alerts
    metrics = vllm_collector.collect()
    metrics_dict = get_metrics_dict(metrics)

    # Add GPU metrics for alerting: alert on the worst (highest) memory usage
    gpu_stats = gpu_collector.collect()
    if gpu_stats:
        metrics_dict["gpu_memory_percent"] = max(
            s.memory_percent for s in gpu_stats
        )

    new_alerts = alert_engine.evaluate(metrics_dict)

    # Dispatch new alerts. dispatch() is a coroutine: schedule it on the
    # running event loop when one exists (e.g. inside Gradio's async
    # runtime), otherwise run it to completion synchronously. The previous
    # asyncio.get_event_loop() pattern is deprecated and raised (then
    # silently swallowed) RuntimeError when called from a thread with no
    # event loop, dropping the alert entirely.
    for alert in new_alerts:
        try:
            try:
                loop = asyncio.get_running_loop()
            except RuntimeError:
                loop = None
            if loop is not None:
                # Fire-and-forget. NOTE(review): the task reference is not
                # retained; a very slow dispatch could be garbage-collected
                # mid-flight — acceptable for best-effort notifications.
                loop.create_task(alert_dispatcher.dispatch(alert))
            else:
                asyncio.run(alert_dispatcher.dispatch(alert))
        except Exception:
            # A failed notification must never break the metrics refresh,
            # but it should be visible in the logs rather than swallowed.
            logger.exception("Failed to dispatch alert")

    # Get alert badge
    active_alerts = alert_engine.get_active_alerts()
    alert_badge = get_alert_badge_html(active_alerts)

    # Connection status
    connection_status = check_connection()
    model_name = get_model_name()

    return (
        # Header
        connection_status,
        model_name,
        alert_badge,
        # GPU tab
        gpu_table,
        gpu_memory_plot,
        gpu_util_plot,
        nccl_status,
        # Inference tab
        throughput,
        ttft,
        batch_size,
        kv_cache,
        throughput_plot,
        prefill_pct,
        decode_pct,
        queue_depth,
        e2e_latency,
        latency_plot,
    )
|
| 158 |
+
|
| 159 |
+
|
| 160 |
+
def create_dashboard():
    """Create the main dashboard application.

    Builds the full Blocks layout — a header row (connection status,
    model name, alert badge) plus eight tabs — and wires a gr.Timer so
    the header, GPU, and inference widgets auto-refresh.

    Returns:
        The constructed gr.Blocks application (not yet launched).
    """

    custom_css = """
    .gradio-container { max-width: 1400px !important; }
    .panel-header { font-size: 1.2em; font-weight: bold; margin-bottom: 10px; }
    """

    # Fix: custom_css was previously defined but never passed to Blocks,
    # so the styling was silently never applied.
    with gr.Blocks(title="LLM Inference Dashboard", css=custom_css) as app:
        gr.Markdown("# LLM Inference Dashboard")
        gr.Markdown("*Real-time monitoring for vLLM inference servers*")

        # Header row: connection status, model info, active alerts
        with gr.Row():
            status_indicator = gr.HTML(
                value=check_connection(),
                label="Connection",
            )
            model_name_display = gr.Textbox(
                label="Model",
                value=get_model_name(),
                interactive=False,
                scale=2,
            )
            active_alerts_display = gr.HTML(
                value=get_alert_badge_html([]),
                label="Alerts",
            )

        # Main tabs
        with gr.Tabs():
            # Tab 1: GPU Status
            with gr.Tab("GPU / Rank Status"):
                gpu_components = create_gpu_panel(history)

            # Tab 2: Inference Metrics
            with gr.Tab("Inference"):
                inference_components = create_inference_panel(history)

            # Tab 3: Quantization
            with gr.Tab("Quantization"):
                quant_components = create_quant_panel()

                # Initial update: quantization info is static per model,
                # so it is populated once at build time rather than on the
                # refresh timer.
                (
                    quant_type, bits, group_size, quant_details, layer_table
                ) = update_quant_panel(quant_collector)

                quant_components["quant_type"].value = quant_type
                quant_components["bits"].value = bits
                quant_components["group_size"].value = group_size

            # Tab 4: Loading Progress
            with gr.Tab("Loading"):
                loading_components = create_loading_panel()

            # Tab 5: Alerts
            with gr.Tab("Alerts"):
                alerts_components = create_alerts_panel(alert_engine, alert_dispatcher)

            # Tab 6: Request Tracing
            with gr.Tab("Tracing"):
                tracing_components = create_tracing_panel(request_tracer)

            # Tab 7: A/B Comparison
            with gr.Tab("A/B Compare"):
                comparison_components = create_comparison_panel()

            # Tab 8: Load Testing
            with gr.Tab("Load Test"):
                loadtest_components = create_loadtest_panel()

        # Auto-refresh timer
        timer = gr.Timer(config.refresh_interval)

        # Collect all outputs for timer update — the order here must match
        # the tuple returned by update_all_metrics().
        timer_outputs = [
            # Header
            status_indicator,
            model_name_display,
            active_alerts_display,
            # GPU tab
            gpu_components["gpu_table"],
            gpu_components["gpu_memory_plot"],
            gpu_components["gpu_util_plot"],
            gpu_components["nccl_status"],
            # Inference tab
            inference_components["throughput"],
            inference_components["ttft"],
            inference_components["batch_size"],
            inference_components["kv_cache"],
            inference_components["throughput_plot"],
            inference_components["prefill_pct"],
            inference_components["decode_pct"],
            inference_components["queue_depth"],
            inference_components["e2e_latency"],
            inference_components["latency_plot"],
        ]

        timer.tick(fn=update_all_metrics, outputs=timer_outputs)

        # Manual refresh for tabs that don't auto-update.
        # TODO(review): these closures are defined but never bound to any
        # button/event — wire them up in the respective panel modules or
        # remove them.
        def refresh_quant():
            return update_quant_panel(quant_collector)

        def refresh_loading():
            return update_loading_panel(loading_tracker)

        def refresh_alerts():
            return update_alerts_panel(alert_engine, db)

    return app
|
| 272 |
+
|
| 273 |
+
|
| 274 |
+
def main():
    """Main entry point.

    Logs startup info, probes the vLLM server once so the UI starts in
    the right state, then builds and launches the dashboard.
    """
    logger.info("Starting LLM Inference Dashboard")
    logger.info(f"vLLM endpoint: {config.metrics_endpoint}")
    logger.info(f"Database: {config.db_path}")

    # Probe the server once at startup.
    if not vllm_collector.check_connection():
        logger.warning("Could not connect to vLLM server - dashboard will show mock data")
    else:
        logger.info("Successfully connected to vLLM server")

        # Set model ready if connected
        loading_tracker.set_ready()

        # Try to detect quantization from the reported model name.
        model_name = vllm_collector.collect().model_name
        if model_name:
            quant_collector.set_model_path(model_name)

    # Create and launch the dashboard
    app = create_dashboard()

    # On HuggingFace Spaces (SPACE_ID set) Gradio picks host/port itself;
    # locally we bind explicitly.
    if os.getenv("SPACE_ID"):
        launch_kwargs = {}
    else:
        launch_kwargs = {
            "server_name": "0.0.0.0",
            "server_port": 7860,
            "share": False,
            "show_error": True,
        }
    app.launch(**launch_kwargs)
|
| 307 |
+
|
| 308 |
+
|
| 309 |
+
# For HuggingFace Spaces - create demo instance
|
| 310 |
+
demo = create_dashboard()
|
| 311 |
+
|
| 312 |
+
if __name__ == "__main__":
|
| 313 |
+
main()
|
collectors/__init__.py
ADDED
|
@@ -0,0 +1,13 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Data collectors for monitoring vLLM inference."""
|
| 2 |
+
|
| 3 |
+
from .gpu_collector import GPUCollector
|
| 4 |
+
from .vllm_collector import VLLMCollector
|
| 5 |
+
from .quant_collector import QuantizationCollector
|
| 6 |
+
from .loading_tracker import LoadingTracker
|
| 7 |
+
|
| 8 |
+
__all__ = [
|
| 9 |
+
"GPUCollector",
|
| 10 |
+
"VLLMCollector",
|
| 11 |
+
"QuantizationCollector",
|
| 12 |
+
"LoadingTracker",
|
| 13 |
+
]
|
collectors/gpu_collector.py
ADDED
|
@@ -0,0 +1,174 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""GPU statistics collector using pynvml."""
|
| 2 |
+
|
| 3 |
+
from dataclasses import dataclass
|
| 4 |
+
from typing import List, Optional
|
| 5 |
+
import logging
|
| 6 |
+
|
| 7 |
+
logger = logging.getLogger(__name__)
|
| 8 |
+
|
| 9 |
+
# Try to import pynvml, provide mock if unavailable
|
| 10 |
+
try:
|
| 11 |
+
import pynvml
|
| 12 |
+
PYNVML_AVAILABLE = True
|
| 13 |
+
except ImportError:
|
| 14 |
+
PYNVML_AVAILABLE = False
|
| 15 |
+
logger.warning("pynvml not available - GPU stats will be simulated")
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
@dataclass
class GPUStats:
    """Snapshot of statistics for a single GPU."""

    gpu_id: int                      # NVML device index
    name: str                        # device name, e.g. "NVIDIA A100"
    memory_used_gb: float            # used device memory in GB
    memory_total_gb: float           # total device memory in GB
    memory_percent: float            # used / total, as a percentage
    gpu_util_percent: float          # SM utilization percentage
    temperature_c: int               # core temperature in Celsius
    power_watts: float               # current power draw (0 if unavailable)
    power_limit_watts: float         # enforced power limit (0 if unavailable)
    tp_rank: Optional[int] = None    # tensor-parallel rank, if mapped
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
class GPUCollector:
    """Collects GPU statistics via NVIDIA Management Library.

    Falls back to simulated stats when pynvml is unavailable or fails
    to initialize, so the dashboard still renders in demo mode.
    """

    # Number of simulated GPUs reported when running without pynvml.
    MOCK_GPU_COUNT = 2

    def __init__(self):
        """Initialize the GPU collector and NVML, if available."""
        self._initialized = False
        self._gpu_count = 0
        self._rank_mapping: dict = {}

        if PYNVML_AVAILABLE:
            try:
                pynvml.nvmlInit()
                self._initialized = True
                self._gpu_count = pynvml.nvmlDeviceGetCount()
                logger.info(f"Initialized pynvml with {self._gpu_count} GPUs")
            except Exception as e:
                logger.error(f"Failed to initialize pynvml: {e}")

    def set_rank_mapping(self, mapping: dict) -> None:
        """
        Set tensor parallel rank to GPU ID mapping.

        Args:
            mapping: Dictionary mapping TP rank to GPU ID
        """
        self._rank_mapping = mapping

    def get_gpu_count(self) -> int:
        """Get the number of available GPUs.

        NOTE(review): returns 0 in mock mode even though collect() yields
        MOCK_GPU_COUNT simulated entries — callers should prefer
        len(collect()) if they need the rendered count.
        """
        return self._gpu_count

    def collect(self) -> List[GPUStats]:
        """
        Collect stats for all GPUs.

        Returns:
            List of GPUStats for each GPU (simulated when NVML is
            unavailable). GPUs whose collection fails are skipped.
        """
        if not self._initialized:
            return self._get_mock_stats()

        stats = []
        for i in range(self._gpu_count):
            try:
                stats.append(self._collect_single_gpu(i))
            except Exception as e:
                # One failing device should not hide the others.
                logger.error(f"Error collecting stats for GPU {i}: {e}")

        return stats

    def _collect_single_gpu(self, gpu_id: int) -> GPUStats:
        """Collect stats for a single GPU via its NVML handle."""
        handle = pynvml.nvmlDeviceGetHandleByIndex(gpu_id)

        # Get device name (older pynvml returns bytes)
        name = pynvml.nvmlDeviceGetName(handle)
        if isinstance(name, bytes):
            name = name.decode("utf-8")

        # Memory info
        mem_info = pynvml.nvmlDeviceGetMemoryInfo(handle)
        memory_used_gb = mem_info.used / 1e9
        memory_total_gb = mem_info.total / 1e9
        memory_percent = (mem_info.used / mem_info.total) * 100

        # Utilization
        util = pynvml.nvmlDeviceGetUtilizationRates(handle)
        gpu_util_percent = util.gpu

        # Temperature
        temperature_c = pynvml.nvmlDeviceGetTemperature(
            handle, pynvml.NVML_TEMPERATURE_GPU
        )

        # Power — not all devices expose power telemetry
        try:
            power_watts = pynvml.nvmlDeviceGetPowerUsage(handle) / 1000.0
            power_limit_watts = pynvml.nvmlDeviceGetEnforcedPowerLimit(handle) / 1000.0
        except pynvml.NVMLError:
            power_watts = 0
            power_limit_watts = 0

        # Find TP rank for this GPU (reverse lookup in rank -> gpu map)
        tp_rank = None
        for rank, gid in self._rank_mapping.items():
            if gid == gpu_id:
                tp_rank = rank
                break

        return GPUStats(
            gpu_id=gpu_id,
            name=name,
            memory_used_gb=memory_used_gb,
            memory_total_gb=memory_total_gb,
            memory_percent=memory_percent,
            gpu_util_percent=gpu_util_percent,
            temperature_c=temperature_c,
            power_watts=power_watts,
            power_limit_watts=power_limit_watts,
            tp_rank=tp_rank,
        )

    def _get_mock_stats(self) -> List[GPUStats]:
        """Return mock stats when pynvml is not available.

        Previously this duplicated two near-identical GPUStats literals;
        the loop keeps the same per-field random distributions and names
        ("Mock GPU 0", "Mock GPU 1") while staying DRY and allowing the
        mock GPU count to be tuned via MOCK_GPU_COUNT.
        """
        import random

        return [
            GPUStats(
                gpu_id=i,
                name=f"Mock GPU {i}",
                memory_used_gb=random.uniform(10, 20),
                memory_total_gb=24.0,
                memory_percent=random.uniform(40, 80),
                gpu_util_percent=random.uniform(20, 90),
                temperature_c=random.randint(40, 70),
                power_watts=random.uniform(100, 300),
                power_limit_watts=350,
                tp_rank=i,
            )
            for i in range(self.MOCK_GPU_COUNT)
        ]

    def shutdown(self) -> None:
        """Clean up NVML resources (best-effort)."""
        if self._initialized and PYNVML_AVAILABLE:
            try:
                pynvml.nvmlShutdown()
            except Exception:
                pass
|
collectors/loading_tracker.py
ADDED
|
@@ -0,0 +1,224 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Model loading progress tracker."""
|
| 2 |
+
|
| 3 |
+
import json
|
| 4 |
+
import re
|
| 5 |
+
import logging
|
| 6 |
+
from dataclasses import dataclass
|
| 7 |
+
from typing import Optional, List, Dict, Any
|
| 8 |
+
from pathlib import Path
|
| 9 |
+
from enum import Enum
|
| 10 |
+
|
| 11 |
+
logger = logging.getLogger(__name__)
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
class LoadingStatus(Enum):
    """Lifecycle states for model loading, from idle through ready/error."""

    NOT_STARTED = "not_started"   # no loading activity observed yet
    DOWNLOADING = "downloading"   # fetching model files
    LOADING = "loading"           # weights being loaded into memory
    READY = "ready"               # model fully loaded and serving
    ERROR = "error"               # loading failed
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
@dataclass
class ShardInfo:
    """Metadata for one model shard (safetensors) file."""

    filename: str       # shard file name from the weight index
    size_mb: float      # on-disk size in MiB (0 if file not found)
    status: str         # one of: pending, loading, loaded
    layers: List[str]   # layer-name prefixes stored in this shard
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
@dataclass
class LoadingProgress:
    """Aggregate snapshot of model loading progress across all shards."""

    status: LoadingStatus                          # current lifecycle state
    total_shards: int                              # shard files expected
    loaded_shards: int                             # shard files completed
    current_shard: Optional[str]                   # shard in flight, if any
    progress_percent: float                        # overall completion, 0-100
    layers_loaded: int                             # distinct layers loaded
    total_layers: int                              # distinct layers expected
    estimated_remaining_seconds: Optional[float]   # ETA, if computable
    error_message: Optional[str] = None            # populated on ERROR status
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
class LoadingTracker:
    """Tracks model loading progress.

    Shard metadata is read from ``model.safetensors.index.json`` when the
    model directory is available; live progress is derived by parsing vLLM
    server log lines fed to :meth:`update_from_log`.
    """

    def __init__(self, model_path: Optional[str] = None):
        """
        Initialize loading tracker.

        Args:
            model_path: Path to model directory. When provided, shard
                metadata is loaded immediately.
        """
        self.model_path = model_path
        self._shards: List[ShardInfo] = []
        self._status = LoadingStatus.NOT_STARTED
        self._progress = 0.0
        self._current_shard: Optional[str] = None
        # Kept for backward compatibility; layer progress is recomputed
        # from shard state in get_progress().
        self._layers_loaded = 0
        self._total_layers = 0
        self._start_time: Optional[float] = None
        self._error_message: Optional[str] = None
        if model_path:
            # Fix: the original populated shards only via set_model_path(),
            # so a path supplied to the constructor was silently ignored.
            self._load_shard_info()

    def set_model_path(self, model_path: str) -> None:
        """Set or update the model path and reload shard metadata."""
        self.model_path = model_path
        self._load_shard_info()

    def _load_shard_info(self) -> None:
        """Load shard information from the safetensors index file."""
        if not self.model_path:
            return

        index_path = self._resolve_path("model.safetensors.index.json")
        if not index_path:
            return

        try:
            with open(index_path) as f:
                index = json.load(f)

            weight_map = index.get("weight_map", {})

            # Group weight names by the shard file that contains them.
            shard_weights: Dict[str, List[str]] = {}
            for weight_name, shard_file in weight_map.items():
                shard_weights.setdefault(shard_file, []).append(weight_name)

            self._shards = []
            for shard_file, weights in sorted(shard_weights.items()):
                shard_path = self._resolve_path(shard_file)
                size_mb = 0
                if shard_path and shard_path.exists():
                    size_mb = shard_path.stat().st_size / (1024 * 1024)

                # Use the first three dotted components as the layer name,
                # e.g. "model.layers.0.self_attn.q_proj" -> "model.layers.0".
                layers = list(set(
                    ".".join(w.split(".")[:3])
                    for w in weights
                    if len(w.split(".")) >= 3
                ))

                self._shards.append(ShardInfo(
                    filename=shard_file,
                    size_mb=size_mb,
                    status="pending",
                    layers=layers,
                ))

            # Count distinct layers across all shards.
            all_layers = set()
            for shard in self._shards:
                all_layers.update(shard.layers)
            self._total_layers = len(all_layers)

        except Exception as e:
            logger.error(f"Error loading shard info: {e}")

    def _resolve_path(self, filename: str) -> Optional[Path]:
        """Return the path to *filename* in the model directory, or None."""
        if not self.model_path:
            return None

        local_path = Path(self.model_path) / filename
        if local_path.exists():
            return local_path

        return None

    def update_from_log(self, log_line: str) -> None:
        """
        Update progress from a vLLM log line.

        Args:
            log_line: Log line from vLLM server
        """
        # Detect loading start
        if "Loading model" in log_line:
            self._status = LoadingStatus.LOADING
            self._start_time = time.time()

        # Detect per-shard progress, e.g. "Loading safetensors: 3/8"
        match = re.search(r"Loading safetensors: (\d+)/(\d+)", log_line)
        if match:
            loaded = int(match.group(1))
            total = int(match.group(2))
            if total > 0:  # guard against a malformed "n/0" line
                self._progress = loaded / total * 100
            for i, shard in enumerate(self._shards):
                if i < loaded:
                    shard.status = "loaded"
                elif i == loaded:
                    shard.status = "loading"
                    self._current_shard = shard.filename

        # Detect completion
        if "Model loaded" in log_line or "Running with" in log_line:
            self._status = LoadingStatus.READY
            self._progress = 100.0
            for shard in self._shards:
                shard.status = "loaded"

        # Detect errors; remember the offending line so get_progress() can
        # surface it (previously error_message was never populated).
        if "Error" in log_line or "Exception" in log_line:
            self._status = LoadingStatus.ERROR
            self._error_message = log_line.strip()

    def get_progress(self) -> LoadingProgress:
        """
        Get current loading progress.

        Returns:
            LoadingProgress with current state
        """
        loaded_shards = sum(1 for s in self._shards if s.status == "loaded")
        total_shards = len(self._shards) if self._shards else 1

        # Linear extrapolation: remaining time scales with remaining percent.
        remaining = None
        if self._start_time and self._progress > 0:
            elapsed = time.time() - self._start_time
            remaining = (elapsed / self._progress) * (100 - self._progress)

        # Count distinct layers covered by fully loaded shards.
        loaded_layers = set()
        for shard in self._shards:
            if shard.status == "loaded":
                loaded_layers.update(shard.layers)

        return LoadingProgress(
            status=self._status,
            total_shards=total_shards,
            loaded_shards=loaded_shards,
            current_shard=self._current_shard,
            progress_percent=self._progress,
            layers_loaded=len(loaded_layers),
            total_layers=self._total_layers,
            estimated_remaining_seconds=remaining,
            error_message=self._error_message,
        )

    def get_shards(self) -> List[ShardInfo]:
        """Get list of all shards with their status."""
        return self._shards

    def set_ready(self) -> None:
        """Mark the model as fully loaded."""
        self._status = LoadingStatus.READY
        self._progress = 100.0
        for shard in self._shards:
            shard.status = "loaded"

    def reset(self) -> None:
        """Reset progress state; shard metadata is kept but marked pending."""
        self._status = LoadingStatus.NOT_STARTED
        self._progress = 0.0
        self._current_shard = None
        self._layers_loaded = 0
        self._start_time = None
        self._error_message = None
        for shard in self._shards:
            shard.status = "pending"
|
collectors/quant_collector.py
ADDED
|
@@ -0,0 +1,259 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Quantization information collector."""
|
| 2 |
+
|
| 3 |
+
import json
|
| 4 |
+
import os
|
| 5 |
+
import logging
|
| 6 |
+
from dataclasses import dataclass
|
| 7 |
+
from typing import Optional, Dict, Any, List
|
| 8 |
+
from pathlib import Path
|
| 9 |
+
|
| 10 |
+
logger = logging.getLogger(__name__)
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
@dataclass
class QuantizationInfo:
    """Quantization details for a model.

    Only ``method`` and ``bits`` are always meaningful; the remaining
    fields are populated depending on which quantization scheme was
    detected (GPTQ, AWQ, bitsandbytes, or none).
    """
    method: str  # GPTQ, AWQ, bitsandbytes, None
    bits: int
    group_size: Optional[int] = None
    desc_act: Optional[bool] = None
    sym: Optional[bool] = None
    compute_dtype: Optional[str] = None
    quant_type: Optional[str] = None  # For bitsandbytes: nf4, fp4
    double_quant: Optional[bool] = None
    # Fix: the default is None, so the annotation must be Optional
    # (the original claimed a plain Dict while defaulting to None).
    raw_config: Optional[Dict[str, Any]] = None

    def to_dict(self) -> Dict[str, Any]:
        """Convert to dictionary for JSON display.

        Optional fields are included only when they carry a value;
        ``raw_config`` is deliberately excluded from the output.
        """
        result = {
            "method": self.method,
            "bits": self.bits,
        }
        if self.group_size is not None:
            result["group_size"] = self.group_size
        if self.desc_act is not None:
            result["desc_act"] = self.desc_act
        if self.sym is not None:
            result["symmetric"] = self.sym
        if self.compute_dtype:
            result["compute_dtype"] = self.compute_dtype
        if self.quant_type:
            result["quant_type"] = self.quant_type
        if self.double_quant is not None:
            result["double_quant"] = self.double_quant
        return result
|
| 45 |
+
|
| 46 |
+
|
| 47 |
+
@dataclass
class LayerPrecision:
    """Precision information for a model layer."""
    layer_name: str            # dotted layer prefix, e.g. "model.layers.0"
    bits: int                  # quantization bit width applied to this layer
    group_size: Optional[int]  # quantization group size, if applicable
    dtype: str                 # compute dtype name, e.g. "float16"
|
| 54 |
+
|
| 55 |
+
|
| 56 |
+
class QuantizationCollector:
    """Detects and collects quantization information from model configs.

    Looks for ``quantize_config.json`` (GPTQ/AWQ tooling output) and the
    ``quantization_config`` section of ``config.json``, either in a local
    model directory or in the HuggingFace hub cache.
    """

    def __init__(self, model_path: Optional[str] = None):
        """
        Initialize quantization collector.

        Args:
            model_path: Path to model directory (local or HF model ID)
        """
        self.model_path = model_path
        self._cached_info: Optional[QuantizationInfo] = None

    def set_model_path(self, model_path: str) -> None:
        """Set or update the model path and invalidate the cached detection."""
        self.model_path = model_path
        self._cached_info = None

    def detect(self) -> QuantizationInfo:
        """
        Detect quantization method and settings.

        Returns:
            QuantizationInfo with detected settings (cached after the
            first successful call until the model path changes).
        """
        if self._cached_info is not None:
            return self._cached_info

        if not self.model_path:
            return QuantizationInfo(method="Unknown", bits=16)

        config = self._load_config()
        quant_config = self._load_quant_config()

        info = self._detect_quantization(config, quant_config)
        self._cached_info = info
        return info

    def _load_config(self) -> Optional[Dict[str, Any]]:
        """Load config.json from model path."""
        return self._load_json("config.json")

    def _load_quant_config(self) -> Optional[Dict[str, Any]]:
        """Load quantize_config.json (GPTQ/AWQ) from model path."""
        return self._load_json("quantize_config.json")

    def _load_json(self, filename: str) -> Optional[Dict[str, Any]]:
        """Load a JSON file from the model directory; None on any failure."""
        config_path = self._resolve_path(filename)
        if config_path and config_path.exists():
            try:
                with open(config_path) as f:
                    return json.load(f)
            except Exception as e:
                logger.error(f"Error loading {filename}: {e}")
        return None

    def _resolve_path(self, filename: str) -> Optional[Path]:
        """Resolve path to a file in the model directory or HF cache."""
        if not self.model_path:
            return None

        # Handle local paths
        local_path = Path(self.model_path) / filename
        if local_path.exists():
            return local_path

        # Handle HuggingFace cache paths:
        # ~/.cache/huggingface/hub/models--org--name/snapshots/<rev>/
        cache_dir = Path.home() / ".cache" / "huggingface" / "hub"
        if cache_dir.exists():
            for model_dir in cache_dir.glob("models--*"):
                # Directory "models--org--name" encodes the repo id "org/name".
                model_name = model_dir.name.replace("models--", "").replace("--", "/")
                # Bug fix: the original compared "org/name" against
                # model_path with "/" replaced by "--" (i.e. "org--name"),
                # which can never match a namespaced repo id. Compare the
                # decoded repo id with the repo id directly.
                if model_name.lower() == self.model_path.lower():
                    snapshot_path = model_dir / "snapshots"
                    if snapshot_path.exists():
                        # Pick the most recently modified snapshot instead of
                        # relying on arbitrary iterdir() ordering.
                        snapshots = sorted(
                            snapshot_path.iterdir(),
                            key=lambda p: p.stat().st_mtime,
                        )
                        if snapshots:
                            file_path = snapshots[-1] / filename
                            if file_path.exists():
                                return file_path

        return None

    def _detect_quantization(
        self,
        config: Optional[Dict[str, Any]],
        quant_config: Optional[Dict[str, Any]],
    ) -> QuantizationInfo:
        """Detect quantization from config files.

        Precedence: quantize_config.json (GPTQ tooling) first, then the
        quantization_config section of config.json, then torch_dtype.
        """

        # Check for GPTQ via quantize_config.json
        if quant_config:
            if "bits" in quant_config:
                return QuantizationInfo(
                    method="GPTQ",
                    bits=quant_config.get("bits", 4),
                    group_size=quant_config.get("group_size", 128),
                    desc_act=quant_config.get("desc_act", False),
                    sym=quant_config.get("sym", True),
                    raw_config=quant_config,
                )

        if not config:
            return QuantizationInfo(method="Unknown", bits=16)

        # Check for quantization_config in config.json
        qc = config.get("quantization_config", {})

        if qc:
            quant_method = qc.get("quant_method", "").lower()

            # AWQ
            if quant_method == "awq":
                return QuantizationInfo(
                    method="AWQ",
                    bits=qc.get("bits", 4),
                    group_size=qc.get("group_size", 128),
                    raw_config=qc,
                )

            # GPTQ (in config.json)
            if quant_method == "gptq":
                return QuantizationInfo(
                    method="GPTQ",
                    bits=qc.get("bits", 4),
                    group_size=qc.get("group_size", 128),
                    desc_act=qc.get("desc_act", False),
                    sym=qc.get("sym", True),
                    raw_config=qc,
                )

            # bitsandbytes
            if qc.get("load_in_4bit") or qc.get("load_in_8bit"):
                bits = 4 if qc.get("load_in_4bit") else 8
                return QuantizationInfo(
                    method="bitsandbytes",
                    bits=bits,
                    compute_dtype=qc.get("bnb_4bit_compute_dtype", "float16"),
                    quant_type=qc.get("bnb_4bit_quant_type", "nf4"),
                    double_quant=qc.get("bnb_4bit_use_double_quant", False),
                    raw_config=qc,
                )

        # No quantization config: fall back to torch_dtype for fp16/bf16.
        torch_dtype = config.get("torch_dtype", "float16")
        if torch_dtype in ("float16", "bfloat16"):
            return QuantizationInfo(
                method="None (FP16/BF16)",
                bits=16,
                compute_dtype=torch_dtype,
            )

        return QuantizationInfo(method="Unknown", bits=16)

    def get_layer_precisions(self) -> List[LayerPrecision]:
        """
        Get per-layer precision information.

        Returns:
            List of LayerPrecision for each layer (empty when the
            safetensors index is unavailable).
        """
        info = self.detect()

        # For quantized models, all layers typically have same precision.
        # This could be extended to parse safetensors index for more detail.

        index_path = self._resolve_path("model.safetensors.index.json")
        if not index_path or not index_path.exists():
            return []

        try:
            with open(index_path) as f:
                index = json.load(f)

            weight_map = index.get("weight_map", {})
            layers = []
            seen_layers = set()

            for weight_name in weight_map.keys():
                # Extract the layer name from the first three dotted parts.
                parts = weight_name.split(".")
                if len(parts) >= 3:
                    layer_name = ".".join(parts[:3])
                    if layer_name not in seen_layers:
                        seen_layers.add(layer_name)
                        layers.append(
                            LayerPrecision(
                                layer_name=layer_name,
                                bits=info.bits,
                                group_size=info.group_size,
                                dtype=info.compute_dtype or "float16",
                            )
                        )

            return layers
        except Exception as e:
            logger.error(f"Error parsing layer precisions: {e}")
            return []
|
collectors/vllm_collector.py
ADDED
|
@@ -0,0 +1,226 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""vLLM metrics collector via Prometheus endpoint."""
|
| 2 |
+
|
| 3 |
+
import requests
|
| 4 |
+
import logging
|
| 5 |
+
from dataclasses import dataclass, field
|
| 6 |
+
from typing import Optional, Dict, List, Any
|
| 7 |
+
from datetime import datetime
|
| 8 |
+
|
| 9 |
+
from utils.prometheus_parser import (
|
| 10 |
+
parse_prometheus_metrics,
|
| 11 |
+
get_metric_value,
|
| 12 |
+
get_histogram_quantile,
|
| 13 |
+
MetricSample,
|
| 14 |
+
)
|
| 15 |
+
|
| 16 |
+
logger = logging.getLogger(__name__)
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
@dataclass
class InferenceMetrics:
    """Inference metrics snapshot from vLLM's Prometheus endpoint.

    All fields default to zero/empty so an instance can be returned even
    when the metrics endpoint is unreachable.
    """
    timestamp: datetime = field(default_factory=datetime.now)  # collection time

    # Request counts (scheduler queue state)
    num_requests_running: int = 0
    num_requests_waiting: int = 0
    num_requests_swapped: int = 0

    # Token throughput (cumulative counters plus a derived rate)
    prompt_tokens_total: int = 0
    generation_tokens_total: int = 0
    tokens_per_second: float = 0.0

    # Latency (median values, in milliseconds)
    ttft_ms: float = 0.0  # Time to first token
    tpot_ms: float = 0.0  # Time per output token
    e2e_latency_ms: float = 0.0  # End-to-end latency

    # Cache usage, as percentages (0-100)
    kv_cache_usage_percent: float = 0.0
    gpu_cache_usage_percent: float = 0.0
    cpu_cache_usage_percent: float = 0.0

    # Model info
    model_name: str = ""
    max_model_len: int = 0

    # Derived values
    prefill_ratio: float = 0.0  # prompt tokens / (prompt + generation tokens)
    batch_size: int = 0         # mirrors num_requests_running
|
| 52 |
+
|
| 53 |
+
class VLLMCollector:
    """Collects metrics from the vLLM Prometheus /metrics endpoint."""

    def __init__(self, metrics_url: str = "http://localhost:8000/metrics"):
        """
        Initialize the vLLM collector.

        Args:
            metrics_url: URL to vLLM's /metrics endpoint
        """
        self.metrics_url = metrics_url
        # Previous counter values, used to derive tokens/sec between polls.
        self._last_prompt_tokens = 0
        self._last_generation_tokens = 0
        self._last_collect_time: Optional[datetime] = None
        self._connected = False

    def check_connection(self) -> bool:
        """Check if the vLLM server is accessible."""
        try:
            response = requests.get(self.metrics_url, timeout=2)
            self._connected = response.status_code == 200
            return self._connected
        except Exception:
            self._connected = False
            return False

    @property
    def is_connected(self) -> bool:
        """Return the connection status observed by the last request."""
        return self._connected

    def collect(self) -> InferenceMetrics:
        """
        Collect all inference metrics from vLLM.

        Returns:
            InferenceMetrics with current values (zeroed defaults when the
            endpoint is unreachable).
        """
        metrics = InferenceMetrics()

        try:
            response = requests.get(self.metrics_url, timeout=5)
            response.raise_for_status()
            self._connected = True

            raw_metrics = parse_prometheus_metrics(response.text)
            metrics = self._parse_metrics(raw_metrics)

        except (requests.exceptions.ConnectionError, requests.exceptions.Timeout):
            # Fix: a timeout also means the server is unreachable; the
            # original only cleared the flag on ConnectionError.
            self._connected = False
            logger.debug("Cannot connect to vLLM metrics endpoint")
        except Exception as e:
            logger.error(f"Error collecting vLLM metrics: {e}")

        return metrics

    def _parse_metrics(self, raw: Dict[str, List[MetricSample]]) -> InferenceMetrics:
        """Parse raw Prometheus samples into an InferenceMetrics snapshot."""
        now = datetime.now()
        metrics = InferenceMetrics(timestamp=now)

        # Gauges: scheduler queue state
        metrics.num_requests_running = int(
            get_metric_value(raw, "vllm:num_requests_running") or 0
        )
        metrics.num_requests_waiting = int(
            get_metric_value(raw, "vllm:num_requests_waiting") or 0
        )
        metrics.num_requests_swapped = int(
            get_metric_value(raw, "vllm:num_requests_swapped") or 0
        )
        metrics.batch_size = metrics.num_requests_running

        # Counters: cumulative token totals
        prompt_tokens = int(get_metric_value(raw, "vllm:prompt_tokens_total") or 0)
        generation_tokens = int(
            get_metric_value(raw, "vllm:generation_tokens_total") or 0
        )

        # Derive tokens/sec from the counter delta since the last poll.
        if self._last_collect_time:
            time_delta = (now - self._last_collect_time).total_seconds()
            if time_delta > 0:
                token_delta = generation_tokens - self._last_generation_tokens
                # Fix: counters reset when the server restarts; clamp so a
                # reset does not produce a negative throughput sample.
                metrics.tokens_per_second = max(token_delta, 0) / time_delta

        self._last_prompt_tokens = prompt_tokens
        self._last_generation_tokens = generation_tokens
        self._last_collect_time = now

        metrics.prompt_tokens_total = prompt_tokens
        metrics.generation_tokens_total = generation_tokens

        # Share of all processed tokens that were prompt (prefill) tokens.
        total_tokens = prompt_tokens + generation_tokens
        if total_tokens > 0:
            metrics.prefill_ratio = prompt_tokens / total_tokens

        # Latency metrics (from histograms, use the median / P50), in ms.
        ttft = get_histogram_quantile(raw, "vllm:time_to_first_token_seconds", 0.5)
        if ttft is not None:
            metrics.ttft_ms = ttft * 1000

        tpot = get_histogram_quantile(raw, "vllm:time_per_output_token_seconds", 0.5)
        if tpot is not None:
            metrics.tpot_ms = tpot * 1000

        e2e = get_histogram_quantile(raw, "vllm:e2e_request_latency_seconds", 0.5)
        if e2e is not None:
            metrics.e2e_latency_ms = e2e * 1000

        # Cache usage gauges are fractions (0-1); convert to percentages.
        metrics.gpu_cache_usage_percent = (
            get_metric_value(raw, "vllm:gpu_cache_usage_perc") or 0
        ) * 100
        metrics.cpu_cache_usage_percent = (
            get_metric_value(raw, "vllm:cpu_cache_usage_perc") or 0
        ) * 100
        metrics.kv_cache_usage_percent = metrics.gpu_cache_usage_percent

        # Model info (from metric labels)
        model_name = self._get_model_name(raw)
        if model_name:
            metrics.model_name = model_name

        return metrics

    def _get_model_name(self, raw: Dict[str, List[MetricSample]]) -> Optional[str]:
        """Extract the model name from the first sample with a model_name label."""
        for metric_name, samples in raw.items():
            for sample in samples:
                if "model_name" in sample.labels:
                    return sample.labels["model_name"]
        return None

    def get_rank_mapping(self) -> Dict[int, int]:
        """
        Get tensor parallel rank to GPU mapping.

        Returns:
            Dictionary mapping TP rank to GPU ID
        """
        # This would typically come from vLLM's internal state.
        # For now, return empty mapping - can be extended.
        return {}

    def get_latency_percentiles(self) -> Dict[str, Dict[str, float]]:
        """
        Get latency percentiles for detailed analysis.

        Returns:
            Dictionary with P50, P95, P99 (in ms) for each latency metric,
            or an empty dict on failure.
        """
        try:
            response = requests.get(self.metrics_url, timeout=5)
            # Fix: fail fast on HTTP errors rather than parsing an error page.
            response.raise_for_status()
            raw = parse_prometheus_metrics(response.text)

            result = {}
            for metric_base in [
                "vllm:time_to_first_token_seconds",
                "vllm:time_per_output_token_seconds",
                "vllm:e2e_request_latency_seconds",
            ]:
                result[metric_base] = {
                    "p50": (get_histogram_quantile(raw, metric_base, 0.5) or 0) * 1000,
                    "p95": (get_histogram_quantile(raw, metric_base, 0.95) or 0) * 1000,
                    "p99": (get_histogram_quantile(raw, metric_base, 0.99) or 0) * 1000,
                }

            return result
        except Exception as e:
            logger.error(f"Error getting latency percentiles: {e}")
            return {}
|
components/__init__.py
ADDED
|
@@ -0,0 +1,27 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""UI components for the Gradio dashboard."""
|
| 2 |
+
|
| 3 |
+
from .gpu_panel import create_gpu_panel, update_gpu_panel
|
| 4 |
+
from .inference_panel import create_inference_panel, update_inference_panel
|
| 5 |
+
from .quant_panel import create_quant_panel, update_quant_panel
|
| 6 |
+
from .loading_panel import create_loading_panel, update_loading_panel
|
| 7 |
+
from .alerts_panel import create_alerts_panel, update_alerts_panel
|
| 8 |
+
from .tracing_panel import create_tracing_panel, update_tracing_panel
|
| 9 |
+
from .comparison_panel import create_comparison_panel
|
| 10 |
+
from .loadtest_panel import create_loadtest_panel
|
| 11 |
+
|
| 12 |
+
# Public API of the components package: names re-exported for
# `from components import ...` (panel constructors and their updaters).
__all__ = [
    "create_gpu_panel",
    "update_gpu_panel",
    "create_inference_panel",
    "update_inference_panel",
    "create_quant_panel",
    "update_quant_panel",
    "create_loading_panel",
    "update_loading_panel",
    "create_alerts_panel",
    "update_alerts_panel",
    "create_tracing_panel",
    "update_tracing_panel",
    "create_comparison_panel",
    "create_loadtest_panel",
]
|
components/alerts_panel.py
ADDED
|
@@ -0,0 +1,253 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Alerts configuration and history panel component."""
|
| 2 |
+
|
| 3 |
+
import gradio as gr
|
| 4 |
+
import pandas as pd
|
| 5 |
+
from datetime import datetime
|
| 6 |
+
from typing import Dict, Any, Tuple, List
|
| 7 |
+
|
| 8 |
+
from services.alerting import AlertEngine, AlertDispatcher, Alert, AlertSeverity
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
def create_alerts_panel(
|
| 12 |
+
alert_engine: AlertEngine,
|
| 13 |
+
alert_dispatcher: AlertDispatcher,
|
| 14 |
+
) -> Dict[str, Any]:
|
| 15 |
+
"""
|
| 16 |
+
Create the alerts panel.
|
| 17 |
+
|
| 18 |
+
Args:
|
| 19 |
+
alert_engine: Alert engine instance
|
| 20 |
+
alert_dispatcher: Alert dispatcher instance
|
| 21 |
+
|
| 22 |
+
Returns:
|
| 23 |
+
Dictionary of Gradio components
|
| 24 |
+
"""
|
| 25 |
+
with gr.Column():
|
| 26 |
+
with gr.Row():
|
| 27 |
+
# Active alerts column
|
| 28 |
+
with gr.Column(scale=2):
|
| 29 |
+
gr.Markdown("### Active Alerts")
|
| 30 |
+
active_alerts_table = gr.Dataframe(
|
| 31 |
+
headers=["Time", "Severity", "Metric", "Value", "Threshold", "Message"],
|
| 32 |
+
datatype=["str", "str", "str", "number", "number", "str"],
|
| 33 |
+
label="Active Alerts",
|
| 34 |
+
interactive=False,
|
| 35 |
+
)
|
| 36 |
+
|
| 37 |
+
gr.Markdown("### Alert History")
|
| 38 |
+
alert_history_table = gr.Dataframe(
|
| 39 |
+
headers=["Time", "Severity", "Message", "Resolved"],
|
| 40 |
+
datatype=["str", "str", "str", "str"],
|
| 41 |
+
label="Recent Alerts",
|
| 42 |
+
interactive=False,
|
| 43 |
+
)
|
| 44 |
+
|
| 45 |
+
# Configuration column
|
| 46 |
+
with gr.Column(scale=1):
|
| 47 |
+
gr.Markdown("### Alert Configuration")
|
| 48 |
+
|
| 49 |
+
kv_threshold = gr.Slider(
|
| 50 |
+
label="KV Cache Alert Threshold (%)",
|
| 51 |
+
minimum=50,
|
| 52 |
+
maximum=100,
|
| 53 |
+
value=90,
|
| 54 |
+
step=5,
|
| 55 |
+
)
|
| 56 |
+
|
| 57 |
+
gpu_memory_threshold = gr.Slider(
|
| 58 |
+
label="GPU Memory Alert Threshold (%)",
|
| 59 |
+
minimum=70,
|
| 60 |
+
maximum=100,
|
| 61 |
+
value=95,
|
| 62 |
+
step=5,
|
| 63 |
+
)
|
| 64 |
+
|
| 65 |
+
ttft_multiplier = gr.Slider(
|
| 66 |
+
label="TTFT Spike Multiplier",
|
| 67 |
+
minimum=1.5,
|
| 68 |
+
maximum=5,
|
| 69 |
+
value=2,
|
| 70 |
+
step=0.5,
|
| 71 |
+
)
|
| 72 |
+
|
| 73 |
+
throughput_drop = gr.Slider(
|
| 74 |
+
label="Throughput Drop Alert (%)",
|
| 75 |
+
minimum=20,
|
| 76 |
+
maximum=80,
|
| 77 |
+
value=50,
|
| 78 |
+
step=10,
|
| 79 |
+
)
|
| 80 |
+
|
| 81 |
+
gr.Markdown("### Webhook Configuration")
|
| 82 |
+
|
| 83 |
+
slack_webhook = gr.Textbox(
|
| 84 |
+
label="Slack Webhook URL",
|
| 85 |
+
placeholder="https://hooks.slack.com/services/...",
|
| 86 |
+
type="password",
|
| 87 |
+
)
|
| 88 |
+
|
| 89 |
+
pagerduty_key = gr.Textbox(
|
| 90 |
+
label="PagerDuty Routing Key",
|
| 91 |
+
placeholder="Enter routing key...",
|
| 92 |
+
type="password",
|
| 93 |
+
)
|
| 94 |
+
|
| 95 |
+
with gr.Row():
|
| 96 |
+
save_config_btn = gr.Button("Save Configuration")
|
| 97 |
+
test_alert_btn = gr.Button("Send Test Alert", variant="secondary")
|
| 98 |
+
|
| 99 |
+
config_status = gr.Textbox(
|
| 100 |
+
label="Status",
|
| 101 |
+
interactive=False,
|
| 102 |
+
visible=True,
|
| 103 |
+
)
|
| 104 |
+
|
| 105 |
+
# Event handlers
|
| 106 |
+
def save_config(kv, gpu, ttft, tp_drop, slack, pd_key):
|
| 107 |
+
# Update alert thresholds
|
| 108 |
+
if "kv_cache_high" in alert_engine.rules:
|
| 109 |
+
alert_engine.rules["kv_cache_high"].threshold = kv
|
| 110 |
+
if "gpu_memory_critical" in alert_engine.rules:
|
| 111 |
+
alert_engine.rules["gpu_memory_critical"].threshold = gpu
|
| 112 |
+
if "ttft_spike" in alert_engine.rules:
|
| 113 |
+
alert_engine.rules["ttft_spike"].multiplier = ttft
|
| 114 |
+
if "throughput_drop" in alert_engine.rules:
|
| 115 |
+
alert_engine.rules["throughput_drop"].percent = tp_drop
|
| 116 |
+
|
| 117 |
+
# Update webhook config
|
| 118 |
+
alert_dispatcher.slack_webhook = slack if slack else None
|
| 119 |
+
alert_dispatcher.pagerduty_key = pd_key if pd_key else None
|
| 120 |
+
|
| 121 |
+
return "Configuration saved successfully"
|
| 122 |
+
|
| 123 |
+
save_config_btn.click(
|
| 124 |
+
fn=save_config,
|
| 125 |
+
inputs=[
|
| 126 |
+
kv_threshold,
|
| 127 |
+
gpu_memory_threshold,
|
| 128 |
+
ttft_multiplier,
|
| 129 |
+
throughput_drop,
|
| 130 |
+
slack_webhook,
|
| 131 |
+
pagerduty_key,
|
| 132 |
+
],
|
| 133 |
+
outputs=config_status,
|
| 134 |
+
)
|
| 135 |
+
|
| 136 |
+
async def send_test():
|
| 137 |
+
success = await alert_dispatcher.send_test_alert()
|
| 138 |
+
if success:
|
| 139 |
+
return "Test alert sent successfully"
|
| 140 |
+
return "Failed to send test alert - check webhook configuration"
|
| 141 |
+
|
| 142 |
+
test_alert_btn.click(
|
| 143 |
+
fn=send_test,
|
| 144 |
+
outputs=config_status,
|
| 145 |
+
)
|
| 146 |
+
|
| 147 |
+
return {
|
| 148 |
+
"active_alerts_table": active_alerts_table,
|
| 149 |
+
"alert_history_table": alert_history_table,
|
| 150 |
+
"kv_threshold": kv_threshold,
|
| 151 |
+
"gpu_memory_threshold": gpu_memory_threshold,
|
| 152 |
+
"ttft_multiplier": ttft_multiplier,
|
| 153 |
+
"throughput_drop": throughput_drop,
|
| 154 |
+
"slack_webhook": slack_webhook,
|
| 155 |
+
"pagerduty_key": pagerduty_key,
|
| 156 |
+
"config_status": config_status,
|
| 157 |
+
}
|
| 158 |
+
|
| 159 |
+
|
| 160 |
+
def update_alerts_panel(
    alert_engine: AlertEngine,
    db=None,
) -> Tuple[pd.DataFrame, pd.DataFrame]:
    """Refresh the alerts panel dataframes.

    Args:
        alert_engine: Alert engine instance providing the active alerts.
        db: Optional database used to pull recent alert history.

    Returns:
        Tuple of (active_alerts_df, history_df).
    """
    # One row per alert that is currently firing.
    active_rows = [
        {
            "Time": alert.timestamp.strftime("%H:%M:%S"),
            "Severity": _format_severity(alert.severity),
            "Metric": alert.metric,
            "Value": round(alert.value, 2),
            "Threshold": round(alert.threshold, 2),
            "Message": alert.message,
        }
        for alert in alert_engine.get_active_alerts()
    ]

    if active_rows:
        active_df = pd.DataFrame(active_rows)
    else:
        # Empty frame with explicit columns keeps the table headers visible.
        active_df = pd.DataFrame(
            columns=["Time", "Severity", "Metric", "Value", "Threshold", "Message"]
        )

    # Historical alerts come from the database (newest 20), when available.
    history_rows = []
    if db:
        for record in db.get_recent_alerts(limit=20):
            history_rows.append({
                "Time": record.timestamp.strftime("%Y-%m-%d %H:%M:%S"),
                "Severity": _format_severity_str(record.severity),
                "Message": record.message,
                "Resolved": "Yes" if record.resolved_at else "No",
            })

    if history_rows:
        history_df = pd.DataFrame(history_rows)
    else:
        history_df = pd.DataFrame(
            columns=["Time", "Severity", "Message", "Resolved"]
        )

    return active_df, history_df
|
| 208 |
+
|
| 209 |
+
|
| 210 |
+
def _format_severity(severity: AlertSeverity) -> str:
    """Map an AlertSeverity enum member to its display label.

    Unknown members fall back to "UNKNOWN".
    """
    labels = {
        AlertSeverity.INFO: "INFO",
        AlertSeverity.WARNING: "WARNING",
        AlertSeverity.CRITICAL: "CRITICAL",
    }
    return labels.get(severity, "UNKNOWN")
|
| 218 |
+
|
| 219 |
+
|
| 220 |
+
def _format_severity_str(severity: str) -> str:
|
| 221 |
+
"""Format severity string for display."""
|
| 222 |
+
return severity.upper()
|
| 223 |
+
|
| 224 |
+
|
| 225 |
+
def get_alert_badge_html(alerts: List[Alert]) -> str:
    """Render active-alert counts as small HTML pill badges.

    Args:
        alerts: List of active alerts.

    Returns:
        HTML string: a green "all clear" span when nothing is firing,
        otherwise one badge per non-empty severity bucket.
    """
    if not alerts:
        return '<span style="color: #2e7d32;">No Active Alerts</span>'

    severities = [a.severity for a in alerts]
    n_critical = severities.count(AlertSeverity.CRITICAL)
    n_warning = severities.count(AlertSeverity.WARNING)

    parts = []
    if n_critical > 0:
        parts.append(
            f'<span style="background: #c62828; color: white; padding: 2px 8px; '
            f'border-radius: 12px; margin-right: 5px;">{n_critical} Critical</span>'
        )
    if n_warning > 0:
        parts.append(
            f'<span style="background: #ff9800; color: white; padding: 2px 8px; '
            f'border-radius: 12px;">{n_warning} Warning</span>'
        )

    return "".join(parts)
|
components/comparison_panel.py
ADDED
|
@@ -0,0 +1,207 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""A/B comparison panel component."""
|
| 2 |
+
|
| 3 |
+
import gradio as gr
|
| 4 |
+
import pandas as pd
|
| 5 |
+
import asyncio
|
| 6 |
+
from typing import Dict, Any, Tuple
|
| 7 |
+
|
| 8 |
+
from services.comparator import ABComparator, DeploymentConfig, ComparisonResult
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
def create_comparison_panel() -> Dict[str, Any]:
    """Build the A/B deployment comparison tab.

    Two endpoints are configured side by side. "Compare Now" runs a single
    snapshot comparison; "Collect Samples (30s)" gathers samples first so
    statistical significance can be reported.

    Returns:
        Dictionary of Gradio components keyed by logical name.
    """
    with gr.Column():
        gr.Markdown("### A/B Deployment Comparison")

        # --- Endpoint configuration ---
        with gr.Row():
            endpoint_a = gr.Textbox(
                label="Deployment A",
                value="http://localhost:8000",
                placeholder="http://host:port",
            )
            name_a = gr.Textbox(
                label="Name A",
                value="Baseline",
                placeholder="e.g., FP16-baseline",
            )

        with gr.Row():
            endpoint_b = gr.Textbox(
                label="Deployment B",
                value="http://localhost:8001",
                placeholder="http://host:port",
            )
            name_b = gr.Textbox(
                label="Name B",
                value="Candidate",
                placeholder="e.g., AWQ-4bit",
            )

        with gr.Row():
            compare_btn = gr.Button("Compare Now", variant="primary")
            collect_samples_btn = gr.Button("Collect Samples (30s)", variant="secondary")

        # Progress / result status line.
        comparison_status = gr.Textbox(
            label="Status",
            interactive=False,
        )

        # --- Per-deployment metric readouts, side by side ---
        with gr.Row():
            with gr.Column():
                gr.Markdown("### Deployment A")
                a_connected = gr.Checkbox(label="Connected", interactive=False)
                a_throughput = gr.Number(label="Throughput (tok/s)", precision=1, interactive=False)
                a_ttft = gr.Number(label="TTFT (ms)", precision=1, interactive=False)
                a_latency = gr.Number(label="E2E Latency (ms)", precision=1, interactive=False)
                a_kv_cache = gr.Number(label="KV Cache %", precision=1, interactive=False)
                a_batch = gr.Number(label="Batch Size", precision=0, interactive=False)

            with gr.Column():
                gr.Markdown("### Deployment B")
                b_connected = gr.Checkbox(label="Connected", interactive=False)
                b_throughput = gr.Number(label="Throughput (tok/s)", precision=1, interactive=False)
                b_ttft = gr.Number(label="TTFT (ms)", precision=1, interactive=False)
                b_latency = gr.Number(label="E2E Latency (ms)", precision=1, interactive=False)
                b_kv_cache = gr.Number(label="KV Cache %", precision=1, interactive=False)
                b_batch = gr.Number(label="Batch Size", precision=0, interactive=False)

        # --- Summary table and verdict ---
        gr.Markdown("### Comparison Summary")
        comparison_table = gr.Dataframe(
            headers=["Metric", "Deployment A", "Deployment B", "Difference"],
            label="Comparison",
            interactive=False,
        )

        recommendation = gr.Markdown("")

        with gr.Row():
            significance_throughput = gr.Textbox(
                label="Throughput Significance",
                interactive=False,
            )
            significance_latency = gr.Textbox(
                label="Latency Significance",
                interactive=False,
            )

        # Both buttons write to the same outputs; the ordering here must
        # stay in sync with format_comparison_results().
        result_outputs = [
            comparison_status,
            a_connected, a_throughput, a_ttft, a_latency, a_kv_cache, a_batch,
            b_connected, b_throughput, b_ttft, b_latency, b_kv_cache, b_batch,
            comparison_table, recommendation,
            significance_throughput, significance_latency,
        ]

        async def _compare_once(ep_a, label_a, ep_b, label_b):
            # Single snapshot comparison of the two endpoints.
            comparator = ABComparator(
                DeploymentConfig(name=label_a, endpoint=ep_a),
                DeploymentConfig(name=label_b, endpoint=ep_b),
            )
            result = await comparator.compare()
            return format_comparison_results(result, comparator)

        compare_btn.click(
            fn=_compare_once,
            inputs=[endpoint_a, name_a, endpoint_b, name_b],
            outputs=result_outputs,
        )

        async def _collect_then_compare(ep_a, label_a, ep_b, label_b):
            comparator = ABComparator(
                DeploymentConfig(name=label_a, endpoint=ep_a),
                DeploymentConfig(name=label_b, endpoint=ep_b),
            )

            # First yield surfaces progress immediately; the remaining
            # outputs are left untouched via None placeholders.
            yield (
                "Collecting samples (0/30)...",
                *[None] * 15
            )

            # Sample collection takes roughly 30 seconds.
            await comparator.collect_samples(count=30)
            result = await comparator.compare()

            yield format_comparison_results(result, comparator)

        collect_samples_btn.click(
            fn=_collect_then_compare,
            inputs=[endpoint_a, name_a, endpoint_b, name_b],
            outputs=result_outputs,
        )

        return {
            "endpoint_a": endpoint_a,
            "name_a": name_a,
            "endpoint_b": endpoint_b,
            "name_b": name_b,
            "comparison_status": comparison_status,
            "a_connected": a_connected,
            "a_throughput": a_throughput,
            "a_ttft": a_ttft,
            "a_latency": a_latency,
            "a_kv_cache": a_kv_cache,
            "a_batch": a_batch,
            "b_connected": b_connected,
            "b_throughput": b_throughput,
            "b_ttft": b_ttft,
            "b_latency": b_latency,
            "b_kv_cache": b_kv_cache,
            "b_batch": b_batch,
            "comparison_table": comparison_table,
            "recommendation": recommendation,
            "significance_throughput": significance_throughput,
            "significance_latency": significance_latency,
        }
|
| 172 |
+
|
| 173 |
+
|
| 174 |
+
def format_comparison_results(
    result: ComparisonResult,
    comparator: ABComparator,
) -> Tuple:
    """Translate a ComparisonResult into the tuple of UI output values.

    The tuple ordering must match the shared outputs list wired up in
    create_comparison_panel().
    """
    dep_a = result.deployment_a
    dep_b = result.deployment_b

    # Summary table built by the comparator itself.
    table_df = pd.DataFrame(comparator.get_comparison_table(result))

    recommendation_md = f"**Recommendation:** {result.recommendation}"

    # A p-value of 1.0 is treated as "no significance test was run".
    sig_throughput = "Not tested"
    if result.p_value_throughput < 1.0:
        label = "Significant" if result.throughput_significant else "Not significant"
        sig_throughput = f"{label} (p={result.p_value_throughput:.4f})"

    sig_latency = "Not tested"
    if result.p_value_latency < 1.0:
        label = "Significant" if result.latency_significant else "Not significant"
        sig_latency = f"{label} (p={result.p_value_latency:.4f})"

    return (
        "Comparison complete",
        dep_a.connected, dep_a.tokens_per_second, dep_a.ttft_ms,
        dep_a.e2e_latency_ms, dep_a.kv_cache_percent, dep_a.batch_size,
        dep_b.connected, dep_b.tokens_per_second, dep_b.ttft_ms,
        dep_b.e2e_latency_ms, dep_b.kv_cache_percent, dep_b.batch_size,
        table_df, recommendation_md,
        sig_throughput, sig_latency,
    )
|
components/gpu_panel.py
ADDED
|
@@ -0,0 +1,191 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""GPU status panel component."""
|
| 2 |
+
|
| 3 |
+
import gradio as gr
|
| 4 |
+
import pandas as pd
|
| 5 |
+
from typing import List, Dict, Any, Tuple
|
| 6 |
+
|
| 7 |
+
from collectors.gpu_collector import GPUCollector, GPUStats
|
| 8 |
+
from utils.history import MetricHistory
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
def create_gpu_panel(history: MetricHistory) -> Dict[str, Any]:
    """Build the GPU / rank status tab.

    Args:
        history: Metric history for charting (accepted here for interface
            symmetry with the other panels; charting is done by the
            update function).

    Returns:
        Dictionary of Gradio components keyed by logical name.
    """
    with gr.Column():
        gr.Markdown("### GPU / Rank Status")

        # Per-GPU statistics table.
        gpu_table = gr.Dataframe(
            headers=["GPU", "Name", "Memory", "Memory %", "Util %", "Temp", "Power", "TP Rank"],
            datatype=["number", "str", "str", "number", "number", "str", "str", "str"],
            label="GPU Statistics",
            interactive=False,
        )

        with gr.Row():
            # Memory time-series, one line per GPU.
            gpu_memory_plot = gr.LinePlot(
                x="time",
                y="value",
                color="gpu",
                title="GPU Memory Usage (GB)",
                x_title="Time",
                y_title="Memory (GB)",
                height=250,
            )

            # Utilization time-series, one line per GPU.
            gpu_util_plot = gr.LinePlot(
                x="time",
                y="value",
                color="gpu",
                title="GPU Utilization (%)",
                x_title="Time",
                y_title="Utilization %",
                height=250,
            )

        # Inter-GPU communication health banner; starts green/healthy.
        nccl_status = gr.HTML(
            value='<div style="padding: 10px; background: #e8f5e9; border-radius: 5px;">'
            '<span style="color: #2e7d32;">NCCL Status: Healthy</span></div>',
            label="Communication Status",
        )

        return {
            "gpu_table": gpu_table,
            "gpu_memory_plot": gpu_memory_plot,
            "gpu_util_plot": gpu_util_plot,
            "nccl_status": nccl_status,
        }
|
| 68 |
+
|
| 69 |
+
|
| 70 |
+
def update_gpu_panel(
    collector: GPUCollector,
    history: MetricHistory,
) -> Tuple[pd.DataFrame, pd.DataFrame, pd.DataFrame, str]:
    """Poll the GPU collector and produce fresh panel contents.

    Args:
        collector: GPU collector instance.
        history: Rolling metric history; new samples are appended here.

    Returns:
        Tuple of (table_data, memory_plot_data, util_plot_data, nccl_html).
    """
    stats = collector.collect()

    # Record the new samples so the time-series charts can grow.
    for stat in stats:
        gpu_label = {"gpu": str(stat.gpu_id)}
        history.add("gpu_memory_gb", stat.memory_used_gb, labels=gpu_label)
        history.add("gpu_util_percent", stat.gpu_util_percent, labels=gpu_label)

    table_data = _build_table(stats)
    memory_df = _build_memory_chart_data(history)
    util_df = _build_util_chart_data(history)

    # NCCL status is a simplified heuristic; real detection would need
    # deeper inspection of vLLM / NCCL metrics.
    nccl_html = _build_nccl_status(stats)

    return table_data, memory_df, util_df, nccl_html
|
| 110 |
+
|
| 111 |
+
|
| 112 |
+
def _build_table(stats: List[GPUStats]) -> pd.DataFrame:
    """Render per-GPU stats as a display-ready DataFrame."""
    rows = [
        {
            "GPU": stat.gpu_id,
            # Slicing is a no-op for short names, so no length check needed.
            "Name": stat.name[:20],
            "Memory": f"{stat.memory_used_gb:.1f}/{stat.memory_total_gb:.1f} GB",
            "Memory %": round(stat.memory_percent, 1),
            "Util %": round(stat.gpu_util_percent, 1),
            "Temp": f"{stat.temperature_c}C",
            "Power": f"{stat.power_watts:.0f}/{stat.power_limit_watts:.0f}W",
            "TP Rank": "-" if stat.tp_rank is None else str(stat.tp_rank),
        }
        for stat in stats
    ]
    return pd.DataFrame(rows)
|
| 128 |
+
|
| 129 |
+
|
| 130 |
+
def _build_memory_chart_data(history: MetricHistory) -> pd.DataFrame:
    """Flatten per-GPU memory history into long-form chart rows."""
    rows = []
    for key, points in history.get_all_series("gpu_memory_gb").items():
        # Series keys carry the GPU label as "...=<id>"; default to "0"
        # when no label is present.
        gpu_id = key.split("=")[-1] if "=" in key else "0"
        # Chart only the most recent 60 samples per GPU.
        rows.extend(
            {"time": point.timestamp, "value": point.value, "gpu": f"GPU {gpu_id}"}
            for point in points[-60:]
        )

    if not rows:
        # Empty frame with the expected columns keeps LinePlot happy.
        return pd.DataFrame({"time": [], "value": [], "gpu": []})

    return pd.DataFrame(rows)
|
| 148 |
+
|
| 149 |
+
|
| 150 |
+
def _build_util_chart_data(history: MetricHistory) -> pd.DataFrame:
    """Flatten per-GPU utilization history into long-form chart rows."""
    rows = []
    for key, points in history.get_all_series("gpu_util_percent").items():
        # Series keys carry the GPU label as "...=<id>"; default to "0".
        gpu_id = key.split("=")[-1] if "=" in key else "0"
        # Chart only the most recent 60 samples per GPU.
        rows.extend(
            {"time": point.timestamp, "value": point.value, "gpu": f"GPU {gpu_id}"}
            for point in points[-60:]
        )

    if not rows:
        # Empty frame with the expected columns keeps LinePlot happy.
        return pd.DataFrame({"time": [], "value": [], "gpu": []})

    return pd.DataFrame(rows)
|
| 168 |
+
|
| 169 |
+
|
| 170 |
+
def _build_nccl_status(stats: List[GPUStats]) -> str:
    """Render an HTML health banner for inter-GPU communication.

    NOTE(review): this is a heuristic only — a GPU reporting zero
    utilization AND zero memory usage is treated as a communication
    problem. Real NCCL error detection would need vLLM metrics.
    """
    if not stats:
        return (
            '<div style="padding: 10px; background: #fff3e0; border-radius: 5px;">'
            '<span style="color: #e65100;">NCCL Status: No GPUs detected</span></div>'
        )

    healthy = all(
        stat.gpu_util_percent > 0 or stat.memory_percent > 0
        for stat in stats
    )

    if healthy:
        return (
            '<div style="padding: 10px; background: #e8f5e9; border-radius: 5px;">'
            f'<span style="color: #2e7d32;">NCCL Status: Healthy ({len(stats)} GPUs)</span></div>'
        )

    return (
        '<div style="padding: 10px; background: #ffebee; border-radius: 5px;">'
        '<span style="color: #c62828;">NCCL Status: Communication issue detected</span></div>'
    )
|
components/inference_panel.py
ADDED
|
@@ -0,0 +1,209 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Inference metrics panel component."""
|
| 2 |
+
|
| 3 |
+
import gradio as gr
|
| 4 |
+
import pandas as pd
|
| 5 |
+
from typing import Dict, Any, Tuple
|
| 6 |
+
|
| 7 |
+
from collectors.vllm_collector import VLLMCollector, InferenceMetrics
|
| 8 |
+
from utils.history import MetricHistory
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
def create_inference_panel(history: MetricHistory) -> Dict[str, Any]:
    """Build the inference metrics tab.

    Args:
        history: Metric history for charting (accepted here for interface
            symmetry; charting is done by the update function).

    Returns:
        Dictionary of Gradio components keyed by logical name.
    """
    with gr.Column():
        gr.Markdown("### Inference Metrics")

        # Headline numbers.
        with gr.Row():
            throughput = gr.Number(
                label="Tokens/sec",
                precision=1,
                interactive=False,
            )
            ttft = gr.Number(
                label="TTFT (ms)",
                precision=1,
                interactive=False,
            )
            batch_size = gr.Number(
                label="Batch Size",
                precision=0,
                interactive=False,
            )
            kv_cache = gr.Number(
                label="KV Cache %",
                precision=1,
                interactive=False,
            )

        # Throughput time-series.
        throughput_plot = gr.LinePlot(
            x="time",
            y="value",
            title="Throughput Over Time",
            x_title="Time",
            y_title="Tokens/sec",
            height=250,
        )

        # Secondary numbers.
        with gr.Row():
            prefill_pct = gr.Number(
                label="Prefill %",
                precision=1,
                interactive=False,
            )
            decode_pct = gr.Number(
                label="Decode %",
                precision=1,
                interactive=False,
            )
            queue_depth = gr.Number(
                label="Queue Depth",
                precision=0,
                interactive=False,
            )
            e2e_latency = gr.Number(
                label="E2E Latency (ms)",
                precision=1,
                interactive=False,
            )

        # TTFT vs E2E latency, one colored line per metric.
        latency_plot = gr.LinePlot(
            x="time",
            y="value",
            color="metric",
            title="Latency Over Time",
            x_title="Time",
            y_title="Latency (ms)",
            height=250,
        )

        return {
            "throughput": throughput,
            "ttft": ttft,
            "batch_size": batch_size,
            "kv_cache": kv_cache,
            "throughput_plot": throughput_plot,
            "prefill_pct": prefill_pct,
            "decode_pct": decode_pct,
            "queue_depth": queue_depth,
            "e2e_latency": e2e_latency,
            "latency_plot": latency_plot,
        }
|
| 103 |
+
|
| 104 |
+
|
| 105 |
+
def update_inference_panel(
    collector: VLLMCollector,
    history: MetricHistory,
) -> Tuple[float, float, int, float, pd.DataFrame, float, float, int, float, pd.DataFrame]:
    """Poll vLLM metrics and produce fresh values for the inference panel.

    Args:
        collector: vLLM collector instance.
        history: Rolling metric history; new samples are appended here.

    Returns:
        Tuple of all metric values and chart data, in the order the
        panel's output components expect.
    """
    metrics = collector.collect()

    # Append the latest samples so the charts can grow.
    history.add("tokens_per_second", metrics.tokens_per_second)
    history.add("ttft_ms", metrics.ttft_ms)
    history.add("e2e_latency_ms", metrics.e2e_latency_ms)
    history.add("kv_cache_percent", metrics.kv_cache_usage_percent)

    throughput_df = _build_throughput_chart(history)
    latency_df = _build_latency_chart(history)

    # Prefill/decode split is derived from the reported prefill ratio.
    prefill_pct = metrics.prefill_ratio * 100
    decode_pct = 100 - prefill_pct

    return (
        metrics.tokens_per_second,
        metrics.ttft_ms,
        metrics.batch_size,
        metrics.kv_cache_usage_percent,
        throughput_df,
        prefill_pct,
        decode_pct,
        metrics.num_requests_waiting,
        metrics.e2e_latency_ms,
        latency_df,
    )
|
| 149 |
+
|
| 150 |
+
|
| 151 |
+
def _build_throughput_chart(history: MetricHistory) -> pd.DataFrame:
    """Turn the throughput history (last 60 samples) into chart rows."""
    points = history.get("tokens_per_second", limit=60)

    if not points:
        # Empty frame with the expected columns keeps LinePlot happy.
        return pd.DataFrame({"time": [], "value": []})

    return pd.DataFrame({
        "time": [p.timestamp for p in points],
        "value": [p.value for p in points],
    })
|
| 162 |
+
|
| 163 |
+
|
| 164 |
+
def _build_latency_chart(history: MetricHistory) -> pd.DataFrame:
    """Combine TTFT and end-to-end latency histories into long-form rows."""
    # (series label, last 60 samples) pairs; TTFT rows come first.
    series = (
        ("TTFT", history.get("ttft_ms", limit=60)),
        ("E2E", history.get("e2e_latency_ms", limit=60)),
    )

    rows = [
        {"time": p.timestamp, "value": p.value, "metric": label}
        for label, points in series
        for p in points
    ]

    if not rows:
        # Empty frame with the expected columns keeps LinePlot happy.
        return pd.DataFrame({"time": [], "value": [], "metric": []})

    return pd.DataFrame(rows)
|
| 189 |
+
|
| 190 |
+
|
| 191 |
+
def get_metrics_dict(metrics: InferenceMetrics) -> Dict[str, float]:
    """Flatten an InferenceMetrics snapshot for the alerting engine.

    Args:
        metrics: InferenceMetrics instance.

    Returns:
        Mapping of alert-rule metric name to its current value.
    """
    return {
        "tokens_per_second": metrics.tokens_per_second,
        "ttft_ms": metrics.ttft_ms,
        "e2e_latency_ms": metrics.e2e_latency_ms,
        "kv_cache_percent": metrics.kv_cache_usage_percent,
        "batch_size": metrics.batch_size,
        "queue_depth": metrics.num_requests_waiting,
        "gpu_cache_percent": metrics.gpu_cache_usage_percent,
    }
|
components/loading_panel.py
ADDED
|
@@ -0,0 +1,151 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Model loading progress panel component."""
|
| 2 |
+
|
| 3 |
+
import gradio as gr
|
| 4 |
+
import pandas as pd
|
| 5 |
+
from typing import Dict, Any, Tuple
|
| 6 |
+
|
| 7 |
+
from collectors.loading_tracker import LoadingTracker, LoadingStatus
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
def create_loading_panel() -> Dict[str, Any]:
    """
    Create the loading progress panel.

    Builds the Gradio widgets that visualize model loading: an HTML status
    badge, a read-only progress slider used as a 0-100 gauge, shard/layer
    counters with an ETA textbox, and a per-shard detail table.

    Returns:
        Dictionary of Gradio components, keyed by name
    """
    with gr.Column():
        gr.Markdown("### Model Loading Progress")

        # Status indicator — starts out showing the NOT_STARTED badge.
        loading_status = gr.HTML(
            value=_build_status_html(LoadingStatus.NOT_STARTED),
        )

        # Progress bar (non-interactive slider doubles as a percentage gauge)
        loading_progress = gr.Slider(
            label="Loading Progress",
            minimum=0,
            maximum=100,
            value=0,
            interactive=False,
        )

        # Counters; "0 / 0" and "-" are the pre-loading placeholder values.
        with gr.Row():
            shards_loaded = gr.Textbox(
                label="Shards Loaded",
                value="0 / 0",
                interactive=False,
            )
            layers_loaded = gr.Textbox(
                label="Layers Loaded",
                value="0 / 0",
                interactive=False,
            )
            eta = gr.Textbox(
                label="ETA",
                value="-",
                interactive=False,
            )

        # Shard details table
        gr.Markdown("#### Shard Details")
        shard_table = gr.Dataframe(
            headers=["Shard", "Size (MB)", "Status", "Layers"],
            datatype=["str", "number", "str", "str"],
            label="Shards",
            interactive=False,
        )

    return {
        "loading_status": loading_status,
        "loading_progress": loading_progress,
        "shards_loaded": shards_loaded,
        "layers_loaded": layers_loaded,
        "eta": eta,
        "shard_table": shard_table,
    }
|
| 68 |
+
|
| 69 |
+
|
| 70 |
+
def update_loading_panel(
    tracker: LoadingTracker,
) -> Tuple[str, float, str, str, str, pd.DataFrame]:
    """
    Refresh the loading panel from the tracker's current state.

    Args:
        tracker: Loading tracker instance

    Returns:
        Tuple of (status_html, progress, shards_text, layers_text,
        eta_text, shard_table), ordered to match the panel outputs.
    """
    progress = tracker.get_progress()
    shards = tracker.get_shards()

    # Only the first 20 shards are tabulated to keep the UI responsive.
    rows = [
        {
            "Shard": s.filename,
            "Size (MB)": round(s.size_mb, 1),
            "Status": _format_shard_status(s.status),
            "Layers": str(len(s.layers)),
        }
        for s in shards[:20]
    ]
    if rows:
        shard_df = pd.DataFrame(rows)
    else:
        # Empty frame with explicit columns keeps the Dataframe schema stable.
        shard_df = pd.DataFrame(columns=["Shard", "Size (MB)", "Status", "Layers"])

    shards_text = f"{progress.loaded_shards} / {progress.total_shards}"
    layers_text = f"{progress.layers_loaded} / {progress.total_layers}"

    # A falsy ETA (None or zero seconds remaining) is rendered as a dash.
    remaining = progress.estimated_remaining_seconds
    if remaining:
        minutes = int(remaining // 60)
        seconds = int(remaining % 60)
        eta_text = f"{minutes}m {seconds}s"
    else:
        eta_text = "-"

    return (
        _build_status_html(progress.status),
        progress.progress_percent,
        shards_text,
        layers_text,
        eta_text,
        shard_df,
    )
|
| 121 |
+
|
| 122 |
+
|
| 123 |
+
def _build_status_html(status: LoadingStatus) -> str:
    """Render the loading status as a colored HTML badge."""
    # (label, foreground, background) per status; any unrecognized status
    # falls back to the neutral grey "Unknown" badge.
    palette = {
        LoadingStatus.NOT_STARTED: ("Not Started", "#9e9e9e", "#fafafa"),
        LoadingStatus.DOWNLOADING: ("Downloading", "#1976d2", "#e3f2fd"),
        LoadingStatus.LOADING: ("Loading", "#ff9800", "#fff3e0"),
        LoadingStatus.READY: ("Ready", "#2e7d32", "#e8f5e9"),
        LoadingStatus.ERROR: ("Error", "#c62828", "#ffebee"),
    }
    text, color, bg_color = palette.get(status, ("Unknown", "#9e9e9e", "#fafafa"))

    return (
        f'<div style="padding: 10px; background: {bg_color}; '
        f'border-radius: 5px; text-align: center;">'
        f'<span style="color: {color}; font-weight: bold; font-size: 1.2em;">'
        f'{text}</span></div>'
    )
|
| 143 |
+
|
| 144 |
+
|
| 145 |
+
def _format_shard_status(status: str) -> str:
|
| 146 |
+
"""Format shard status with indicator."""
|
| 147 |
+
if status == "loaded":
|
| 148 |
+
return "Loaded"
|
| 149 |
+
if status == "loading":
|
| 150 |
+
return "Loading..."
|
| 151 |
+
return "Pending"
|
components/loadtest_panel.py
ADDED
|
@@ -0,0 +1,220 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Load testing panel component."""
|
| 2 |
+
|
| 3 |
+
import gradio as gr
|
| 4 |
+
import pandas as pd
|
| 5 |
+
import asyncio
|
| 6 |
+
from typing import Dict, Any, Optional
|
| 7 |
+
|
| 8 |
+
from services.load_tester import LoadTester, LoadTestConfig
|
| 9 |
+
from storage.models import LoadTestResult
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
# Global load tester instance (managed by the panel)
# NOTE: module-level singleton — at most one load test runs at a time;
# start_load_test replaces it and stop_load_test signals it to stop.
_active_load_tester: Optional[LoadTester] = None
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
def create_loadtest_panel() -> Dict[str, Any]:
    """
    Create the load testing panel.

    Lays out the test configuration controls, start/stop buttons, live
    progress counters, result metrics, and a latency-over-time chart, and
    wires the buttons to the module-level ``_active_load_tester`` singleton.

    Returns:
        Dictionary of Gradio components, keyed by name
    """
    global _active_load_tester

    with gr.Column():
        gr.Markdown("### Load Testing")

        # Configuration
        with gr.Row():
            target_endpoint = gr.Textbox(
                label="Target Endpoint",
                value="http://localhost:8000",
                placeholder="http://host:port",
            )

        with gr.Row():
            concurrent_users = gr.Slider(
                label="Concurrent Users",
                minimum=1,
                maximum=100,
                value=10,
                step=1,
            )
            requests_per_second = gr.Slider(
                label="Requests/Second",
                minimum=0.1,
                maximum=50,
                value=5,
                step=0.5,
            )
            duration = gr.Slider(
                label="Duration (seconds)",
                minimum=10,
                maximum=300,
                value=60,
                step=10,
            )

        with gr.Row():
            prompt_distribution = gr.Dropdown(
                choices=["fixed", "realistic", "random"],
                value="fixed",
                label="Prompt Distribution",
            )
            max_tokens = gr.Slider(
                label="Max Tokens",
                minimum=10,
                maximum=500,
                value=100,
                step=10,
            )

        # Control buttons
        with gr.Row():
            start_btn = gr.Button("Start Load Test", variant="primary")
            stop_btn = gr.Button("Stop", variant="stop")

        # Status
        test_status = gr.Textbox(
            label="Status",
            interactive=False,
        )

        # Progress
        with gr.Row():
            progress_elapsed = gr.Number(label="Elapsed (s)", precision=0, interactive=False)
            progress_requests = gr.Number(label="Total Requests", precision=0, interactive=False)
            progress_success = gr.Number(label="Successful", precision=0, interactive=False)
            progress_failed = gr.Number(label="Failed", precision=0, interactive=False)

        # Results
        gr.Markdown("### Results")
        with gr.Row():
            result_avg = gr.Number(label="Avg Latency (ms)", precision=1, interactive=False)
            result_p50 = gr.Number(label="P50 (ms)", precision=1, interactive=False)
            result_p95 = gr.Number(label="P95 (ms)", precision=1, interactive=False)
            result_p99 = gr.Number(label="P99 (ms)", precision=1, interactive=False)

        with gr.Row():
            result_throughput = gr.Number(label="Throughput (req/s)", precision=2, interactive=False)
            result_saturation = gr.Number(label="Saturation Point", precision=1, interactive=False)

        # Latency over time chart
        latency_chart = gr.LinePlot(
            x="time",
            y="latency_ms",
            title="Latency Over Time",
            x_title="Time",
            y_title="Latency (ms)",
            height=250,
        )

    # Event handlers

    # Async generator: each yield is a 12-tuple that must match the order of
    # the `outputs=` list passed to start_btn.click below.
    async def start_load_test(endpoint, users, rps, dur, dist, max_tok):
        global _active_load_tester

        config = LoadTestConfig(
            target_endpoint=endpoint,
            concurrent_users=int(users),
            requests_per_second=rps,
            duration_seconds=int(dur),
            prompt_length_distribution=dist,
            max_tokens=int(max_tok),
        )

        _active_load_tester = LoadTester(config)

        # Initial status
        yield (
            "Starting load test...",
            0, 0, 0, 0,
            0, 0, 0, 0, 0, None,
            pd.DataFrame({"time": [], "latency_ms": []}),
        )

        try:
            result = await _active_load_tester.run()
            yield format_load_test_results(result, _active_load_tester)
        except Exception as e:
            # Surface the failure in the status box and reset all outputs.
            yield (
                f"Error: {str(e)}",
                0, 0, 0, 0,
                0, 0, 0, 0, 0, None,
                pd.DataFrame({"time": [], "latency_ms": []}),
            )

    def stop_load_test():
        global _active_load_tester
        if _active_load_tester:
            _active_load_tester.stop()
            return "Load test stopped"
        return "No active load test"

    start_btn.click(
        fn=start_load_test,
        inputs=[
            target_endpoint, concurrent_users, requests_per_second,
            duration, prompt_distribution, max_tokens
        ],
        outputs=[
            test_status,
            progress_elapsed, progress_requests, progress_success, progress_failed,
            result_avg, result_p50, result_p95, result_p99, result_throughput, result_saturation,
            latency_chart,
        ],
    )

    stop_btn.click(
        fn=stop_load_test,
        outputs=test_status,
    )

    return {
        "target_endpoint": target_endpoint,
        "concurrent_users": concurrent_users,
        "requests_per_second": requests_per_second,
        "duration": duration,
        "prompt_distribution": prompt_distribution,
        "max_tokens": max_tokens,
        "test_status": test_status,
        "progress_elapsed": progress_elapsed,
        "progress_requests": progress_requests,
        "progress_success": progress_success,
        "progress_failed": progress_failed,
        "result_avg": result_avg,
        "result_p50": result_p50,
        "result_p95": result_p95,
        "result_p99": result_p99,
        "result_throughput": result_throughput,
        "result_saturation": result_saturation,
        "latency_chart": latency_chart,
    }
|
| 193 |
+
|
| 194 |
+
|
| 195 |
+
def format_load_test_results(
    result: LoadTestResult,
    tester: LoadTester,
) -> tuple:
    """Shape a finished load-test run into the panel's output tuple.

    Args:
        result: Aggregated statistics from a completed run.
        tester: The LoadTester that produced the result; it supplies the
            per-request latency timeseries for the chart.

    Returns:
        Tuple ordered to match the load-test panel's output components.
    """
    # An empty frame with the expected columns keeps the LinePlot schema
    # stable even when the run produced no samples.
    samples = tester.get_latency_timeseries()
    latency_df = (
        pd.DataFrame(samples)
        if samples
        else pd.DataFrame({"time": [], "latency_ms": []})
    )

    status_line = f"Load test complete: {result.total_requests} requests"
    return (
        status_line,
        result.duration_seconds,
        result.total_requests,
        result.successful_requests,
        result.failed_requests,
        result.avg_latency_ms,
        result.p50_latency_ms,
        result.p95_latency_ms,
        result.p99_latency_ms,
        result.throughput_rps,
        result.saturation_point,
        latency_df,
    )
|
components/quant_panel.py
ADDED
|
@@ -0,0 +1,118 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Quantization details panel component."""
|
| 2 |
+
|
| 3 |
+
import gradio as gr
|
| 4 |
+
import pandas as pd
|
| 5 |
+
from typing import Dict, Any, Tuple, Optional
|
| 6 |
+
|
| 7 |
+
from collectors.quant_collector import QuantizationCollector, QuantizationInfo
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
def create_quant_panel() -> Dict[str, Any]:
    """
    Create the quantization details panel.

    Builds a headline row (method / bits / group size), a JSON viewer for
    the full quantization configuration, and a per-layer precision table.

    Returns:
        Dictionary of Gradio components, keyed by name
    """
    with gr.Column():
        gr.Markdown("### Quantization Details")

        # Headline summary of the detected quantization scheme.
        with gr.Row():
            quant_type = gr.Textbox(
                label="Quantization Method",
                interactive=False,
            )
            bits = gr.Number(
                label="Bits",
                precision=0,
                interactive=False,
            )
            group_size = gr.Number(
                label="Group Size",
                precision=0,
                interactive=False,
            )

        # Full configuration JSON
        quant_details = gr.JSON(
            label="Full Configuration",
        )

        # Layer precision table
        gr.Markdown("#### Per-Layer Precision")
        layer_table = gr.Dataframe(
            headers=["Layer", "Bits", "Group Size", "Dtype"],
            datatype=["str", "number", "str", "str"],
            label="Layer Precisions",
            interactive=False,
        )

    return {
        "quant_type": quant_type,
        "bits": bits,
        "group_size": group_size,
        "quant_details": quant_details,
        "layer_table": layer_table,
    }
|
| 57 |
+
|
| 58 |
+
|
| 59 |
+
def update_quant_panel(
    collector: QuantizationCollector,
) -> Tuple[str, int, Optional[int], Dict, pd.DataFrame]:
    """
    Refresh the quantization panel from the collector.

    Args:
        collector: Quantization collector instance

    Returns:
        Tuple of (method, bits, group_size, details_json, layer_table)
    """
    info = collector.detect()
    layers = collector.get_layer_precisions()

    # Cap the table at 20 rows to keep the UI responsive.
    rows = [
        {
            "Layer": layer.layer_name,
            "Bits": layer.bits,
            "Group Size": str(layer.group_size) if layer.group_size else "-",
            "Dtype": layer.dtype,
        }
        for layer in layers[:20]
    ]

    if rows:
        layer_df = pd.DataFrame(rows)
    else:
        # Explicit columns keep the Dataframe schema stable when empty.
        layer_df = pd.DataFrame(columns=["Layer", "Bits", "Group Size", "Dtype"])

    return info.method, info.bits, info.group_size, info.to_dict(), layer_df
|
| 95 |
+
|
| 96 |
+
|
| 97 |
+
def get_quant_summary(info: QuantizationInfo) -> str:
    """
    Build a one-line human-readable summary of the quantization setup.

    Args:
        info: QuantizationInfo instance

    Returns:
        Human-readable summary string
    """
    # Unquantized models are described by their compute dtype alone.
    if info.method == "None (FP16/BF16)":
        dtype = info.compute_dtype or "float16"
        return f"Full precision ({dtype})"

    # Optional detail fragments are appended only when present.
    pieces = [f"{info.method} {info.bits}-bit"]
    if info.group_size:
        pieces.append(f", group size {info.group_size}")
    if info.quant_type:
        pieces.append(f" ({info.quant_type})")
    return "".join(pieces)
|
components/tracing_panel.py
ADDED
|
@@ -0,0 +1,186 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Request tracing panel component."""
|
| 2 |
+
|
| 3 |
+
import gradio as gr
|
| 4 |
+
import pandas as pd
|
| 5 |
+
from typing import Dict, Any, Tuple
|
| 6 |
+
|
| 7 |
+
from services.request_tracer import RequestTracer
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
def create_tracing_panel(tracer: RequestTracer) -> Dict[str, Any]:
    """
    Create the request tracing panel.

    Lays out filter controls, summary statistics, a per-request trace
    table, an average latency-breakdown bar chart, and latency
    percentiles; the Refresh button re-queries the given tracer.

    Args:
        tracer: Request tracer instance

    Returns:
        Dictionary of Gradio components, keyed by name
    """
    with gr.Column():
        gr.Markdown("### Request Tracing")

        # Filter controls
        with gr.Row():
            trace_filter = gr.Dropdown(
                choices=["All Requests", "Slow Only"],
                value="All Requests",
                label="Filter",
            )
            trace_limit = gr.Slider(
                minimum=10,
                maximum=500,
                value=100,
                step=10,
                label="Show Last N Requests",
            )
            refresh_btn = gr.Button("Refresh", size="sm")

        # Summary stats
        with gr.Row():
            total_requests = gr.Number(
                label="Total Requests",
                precision=0,
                interactive=False,
            )
            slow_requests = gr.Number(
                label="Slow Requests",
                precision=0,
                interactive=False,
            )
            slow_rate = gr.Number(
                label="Slow Rate %",
                precision=1,
                interactive=False,
            )
            baseline_p95 = gr.Number(
                label="Baseline P95 (ms)",
                precision=1,
                interactive=False,
            )

        # Traces table
        traces_table = gr.Dataframe(
            headers=[
                "ID", "Prompt Toks", "Output Toks",
                "Queue (ms)", "Prefill (ms)", "Decode (ms)",
                "Total (ms)", "Tok/s", "Slow?"
            ],
            datatype=[
                "str", "number", "number",
                "number", "number", "number",
                "number", "number", "str"
            ],
            label="Request Traces",
            interactive=False,
        )

        # Latency breakdown chart
        gr.Markdown("#### Average Latency Breakdown")
        latency_breakdown = gr.BarPlot(
            x="phase",
            y="ms",
            title="Latency by Phase",
            x_title="Phase",
            y_title="Time (ms)",
            height=200,
        )

        # Percentiles
        gr.Markdown("#### Latency Percentiles")
        with gr.Row():
            p50 = gr.Number(label="P50 (ms)", precision=1, interactive=False)
            p95 = gr.Number(label="P95 (ms)", precision=1, interactive=False)
            p99 = gr.Number(label="P99 (ms)", precision=1, interactive=False)

    # Event handlers

    # The tracer is captured by closure so the click handler needs no state.
    def refresh_traces(filter_val, limit):
        slow_only = filter_val == "Slow Only"
        return update_tracing_panel(tracer, slow_only, int(limit))

    # Output order must match the tuple returned by update_tracing_panel.
    refresh_btn.click(
        fn=refresh_traces,
        inputs=[trace_filter, trace_limit],
        outputs=[
            total_requests, slow_requests, slow_rate, baseline_p95,
            traces_table, latency_breakdown, p50, p95, p99
        ],
    )

    return {
        "trace_filter": trace_filter,
        "trace_limit": trace_limit,
        "total_requests": total_requests,
        "slow_requests": slow_requests,
        "slow_rate": slow_rate,
        "baseline_p95": baseline_p95,
        "traces_table": traces_table,
        "latency_breakdown": latency_breakdown,
        "p50": p50,
        "p95": p95,
        "p99": p99,
    }
|
| 123 |
+
|
| 124 |
+
|
| 125 |
+
def update_tracing_panel(
    tracer: RequestTracer,
    slow_only: bool = False,
    limit: int = 100,
) -> Tuple[int, int, float, float, pd.DataFrame, pd.DataFrame, float, float, float]:
    """
    Refresh every tracing-panel component from the tracer.

    Args:
        tracer: Request tracer instance
        slow_only: Only show slow requests
        limit: Maximum number of traces to show

    Returns:
        Tuple of all component values, ordered to match the panel outputs
    """
    stats = tracer.get_stats()
    traces = tracer.get_recent_traces(limit=limit, slow_only=slow_only)
    breakdown = tracer.get_latency_breakdown()
    percentiles = tracer.get_percentiles()

    table_columns = [
        "ID", "Prompt Toks", "Output Toks",
        "Queue (ms)", "Prefill (ms)", "Decode (ms)",
        "Total (ms)", "Tok/s", "Slow?"
    ]

    # reversed() puts the most recent trace at the top of the table.
    rows = [
        {
            "ID": t.request_id,
            "Prompt Toks": t.prompt_tokens,
            "Output Toks": t.output_tokens,
            "Queue (ms)": round(t.queue_time_ms, 1),
            "Prefill (ms)": round(t.prefill_time_ms, 1),
            "Decode (ms)": round(t.decode_time_ms, 1),
            "Total (ms)": round(t.total_time_ms, 1),
            "Tok/s": round(t.tokens_per_second, 1),
            "Slow?": "Yes" if t.is_slow else "",
        }
        for t in reversed(traces)
    ]
    traces_df = pd.DataFrame(rows) if rows else pd.DataFrame(columns=table_columns)

    # Average time spent in each request phase for the bar chart.
    breakdown_df = pd.DataFrame({
        "phase": ["Queue", "Prefill", "Decode"],
        "ms": [breakdown.queue_ms, breakdown.prefill_ms, breakdown.decode_ms],
    })

    return (
        stats["total_requests"],
        stats["slow_requests"],
        stats.get("slow_rate_percent", 0),
        # get_stats may report baseline_p95 as None; the Number widget needs 0.
        stats.get("baseline_p95", 0) or 0,
        traces_df,
        breakdown_df,
        percentiles["p50"],
        percentiles["p95"],
        percentiles["p99"],
    )
|
config.py
ADDED
|
@@ -0,0 +1,67 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Configuration settings for LLM Inference Dashboard."""
|
| 2 |
+
|
| 3 |
+
from dataclasses import dataclass, field
|
| 4 |
+
from typing import Optional
|
| 5 |
+
import os
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
@dataclass
class Config:
    """Dashboard configuration with sensible defaults.

    Values can be overridden per instance, or loaded from environment
    variables via :meth:`from_env`.
    """

    # vLLM Connection
    vllm_host: str = "localhost"
    vllm_port: int = 8000
    model_path: Optional[str] = None

    # Dashboard
    refresh_interval: float = 1.0
    history_length: int = 300  # 5 minutes at 1s intervals

    # Database
    db_path: str = "data/metrics.db"

    # Alert Thresholds
    alert_kv_cache_threshold: float = 90.0
    alert_gpu_memory_threshold: float = 95.0
    alert_ttft_multiplier: float = 2.0
    alert_throughput_drop_pct: float = 50.0

    # Webhooks
    slack_webhook: Optional[str] = None
    pagerduty_routing_key: Optional[str] = None
    generic_webhooks: list = field(default_factory=list)

    # Load Testing Defaults
    loadtest_concurrent_users: int = 10
    loadtest_rps: float = 5.0
    loadtest_duration: int = 60

    @property
    def _base_url(self) -> str:
        """Root URL of the vLLM server, shared by all endpoint properties."""
        return f"http://{self.vllm_host}:{self.vllm_port}"

    @property
    def metrics_endpoint(self) -> str:
        """Prometheus-format metrics endpoint exposed by vLLM."""
        return self._base_url + "/metrics"

    @property
    def openai_endpoint(self) -> str:
        """OpenAI-compatible API base path."""
        return self._base_url + "/v1"

    @property
    def health_endpoint(self) -> str:
        """Liveness/health-check endpoint."""
        return self._base_url + "/health"

    @classmethod
    def from_env(cls) -> "Config":
        """Create config from environment variables, falling back to defaults."""
        read = os.getenv
        return cls(
            vllm_host=read("VLLM_HOST", "localhost"),
            vllm_port=int(read("VLLM_PORT", "8000")),
            model_path=read("MODEL_PATH"),
            refresh_interval=float(read("REFRESH_INTERVAL", "1.0")),
            db_path=read("DB_PATH", "data/metrics.db"),
            slack_webhook=read("SLACK_WEBHOOK"),
            pagerduty_routing_key=read("PAGERDUTY_KEY"),
        )


# Global config instance
config = Config.from_env()
|
requirements.txt
ADDED
|
@@ -0,0 +1,16 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Core
|
| 2 |
+
gradio>=5.0.0
|
| 3 |
+
requests>=2.28.0
|
| 4 |
+
aiohttp>=3.9.0
|
| 5 |
+
|
| 6 |
+
# Data processing
|
| 7 |
+
pandas>=2.0.0
|
| 8 |
+
numpy>=1.24.0
|
| 9 |
+
scipy>=1.11.0
|
| 10 |
+
|
| 11 |
+
# Model utilities
|
| 12 |
+
safetensors>=0.4.0
|
| 13 |
+
huggingface-hub>=0.20.0
|
| 14 |
+
|
| 15 |
+
# GPU monitoring (optional - will use mock data if unavailable)
|
| 16 |
+
nvidia-ml-py3>=7.352.0
|
services/__init__.py
ADDED
|
@@ -0,0 +1,18 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Services for alerting, tracing, comparison, and load testing."""
|
| 2 |
+
|
| 3 |
+
from .alerting import AlertEngine, AlertDispatcher, Alert
|
| 4 |
+
from .request_tracer import RequestTracer
|
| 5 |
+
from .comparator import ABComparator, DeploymentConfig, ComparisonResult
|
| 6 |
+
from .load_tester import LoadTester, LoadTestConfig
|
| 7 |
+
|
| 8 |
+
__all__ = [
|
| 9 |
+
"AlertEngine",
|
| 10 |
+
"AlertDispatcher",
|
| 11 |
+
"Alert",
|
| 12 |
+
"RequestTracer",
|
| 13 |
+
"ABComparator",
|
| 14 |
+
"DeploymentConfig",
|
| 15 |
+
"ComparisonResult",
|
| 16 |
+
"LoadTester",
|
| 17 |
+
"LoadTestConfig",
|
| 18 |
+
]
|
services/alerting.py
ADDED
|
@@ -0,0 +1,421 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Alert engine and webhook dispatch for monitoring thresholds."""
|
| 2 |
+
|
| 3 |
+
import asyncio
|
| 4 |
+
import logging
|
| 5 |
+
from dataclasses import dataclass, field
|
| 6 |
+
from datetime import datetime
|
| 7 |
+
from typing import Dict, List, Optional, Any, Callable
|
| 8 |
+
from enum import Enum
|
| 9 |
+
|
| 10 |
+
import aiohttp
|
| 11 |
+
|
| 12 |
+
from storage.database import MetricsDB
|
| 13 |
+
from storage.models import AlertRecord
|
| 14 |
+
|
| 15 |
+
logger = logging.getLogger(__name__)
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
class AlertSeverity(Enum):
    """Severity levels for alerts, from least to most urgent."""
    INFO = "info"
    WARNING = "warning"
    CRITICAL = "critical"  # CRITICAL alerts are the only ones routed to PagerDuty
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
@dataclass
class AlertRule:
    """Configuration for an alert rule."""
    name: str  # unique rule identifier; also used as the key in the rule registry
    metric: str  # metric key looked up in the metrics dict passed to evaluate()
    condition: str  # >, <, >=, <=, ==
    threshold: float  # static threshold; ignored when threshold_type is dynamic
    severity: AlertSeverity
    message: str  # human-readable text attached to each triggered alert
    # For dynamic thresholds
    threshold_type: str = "static"  # static, baseline_multiplier, baseline_percent
    multiplier: float = 1.0  # effective threshold = baseline * multiplier (baseline_multiplier)
    percent: float = 100.0  # effective threshold = baseline * percent/100 (baseline_percent)
    cooldown_seconds: int = 60  # minimum seconds between repeated triggers of this rule
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
@dataclass
class Alert:
    """A single triggered alert instance produced by the alert engine."""
    rule_name: str
    metric: str
    value: float
    threshold: float
    severity: AlertSeverity
    message: str
    timestamp: datetime = field(default_factory=datetime.now)
    resolved: bool = False

    def to_dict(self) -> Dict[str, Any]:
        """Serialize the alert to a JSON-friendly dictionary."""
        serialized = dict(
            rule_name=self.rule_name,
            metric=self.metric,
            value=self.value,
            threshold=self.threshold,
            severity=self.severity.value,
            message=self.message,
            timestamp=self.timestamp.isoformat(),
            resolved=self.resolved,
        )
        return serialized
|
| 63 |
+
|
| 64 |
+
|
| 65 |
+
# Default alert rules, keyed by rule name. Copied (not shared) into each
# AlertEngine so per-engine add_rule/remove_rule cannot mutate this table.
DEFAULT_RULES = {
    # Static threshold: KV cache near capacity is an early saturation signal.
    "kv_cache_high": AlertRule(
        name="kv_cache_high",
        metric="kv_cache_percent",
        condition=">",
        threshold=90.0,
        severity=AlertSeverity.WARNING,
        message="KV cache utilization above 90%",
    ),
    # Static threshold: near-OOM GPU memory is treated as critical.
    "gpu_memory_critical": AlertRule(
        name="gpu_memory_critical",
        metric="gpu_memory_percent",
        condition=">",
        threshold=95.0,
        severity=AlertSeverity.CRITICAL,
        message="GPU memory critically high (>95%)",
    ),
    # Dynamic threshold: fires when TTFT exceeds 2x the recorded baseline.
    "ttft_spike": AlertRule(
        name="ttft_spike",
        metric="ttft_ms",
        condition=">",
        threshold=0,  # Dynamic
        threshold_type="baseline_multiplier",
        multiplier=2.0,
        severity=AlertSeverity.WARNING,
        message="Time to first token spiked to 2x baseline",
    ),
    # Dynamic threshold: fires when throughput drops below 50% of baseline.
    "throughput_drop": AlertRule(
        name="throughput_drop",
        metric="tokens_per_second",
        condition="<",
        threshold=0,  # Dynamic
        threshold_type="baseline_percent",
        percent=50.0,
        severity=AlertSeverity.WARNING,
        message="Throughput dropped below 50% of baseline",
    ),
    # Static threshold on scheduler backlog.
    "queue_buildup": AlertRule(
        name="queue_buildup",
        metric="queue_depth",
        condition=">",
        threshold=50.0,
        severity=AlertSeverity.WARNING,
        message="Request queue depth exceeds 50",
    ),
}
|
| 112 |
+
|
| 113 |
+
|
| 114 |
+
class AlertEngine:
    """Evaluates metrics against alert rules.

    Holds the rule set, currently-firing alerts, per-metric baselines
    (used to resolve dynamic thresholds), and per-rule cooldown state.
    No internal locking -- intended to be driven from a single polling loop.
    """

    def __init__(self, db: Optional[MetricsDB] = None):
        """
        Initialize alert engine.

        Args:
            db: Optional database for persisting alerts
        """
        self.db = db
        # Copy so add_rule/remove_rule never mutate the shared DEFAULT_RULES.
        self.rules: Dict[str, AlertRule] = dict(DEFAULT_RULES)
        self.active_alerts: Dict[str, Alert] = {}  # rule name -> firing alert
        self.baselines: Dict[str, float] = {}  # metric name -> baseline value
        self._last_trigger_times: Dict[str, datetime] = {}  # rule name -> last fire time (cooldown)
        self._callbacks: List[Callable[[Alert], None]] = []  # invoked for each new alert

    def add_rule(self, rule: AlertRule) -> None:
        """Add or update an alert rule (keyed by rule.name)."""
        self.rules[rule.name] = rule

    def remove_rule(self, name: str) -> None:
        """Remove an alert rule; a no-op if the rule does not exist."""
        self.rules.pop(name, None)

    def set_baseline(self, metric: str, value: float) -> None:
        """Set (or overwrite) the baseline value for a metric."""
        self.baselines[metric] = value

    def update_baselines(self, metrics: Dict[str, float]) -> None:
        """Update baseline values from current metrics.

        Only fills in metrics that have no baseline yet and whose current
        value is positive; existing baselines are never overwritten here
        (use set_baseline for that).
        """
        for metric, value in metrics.items():
            if metric not in self.baselines and value > 0:
                self.baselines[metric] = value

    def on_alert(self, callback: Callable[[Alert], None]) -> None:
        """Register a callback invoked synchronously for each new alert."""
        self._callbacks.append(callback)

    def evaluate(self, metrics: Dict[str, float]) -> List[Alert]:
        """
        Evaluate metrics against all rules.

        Triggered rules respect their per-rule cooldown. Rules whose
        condition is no longer met have their active alert marked resolved
        and removed from the active set.

        Args:
            metrics: Current metric values

        Returns:
            List of newly triggered alerts
        """
        new_alerts = []

        for rule_name, rule in self.rules.items():
            if rule.metric not in metrics:
                # Metric not reported this cycle; leave any active alert untouched.
                continue

            value = metrics[rule.metric]
            threshold = self._get_threshold(rule)

            if threshold is None:
                # Dynamic rule with no baseline recorded yet -- cannot evaluate.
                continue

            triggered = self._check_condition(value, rule.condition, threshold)

            if triggered:
                # Check cooldown
                if rule_name in self._last_trigger_times:
                    elapsed = (
                        datetime.now() - self._last_trigger_times[rule_name]
                    ).total_seconds()
                    if elapsed < rule.cooldown_seconds:
                        continue

                # Create alert
                alert = Alert(
                    rule_name=rule_name,
                    metric=rule.metric,
                    value=value,
                    threshold=threshold,
                    severity=rule.severity,
                    message=rule.message,
                )

                self.active_alerts[rule_name] = alert
                self._last_trigger_times[rule_name] = datetime.now()
                new_alerts.append(alert)

                # Persist to database
                if self.db:
                    record = AlertRecord(
                        rule_name=rule_name,
                        severity=rule.severity.value,
                        metric_name=rule.metric,
                        value=value,
                        threshold=threshold,
                        message=rule.message,
                    )
                    self.db.insert_alert(record)

                # Notify callbacks; one failing callback must not block the rest.
                for callback in self._callbacks:
                    try:
                        callback(alert)
                    except Exception as e:
                        logger.error(f"Alert callback error: {e}")

            elif rule_name in self.active_alerts:
                # Resolve alert: condition cleared this cycle, mark and drop it.
                self.active_alerts[rule_name].resolved = True
                del self.active_alerts[rule_name]

        return new_alerts

    def _get_threshold(self, rule: AlertRule) -> Optional[float]:
        """Calculate the effective threshold for a rule.

        Returns None for a dynamic rule whose metric has no baseline yet.
        """
        if rule.threshold_type == "static":
            return rule.threshold

        baseline = self.baselines.get(rule.metric)
        if baseline is None:
            return None

        if rule.threshold_type == "baseline_multiplier":
            return baseline * rule.multiplier

        if rule.threshold_type == "baseline_percent":
            return baseline * (rule.percent / 100.0)

        # Unknown threshold_type: fall back to the static value.
        return rule.threshold

    def _check_condition(
        self, value: float, condition: str, threshold: float
    ) -> bool:
        """Check if condition is met; unknown operators evaluate to False."""
        if condition == ">":
            return value > threshold
        if condition == ">=":
            return value >= threshold
        if condition == "<":
            return value < threshold
        if condition == "<=":
            return value <= threshold
        if condition == "==":
            # Tolerant float equality rather than exact comparison.
            return abs(value - threshold) < 0.001
        return False

    def get_active_alerts(self) -> List[Alert]:
        """Get all active (unresolved) alerts."""
        return list(self.active_alerts.values())

    def clear_alert(self, rule_name: str) -> None:
        """Manually clear an alert (without marking it resolved)."""
        if rule_name in self.active_alerts:
            del self.active_alerts[rule_name]
|
| 267 |
+
|
| 268 |
+
|
| 269 |
+
class AlertDispatcher:
    """Dispatches alerts to external services.

    Supports Slack incoming webhooks, PagerDuty Events API v2, and any
    number of generic JSON webhooks. All sends are best-effort: failures
    are logged and never raised to the caller.
    """

    def __init__(
        self,
        slack_webhook: Optional[str] = None,
        pagerduty_key: Optional[str] = None,
        generic_webhooks: Optional[List[str]] = None,
    ):
        """
        Initialize alert dispatcher.

        Args:
            slack_webhook: Slack incoming webhook URL
            pagerduty_key: PagerDuty routing key
            generic_webhooks: List of generic webhook URLs
        """
        self.slack_webhook = slack_webhook
        self.pagerduty_key = pagerduty_key
        self.generic_webhooks = generic_webhooks or []

    async def dispatch(self, alert: Alert) -> None:
        """
        Dispatch alert to all configured services concurrently.

        Args:
            alert: Alert to dispatch
        """
        tasks = []

        if self.slack_webhook:
            tasks.append(self._send_slack(alert))

        # PagerDuty is reserved for critical alerts to avoid paging noise.
        if self.pagerduty_key and alert.severity == AlertSeverity.CRITICAL:
            tasks.append(self._send_pagerduty(alert))

        for webhook in self.generic_webhooks:
            tasks.append(self._send_generic(webhook, alert))

        if tasks:
            # return_exceptions=True: one failing sink must not cancel the others.
            await asyncio.gather(*tasks, return_exceptions=True)

    async def _post_json(
        self,
        url: str,
        payload: Dict[str, Any],
        ok: Callable[[int], bool],
        fail_msg: str,
        error_msg: str,
    ) -> None:
        """
        POST a JSON payload with a 10s timeout (shared skeleton for all sinks).

        Args:
            url: Destination URL
            payload: JSON-serializable request body
            ok: Predicate deciding whether an HTTP status counts as success
            fail_msg: Log prefix used for a non-success HTTP status
            error_msg: Log prefix used for a transport-level error
        """
        try:
            async with aiohttp.ClientSession() as session:
                async with session.post(
                    url,
                    json=payload,
                    timeout=aiohttp.ClientTimeout(total=10),
                ) as response:
                    if not ok(response.status):
                        logger.error(f"{fail_msg}: {response.status}")
        except Exception as e:
            logger.error(f"{error_msg}: {e}")

    async def _send_slack(self, alert: Alert) -> None:
        """Send alert to Slack as a colored attachment message."""
        color = "danger" if alert.severity == AlertSeverity.CRITICAL else "warning"
        emoji = "🚨" if alert.severity == AlertSeverity.CRITICAL else "⚠️"

        payload = {
            "text": f"{emoji} *{alert.severity.value.upper()}*: {alert.message}",
            "attachments": [
                {
                    "color": color,
                    "fields": [
                        {
                            "title": "Metric",
                            "value": alert.metric,
                            "short": True,
                        },
                        {
                            "title": "Value",
                            "value": f"{alert.value:.2f}",
                            "short": True,
                        },
                        {
                            "title": "Threshold",
                            "value": f"{alert.threshold:.2f}",
                            "short": True,
                        },
                        {
                            "title": "Time",
                            "value": alert.timestamp.strftime("%Y-%m-%d %H:%M:%S"),
                            "short": True,
                        },
                    ],
                }
            ],
        }

        await self._post_json(
            self.slack_webhook,
            payload,
            lambda status: status == 200,
            "Slack webhook failed",
            "Error sending Slack alert",
        )

    async def _send_pagerduty(self, alert: Alert) -> None:
        """Send alert to PagerDuty via the Events API v2."""
        payload = {
            "routing_key": self.pagerduty_key,
            "event_action": "trigger",
            # Stable dedup key so repeated triggers of a rule update one incident.
            "dedup_key": f"llm-dashboard-{alert.rule_name}",
            "payload": {
                "summary": alert.message,
                "severity": "critical",
                "source": "llm-inference-dashboard",
                "custom_details": {
                    "metric": alert.metric,
                    "value": alert.value,
                    "threshold": alert.threshold,
                },
            },
        }

        # The events endpoint answers 202 Accepted on success.
        await self._post_json(
            "https://events.pagerduty.com/v2/enqueue",
            payload,
            lambda status: status == 202,
            "PagerDuty failed",
            "Error sending PagerDuty alert",
        )

    async def _send_generic(self, webhook_url: str, alert: Alert) -> None:
        """Send alert to a generic webhook as its raw dict form."""
        await self._post_json(
            webhook_url,
            alert.to_dict(),
            lambda status: status < 400,
            f"Webhook {webhook_url} failed",
            f"Error sending to webhook {webhook_url}",
        )

    async def send_test_alert(self) -> bool:
        """Send a test alert to verify configuration.

        Returns:
            True when dispatch completed without raising (individual sink
            failures are logged, not raised).
        """
        test_alert = Alert(
            rule_name="test_alert",
            metric="test_metric",
            value=100.0,
            threshold=50.0,
            severity=AlertSeverity.INFO,
            message="This is a test alert from LLM Inference Dashboard",
        )

        try:
            await self.dispatch(test_alert)
            return True
        except Exception as e:
            logger.error(f"Test alert failed: {e}")
            return False
|
services/comparator.py
ADDED
|
@@ -0,0 +1,366 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""A/B comparison of vLLM deployments."""
|
| 2 |
+
|
| 3 |
+
import asyncio
|
| 4 |
+
import logging
|
| 5 |
+
from dataclasses import dataclass, field
|
| 6 |
+
from typing import Optional, Dict, List, Any
|
| 7 |
+
from datetime import datetime
|
| 8 |
+
|
| 9 |
+
import aiohttp
|
| 10 |
+
from scipy import stats
|
| 11 |
+
|
| 12 |
+
from utils.prometheus_parser import (
|
| 13 |
+
parse_prometheus_metrics,
|
| 14 |
+
get_metric_value,
|
| 15 |
+
get_histogram_quantile,
|
| 16 |
+
)
|
| 17 |
+
|
| 18 |
+
logger = logging.getLogger(__name__)
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
@dataclass
class DeploymentConfig:
    """Configuration for a vLLM deployment."""
    name: str
    endpoint: str  # Base URL (e.g., http://localhost:8000)
    model_name: str = ""
    quantization: str = ""

    @property
    def metrics_url(self) -> str:
        """Prometheus scrape URL derived from the base endpoint."""
        return self.endpoint + "/metrics"
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
@dataclass
class DeploymentMetrics:
    """Metrics collected from a deployment (one /metrics scrape)."""
    endpoint: str  # base URL the sample was collected from
    timestamp: datetime = field(default_factory=datetime.now)
    connected: bool = False  # False when the scrape failed or returned non-200

    # Throughput
    tokens_per_second: float = 0.0
    throughput_samples: List[float] = field(default_factory=list)  # per-scrape values for t-tests

    # Latency (p50 values read from vLLM histograms, converted to ms)
    ttft_ms: float = 0.0  # time to first token
    tpot_ms: float = 0.0  # time per output token
    e2e_latency_ms: float = 0.0  # end-to-end request latency
    latency_samples: List[float] = field(default_factory=list)  # per-scrape values for t-tests

    # Resources
    gpu_memory_gb: float = 0.0
    kv_cache_percent: float = 0.0  # GPU KV cache utilization, 0-100
    batch_size: int = 0  # currently running requests

    # Model info
    model_name: str = ""  # taken from Prometheus labels when present
|
| 58 |
+
|
| 59 |
+
|
| 60 |
+
@dataclass
class ComparisonResult:
    """Result of comparing two deployments (A is the reference baseline)."""
    deployment_a: DeploymentMetrics
    deployment_b: DeploymentMetrics
    timestamp: datetime = field(default_factory=datetime.now)

    # Differences, expressed as B relative to A (positive => B is higher)
    throughput_diff_pct: float = 0.0
    ttft_diff_pct: float = 0.0
    latency_diff_pct: float = 0.0
    memory_diff_gb: float = 0.0  # absolute GB difference (B - A)

    # Statistical significance (two-sample t-test at p < 0.05)
    throughput_significant: bool = False
    latency_significant: bool = False
    p_value_throughput: float = 1.0
    p_value_latency: float = 1.0

    # Recommendation -- human-readable summary of the comparison
    recommendation: str = ""
|
| 81 |
+
|
| 82 |
+
|
| 83 |
+
class ABComparator:
    """Compares metrics between two vLLM deployments.

    Deployment A is treated as the baseline; all percentage differences
    are B relative to A.
    """

    def __init__(
        self,
        deployment_a: DeploymentConfig,
        deployment_b: DeploymentConfig,
        sample_count: int = 30,
    ):
        """
        Initialize comparator.

        Args:
            deployment_a: First deployment configuration (baseline)
            deployment_b: Second deployment configuration
            sample_count: Number of samples to collect for statistical tests
        """
        self.deployment_a = deployment_a
        self.deployment_b = deployment_b
        self.sample_count = sample_count
        self._samples_a: List[DeploymentMetrics] = []
        self._samples_b: List[DeploymentMetrics] = []

    async def collect_metrics(self, config: DeploymentConfig) -> DeploymentMetrics:
        """
        Collect current metrics from a deployment's Prometheus endpoint.

        Args:
            config: Deployment configuration

        Returns:
            DeploymentMetrics with current values; ``connected`` stays False
            when the scrape fails or returns a non-200 status.
        """
        metrics = DeploymentMetrics(endpoint=config.endpoint)

        try:
            async with aiohttp.ClientSession() as session:
                async with session.get(
                    config.metrics_url,
                    timeout=aiohttp.ClientTimeout(total=5),
                ) as response:
                    if response.status != 200:
                        return metrics

                    text = await response.text()
                    raw = parse_prometheus_metrics(text)
                    metrics.connected = True

                    # p50 latencies, converted from seconds to milliseconds.
                    metrics.tokens_per_second = self._calculate_tps(raw)
                    metrics.ttft_ms = (
                        get_histogram_quantile(
                            raw, "vllm:time_to_first_token_seconds", 0.5
                        )
                        or 0
                    ) * 1000
                    metrics.tpot_ms = (
                        get_histogram_quantile(
                            raw, "vllm:time_per_output_token_seconds", 0.5
                        )
                        or 0
                    ) * 1000
                    metrics.e2e_latency_ms = (
                        get_histogram_quantile(
                            raw, "vllm:e2e_request_latency_seconds", 0.5
                        )
                        or 0
                    ) * 1000
                    metrics.kv_cache_percent = (
                        get_metric_value(raw, "vllm:gpu_cache_usage_perc") or 0
                    ) * 100
                    metrics.batch_size = int(
                        get_metric_value(raw, "vllm:num_requests_running") or 0
                    )

                    # Model name from labels: take the first one found and stop.
                    # (The previous version's break only exited the inner loop,
                    # so every remaining metric family was still scanned and
                    # could overwrite the value.)
                    metrics.model_name = next(
                        (
                            sample.labels["model_name"]
                            for samples in raw.values()
                            for sample in samples
                            if "model_name" in sample.labels
                        ),
                        "",
                    )

        except Exception as e:
            logger.error(f"Error collecting metrics from {config.endpoint}: {e}")

        return metrics

    def _calculate_tps(self, raw: Dict) -> float:
        """Calculate tokens per second from counter metrics.

        This is a simplified estimate (running requests / median time per
        output token); a precise value would require tracking counter
        deltas over time.
        """
        generation_total = get_metric_value(raw, "vllm:generation_tokens_total") or 0
        if generation_total > 0:
            # Estimate based on running requests
            running = get_metric_value(raw, "vllm:num_requests_running") or 1
            tpot = (
                get_histogram_quantile(
                    raw, "vllm:time_per_output_token_seconds", 0.5
                )
                or 0.05
            )
            if tpot > 0:
                return running / tpot
        return 0

    async def collect_samples(self, count: Optional[int] = None) -> None:
        """
        Collect multiple samples for statistical comparison.

        Scrapes both deployments concurrently, once per second, discarding
        samples from endpoints that could not be reached.

        Args:
            count: Number of samples to collect (defaults to sample_count)
        """
        if count is None:
            count = self.sample_count

        self._samples_a.clear()
        self._samples_b.clear()

        for i in range(count):
            metrics_a, metrics_b = await asyncio.gather(
                self.collect_metrics(self.deployment_a),
                self.collect_metrics(self.deployment_b),
            )

            if metrics_a.connected:
                metrics_a.throughput_samples = [metrics_a.tokens_per_second]
                metrics_a.latency_samples = [metrics_a.e2e_latency_ms]
                self._samples_a.append(metrics_a)

            if metrics_b.connected:
                metrics_b.throughput_samples = [metrics_b.tokens_per_second]
                metrics_b.latency_samples = [metrics_b.e2e_latency_ms]
                self._samples_b.append(metrics_b)

            # Wait between samples (not after the last one)
            if i < count - 1:
                await asyncio.sleep(1)

    async def compare(self) -> ComparisonResult:
        """
        Perform a comparison between the two deployments.

        Uses one fresh scrape per deployment for point-in-time values, and
        any previously collected samples for significance testing.

        Returns:
            ComparisonResult with differences, significance, and a
            human-readable recommendation
        """
        # Collect current metrics
        metrics_a, metrics_b = await asyncio.gather(
            self.collect_metrics(self.deployment_a),
            self.collect_metrics(self.deployment_b),
        )

        result = ComparisonResult(
            deployment_a=metrics_a,
            deployment_b=metrics_b,
        )

        # Calculate differences (guarded against division by zero when A
        # reports no data)
        if metrics_a.tokens_per_second > 0:
            result.throughput_diff_pct = (
                (metrics_b.tokens_per_second - metrics_a.tokens_per_second)
                / metrics_a.tokens_per_second
            ) * 100

        if metrics_a.ttft_ms > 0:
            result.ttft_diff_pct = (
                (metrics_b.ttft_ms - metrics_a.ttft_ms) / metrics_a.ttft_ms
            ) * 100

        if metrics_a.e2e_latency_ms > 0:
            result.latency_diff_pct = (
                (metrics_b.e2e_latency_ms - metrics_a.e2e_latency_ms)
                / metrics_a.e2e_latency_ms
            ) * 100

        result.memory_diff_gb = metrics_b.gpu_memory_gb - metrics_a.gpu_memory_gb

        # Statistical significance (only if samples were collected)
        if self._samples_a and self._samples_b:
            result = self._add_significance(result)

        # Generate recommendation
        result.recommendation = self._generate_recommendation(result)

        return result

    def _add_significance(self, result: ComparisonResult) -> ComparisonResult:
        """Attach two-sample t-test results (p < 0.05) to the comparison."""
        tps_a = [s.tokens_per_second for s in self._samples_a]
        tps_b = [s.tokens_per_second for s in self._samples_b]

        lat_a = [s.e2e_latency_ms for s in self._samples_a]
        lat_b = [s.e2e_latency_ms for s in self._samples_b]

        if len(tps_a) >= 2 and len(tps_b) >= 2:
            try:
                _, p_tps = stats.ttest_ind(tps_a, tps_b)
                # Cast: scipy returns a numpy scalar; keep the field a plain float.
                result.p_value_throughput = float(p_tps)
                result.throughput_significant = p_tps < 0.05
            except Exception:
                pass

        if len(lat_a) >= 2 and len(lat_b) >= 2:
            try:
                _, p_lat = stats.ttest_ind(lat_a, lat_b)
                result.p_value_latency = float(p_lat)
                result.latency_significant = p_lat < 0.05
            except Exception:
                pass

        return result

    def _generate_recommendation(self, result: ComparisonResult) -> str:
        """Generate a human-readable recommendation (only differences >5%
        in throughput/latency or >1GB in memory are reported)."""
        parts = []
        a_name = self.deployment_a.name
        b_name = self.deployment_b.name

        # Throughput comparison: positive diff means B is faster.
        if abs(result.throughput_diff_pct) > 5:
            faster = b_name if result.throughput_diff_pct > 0 else a_name
            diff = abs(result.throughput_diff_pct)
            sig = " (statistically significant)" if result.throughput_significant else ""
            parts.append(f"{faster} has {diff:.1f}% higher throughput{sig}")

        # Latency comparison: positive diff means B is slower, so A wins.
        if abs(result.latency_diff_pct) > 5:
            faster = a_name if result.latency_diff_pct > 0 else b_name
            diff = abs(result.latency_diff_pct)
            sig = " (statistically significant)" if result.latency_significant else ""
            parts.append(f"{faster} has {diff:.1f}% lower latency{sig}")

        # Memory comparison: positive diff means B uses more, so A is leaner.
        if abs(result.memory_diff_gb) > 1:
            lower = a_name if result.memory_diff_gb > 0 else b_name
            diff = abs(result.memory_diff_gb)
            parts.append(f"{lower} uses {diff:.1f}GB less GPU memory")

        if not parts:
            return "Both deployments show similar performance"

        return ". ".join(parts) + "."

    def get_comparison_table(self, result: ComparisonResult) -> List[Dict[str, Any]]:
        """
        Generate comparison table data for UI display.

        Args:
            result: Comparison result

        Returns:
            List of rows; each row has a "Metric" key, one column per
            deployment name, and a "Diff" column
        """
        return [
            {
                "Metric": "Throughput (tok/s)",
                self.deployment_a.name: f"{result.deployment_a.tokens_per_second:.1f}",
                self.deployment_b.name: f"{result.deployment_b.tokens_per_second:.1f}",
                "Diff": f"{result.throughput_diff_pct:+.1f}%",
            },
            {
                "Metric": "TTFT (ms)",
                self.deployment_a.name: f"{result.deployment_a.ttft_ms:.1f}",
                self.deployment_b.name: f"{result.deployment_b.ttft_ms:.1f}",
                "Diff": f"{result.ttft_diff_pct:+.1f}%",
            },
            {
                "Metric": "E2E Latency (ms)",
                self.deployment_a.name: f"{result.deployment_a.e2e_latency_ms:.1f}",
                self.deployment_b.name: f"{result.deployment_b.e2e_latency_ms:.1f}",
                "Diff": f"{result.latency_diff_pct:+.1f}%",
            },
            {
                "Metric": "KV Cache %",
                self.deployment_a.name: f"{result.deployment_a.kv_cache_percent:.1f}",
                self.deployment_b.name: f"{result.deployment_b.kv_cache_percent:.1f}",
                "Diff": "-",
            },
            {
                "Metric": "Batch Size",
                self.deployment_a.name: str(result.deployment_a.batch_size),
                self.deployment_b.name: str(result.deployment_b.batch_size),
                "Diff": "-",
            },
        ]
|
services/load_tester.py
ADDED
|
@@ -0,0 +1,359 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Load testing engine for vLLM endpoints."""
|
| 2 |
+
|
| 3 |
+
import asyncio
|
| 4 |
+
import logging
|
| 5 |
+
import statistics
|
| 6 |
+
import time
|
| 7 |
+
import uuid
|
| 8 |
+
from dataclasses import dataclass, field
|
| 9 |
+
from datetime import datetime
|
| 10 |
+
from typing import List, Optional, Dict, Any, Callable
|
| 11 |
+
from collections import deque
|
| 12 |
+
|
| 13 |
+
import aiohttp
|
| 14 |
+
import numpy as np
|
| 15 |
+
|
| 16 |
+
from storage.database import MetricsDB
|
| 17 |
+
from storage.models import LoadTestResult
|
| 18 |
+
|
| 19 |
+
logger = logging.getLogger(__name__)
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
@dataclass
class LoadTestConfig:
    """Configuration for a load test.

    Describes the target endpoint, the offered load (concurrency and
    request rate), the test duration, and how prompts are generated.
    """
    # Base URL of the OpenAI-compatible server, e.g. "http://host:8000".
    # The tester appends "/v1/chat/completions" to this value.
    target_endpoint: str
    # Maximum number of requests allowed in flight at once.
    concurrent_users: int = 10
    # Target request launch rate (requests started per second).
    requests_per_second: float = 5.0
    # Wall-clock length of the test.
    duration_seconds: int = 60
    # Prompt sent when the distribution is "fixed" (and the fallback otherwise).
    prompt: str = "Hello, please write a short story about a robot."
    # max_tokens forwarded in each completion request.
    max_tokens: int = 100
    # Prompt generation strategy: "fixed" (always `prompt`),
    # "realistic" (sample from a short-to-long prompt pool),
    # or "random" (random word salad of 5-100 words).
    prompt_length_distribution: str = "fixed"  # fixed, realistic, random
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
@dataclass
class RequestResult:
    """Result of a single request."""
    # True when the endpoint answered with HTTP 200.
    success: bool
    # End-to-end wall time of the request in milliseconds.
    latency_ms: float
    # Completion tokens reported by the server's usage block (0 on failure).
    tokens: int
    # Short error description ("Timeout", "HTTP 503", exception text) or None.
    error: Optional[str] = None
    # When the result was recorded (request completion time).
    timestamp: datetime = field(default_factory=datetime.now)
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
class LoadTester:
    """Load testing engine for vLLM inference endpoints.

    Launches chat-completion requests at a configured rate against an
    OpenAI-compatible endpoint, caps in-flight concurrency, records
    per-request latency/success, and summarizes the run as a
    LoadTestResult (optionally persisted to the metrics database).
    """

    def __init__(self, config: LoadTestConfig, db: Optional[MetricsDB] = None):
        """
        Initialize load tester.

        Args:
            config: Load test configuration
            db: Optional database for storing results
        """
        self.config = config
        self.db = db
        self.running = False
        self._results: List[RequestResult] = []
        self._latency_over_time: deque = deque(maxlen=10000)
        self._progress_callback: Optional[Callable[[Dict], None]] = None
        self._start_time: Optional[float] = None

    def set_config(self, config: LoadTestConfig) -> None:
        """Update configuration."""
        self.config = config

    def on_progress(self, callback: Callable[[Dict], None]) -> None:
        """Register progress callback (called once per launched request)."""
        self._progress_callback = callback

    async def run(self) -> LoadTestResult:
        """
        Run the load test.

        Returns:
            LoadTestResult with aggregated results
        """
        self.running = True
        self._results = []
        self._latency_over_time.clear()
        self._start_time = time.time()

        test_id = str(uuid.uuid4())[:8]

        logger.info(
            f"Starting load test {test_id}: "
            f"{self.config.concurrent_users} users, "
            f"{self.config.requests_per_second} RPS, "
            f"{self.config.duration_seconds}s"
        )

        # The semaphore must be held for the *entire lifetime* of each
        # request (inside the task). Acquiring it only around create_task
        # would release it immediately and never actually cap concurrency.
        semaphore = asyncio.Semaphore(self.config.concurrent_users)

        async def bounded_request(session) -> None:
            async with semaphore:
                await self._make_request(session)

        # Pacing interval between launches; guard against RPS <= 0
        # (ZeroDivisionError) by falling back to one launch per second.
        if self.config.requests_per_second > 0:
            interval = 1.0 / self.config.requests_per_second
        else:
            interval = 1.0

        tasks: List[asyncio.Task] = []
        end_time = time.time() + self.config.duration_seconds

        try:
            # One shared session reuses connections instead of paying a
            # TCP/TLS setup per request.
            async with aiohttp.ClientSession() as session:
                while time.time() < end_time and self.running:
                    tasks.append(asyncio.create_task(bounded_request(session)))

                    # Report progress on every launch tick.
                    if self._progress_callback:
                        self._progress_callback(self._get_progress())

                    await asyncio.sleep(interval)

                # Drain in-flight requests before the session closes.
                if tasks:
                    await asyncio.gather(*tasks, return_exceptions=True)

        except asyncio.CancelledError:
            logger.info("Load test cancelled")
        except Exception as e:
            logger.error(f"Load test error: {e}")
        finally:
            self.running = False

        # Analyze results
        result = self._analyze_results(test_id)

        # Persist to database (best-effort; a storage error must not lose
        # the in-memory result).
        if self.db:
            try:
                self.db.insert_load_test(result)
            except Exception as e:
                logger.error(f"Error persisting load test: {e}")

        return result

    def stop(self) -> None:
        """Stop the running load test (the run loop exits on its next tick)."""
        self.running = False

    async def _make_request(
        self, session: Optional["aiohttp.ClientSession"] = None
    ) -> None:
        """Make a single request to the target endpoint and record the result.

        Args:
            session: Shared client session. When omitted a temporary session
                is created (kept for backward compatibility with direct calls).
        """
        if session is None:
            async with aiohttp.ClientSession() as own_session:
                await self._make_request(own_session)
            return

        prompt = self._generate_prompt()
        start = time.perf_counter()
        tokens = 0
        error = None
        success = False

        try:
            payload = {
                "model": "default",
                "messages": [{"role": "user", "content": prompt}],
                "max_tokens": self.config.max_tokens,
                "stream": False,
            }

            async with session.post(
                f"{self.config.target_endpoint}/v1/chat/completions",
                json=payload,
                timeout=aiohttp.ClientTimeout(total=60),
            ) as response:
                if response.status == 200:
                    data = await response.json()
                    tokens = data.get("usage", {}).get("completion_tokens", 0)
                    success = True
                else:
                    error = f"HTTP {response.status}"

        except asyncio.TimeoutError:
            error = "Timeout"
        except Exception as e:
            error = str(e)

        # Latency covers the full round trip, including failures/timeouts.
        latency = (time.perf_counter() - start) * 1000

        self._results.append(
            RequestResult(
                success=success,
                latency_ms=latency,
                tokens=tokens,
                error=error,
            )
        )
        self._latency_over_time.append({
            "time": datetime.now(),
            "latency_ms": latency,
            "success": success,
        })

    def _generate_prompt(self) -> str:
        """Generate a prompt based on the configured length distribution."""
        import random

        if self.config.prompt_length_distribution == "realistic":
            # Sample from a short-to-long prompt pool to mimic real traffic.
            prompts = [
                "Hello!",
                "Write a haiku about programming.",
                "Explain quantum computing in simple terms.",
                "Write a detailed technical analysis of transformer architectures and their impact on modern NLP systems.",
                "Compare and contrast the approaches of different programming paradigms including object-oriented, functional, and procedural programming. Provide examples in Python for each.",
            ]
            return random.choice(prompts)

        if self.config.prompt_length_distribution == "random":
            words = ["the", "quick", "brown", "fox", "jumps", "over", "lazy", "dog"]
            length = random.randint(5, 100)
            return " ".join(random.choices(words, k=length))

        # "fixed" and any unrecognized value fall back to the static prompt.
        return self.config.prompt

    def _get_progress(self) -> Dict[str, Any]:
        """Get a snapshot of current progress for the UI callback."""
        if not self._start_time:
            return {}

        elapsed = time.time() - self._start_time
        successful = sum(1 for r in self._results if r.success)
        latencies = [r.latency_ms for r in self._results if r.success]

        return {
            "elapsed_seconds": elapsed,
            "total_requests": len(self._results),
            "successful_requests": successful,
            "failed_requests": len(self._results) - successful,
            "avg_latency_ms": statistics.mean(latencies) if latencies else 0,
            "current_rps": len(self._results) / elapsed if elapsed > 0 else 0,
        }

    def _analyze_results(self, test_id: str) -> LoadTestResult:
        """Aggregate per-request results into a LoadTestResult summary."""
        successful = [r for r in self._results if r.success]
        failed = [r for r in self._results if not r.success]

        latencies = [r.latency_ms for r in successful]

        # No successful requests: return a zeroed summary rather than raising.
        if not latencies:
            return LoadTestResult(
                test_id=test_id,
                target_endpoint=self.config.target_endpoint,
                concurrent_users=self.config.concurrent_users,
                requests_per_second=self.config.requests_per_second,
                duration_seconds=self.config.duration_seconds,
                total_requests=len(self._results),
                successful_requests=0,
                failed_requests=len(failed),
                avg_latency_ms=0,
                p50_latency_ms=0,
                p95_latency_ms=0,
                p99_latency_ms=0,
                throughput_rps=0,
            )

        # Percentiles via nearest-rank on the sorted latencies.
        sorted_latencies = sorted(latencies)
        n = len(sorted_latencies)

        p50 = sorted_latencies[int(n * 0.50)]
        p95 = sorted_latencies[int(n * 0.95)]
        p99 = sorted_latencies[min(int(n * 0.99), n - 1)]

        # Throughput over the actual elapsed time (the test may have been
        # stopped early), falling back to the configured duration.
        duration = self.config.duration_seconds
        if self._start_time:
            duration = time.time() - self._start_time

        throughput = len(successful) / duration if duration > 0 else 0

        # Detect saturation point
        saturation = self._find_saturation_point()

        return LoadTestResult(
            test_id=test_id,
            target_endpoint=self.config.target_endpoint,
            concurrent_users=self.config.concurrent_users,
            requests_per_second=self.config.requests_per_second,
            duration_seconds=self.config.duration_seconds,
            total_requests=len(self._results),
            successful_requests=len(successful),
            failed_requests=len(failed),
            avg_latency_ms=statistics.mean(latencies),
            p50_latency_ms=p50,
            p95_latency_ms=p95,
            p99_latency_ms=p99,
            throughput_rps=throughput,
            saturation_point=saturation,
        )

    def _find_saturation_point(self) -> Optional[float]:
        """
        Find the point where latency starts increasing dramatically.

        Returns:
            Estimated request rate at the saturation point, or None if the
            sample is too small or no 2x latency jump is observed.
        """
        if len(self._latency_over_time) < 20:
            return None

        # Group latencies into ~10 equal time buckets and average each.
        latencies = list(self._latency_over_time)
        bucket_size = len(latencies) // 10

        bucket_avgs = []
        for i in range(0, len(latencies), bucket_size):
            bucket = latencies[i : i + bucket_size]
            if bucket:
                avg = statistics.mean(r["latency_ms"] for r in bucket)
                bucket_avgs.append(avg)

        if len(bucket_avgs) < 3:
            return None

        # Saturation = first bucket whose average latency doubles the baseline.
        baseline = bucket_avgs[0]
        for i, avg in enumerate(bucket_avgs):
            if avg > baseline * 2:
                # Estimate the RPS at that point from elapsed fraction.
                elapsed = self.config.duration_seconds * (i / len(bucket_avgs))
                return len(self._results) / elapsed if elapsed > 0 else None

        return None

    def get_latency_timeseries(self) -> List[Dict[str, Any]]:
        """
        Get latency over time for charting.

        Returns:
            List of {time, latency_ms} points
        """
        return [
            {"time": p["time"], "latency_ms": p["latency_ms"]}
            for p in self._latency_over_time
        ]

    def get_latency_histogram(self, bins: int = 20) -> Dict[str, Any]:
        """
        Get latency histogram data (successful requests only).

        Args:
            bins: Number of histogram bins

        Returns:
            Dictionary with bin centers ("bins") and per-bin counts ("counts")
        """
        latencies = [r.latency_ms for r in self._results if r.success]

        if not latencies:
            return {"bins": [], "counts": []}

        counts, edges = np.histogram(latencies, bins=bins)

        return {
            "bins": [(edges[i] + edges[i + 1]) / 2 for i in range(len(counts))],
            "counts": counts.tolist(),
        }
|
services/request_tracer.py
ADDED
|
@@ -0,0 +1,272 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Request tracing and latency analysis."""
|
| 2 |
+
|
| 3 |
+
import logging
|
| 4 |
+
import uuid
|
| 5 |
+
from dataclasses import dataclass
|
| 6 |
+
from datetime import datetime
|
| 7 |
+
from typing import Dict, List, Optional, Any
|
| 8 |
+
from collections import deque
|
| 9 |
+
import statistics
|
| 10 |
+
|
| 11 |
+
from storage.database import MetricsDB
|
| 12 |
+
from storage.models import RequestTrace
|
| 13 |
+
|
| 14 |
+
logger = logging.getLogger(__name__)
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
@dataclass
class LatencyBreakdown:
    """Per-phase breakdown of a request's latency, all in milliseconds."""
    queue_ms: float
    prefill_ms: float
    decode_ms: float
    total_ms: float

    @property
    def as_dict(self) -> Dict[str, float]:
        """Return the breakdown as a phase-name -> milliseconds mapping."""
        names = ("queue", "prefill", "decode", "total")
        values = (self.queue_ms, self.prefill_ms, self.decode_ms, self.total_ms)
        return dict(zip(names, values))
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
@dataclass
class TraceCorrelation:
    """Correlation analysis for a trace."""
    # True when GPU memory grew by more than 1 GB during the request.
    memory_pressure: bool
    # Heuristic label: "batch_contention", "queue_congestion",
    # "long_prompt", or "normal".
    likely_cause: str
    # GPU memory delta (end - start) over the request, in GB.
    memory_delta_gb: float
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
class RequestTracer:
    """Tracks and analyzes request latency.

    Keeps a bounded in-memory ring of traces, maintains a rolling P95
    baseline used to flag slow requests (>1.5x P95), and optionally
    persists each trace to the metrics database.
    """

    def __init__(self, db: Optional[MetricsDB] = None, p95_window: int = 100):
        """
        Initialize request tracer.

        Args:
            db: Optional database for persisting traces
            p95_window: Number of recent requests for P95 calculation
        """
        self.db = db
        self._traces: deque = deque(maxlen=1000)
        self._latency_window: deque = deque(maxlen=p95_window)
        self._baseline_p95: Optional[float] = None
        self._slow_threshold_ms: Optional[float] = None

    def record_trace(
        self,
        request_id: Optional[str] = None,
        prompt_tokens: int = 0,
        output_tokens: int = 0,
        queue_time_ms: float = 0,
        prefill_time_ms: float = 0,
        decode_time_ms: float = 0,
        total_time_ms: Optional[float] = None,
        gpu_memory_start: float = 0,
        gpu_memory_end: float = 0,
    ) -> RequestTrace:
        """
        Record a request trace.

        Args:
            request_id: Unique request identifier (generated when omitted)
            prompt_tokens: Number of prompt tokens
            output_tokens: Number of output tokens
            queue_time_ms: Time spent in queue
            prefill_time_ms: Time for prefill/prompt processing
            decode_time_ms: Time for token generation
            total_time_ms: Total end-to-end time (defaults to phase sum)
            gpu_memory_start: GPU memory at request start
            gpu_memory_end: GPU memory at request end

        Returns:
            Created RequestTrace
        """
        if request_id is None:
            request_id = str(uuid.uuid4())[:8]

        if total_time_ms is None:
            total_time_ms = queue_time_ms + prefill_time_ms + decode_time_ms

        # Decode throughput; only meaningful when decode time is nonzero.
        tokens_per_sec = 0.0
        if decode_time_ms > 0:
            tokens_per_sec = (output_tokens / decode_time_ms) * 1000

        # Flag as slow against the current threshold (1.5x rolling P95);
        # no flagging until enough samples establish a baseline.
        is_slow = False
        if self._slow_threshold_ms and total_time_ms > self._slow_threshold_ms:
            is_slow = True

        trace = RequestTrace(
            request_id=request_id,
            prompt_tokens=prompt_tokens,
            output_tokens=output_tokens,
            queue_time_ms=queue_time_ms,
            prefill_time_ms=prefill_time_ms,
            decode_time_ms=decode_time_ms,
            total_time_ms=total_time_ms,
            tokens_per_second=tokens_per_sec,
            gpu_memory_at_start=gpu_memory_start,
            gpu_memory_at_end=gpu_memory_end,
            is_slow=is_slow,
        )

        # Store in memory
        self._traces.append(trace)
        self._latency_window.append(total_time_ms)

        # Update P95 baseline
        self._update_baseline()

        # Persist to database (best-effort; storage errors must not drop
        # the in-memory trace).
        if self.db:
            try:
                self.db.insert_trace(trace)
            except Exception as e:
                logger.error(f"Error persisting trace: {e}")

        # Log slow requests
        if is_slow:
            logger.warning(
                f"Slow request {request_id}: {total_time_ms:.1f}ms "
                f"(threshold: {self._slow_threshold_ms:.1f}ms)"
            )

        return trace

    def _update_baseline(self) -> None:
        """Update P95 baseline from recent requests (needs >= 10 samples)."""
        if len(self._latency_window) >= 10:
            sorted_latencies = sorted(self._latency_window)
            p95_idx = int(len(sorted_latencies) * 0.95)
            self._baseline_p95 = sorted_latencies[p95_idx]
            # Set slow threshold at 1.5x P95
            self._slow_threshold_ms = self._baseline_p95 * 1.5

    def get_recent_traces(
        self, limit: int = 100, slow_only: bool = False
    ) -> List[RequestTrace]:
        """
        Get recent traces.

        Args:
            limit: Maximum number of traces
            slow_only: Only return slow requests

        Returns:
            List of RequestTrace objects (oldest first, most recent `limit`)
        """
        traces = list(self._traces)

        if slow_only:
            traces = [t for t in traces if t.is_slow]

        return traces[-limit:]

    def get_latency_breakdown(self) -> LatencyBreakdown:
        """
        Get average latency breakdown over the last 100 traces.

        Returns:
            LatencyBreakdown with average times (zeros when no traces)
        """
        if not self._traces:
            return LatencyBreakdown(0, 0, 0, 0)

        recent = list(self._traces)[-100:]

        return LatencyBreakdown(
            queue_ms=statistics.mean(t.queue_time_ms for t in recent),
            prefill_ms=statistics.mean(t.prefill_time_ms for t in recent),
            decode_ms=statistics.mean(t.decode_time_ms for t in recent),
            total_ms=statistics.mean(t.total_time_ms for t in recent),
        )

    def correlate_with_gpu_pressure(self, trace: RequestTrace) -> TraceCorrelation:
        """
        Correlate trace latency with GPU memory pressure.

        Args:
            trace: Request trace to analyze

        Returns:
            TraceCorrelation analysis
        """
        memory_delta = trace.gpu_memory_at_end - trace.gpu_memory_at_start

        # Heuristic cause classification based on which phase dominates.
        if memory_delta > 2.0:
            cause = "batch_contention"
        elif trace.queue_time_ms > trace.total_time_ms * 0.3:
            cause = "queue_congestion"
        elif trace.prefill_time_ms > trace.decode_time_ms * 2:
            cause = "long_prompt"
        else:
            cause = "normal"

        return TraceCorrelation(
            memory_pressure=memory_delta > 1.0,
            likely_cause=cause,
            memory_delta_gb=memory_delta,
        )

    def get_percentiles(self) -> Dict[str, float]:
        """
        Get latency percentiles over the rolling window.

        Returns:
            Dictionary with P50, P95, P99 values (zeros when empty)
        """
        if not self._latency_window:
            return {"p50": 0, "p95": 0, "p99": 0}

        sorted_latencies = sorted(self._latency_window)
        n = len(sorted_latencies)

        return {
            "p50": sorted_latencies[int(n * 0.50)],
            "p95": sorted_latencies[int(n * 0.95)],
            "p99": sorted_latencies[min(int(n * 0.99), n - 1)],
        }

    def get_stats(self) -> Dict[str, Any]:
        """
        Get comprehensive statistics.

        Returns:
            Dictionary with request counts, slow rate, latency average,
            percentiles, phase breakdown, and the current P95 baseline.
            The empty and populated cases return the same key schema.
        """
        if not self._traces:
            # Keep the schema identical to the populated case so consumers
            # never have to special-case an empty tracer.
            return {
                "total_requests": 0,
                "slow_requests": 0,
                "slow_rate_percent": 0,
                "avg_latency_ms": 0,
                "percentiles": {"p50": 0, "p95": 0, "p99": 0},
                "breakdown": {"queue": 0, "prefill": 0, "decode": 0, "total": 0},
                "baseline_p95": None,
            }

        traces = list(self._traces)
        slow_count = sum(1 for t in traces if t.is_slow)
        breakdown = self.get_latency_breakdown()

        return {
            "total_requests": len(traces),
            "slow_requests": slow_count,
            "slow_rate_percent": (slow_count / len(traces)) * 100,
            "avg_latency_ms": breakdown.total_ms,
            "percentiles": self.get_percentiles(),
            "breakdown": breakdown.as_dict,
            "baseline_p95": self._baseline_p95,
        }

    def clear(self) -> None:
        """Clear all traces and reset the slow-request baseline."""
        self._traces.clear()
        self._latency_window.clear()
        self._baseline_p95 = None
        self._slow_threshold_ms = None
|
storage/__init__.py
ADDED
|
@@ -0,0 +1,11 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Storage layer for persistent metrics and traces."""
|
| 2 |
+
|
| 3 |
+
from .database import MetricsDB
|
| 4 |
+
from .models import MetricRecord, AlertRecord, RequestTrace
|
| 5 |
+
|
| 6 |
+
__all__ = [
|
| 7 |
+
"MetricsDB",
|
| 8 |
+
"MetricRecord",
|
| 9 |
+
"AlertRecord",
|
| 10 |
+
"RequestTrace",
|
| 11 |
+
]
|
storage/database.py
ADDED
|
@@ -0,0 +1,448 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""SQLite database operations for metrics storage."""
|
| 2 |
+
|
| 3 |
+
import sqlite3
|
| 4 |
+
import json
|
| 5 |
+
import logging
|
| 6 |
+
from datetime import datetime, timedelta
|
| 7 |
+
from pathlib import Path
|
| 8 |
+
from typing import List, Optional, Dict, Any
|
| 9 |
+
from contextlib import contextmanager
|
| 10 |
+
|
| 11 |
+
from .models import MetricRecord, AlertRecord, RequestTrace, LoadTestResult
|
| 12 |
+
|
| 13 |
+
logger = logging.getLogger(__name__)
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
class MetricsDB:
|
| 17 |
+
"""SQLite database for storing metrics, alerts, and traces."""
|
| 18 |
+
|
| 19 |
+
    def __init__(self, db_path: str = "data/metrics.db"):
        """
        Initialize database connection.

        Creates the parent directory and all tables/indexes up front so
        every later operation can assume the schema exists.

        Args:
            db_path: Path to SQLite database file
        """
        self.db_path = db_path
        self._ensure_directory()
        self._init_schema()
|
| 29 |
+
|
| 30 |
+
def _ensure_directory(self) -> None:
|
| 31 |
+
"""Ensure the database directory exists."""
|
| 32 |
+
Path(self.db_path).parent.mkdir(parents=True, exist_ok=True)
|
| 33 |
+
|
| 34 |
+
    @contextmanager
    def _get_connection(self):
        """Yield a SQLite connection that commits on success.

        The connection commits if the `with` body completes, rolls back
        (and re-raises) on any exception, and is always closed. Rows are
        returned as sqlite3.Row for name-based column access.
        """
        conn = sqlite3.connect(self.db_path)
        conn.row_factory = sqlite3.Row
        try:
            yield conn
            conn.commit()
        except Exception:
            conn.rollback()
            raise
        finally:
            conn.close()
|
| 47 |
+
|
| 48 |
+
    def _init_schema(self) -> None:
        """Initialize database schema.

        Creates all four tables (metrics, alerts, request_traces,
        load_tests) plus the metrics indexes. Every statement uses
        IF NOT EXISTS, so calling this repeatedly is safe.
        """
        with self._get_connection() as conn:
            cursor = conn.cursor()

            # Metrics table: one row per sampled value; labels is an
            # optional JSON-encoded dict.
            cursor.execute("""
                CREATE TABLE IF NOT EXISTS metrics (
                    id INTEGER PRIMARY KEY AUTOINCREMENT,
                    timestamp DATETIME DEFAULT CURRENT_TIMESTAMP,
                    metric_name TEXT NOT NULL,
                    value REAL NOT NULL,
                    labels TEXT
                )
            """)

            # Indexes for metrics: time-range scans and name+time queries.
            cursor.execute("""
                CREATE INDEX IF NOT EXISTS idx_metrics_timestamp
                ON metrics(timestamp)
            """)
            cursor.execute("""
                CREATE INDEX IF NOT EXISTS idx_metrics_name_time
                ON metrics(metric_name, timestamp)
            """)

            # Alerts table: resolved_at stays NULL while the alert is active.
            cursor.execute("""
                CREATE TABLE IF NOT EXISTS alerts (
                    id INTEGER PRIMARY KEY AUTOINCREMENT,
                    timestamp DATETIME DEFAULT CURRENT_TIMESTAMP,
                    rule_name TEXT,
                    severity TEXT,
                    metric_name TEXT,
                    value REAL,
                    threshold REAL,
                    message TEXT,
                    resolved_at DATETIME
                )
            """)

            # Request traces table: per-request latency phases;
            # request_id is unique so re-inserting the same trace fails.
            cursor.execute("""
                CREATE TABLE IF NOT EXISTS request_traces (
                    id INTEGER PRIMARY KEY AUTOINCREMENT,
                    request_id TEXT UNIQUE,
                    timestamp DATETIME DEFAULT CURRENT_TIMESTAMP,
                    prompt_tokens INTEGER,
                    output_tokens INTEGER,
                    queue_time_ms REAL,
                    prefill_time_ms REAL,
                    decode_time_ms REAL,
                    total_time_ms REAL,
                    tokens_per_second REAL,
                    is_slow BOOLEAN
                )
            """)

            # Load test results table: one summary row per completed test.
            cursor.execute("""
                CREATE TABLE IF NOT EXISTS load_tests (
                    id INTEGER PRIMARY KEY AUTOINCREMENT,
                    test_id TEXT UNIQUE,
                    timestamp DATETIME DEFAULT CURRENT_TIMESTAMP,
                    target_endpoint TEXT,
                    concurrent_users INTEGER,
                    requests_per_second REAL,
                    duration_seconds INTEGER,
                    total_requests INTEGER,
                    successful_requests INTEGER,
                    failed_requests INTEGER,
                    avg_latency_ms REAL,
                    p50_latency_ms REAL,
                    p95_latency_ms REAL,
                    p99_latency_ms REAL,
                    throughput_rps REAL,
                    saturation_point REAL
                )
            """)
|
| 127 |
+
|
| 128 |
+
# Metrics operations
|
| 129 |
+
|
| 130 |
+
def insert_metric(self, record: MetricRecord) -> int:
|
| 131 |
+
"""Insert a metric record."""
|
| 132 |
+
with self._get_connection() as conn:
|
| 133 |
+
cursor = conn.cursor()
|
| 134 |
+
cursor.execute(
|
| 135 |
+
"""
|
| 136 |
+
INSERT INTO metrics (timestamp, metric_name, value, labels)
|
| 137 |
+
VALUES (?, ?, ?, ?)
|
| 138 |
+
""",
|
| 139 |
+
(
|
| 140 |
+
record.timestamp.isoformat(),
|
| 141 |
+
record.metric_name,
|
| 142 |
+
record.value,
|
| 143 |
+
json.dumps(record.labels) if record.labels else None,
|
| 144 |
+
),
|
| 145 |
+
)
|
| 146 |
+
return cursor.lastrowid
|
| 147 |
+
|
| 148 |
+
def insert_metrics_batch(self, records: List[MetricRecord]) -> None:
|
| 149 |
+
"""Insert multiple metric records efficiently."""
|
| 150 |
+
with self._get_connection() as conn:
|
| 151 |
+
cursor = conn.cursor()
|
| 152 |
+
cursor.executemany(
|
| 153 |
+
"""
|
| 154 |
+
INSERT INTO metrics (timestamp, metric_name, value, labels)
|
| 155 |
+
VALUES (?, ?, ?, ?)
|
| 156 |
+
""",
|
| 157 |
+
[
|
| 158 |
+
(
|
| 159 |
+
r.timestamp.isoformat(),
|
| 160 |
+
r.metric_name,
|
| 161 |
+
r.value,
|
| 162 |
+
json.dumps(r.labels) if r.labels else None,
|
| 163 |
+
)
|
| 164 |
+
for r in records
|
| 165 |
+
],
|
| 166 |
+
)
|
| 167 |
+
|
| 168 |
+
def query_metrics(
|
| 169 |
+
self,
|
| 170 |
+
metric_name: str,
|
| 171 |
+
start: datetime,
|
| 172 |
+
end: datetime,
|
| 173 |
+
labels: Optional[Dict[str, str]] = None,
|
| 174 |
+
) -> List[MetricRecord]:
|
| 175 |
+
"""Query metrics by name and time range."""
|
| 176 |
+
with self._get_connection() as conn:
|
| 177 |
+
cursor = conn.cursor()
|
| 178 |
+
cursor.execute(
|
| 179 |
+
"""
|
| 180 |
+
SELECT id, timestamp, metric_name, value, labels
|
| 181 |
+
FROM metrics
|
| 182 |
+
WHERE metric_name = ? AND timestamp BETWEEN ? AND ?
|
| 183 |
+
ORDER BY timestamp
|
| 184 |
+
""",
|
| 185 |
+
(metric_name, start.isoformat(), end.isoformat()),
|
| 186 |
+
)
|
| 187 |
+
|
| 188 |
+
records = []
|
| 189 |
+
for row in cursor.fetchall():
|
| 190 |
+
record = MetricRecord.from_row(tuple(row))
|
| 191 |
+
if labels:
|
| 192 |
+
# Filter by labels if specified
|
| 193 |
+
if all(record.labels.get(k) == v for k, v in labels.items()):
|
| 194 |
+
records.append(record)
|
| 195 |
+
else:
|
| 196 |
+
records.append(record)
|
| 197 |
+
|
| 198 |
+
return records
|
| 199 |
+
|
| 200 |
+
def query_aggregated(
|
| 201 |
+
self,
|
| 202 |
+
metric_name: str,
|
| 203 |
+
start: datetime,
|
| 204 |
+
end: datetime,
|
| 205 |
+
aggregation: str = "avg",
|
| 206 |
+
bucket_minutes: int = 1,
|
| 207 |
+
) -> List[Dict[str, Any]]:
|
| 208 |
+
"""Query metrics with time bucketing and aggregation."""
|
| 209 |
+
agg_func = {
|
| 210 |
+
"avg": "AVG",
|
| 211 |
+
"max": "MAX",
|
| 212 |
+
"min": "MIN",
|
| 213 |
+
"sum": "SUM",
|
| 214 |
+
"count": "COUNT",
|
| 215 |
+
}.get(aggregation, "AVG")
|
| 216 |
+
|
| 217 |
+
with self._get_connection() as conn:
|
| 218 |
+
cursor = conn.cursor()
|
| 219 |
+
cursor.execute(
|
| 220 |
+
f"""
|
| 221 |
+
SELECT
|
| 222 |
+
datetime(
|
| 223 |
+
strftime('%Y-%m-%d %H:', timestamp) ||
|
| 224 |
+
printf('%02d', (CAST(strftime('%M', timestamp) AS INTEGER) / {bucket_minutes}) * {bucket_minutes}) ||
|
| 225 |
+
':00'
|
| 226 |
+
) as bucket,
|
| 227 |
+
{agg_func}(value) as value
|
| 228 |
+
FROM metrics
|
| 229 |
+
WHERE metric_name = ? AND timestamp BETWEEN ? AND ?
|
| 230 |
+
GROUP BY bucket
|
| 231 |
+
ORDER BY bucket
|
| 232 |
+
""",
|
| 233 |
+
(metric_name, start.isoformat(), end.isoformat()),
|
| 234 |
+
)
|
| 235 |
+
|
| 236 |
+
return [
|
| 237 |
+
{"time": row["bucket"], "value": row["value"]}
|
| 238 |
+
for row in cursor.fetchall()
|
| 239 |
+
]
|
| 240 |
+
|
| 241 |
+
# Alert operations
|
| 242 |
+
|
| 243 |
+
def insert_alert(self, alert: AlertRecord) -> int:
|
| 244 |
+
"""Insert an alert record."""
|
| 245 |
+
with self._get_connection() as conn:
|
| 246 |
+
cursor = conn.cursor()
|
| 247 |
+
cursor.execute(
|
| 248 |
+
"""
|
| 249 |
+
INSERT INTO alerts
|
| 250 |
+
(timestamp, rule_name, severity, metric_name, value, threshold, message, resolved_at)
|
| 251 |
+
VALUES (?, ?, ?, ?, ?, ?, ?, ?)
|
| 252 |
+
""",
|
| 253 |
+
(
|
| 254 |
+
alert.timestamp.isoformat(),
|
| 255 |
+
alert.rule_name,
|
| 256 |
+
alert.severity,
|
| 257 |
+
alert.metric_name,
|
| 258 |
+
alert.value,
|
| 259 |
+
alert.threshold,
|
| 260 |
+
alert.message,
|
| 261 |
+
alert.resolved_at.isoformat() if alert.resolved_at else None,
|
| 262 |
+
),
|
| 263 |
+
)
|
| 264 |
+
return cursor.lastrowid
|
| 265 |
+
|
| 266 |
+
def get_active_alerts(self) -> List[AlertRecord]:
|
| 267 |
+
"""Get all unresolved alerts."""
|
| 268 |
+
with self._get_connection() as conn:
|
| 269 |
+
cursor = conn.cursor()
|
| 270 |
+
cursor.execute(
|
| 271 |
+
"""
|
| 272 |
+
SELECT id, timestamp, rule_name, severity, metric_name, value, threshold, message, resolved_at
|
| 273 |
+
FROM alerts
|
| 274 |
+
WHERE resolved_at IS NULL
|
| 275 |
+
ORDER BY timestamp DESC
|
| 276 |
+
"""
|
| 277 |
+
)
|
| 278 |
+
return [AlertRecord.from_row(tuple(row)) for row in cursor.fetchall()]
|
| 279 |
+
|
| 280 |
+
def get_recent_alerts(self, limit: int = 100) -> List[AlertRecord]:
|
| 281 |
+
"""Get recent alerts."""
|
| 282 |
+
with self._get_connection() as conn:
|
| 283 |
+
cursor = conn.cursor()
|
| 284 |
+
cursor.execute(
|
| 285 |
+
"""
|
| 286 |
+
SELECT id, timestamp, rule_name, severity, metric_name, value, threshold, message, resolved_at
|
| 287 |
+
FROM alerts
|
| 288 |
+
ORDER BY timestamp DESC
|
| 289 |
+
LIMIT ?
|
| 290 |
+
""",
|
| 291 |
+
(limit,),
|
| 292 |
+
)
|
| 293 |
+
return [AlertRecord.from_row(tuple(row)) for row in cursor.fetchall()]
|
| 294 |
+
|
| 295 |
+
def resolve_alert(self, alert_id: int) -> None:
|
| 296 |
+
"""Mark an alert as resolved."""
|
| 297 |
+
with self._get_connection() as conn:
|
| 298 |
+
cursor = conn.cursor()
|
| 299 |
+
cursor.execute(
|
| 300 |
+
"""
|
| 301 |
+
UPDATE alerts SET resolved_at = ? WHERE id = ?
|
| 302 |
+
""",
|
| 303 |
+
(datetime.now().isoformat(), alert_id),
|
| 304 |
+
)
|
| 305 |
+
|
| 306 |
+
# Request trace operations
|
| 307 |
+
|
| 308 |
+
def insert_trace(self, trace: RequestTrace) -> int:
|
| 309 |
+
"""Insert a request trace."""
|
| 310 |
+
with self._get_connection() as conn:
|
| 311 |
+
cursor = conn.cursor()
|
| 312 |
+
cursor.execute(
|
| 313 |
+
"""
|
| 314 |
+
INSERT OR REPLACE INTO request_traces
|
| 315 |
+
(request_id, timestamp, prompt_tokens, output_tokens,
|
| 316 |
+
queue_time_ms, prefill_time_ms, decode_time_ms, total_time_ms,
|
| 317 |
+
tokens_per_second, is_slow)
|
| 318 |
+
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
|
| 319 |
+
""",
|
| 320 |
+
(
|
| 321 |
+
trace.request_id,
|
| 322 |
+
trace.timestamp.isoformat(),
|
| 323 |
+
trace.prompt_tokens,
|
| 324 |
+
trace.output_tokens,
|
| 325 |
+
trace.queue_time_ms,
|
| 326 |
+
trace.prefill_time_ms,
|
| 327 |
+
trace.decode_time_ms,
|
| 328 |
+
trace.total_time_ms,
|
| 329 |
+
trace.tokens_per_second,
|
| 330 |
+
trace.is_slow,
|
| 331 |
+
),
|
| 332 |
+
)
|
| 333 |
+
return cursor.lastrowid
|
| 334 |
+
|
| 335 |
+
def get_recent_traces(
|
| 336 |
+
self, limit: int = 100, slow_only: bool = False
|
| 337 |
+
) -> List[RequestTrace]:
|
| 338 |
+
"""Get recent request traces."""
|
| 339 |
+
with self._get_connection() as conn:
|
| 340 |
+
cursor = conn.cursor()
|
| 341 |
+
query = """
|
| 342 |
+
SELECT id, request_id, timestamp, prompt_tokens, output_tokens,
|
| 343 |
+
queue_time_ms, prefill_time_ms, decode_time_ms, total_time_ms,
|
| 344 |
+
tokens_per_second, is_slow
|
| 345 |
+
FROM request_traces
|
| 346 |
+
"""
|
| 347 |
+
if slow_only:
|
| 348 |
+
query += " WHERE is_slow = 1"
|
| 349 |
+
query += " ORDER BY timestamp DESC LIMIT ?"
|
| 350 |
+
|
| 351 |
+
cursor.execute(query, (limit,))
|
| 352 |
+
return [RequestTrace.from_row(tuple(row)) for row in cursor.fetchall()]
|
| 353 |
+
|
| 354 |
+
def get_trace_stats(self) -> Dict[str, Any]:
|
| 355 |
+
"""Get aggregate statistics for traces."""
|
| 356 |
+
with self._get_connection() as conn:
|
| 357 |
+
cursor = conn.cursor()
|
| 358 |
+
cursor.execute(
|
| 359 |
+
"""
|
| 360 |
+
SELECT
|
| 361 |
+
COUNT(*) as total,
|
| 362 |
+
AVG(total_time_ms) as avg_latency,
|
| 363 |
+
AVG(queue_time_ms) as avg_queue,
|
| 364 |
+
AVG(prefill_time_ms) as avg_prefill,
|
| 365 |
+
AVG(decode_time_ms) as avg_decode,
|
| 366 |
+
SUM(CASE WHEN is_slow THEN 1 ELSE 0 END) as slow_count
|
| 367 |
+
FROM request_traces
|
| 368 |
+
WHERE timestamp > datetime('now', '-1 hour')
|
| 369 |
+
"""
|
| 370 |
+
)
|
| 371 |
+
row = cursor.fetchone()
|
| 372 |
+
return {
|
| 373 |
+
"total_requests": row["total"] or 0,
|
| 374 |
+
"avg_latency_ms": row["avg_latency"] or 0,
|
| 375 |
+
"avg_queue_ms": row["avg_queue"] or 0,
|
| 376 |
+
"avg_prefill_ms": row["avg_prefill"] or 0,
|
| 377 |
+
"avg_decode_ms": row["avg_decode"] or 0,
|
| 378 |
+
"slow_request_count": row["slow_count"] or 0,
|
| 379 |
+
}
|
| 380 |
+
|
| 381 |
+
# Load test operations
|
| 382 |
+
|
| 383 |
+
def insert_load_test(self, result: LoadTestResult) -> int:
|
| 384 |
+
"""Insert a load test result."""
|
| 385 |
+
with self._get_connection() as conn:
|
| 386 |
+
cursor = conn.cursor()
|
| 387 |
+
cursor.execute(
|
| 388 |
+
"""
|
| 389 |
+
INSERT INTO load_tests
|
| 390 |
+
(test_id, timestamp, target_endpoint, concurrent_users,
|
| 391 |
+
requests_per_second, duration_seconds, total_requests,
|
| 392 |
+
successful_requests, failed_requests, avg_latency_ms,
|
| 393 |
+
p50_latency_ms, p95_latency_ms, p99_latency_ms,
|
| 394 |
+
throughput_rps, saturation_point)
|
| 395 |
+
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
|
| 396 |
+
""",
|
| 397 |
+
(
|
| 398 |
+
result.test_id,
|
| 399 |
+
result.timestamp.isoformat(),
|
| 400 |
+
result.target_endpoint,
|
| 401 |
+
result.concurrent_users,
|
| 402 |
+
result.requests_per_second,
|
| 403 |
+
result.duration_seconds,
|
| 404 |
+
result.total_requests,
|
| 405 |
+
result.successful_requests,
|
| 406 |
+
result.failed_requests,
|
| 407 |
+
result.avg_latency_ms,
|
| 408 |
+
result.p50_latency_ms,
|
| 409 |
+
result.p95_latency_ms,
|
| 410 |
+
result.p99_latency_ms,
|
| 411 |
+
result.throughput_rps,
|
| 412 |
+
result.saturation_point,
|
| 413 |
+
),
|
| 414 |
+
)
|
| 415 |
+
return cursor.lastrowid
|
| 416 |
+
|
| 417 |
+
def get_recent_load_tests(self, limit: int = 10) -> List[Dict[str, Any]]:
|
| 418 |
+
"""Get recent load test results."""
|
| 419 |
+
with self._get_connection() as conn:
|
| 420 |
+
cursor = conn.cursor()
|
| 421 |
+
cursor.execute(
|
| 422 |
+
"""
|
| 423 |
+
SELECT * FROM load_tests
|
| 424 |
+
ORDER BY timestamp DESC
|
| 425 |
+
LIMIT ?
|
| 426 |
+
""",
|
| 427 |
+
(limit,),
|
| 428 |
+
)
|
| 429 |
+
return [dict(row) for row in cursor.fetchall()]
|
| 430 |
+
|
| 431 |
+
# Cleanup operations
|
| 432 |
+
|
| 433 |
+
def cleanup_old_data(self, days: int = 7) -> int:
|
| 434 |
+
"""Remove data older than specified days."""
|
| 435 |
+
cutoff = (datetime.now() - timedelta(days=days)).isoformat()
|
| 436 |
+
|
| 437 |
+
with self._get_connection() as conn:
|
| 438 |
+
cursor = conn.cursor()
|
| 439 |
+
total_deleted = 0
|
| 440 |
+
|
| 441 |
+
for table in ["metrics", "alerts", "request_traces"]:
|
| 442 |
+
cursor.execute(
|
| 443 |
+
f"DELETE FROM {table} WHERE timestamp < ?",
|
| 444 |
+
(cutoff,),
|
| 445 |
+
)
|
| 446 |
+
total_deleted += cursor.rowcount
|
| 447 |
+
|
| 448 |
+
return total_deleted
|
storage/models.py
ADDED
|
@@ -0,0 +1,165 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Data models for storage layer."""
|
| 2 |
+
|
| 3 |
+
from dataclasses import dataclass, field
|
| 4 |
+
from datetime import datetime
|
| 5 |
+
from typing import Optional, Dict, Any
|
| 6 |
+
import json
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
@dataclass
class MetricRecord:
    """A single metric record for storage."""
    metric_name: str
    value: float
    timestamp: datetime = field(default_factory=datetime.now)
    labels: Dict[str, str] = field(default_factory=dict)
    id: Optional[int] = None

    def to_dict(self) -> Dict[str, Any]:
        """Serialize to a JSON-friendly dict (timestamp as ISO-8601 text)."""
        payload: Dict[str, Any] = {"id": self.id}
        payload["metric_name"] = self.metric_name
        payload["value"] = self.value
        payload["timestamp"] = self.timestamp.isoformat()
        payload["labels"] = self.labels
        return payload

    @classmethod
    def from_row(cls, row: tuple) -> "MetricRecord":
        """Rebuild from a (id, timestamp, metric_name, value, labels) DB row.

        Labels are stored as JSON text (or NULL) in the database.
        """
        return cls(
            id=row[0],
            metric_name=row[2],
            value=row[3],
            timestamp=datetime.fromisoformat(row[1]),
            labels=json.loads(row[4]) if row[4] else {},
        )
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
@dataclass
class AlertRecord:
    """An alert record for storage."""
    rule_name: str
    severity: str
    metric_name: str
    value: float
    threshold: float
    message: str
    timestamp: datetime = field(default_factory=datetime.now)
    resolved_at: Optional[datetime] = None
    id: Optional[int] = None

    def to_dict(self) -> Dict[str, Any]:
        """Serialize to a JSON-friendly dict; datetimes become ISO-8601 strings."""
        resolved = self.resolved_at.isoformat() if self.resolved_at else None
        return {
            "id": self.id,
            "rule_name": self.rule_name,
            "severity": self.severity,
            "metric_name": self.metric_name,
            "value": self.value,
            "threshold": self.threshold,
            "message": self.message,
            "timestamp": self.timestamp.isoformat(),
            "resolved_at": resolved,
        }

    @classmethod
    def from_row(cls, row: tuple) -> "AlertRecord":
        """Rebuild from an alerts table row; tolerates rows missing trailing columns."""
        msg = row[7] if len(row) > 7 else ""
        resolved = None
        if len(row) > 8 and row[8]:
            resolved = datetime.fromisoformat(row[8])
        return cls(
            id=row[0],
            timestamp=datetime.fromisoformat(row[1]),
            rule_name=row[2],
            severity=row[3],
            metric_name=row[4],
            value=row[5],
            threshold=row[6],
            message=msg,
            resolved_at=resolved,
        )
|
| 77 |
+
|
| 78 |
+
|
| 79 |
+
@dataclass
class RequestTrace:
    """A request trace for latency analysis."""
    request_id: str
    prompt_tokens: int
    output_tokens: int
    queue_time_ms: float
    prefill_time_ms: float
    decode_time_ms: float
    total_time_ms: float
    tokens_per_second: float
    gpu_memory_at_start: float = 0.0
    gpu_memory_at_end: float = 0.0
    is_slow: bool = False
    timestamp: datetime = field(default_factory=datetime.now)
    id: Optional[int] = None

    def to_dict(self) -> Dict[str, Any]:
        """Serialize for display; timing fields are rounded to 2 decimals."""
        out: Dict[str, Any] = {
            "id": self.id,
            "request_id": self.request_id,
            "timestamp": self.timestamp.isoformat(),
            "prompt_tokens": self.prompt_tokens,
            "output_tokens": self.output_tokens,
        }
        for key in ("queue_time_ms", "prefill_time_ms", "decode_time_ms",
                    "total_time_ms", "tokens_per_second"):
            out[key] = round(getattr(self, key), 2)
        out["is_slow"] = self.is_slow
        return out

    @classmethod
    def from_row(cls, row: tuple) -> "RequestTrace":
        """Rebuild from a request_traces row; tolerates missing trailing columns."""
        tps = row[9] if len(row) > 9 else 0
        slow = bool(row[10]) if len(row) > 10 else False
        return cls(
            id=row[0],
            request_id=row[1],
            timestamp=datetime.fromisoformat(row[2]),
            prompt_tokens=row[3],
            output_tokens=row[4],
            queue_time_ms=row[5],
            prefill_time_ms=row[6],
            decode_time_ms=row[7],
            total_time_ms=row[8],
            tokens_per_second=tps,
            is_slow=slow,
        )
|
| 126 |
+
|
| 127 |
+
|
| 128 |
+
@dataclass
class LoadTestResult:
    """Results from a load test run."""
    test_id: str
    target_endpoint: str
    concurrent_users: int
    requests_per_second: float
    duration_seconds: int
    total_requests: int
    successful_requests: int
    failed_requests: int
    avg_latency_ms: float
    p50_latency_ms: float
    p95_latency_ms: float
    p99_latency_ms: float
    throughput_rps: float
    saturation_point: Optional[float] = None
    timestamp: datetime = field(default_factory=datetime.now)
    id: Optional[int] = None

    def to_dict(self) -> Dict[str, Any]:
        """Serialize for display; latency/throughput fields rounded to 2 decimals."""
        out: Dict[str, Any] = {
            "test_id": self.test_id,
            "target_endpoint": self.target_endpoint,
            "concurrent_users": self.concurrent_users,
            "requests_per_second": self.requests_per_second,
            "duration_seconds": self.duration_seconds,
            "total_requests": self.total_requests,
            "successful_requests": self.successful_requests,
            "failed_requests": self.failed_requests,
        }
        for key in ("avg_latency_ms", "p50_latency_ms", "p95_latency_ms",
                    "p99_latency_ms", "throughput_rps"):
            out[key] = round(getattr(self, key), 2)
        out["saturation_point"] = self.saturation_point
        out["timestamp"] = self.timestamp.isoformat()
        return out
|
utils/__init__.py
ADDED
|
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Utility modules for the dashboard."""
|
| 2 |
+
|
| 3 |
+
from .prometheus_parser import parse_prometheus_metrics
|
| 4 |
+
from .history import MetricHistory
|
| 5 |
+
|
| 6 |
+
__all__ = ["parse_prometheus_metrics", "MetricHistory"]
|
utils/history.py
ADDED
|
@@ -0,0 +1,163 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""In-memory metric history buffer for time-series data."""
|
| 2 |
+
|
| 3 |
+
from collections import deque
|
| 4 |
+
from dataclasses import dataclass, field
|
| 5 |
+
from datetime import datetime
|
| 6 |
+
from typing import Dict, List, Any, Optional
|
| 7 |
+
import threading
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
@dataclass
class HistoryPoint:
    """A single point in metric history."""
    timestamp: datetime
    value: float
    labels: Dict[str, str] = field(default_factory=dict)


class MetricHistory:
    """
    Thread-safe in-memory buffer for metric history.

    Keeps one bounded deque of HistoryPoints per (metric, labels) series,
    providing a rolling window of values suitable for charting.
    """

    def __init__(self, max_length: int = 300):
        """
        Initialize the history buffer.

        Args:
            max_length: Maximum number of points retained per series
        """
        self.max_length = max_length
        self._data: Dict[str, deque] = {}
        self._lock = threading.Lock()

    def add(self, metric_name: str, value: float, labels: Optional[Dict[str, str]] = None) -> None:
        """
        Record a data point.

        Args:
            metric_name: Name of the metric
            value: Metric value
            labels: Optional labels distinguishing the series
        """
        sample = HistoryPoint(
            timestamp=datetime.now(),
            value=value,
            labels=labels or {},
        )
        # Labels are folded into the key so each combination is its own series.
        key = self._make_key(metric_name, labels)

        with self._lock:
            series = self._data.get(key)
            if series is None:
                series = self._data[key] = deque(maxlen=self.max_length)
            series.append(sample)

    def get(
        self,
        metric_name: str,
        labels: Optional[Dict[str, str]] = None,
        limit: Optional[int] = None
    ) -> List[HistoryPoint]:
        """
        Return stored points for a series, oldest first.

        Args:
            metric_name: Name of the metric
            labels: Optional label filter
            limit: If truthy, return at most this many most-recent points

        Returns:
            List of history points (empty when the series is unknown)
        """
        key = self._make_key(metric_name, labels)

        with self._lock:
            series = self._data.get(key)
            snapshot = list(series) if series is not None else []

        return snapshot[-limit:] if limit else snapshot

    def get_latest(
        self,
        metric_name: str,
        labels: Optional[Dict[str, str]] = None
    ) -> Optional[HistoryPoint]:
        """Return the newest point for a series, or None if empty."""
        tail = self.get(metric_name, labels, limit=1)
        return tail[-1] if tail else None

    def get_all_series(self, metric_name: str) -> Dict[str, List[HistoryPoint]]:
        """
        Return every stored series for a metric (all label combinations).

        Args:
            metric_name: Base metric name

        Returns:
            Dictionary mapping series keys to copied point lists
        """
        prefix = f"{metric_name}:"

        with self._lock:
            return {
                key: list(points)
                for key, points in self._data.items()
                if key == metric_name or key.startswith(prefix)
            }

    def to_dataframe(self, metric_name: str, labels: Optional[Dict[str, str]] = None):
        """
        Convert one series to a pandas DataFrame.

        Args:
            metric_name: Name of the metric
            labels: Optional label filter

        Returns:
            DataFrame with `time` and `value` columns (plus one column per
            label key present on the stored points)
        """
        import pandas as pd

        rows = [
            {"time": p.timestamp, "value": p.value, **p.labels}
            for p in self.get(metric_name, labels)
        ]
        if not rows:
            return pd.DataFrame(columns=["time", "value"])
        return pd.DataFrame(rows)

    def clear(self, metric_name: Optional[str] = None) -> None:
        """
        Drop stored history.

        Args:
            metric_name: If provided, remove only that metric's series
                (every label combination); otherwise remove everything
        """
        with self._lock:
            if not metric_name:
                self._data.clear()
                return
            prefix = f"{metric_name}:"
            doomed = [
                k for k in self._data
                if k == metric_name or k.startswith(prefix)
            ]
            for k in doomed:
                del self._data[k]

    def _make_key(self, metric_name: str, labels: Optional[Dict[str, str]]) -> str:
        """Build the series key: the bare name, or name:`k=v,...` sorted by key."""
        if not labels:
            return metric_name

        encoded = ",".join(f"{k}={v}" for k, v in sorted(labels.items()))
        return f"{metric_name}:{encoded}"
|
utils/prometheus_parser.py
ADDED
|
@@ -0,0 +1,195 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Parser for Prometheus text format metrics."""
|
| 2 |
+
|
| 3 |
+
import re
|
| 4 |
+
from typing import Dict, List, Any, Optional
|
| 5 |
+
from dataclasses import dataclass
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
@dataclass
class MetricSample:
    """A single metric sample with labels and value."""
    # Metric name exactly as it appears in the exposition line.
    name: str
    # Label key/value pairs from the `{...}` selector; empty dict when absent.
    labels: Dict[str, str]
    # Sample value; may be nan/+inf/-inf per the Prometheus text format.
    value: float
    # Optional trailing timestamp from the line, when present.
    timestamp: Optional[float] = None
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
def parse_prometheus_metrics(text: str) -> Dict[str, List[MetricSample]]:
    """
    Parse Prometheus text format into structured metrics.

    Args:
        text: Raw Prometheus metrics text

    Returns:
        Dictionary mapping metric names to lists of samples
    """
    parsed: Dict[str, List[MetricSample]] = {}

    for raw_line in text.strip().split("\n"):
        stripped = raw_line.strip()

        # Blank lines and '# HELP' / '# TYPE' comment lines carry no samples.
        if not stripped or stripped.startswith("#"):
            continue

        sample = _parse_metric_line(stripped)
        if sample is not None:
            parsed.setdefault(sample.name, []).append(sample)

    return parsed
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
def _parse_metric_line(line: str) -> Optional[MetricSample]:
    """Parse a single Prometheus sample line into a MetricSample.

    Accepted shapes (per the Prometheus exposition format):
        metric_name{label1="v1",label2="v2"} value [timestamp]
        metric_name value [timestamp]

    The timestamp is an optional milliseconds-since-epoch integer; it may
    be negative (the previous pattern rejected negative timestamps), and a
    fractional part is tolerated for robustness.

    Returns:
        The parsed sample, or None when the line is not a valid sample.
    """
    # One pattern with an optional `{...}` label block instead of two
    # near-duplicate patterns; the timestamp group allows a sign/decimals.
    match = re.match(
        r'^([a-zA-Z_:][a-zA-Z0-9_:]*)(?:\{([^}]*)\})?\s+(\S+)(?:\s+(-?\d+(?:\.\d+)?))?$',
        line
    )
    if not match:
        return None

    name, labels_str, value_str, timestamp_str = match.groups()
    return MetricSample(
        name=name,
        labels=_parse_labels(labels_str) if labels_str else {},
        value=_parse_value(value_str),
        timestamp=float(timestamp_str) if timestamp_str else None,
    )
|
| 86 |
+
|
| 87 |
+
|
| 88 |
+
def _parse_labels(labels_str: str) -> Dict[str, str]:
|
| 89 |
+
"""Parse label string into dictionary."""
|
| 90 |
+
labels = {}
|
| 91 |
+
|
| 92 |
+
# Pattern: key="value"
|
| 93 |
+
for match in re.finditer(r'([a-zA-Z_][a-zA-Z0-9_]*)="([^"]*)"', labels_str):
|
| 94 |
+
labels[match.group(1)] = match.group(2)
|
| 95 |
+
|
| 96 |
+
return labels
|
| 97 |
+
|
| 98 |
+
|
| 99 |
+
def _parse_value(value_str: str) -> float:
|
| 100 |
+
"""Parse metric value, handling special cases."""
|
| 101 |
+
if value_str.lower() == "nan":
|
| 102 |
+
return float("nan")
|
| 103 |
+
if value_str.lower() == "+inf":
|
| 104 |
+
return float("inf")
|
| 105 |
+
if value_str.lower() == "-inf":
|
| 106 |
+
return float("-inf")
|
| 107 |
+
return float(value_str)
|
| 108 |
+
|
| 109 |
+
|
| 110 |
+
def get_metric_value(
    metrics: Dict[str, List[MetricSample]],
    name: str,
    labels: Optional[Dict[str, str]] = None
) -> Optional[float]:
    """
    Get a specific metric value by name and optional labels.

    Args:
        metrics: Parsed metrics dictionary
        name: Metric name
        labels: Optional label filter (every given key/value must match)

    Returns:
        Value of the first matching sample, or None if not found
    """
    for sample in metrics.get(name, []):
        matches = labels is None or all(
            sample.labels.get(k) == v for k, v in labels.items()
        )
        if matches:
            return sample.value

    return None
|
| 138 |
+
|
| 139 |
+
|
| 140 |
+
def get_histogram_quantile(
    metrics: Dict[str, List[MetricSample]],
    name: str,
    quantile: float,
    labels: Optional[Dict[str, str]] = None
) -> Optional[float]:
    """
    Get a histogram quantile value from a Prometheus histogram.

    Prometheus `_bucket` samples are cumulative counts with the upper bound
    in the `le` label, and the `le="+Inf"` bucket holds the total
    observation count. The quantile is located by linear interpolation
    within the first bucket whose cumulative count reaches
    `quantile * total` (same approach as PromQL's histogram_quantile()).
    The previous implementation dropped the +Inf bucket and used the last
    finite bucket's count as the total, which skewed quantiles whenever
    observations fell above the largest finite bound.

    Args:
        metrics: Parsed metrics dictionary
        name: Base metric name (without _bucket suffix)
        quantile: Desired quantile (e.g., 0.95 for P95)
        labels: Optional label filter

    Returns:
        Approximate quantile value (clamped to the largest finite bound
        when the quantile lies beyond it), or None if the histogram is
        missing or empty.
    """
    bucket_name = f"{name}_bucket"
    if bucket_name not in metrics:
        return None

    finite_buckets = []
    total = 0.0
    for sample in metrics[bucket_name]:
        if labels and not all(sample.labels.get(k) == v for k, v in labels.items()):
            continue
        le = sample.labels.get("le")
        if le is None:
            continue
        if le == "+Inf":
            # The +Inf bucket is the true total observation count.
            total = max(total, sample.value)
        else:
            finite_buckets.append((float(le), sample.value))

    if not finite_buckets:
        return None

    finite_buckets.sort(key=lambda b: b[0])

    # Fall back to the last finite bucket's count if no +Inf sample exists.
    if total <= 0:
        total = finite_buckets[-1][1]
    if total <= 0:
        return None

    # Locate the bucket containing the target rank and interpolate linearly.
    target = quantile * total
    prev_bound = 0.0
    prev_count = 0.0
    for bound, count in finite_buckets:
        if count >= target:
            fraction = (target - prev_count) / (count - prev_count) if count > prev_count else 0
            return prev_bound + fraction * (bound - prev_bound)
        prev_bound = bound
        prev_count = count

    # Quantile lies above the largest finite bound: clamp to it.
    return finite_buckets[-1][0]
|