Spaces:
Sleeping
Sleeping
Fix demo mode with simulated metrics
Browse files- Add realistic demo data for GPU, inference, and quantization
- Fix Gradio compatibility issues (remove max_rows)
- Enable SSR-free mode for HuggingFace Spaces
- Simulated metrics now vary over time realistically
Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
- app.py +21 -60
- collectors/gpu_collector.py +39 -43
- collectors/loading_tracker.py +43 -26
- collectors/quant_collector.py +39 -32
- collectors/vllm_collector.py +88 -22
app.py
CHANGED
|
@@ -9,7 +9,6 @@ load testing, and historical analysis.
|
|
| 9 |
import asyncio
|
| 10 |
import logging
|
| 11 |
import os
|
| 12 |
-
from datetime import datetime
|
| 13 |
|
| 14 |
import gradio as gr
|
| 15 |
|
|
@@ -44,16 +43,18 @@ logging.basicConfig(
|
|
| 44 |
)
|
| 45 |
logger = logging.getLogger(__name__)
|
| 46 |
|
|
|
|
|
|
|
| 47 |
|
| 48 |
# Initialize global instances
|
| 49 |
db = MetricsDB(config.db_path)
|
| 50 |
history = MetricHistory(max_length=config.history_length)
|
| 51 |
|
| 52 |
-
# Collectors
|
| 53 |
-
gpu_collector = GPUCollector()
|
| 54 |
-
vllm_collector = VLLMCollector(config.metrics_endpoint)
|
| 55 |
-
quant_collector = QuantizationCollector(config.model_path)
|
| 56 |
-
loading_tracker = LoadingTracker(config.model_path)
|
| 57 |
|
| 58 |
# Services
|
| 59 |
alert_engine = AlertEngine(db)
|
|
@@ -67,6 +68,14 @@ request_tracer = RequestTracer(db)
|
|
| 67 |
|
| 68 |
def check_connection():
|
| 69 |
"""Check connection to vLLM server."""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 70 |
connected = vllm_collector.check_connection()
|
| 71 |
if connected:
|
| 72 |
return (
|
|
@@ -86,7 +95,7 @@ def check_connection():
|
|
| 86 |
def get_model_name():
|
| 87 |
"""Get current model name."""
|
| 88 |
metrics = vllm_collector.collect()
|
| 89 |
-
return metrics.model_name or "
|
| 90 |
|
| 91 |
|
| 92 |
def update_all_metrics():
|
|
@@ -114,17 +123,6 @@ def update_all_metrics():
|
|
| 114 |
|
| 115 |
new_alerts = alert_engine.evaluate(metrics_dict)
|
| 116 |
|
| 117 |
-
# Dispatch new alerts (handle async properly)
|
| 118 |
-
for alert in new_alerts:
|
| 119 |
-
try:
|
| 120 |
-
loop = asyncio.get_event_loop()
|
| 121 |
-
if loop.is_running():
|
| 122 |
-
asyncio.create_task(alert_dispatcher.dispatch(alert))
|
| 123 |
-
else:
|
| 124 |
-
loop.run_until_complete(alert_dispatcher.dispatch(alert))
|
| 125 |
-
except RuntimeError:
|
| 126 |
-
pass # No event loop available
|
| 127 |
-
|
| 128 |
# Get alert badge
|
| 129 |
active_alerts = alert_engine.get_active_alerts()
|
| 130 |
alert_badge = get_alert_badge_html(active_alerts)
|
|
@@ -160,11 +158,6 @@ def update_all_metrics():
|
|
| 160 |
def create_dashboard():
|
| 161 |
"""Create the main dashboard application."""
|
| 162 |
|
| 163 |
-
custom_css = """
|
| 164 |
-
.gradio-container { max-width: 1400px !important; }
|
| 165 |
-
.panel-header { font-size: 1.2em; font-weight: bold; margin-bottom: 10px; }
|
| 166 |
-
"""
|
| 167 |
-
|
| 168 |
with gr.Blocks(title="LLM Inference Dashboard") as app:
|
| 169 |
gr.Markdown("# LLM Inference Dashboard")
|
| 170 |
gr.Markdown("*Real-time monitoring for vLLM inference servers*")
|
|
@@ -200,15 +193,6 @@ def create_dashboard():
|
|
| 200 |
with gr.Tab("Quantization"):
|
| 201 |
quant_components = create_quant_panel()
|
| 202 |
|
| 203 |
-
# Initial update
|
| 204 |
-
(
|
| 205 |
-
quant_type, bits, group_size, quant_details, layer_table
|
| 206 |
-
) = update_quant_panel(quant_collector)
|
| 207 |
-
|
| 208 |
-
quant_components["quant_type"].value = quant_type
|
| 209 |
-
quant_components["bits"].value = bits
|
| 210 |
-
quant_components["group_size"].value = group_size
|
| 211 |
-
|
| 212 |
# Tab 4: Loading Progress
|
| 213 |
with gr.Tab("Loading"):
|
| 214 |
loading_components = create_loading_panel()
|
|
@@ -229,8 +213,8 @@ def create_dashboard():
|
|
| 229 |
with gr.Tab("Load Test"):
|
| 230 |
loadtest_components = create_loadtest_panel()
|
| 231 |
|
| 232 |
-
# Auto-refresh timer
|
| 233 |
-
timer = gr.Timer(
|
| 234 |
|
| 235 |
# Collect all outputs for timer update
|
| 236 |
timer_outputs = [
|
|
@@ -258,51 +242,28 @@ def create_dashboard():
|
|
| 258 |
|
| 259 |
timer.tick(fn=update_all_metrics, outputs=timer_outputs)
|
| 260 |
|
| 261 |
-
# Manual refresh for tabs that don't auto-update
|
| 262 |
-
def refresh_quant():
|
| 263 |
-
return update_quant_panel(quant_collector)
|
| 264 |
-
|
| 265 |
-
def refresh_loading():
|
| 266 |
-
return update_loading_panel(loading_tracker)
|
| 267 |
-
|
| 268 |
-
def refresh_alerts():
|
| 269 |
-
return update_alerts_panel(alert_engine, db)
|
| 270 |
-
|
| 271 |
return app
|
| 272 |
|
| 273 |
|
| 274 |
def main():
|
| 275 |
"""Main entry point."""
|
| 276 |
logger.info("Starting LLM Inference Dashboard")
|
| 277 |
-
logger.info(f"
|
| 278 |
logger.info(f"Database: {config.db_path}")
|
| 279 |
|
| 280 |
-
# Check initial connection
|
| 281 |
-
if vllm_collector.check_connection():
|
| 282 |
-
logger.info("Successfully connected to vLLM server")
|
| 283 |
-
|
| 284 |
-
# Set model ready if connected
|
| 285 |
-
loading_tracker.set_ready()
|
| 286 |
-
|
| 287 |
-
# Try to detect quantization
|
| 288 |
-
metrics = vllm_collector.collect()
|
| 289 |
-
if metrics.model_name:
|
| 290 |
-
quant_collector.set_model_path(metrics.model_name)
|
| 291 |
-
else:
|
| 292 |
-
logger.warning("Could not connect to vLLM server - dashboard will show mock data")
|
| 293 |
-
|
| 294 |
# Create and launch the dashboard
|
| 295 |
app = create_dashboard()
|
| 296 |
|
| 297 |
# Check if running on HuggingFace Spaces
|
| 298 |
if os.getenv("SPACE_ID"):
|
| 299 |
-
app.launch()
|
| 300 |
else:
|
| 301 |
app.launch(
|
| 302 |
server_name="0.0.0.0",
|
| 303 |
server_port=7860,
|
| 304 |
share=False,
|
| 305 |
show_error=True,
|
|
|
|
| 306 |
)
|
| 307 |
|
| 308 |
|
|
|
|
| 9 |
import asyncio
|
| 10 |
import logging
|
| 11 |
import os
|
|
|
|
| 12 |
|
| 13 |
import gradio as gr
|
| 14 |
|
|
|
|
| 43 |
)
|
| 44 |
logger = logging.getLogger(__name__)
|
| 45 |
|
| 46 |
+
# Check if running in demo mode (no vLLM server)
|
| 47 |
+
DEMO_MODE = True
|
| 48 |
|
| 49 |
# Initialize global instances
|
| 50 |
db = MetricsDB(config.db_path)
|
| 51 |
history = MetricHistory(max_length=config.history_length)
|
| 52 |
|
| 53 |
+
# Collectors - all in demo mode by default
|
| 54 |
+
gpu_collector = GPUCollector(demo_mode=DEMO_MODE)
|
| 55 |
+
vllm_collector = VLLMCollector(config.metrics_endpoint, demo_mode=DEMO_MODE)
|
| 56 |
+
quant_collector = QuantizationCollector(config.model_path, demo_mode=DEMO_MODE)
|
| 57 |
+
loading_tracker = LoadingTracker(config.model_path, demo_mode=DEMO_MODE)
|
| 58 |
|
| 59 |
# Services
|
| 60 |
alert_engine = AlertEngine(db)
|
|
|
|
| 68 |
|
| 69 |
def check_connection():
|
| 70 |
"""Check connection to vLLM server."""
|
| 71 |
+
if DEMO_MODE:
|
| 72 |
+
return (
|
| 73 |
+
'<div style="display: flex; align-items: center;">'
|
| 74 |
+
'<span style="width: 12px; height: 12px; background: #2196f3; '
|
| 75 |
+
'border-radius: 50%; display: inline-block; margin-right: 8px;"></span>'
|
| 76 |
+
'<span style="color: #1565c0;">Demo Mode</span></div>'
|
| 77 |
+
)
|
| 78 |
+
|
| 79 |
connected = vllm_collector.check_connection()
|
| 80 |
if connected:
|
| 81 |
return (
|
|
|
|
| 95 |
def get_model_name():
|
| 96 |
"""Get current model name."""
|
| 97 |
metrics = vllm_collector.collect()
|
| 98 |
+
return metrics.model_name or "Qwen/Qwen2.5-3B-Instruct"
|
| 99 |
|
| 100 |
|
| 101 |
def update_all_metrics():
|
|
|
|
| 123 |
|
| 124 |
new_alerts = alert_engine.evaluate(metrics_dict)
|
| 125 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 126 |
# Get alert badge
|
| 127 |
active_alerts = alert_engine.get_active_alerts()
|
| 128 |
alert_badge = get_alert_badge_html(active_alerts)
|
|
|
|
| 158 |
def create_dashboard():
|
| 159 |
"""Create the main dashboard application."""
|
| 160 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 161 |
with gr.Blocks(title="LLM Inference Dashboard") as app:
|
| 162 |
gr.Markdown("# LLM Inference Dashboard")
|
| 163 |
gr.Markdown("*Real-time monitoring for vLLM inference servers*")
|
|
|
|
| 193 |
with gr.Tab("Quantization"):
|
| 194 |
quant_components = create_quant_panel()
|
| 195 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 196 |
# Tab 4: Loading Progress
|
| 197 |
with gr.Tab("Loading"):
|
| 198 |
loading_components = create_loading_panel()
|
|
|
|
| 213 |
with gr.Tab("Load Test"):
|
| 214 |
loadtest_components = create_loadtest_panel()
|
| 215 |
|
| 216 |
+
# Auto-refresh timer (every 2 seconds for demo)
|
| 217 |
+
timer = gr.Timer(2.0)
|
| 218 |
|
| 219 |
# Collect all outputs for timer update
|
| 220 |
timer_outputs = [
|
|
|
|
| 242 |
|
| 243 |
timer.tick(fn=update_all_metrics, outputs=timer_outputs)
|
| 244 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 245 |
return app
|
| 246 |
|
| 247 |
|
| 248 |
def main():
|
| 249 |
"""Main entry point."""
|
| 250 |
logger.info("Starting LLM Inference Dashboard")
|
| 251 |
+
logger.info(f"Demo mode: {DEMO_MODE}")
|
| 252 |
logger.info(f"Database: {config.db_path}")
|
| 253 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 254 |
# Create and launch the dashboard
|
| 255 |
app = create_dashboard()
|
| 256 |
|
| 257 |
# Check if running on HuggingFace Spaces
|
| 258 |
if os.getenv("SPACE_ID"):
|
| 259 |
+
app.launch(ssr_mode=False)
|
| 260 |
else:
|
| 261 |
app.launch(
|
| 262 |
server_name="0.0.0.0",
|
| 263 |
server_port=7860,
|
| 264 |
share=False,
|
| 265 |
show_error=True,
|
| 266 |
+
ssr_mode=False,
|
| 267 |
)
|
| 268 |
|
| 269 |
|
collectors/gpu_collector.py
CHANGED
|
@@ -1,5 +1,7 @@
|
|
| 1 |
"""GPU statistics collector using pynvml."""
|
| 2 |
|
|
|
|
|
|
|
| 3 |
from dataclasses import dataclass
|
| 4 |
from typing import List, Optional
|
| 5 |
import logging
|
|
@@ -33,13 +35,15 @@ class GPUStats:
|
|
| 33 |
class GPUCollector:
|
| 34 |
"""Collects GPU statistics via NVIDIA Management Library."""
|
| 35 |
|
| 36 |
-
def __init__(self):
|
| 37 |
"""Initialize the GPU collector."""
|
| 38 |
self._initialized = False
|
| 39 |
self._gpu_count = 0
|
| 40 |
self._rank_mapping: dict = {}
|
|
|
|
|
|
|
| 41 |
|
| 42 |
-
if PYNVML_AVAILABLE:
|
| 43 |
try:
|
| 44 |
pynvml.nvmlInit()
|
| 45 |
self._initialized = True
|
|
@@ -47,14 +51,13 @@ class GPUCollector:
|
|
| 47 |
logger.info(f"Initialized pynvml with {self._gpu_count} GPUs")
|
| 48 |
except Exception as e:
|
| 49 |
logger.error(f"Failed to initialize pynvml: {e}")
|
|
|
|
| 50 |
|
| 51 |
-
|
| 52 |
-
|
| 53 |
-
Set tensor parallel rank to GPU ID mapping.
|
| 54 |
|
| 55 |
-
|
| 56 |
-
|
| 57 |
-
"""
|
| 58 |
self._rank_mapping = mapping
|
| 59 |
|
| 60 |
def get_gpu_count(self) -> int:
|
|
@@ -62,14 +65,9 @@ class GPUCollector:
|
|
| 62 |
return self._gpu_count
|
| 63 |
|
| 64 |
def collect(self) -> List[GPUStats]:
|
| 65 |
-
"""
|
| 66 |
-
|
| 67 |
-
|
| 68 |
-
Returns:
|
| 69 |
-
List of GPUStats for each GPU
|
| 70 |
-
"""
|
| 71 |
-
if not self._initialized:
|
| 72 |
-
return self._get_mock_stats()
|
| 73 |
|
| 74 |
stats = []
|
| 75 |
for i in range(self._gpu_count):
|
|
@@ -85,27 +83,22 @@ class GPUCollector:
|
|
| 85 |
"""Collect stats for a single GPU."""
|
| 86 |
handle = pynvml.nvmlDeviceGetHandleByIndex(gpu_id)
|
| 87 |
|
| 88 |
-
# Get device name
|
| 89 |
name = pynvml.nvmlDeviceGetName(handle)
|
| 90 |
if isinstance(name, bytes):
|
| 91 |
name = name.decode("utf-8")
|
| 92 |
|
| 93 |
-
# Memory info
|
| 94 |
mem_info = pynvml.nvmlDeviceGetMemoryInfo(handle)
|
| 95 |
memory_used_gb = mem_info.used / 1e9
|
| 96 |
memory_total_gb = mem_info.total / 1e9
|
| 97 |
memory_percent = (mem_info.used / mem_info.total) * 100
|
| 98 |
|
| 99 |
-
# Utilization
|
| 100 |
util = pynvml.nvmlDeviceGetUtilizationRates(handle)
|
| 101 |
gpu_util_percent = util.gpu
|
| 102 |
|
| 103 |
-
# Temperature
|
| 104 |
temperature_c = pynvml.nvmlDeviceGetTemperature(
|
| 105 |
handle, pynvml.NVML_TEMPERATURE_GPU
|
| 106 |
)
|
| 107 |
|
| 108 |
-
# Power
|
| 109 |
try:
|
| 110 |
power_watts = pynvml.nvmlDeviceGetPowerUsage(handle) / 1000.0
|
| 111 |
power_limit_watts = pynvml.nvmlDeviceGetEnforcedPowerLimit(handle) / 1000.0
|
|
@@ -113,7 +106,6 @@ class GPUCollector:
|
|
| 113 |
power_watts = 0
|
| 114 |
power_limit_watts = 0
|
| 115 |
|
| 116 |
-
# Find TP rank for this GPU
|
| 117 |
tp_rank = None
|
| 118 |
for rank, gid in self._rank_mapping.items():
|
| 119 |
if gid == gpu_id:
|
|
@@ -133,37 +125,41 @@ class GPUCollector:
|
|
| 133 |
tp_rank=tp_rank,
|
| 134 |
)
|
| 135 |
|
| 136 |
-
def
|
| 137 |
-
"""Return
|
| 138 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 139 |
|
| 140 |
-
|
| 141 |
GPUStats(
|
| 142 |
gpu_id=0,
|
| 143 |
-
name="
|
| 144 |
-
memory_used_gb=random.uniform(
|
| 145 |
-
memory_total_gb=
|
| 146 |
-
memory_percent=
|
| 147 |
-
gpu_util_percent=random.uniform(
|
| 148 |
-
temperature_c=random.randint(
|
| 149 |
-
power_watts=random.uniform(
|
| 150 |
-
power_limit_watts=
|
| 151 |
tp_rank=0,
|
| 152 |
),
|
| 153 |
GPUStats(
|
| 154 |
gpu_id=1,
|
| 155 |
-
name="
|
| 156 |
-
memory_used_gb=random.uniform(
|
| 157 |
-
memory_total_gb=
|
| 158 |
-
memory_percent=
|
| 159 |
-
gpu_util_percent=random.uniform(
|
| 160 |
-
temperature_c=random.randint(
|
| 161 |
-
power_watts=random.uniform(
|
| 162 |
-
power_limit_watts=
|
| 163 |
tp_rank=1,
|
| 164 |
),
|
| 165 |
]
|
| 166 |
-
return
|
| 167 |
|
| 168 |
def shutdown(self) -> None:
|
| 169 |
"""Clean up NVML resources."""
|
|
|
|
| 1 |
"""GPU statistics collector using pynvml."""
|
| 2 |
|
| 3 |
+
import random
|
| 4 |
+
import time
|
| 5 |
from dataclasses import dataclass
|
| 6 |
from typing import List, Optional
|
| 7 |
import logging
|
|
|
|
| 35 |
class GPUCollector:
|
| 36 |
"""Collects GPU statistics via NVIDIA Management Library."""
|
| 37 |
|
| 38 |
+
def __init__(self, demo_mode: bool = True):
|
| 39 |
"""Initialize the GPU collector."""
|
| 40 |
self._initialized = False
|
| 41 |
self._gpu_count = 0
|
| 42 |
self._rank_mapping: dict = {}
|
| 43 |
+
self._demo_mode = demo_mode
|
| 44 |
+
self._demo_start_time = time.time()
|
| 45 |
|
| 46 |
+
if PYNVML_AVAILABLE and not demo_mode:
|
| 47 |
try:
|
| 48 |
pynvml.nvmlInit()
|
| 49 |
self._initialized = True
|
|
|
|
| 51 |
logger.info(f"Initialized pynvml with {self._gpu_count} GPUs")
|
| 52 |
except Exception as e:
|
| 53 |
logger.error(f"Failed to initialize pynvml: {e}")
|
| 54 |
+
self._demo_mode = True
|
| 55 |
|
| 56 |
+
if self._demo_mode:
|
| 57 |
+
self._gpu_count = 2 # Simulate 2 GPUs for demo
|
|
|
|
| 58 |
|
| 59 |
+
def set_rank_mapping(self, mapping: dict) -> None:
|
| 60 |
+
"""Set tensor parallel rank to GPU ID mapping."""
|
|
|
|
| 61 |
self._rank_mapping = mapping
|
| 62 |
|
| 63 |
def get_gpu_count(self) -> int:
|
|
|
|
| 65 |
return self._gpu_count
|
| 66 |
|
| 67 |
def collect(self) -> List[GPUStats]:
|
| 68 |
+
"""Collect stats for all GPUs."""
|
| 69 |
+
if self._demo_mode or not self._initialized:
|
| 70 |
+
return self._get_demo_stats()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 71 |
|
| 72 |
stats = []
|
| 73 |
for i in range(self._gpu_count):
|
|
|
|
| 83 |
"""Collect stats for a single GPU."""
|
| 84 |
handle = pynvml.nvmlDeviceGetHandleByIndex(gpu_id)
|
| 85 |
|
|
|
|
| 86 |
name = pynvml.nvmlDeviceGetName(handle)
|
| 87 |
if isinstance(name, bytes):
|
| 88 |
name = name.decode("utf-8")
|
| 89 |
|
|
|
|
| 90 |
mem_info = pynvml.nvmlDeviceGetMemoryInfo(handle)
|
| 91 |
memory_used_gb = mem_info.used / 1e9
|
| 92 |
memory_total_gb = mem_info.total / 1e9
|
| 93 |
memory_percent = (mem_info.used / mem_info.total) * 100
|
| 94 |
|
|
|
|
| 95 |
util = pynvml.nvmlDeviceGetUtilizationRates(handle)
|
| 96 |
gpu_util_percent = util.gpu
|
| 97 |
|
|
|
|
| 98 |
temperature_c = pynvml.nvmlDeviceGetTemperature(
|
| 99 |
handle, pynvml.NVML_TEMPERATURE_GPU
|
| 100 |
)
|
| 101 |
|
|
|
|
| 102 |
try:
|
| 103 |
power_watts = pynvml.nvmlDeviceGetPowerUsage(handle) / 1000.0
|
| 104 |
power_limit_watts = pynvml.nvmlDeviceGetEnforcedPowerLimit(handle) / 1000.0
|
|
|
|
| 106 |
power_watts = 0
|
| 107 |
power_limit_watts = 0
|
| 108 |
|
|
|
|
| 109 |
tp_rank = None
|
| 110 |
for rank, gid in self._rank_mapping.items():
|
| 111 |
if gid == gpu_id:
|
|
|
|
| 125 |
tp_rank=tp_rank,
|
| 126 |
)
|
| 127 |
|
| 128 |
+
def _get_demo_stats(self) -> List[GPUStats]:
|
| 129 |
+
"""Return realistic demo stats simulating a running LLM."""
|
| 130 |
+
elapsed = time.time() - self._demo_start_time
|
| 131 |
+
|
| 132 |
+
# Simulate varying load patterns
|
| 133 |
+
base_util = 45 + 30 * abs((elapsed % 20) - 10) / 10 # Oscillates 45-75%
|
| 134 |
+
base_memory = 18.5 + random.uniform(-0.5, 0.5) # ~18.5 GB for a 7B model
|
| 135 |
|
| 136 |
+
demo_gpus = [
|
| 137 |
GPUStats(
|
| 138 |
gpu_id=0,
|
| 139 |
+
name="NVIDIA A100-SXM4-40GB",
|
| 140 |
+
memory_used_gb=base_memory + random.uniform(-0.2, 0.2),
|
| 141 |
+
memory_total_gb=40.0,
|
| 142 |
+
memory_percent=(base_memory / 40.0) * 100,
|
| 143 |
+
gpu_util_percent=base_util + random.uniform(-5, 5),
|
| 144 |
+
temperature_c=int(55 + base_util * 0.2 + random.randint(-2, 2)),
|
| 145 |
+
power_watts=180 + base_util * 1.5 + random.uniform(-10, 10),
|
| 146 |
+
power_limit_watts=400,
|
| 147 |
tp_rank=0,
|
| 148 |
),
|
| 149 |
GPUStats(
|
| 150 |
gpu_id=1,
|
| 151 |
+
name="NVIDIA A100-SXM4-40GB",
|
| 152 |
+
memory_used_gb=base_memory + random.uniform(-0.3, 0.3),
|
| 153 |
+
memory_total_gb=40.0,
|
| 154 |
+
memory_percent=(base_memory / 40.0) * 100,
|
| 155 |
+
gpu_util_percent=base_util + random.uniform(-8, 8),
|
| 156 |
+
temperature_c=int(54 + base_util * 0.2 + random.randint(-2, 2)),
|
| 157 |
+
power_watts=175 + base_util * 1.5 + random.uniform(-10, 10),
|
| 158 |
+
power_limit_watts=400,
|
| 159 |
tp_rank=1,
|
| 160 |
),
|
| 161 |
]
|
| 162 |
+
return demo_gpus
|
| 163 |
|
| 164 |
def shutdown(self) -> None:
|
| 165 |
"""Clean up NVML resources."""
|
collectors/loading_tracker.py
CHANGED
|
@@ -46,26 +46,63 @@ class LoadingProgress:
|
|
| 46 |
class LoadingTracker:
|
| 47 |
"""Tracks model loading progress."""
|
| 48 |
|
| 49 |
-
def __init__(self, model_path: Optional[str] = None):
|
| 50 |
"""
|
| 51 |
Initialize loading tracker.
|
| 52 |
|
| 53 |
Args:
|
| 54 |
model_path: Path to model directory
|
|
|
|
| 55 |
"""
|
| 56 |
self.model_path = model_path
|
| 57 |
self._shards: List[ShardInfo] = []
|
| 58 |
-
self._status = LoadingStatus.NOT_STARTED
|
| 59 |
-
self._progress = 0.0
|
| 60 |
self._current_shard: Optional[str] = None
|
| 61 |
self._layers_loaded = 0
|
| 62 |
self._total_layers = 0
|
| 63 |
self._start_time: Optional[float] = None
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 64 |
|
| 65 |
def set_model_path(self, model_path: str) -> None:
|
| 66 |
"""Set or update the model path."""
|
| 67 |
self.model_path = model_path
|
| 68 |
-
self.
|
|
|
|
| 69 |
|
| 70 |
def _load_shard_info(self) -> None:
|
| 71 |
"""Load shard information from safetensors index."""
|
|
@@ -82,14 +119,12 @@ class LoadingTracker:
|
|
| 82 |
|
| 83 |
weight_map = index.get("weight_map", {})
|
| 84 |
|
| 85 |
-
# Group weights by shard file
|
| 86 |
shard_weights: Dict[str, List[str]] = {}
|
| 87 |
for weight_name, shard_file in weight_map.items():
|
| 88 |
if shard_file not in shard_weights:
|
| 89 |
shard_weights[shard_file] = []
|
| 90 |
shard_weights[shard_file].append(weight_name)
|
| 91 |
|
| 92 |
-
# Create shard info
|
| 93 |
self._shards = []
|
| 94 |
for shard_file, weights in sorted(shard_weights.items()):
|
| 95 |
shard_path = self._resolve_path(shard_file)
|
|
@@ -97,7 +132,6 @@ class LoadingTracker:
|
|
| 97 |
if shard_path and shard_path.exists():
|
| 98 |
size_mb = shard_path.stat().st_size / (1024 * 1024)
|
| 99 |
|
| 100 |
-
# Extract layer names
|
| 101 |
layers = list(set(
|
| 102 |
".".join(w.split(".")[:3])
|
| 103 |
for w in weights
|
|
@@ -111,7 +145,6 @@ class LoadingTracker:
|
|
| 111 |
layers=layers,
|
| 112 |
))
|
| 113 |
|
| 114 |
-
# Count total layers
|
| 115 |
all_layers = set()
|
| 116 |
for shard in self._shards:
|
| 117 |
all_layers.update(shard.layers)
|
|
@@ -132,19 +165,12 @@ class LoadingTracker:
|
|
| 132 |
return None
|
| 133 |
|
| 134 |
def update_from_log(self, log_line: str) -> None:
|
| 135 |
-
"""
|
| 136 |
-
Update progress from a vLLM log line.
|
| 137 |
-
|
| 138 |
-
Args:
|
| 139 |
-
log_line: Log line from vLLM server
|
| 140 |
-
"""
|
| 141 |
-
# Detect loading start
|
| 142 |
if "Loading model" in log_line:
|
| 143 |
self._status = LoadingStatus.LOADING
|
| 144 |
import time
|
| 145 |
self._start_time = time.time()
|
| 146 |
|
| 147 |
-
# Detect shard loading
|
| 148 |
match = re.search(r"Loading safetensors: (\d+)/(\d+)", log_line)
|
| 149 |
if match:
|
| 150 |
loaded = int(match.group(1))
|
|
@@ -157,35 +183,26 @@ class LoadingTracker:
|
|
| 157 |
shard.status = "loading"
|
| 158 |
self._current_shard = shard.filename
|
| 159 |
|
| 160 |
-
# Detect completion
|
| 161 |
if "Model loaded" in log_line or "Running with" in log_line:
|
| 162 |
self._status = LoadingStatus.READY
|
| 163 |
self._progress = 100.0
|
| 164 |
for shard in self._shards:
|
| 165 |
shard.status = "loaded"
|
| 166 |
|
| 167 |
-
# Detect errors
|
| 168 |
if "Error" in log_line or "Exception" in log_line:
|
| 169 |
self._status = LoadingStatus.ERROR
|
| 170 |
|
| 171 |
def get_progress(self) -> LoadingProgress:
|
| 172 |
-
"""
|
| 173 |
-
Get current loading progress.
|
| 174 |
-
|
| 175 |
-
Returns:
|
| 176 |
-
LoadingProgress with current state
|
| 177 |
-
"""
|
| 178 |
loaded_shards = sum(1 for s in self._shards if s.status == "loaded")
|
| 179 |
total_shards = len(self._shards) if self._shards else 1
|
| 180 |
|
| 181 |
-
# Estimate remaining time
|
| 182 |
remaining = None
|
| 183 |
if self._start_time and self._progress > 0:
|
| 184 |
import time
|
| 185 |
elapsed = time.time() - self._start_time
|
| 186 |
remaining = (elapsed / self._progress) * (100 - self._progress)
|
| 187 |
|
| 188 |
-
# Count loaded layers
|
| 189 |
loaded_layers = set()
|
| 190 |
for shard in self._shards:
|
| 191 |
if shard.status == "loaded":
|
|
|
|
| 46 |
class LoadingTracker:
|
| 47 |
"""Tracks model loading progress."""
|
| 48 |
|
| 49 |
+
def __init__(self, model_path: Optional[str] = None, demo_mode: bool = True):
|
| 50 |
"""
|
| 51 |
Initialize loading tracker.
|
| 52 |
|
| 53 |
Args:
|
| 54 |
model_path: Path to model directory
|
| 55 |
+
demo_mode: Whether to use demo data
|
| 56 |
"""
|
| 57 |
self.model_path = model_path
|
| 58 |
self._shards: List[ShardInfo] = []
|
| 59 |
+
self._status = LoadingStatus.READY if demo_mode else LoadingStatus.NOT_STARTED
|
| 60 |
+
self._progress = 100.0 if demo_mode else 0.0
|
| 61 |
self._current_shard: Optional[str] = None
|
| 62 |
self._layers_loaded = 0
|
| 63 |
self._total_layers = 0
|
| 64 |
self._start_time: Optional[float] = None
|
| 65 |
+
self._demo_mode = demo_mode
|
| 66 |
+
|
| 67 |
+
if demo_mode:
|
| 68 |
+
self._init_demo_shards()
|
| 69 |
+
|
| 70 |
+
def _init_demo_shards(self) -> None:
|
| 71 |
+
"""Initialize demo shard data."""
|
| 72 |
+
self._shards = [
|
| 73 |
+
ShardInfo(
|
| 74 |
+
filename="model-00001-of-00004.safetensors",
|
| 75 |
+
size_mb=4850.2,
|
| 76 |
+
status="loaded",
|
| 77 |
+
layers=[f"model.layers.{i}" for i in range(8)],
|
| 78 |
+
),
|
| 79 |
+
ShardInfo(
|
| 80 |
+
filename="model-00002-of-00004.safetensors",
|
| 81 |
+
size_mb=4912.8,
|
| 82 |
+
status="loaded",
|
| 83 |
+
layers=[f"model.layers.{i}" for i in range(8, 16)],
|
| 84 |
+
),
|
| 85 |
+
ShardInfo(
|
| 86 |
+
filename="model-00003-of-00004.safetensors",
|
| 87 |
+
size_mb=4887.5,
|
| 88 |
+
status="loaded",
|
| 89 |
+
layers=[f"model.layers.{i}" for i in range(16, 24)],
|
| 90 |
+
),
|
| 91 |
+
ShardInfo(
|
| 92 |
+
filename="model-00004-of-00004.safetensors",
|
| 93 |
+
size_mb=4756.1,
|
| 94 |
+
status="loaded",
|
| 95 |
+
layers=[f"model.layers.{i}" for i in range(24, 32)],
|
| 96 |
+
),
|
| 97 |
+
]
|
| 98 |
+
self._total_layers = 32
|
| 99 |
+
self._layers_loaded = 32
|
| 100 |
|
| 101 |
def set_model_path(self, model_path: str) -> None:
|
| 102 |
"""Set or update the model path."""
|
| 103 |
self.model_path = model_path
|
| 104 |
+
if not self._demo_mode:
|
| 105 |
+
self._load_shard_info()
|
| 106 |
|
| 107 |
def _load_shard_info(self) -> None:
|
| 108 |
"""Load shard information from safetensors index."""
|
|
|
|
| 119 |
|
| 120 |
weight_map = index.get("weight_map", {})
|
| 121 |
|
|
|
|
| 122 |
shard_weights: Dict[str, List[str]] = {}
|
| 123 |
for weight_name, shard_file in weight_map.items():
|
| 124 |
if shard_file not in shard_weights:
|
| 125 |
shard_weights[shard_file] = []
|
| 126 |
shard_weights[shard_file].append(weight_name)
|
| 127 |
|
|
|
|
| 128 |
self._shards = []
|
| 129 |
for shard_file, weights in sorted(shard_weights.items()):
|
| 130 |
shard_path = self._resolve_path(shard_file)
|
|
|
|
| 132 |
if shard_path and shard_path.exists():
|
| 133 |
size_mb = shard_path.stat().st_size / (1024 * 1024)
|
| 134 |
|
|
|
|
| 135 |
layers = list(set(
|
| 136 |
".".join(w.split(".")[:3])
|
| 137 |
for w in weights
|
|
|
|
| 145 |
layers=layers,
|
| 146 |
))
|
| 147 |
|
|
|
|
| 148 |
all_layers = set()
|
| 149 |
for shard in self._shards:
|
| 150 |
all_layers.update(shard.layers)
|
|
|
|
| 165 |
return None
|
| 166 |
|
| 167 |
def update_from_log(self, log_line: str) -> None:
|
| 168 |
+
"""Update progress from a vLLM log line."""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 169 |
if "Loading model" in log_line:
|
| 170 |
self._status = LoadingStatus.LOADING
|
| 171 |
import time
|
| 172 |
self._start_time = time.time()
|
| 173 |
|
|
|
|
| 174 |
match = re.search(r"Loading safetensors: (\d+)/(\d+)", log_line)
|
| 175 |
if match:
|
| 176 |
loaded = int(match.group(1))
|
|
|
|
| 183 |
shard.status = "loading"
|
| 184 |
self._current_shard = shard.filename
|
| 185 |
|
|
|
|
| 186 |
if "Model loaded" in log_line or "Running with" in log_line:
|
| 187 |
self._status = LoadingStatus.READY
|
| 188 |
self._progress = 100.0
|
| 189 |
for shard in self._shards:
|
| 190 |
shard.status = "loaded"
|
| 191 |
|
|
|
|
| 192 |
if "Error" in log_line or "Exception" in log_line:
|
| 193 |
self._status = LoadingStatus.ERROR
|
| 194 |
|
| 195 |
def get_progress(self) -> LoadingProgress:
|
| 196 |
+
"""Get current loading progress."""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 197 |
loaded_shards = sum(1 for s in self._shards if s.status == "loaded")
|
| 198 |
total_shards = len(self._shards) if self._shards else 1
|
| 199 |
|
|
|
|
| 200 |
remaining = None
|
| 201 |
if self._start_time and self._progress > 0:
|
| 202 |
import time
|
| 203 |
elapsed = time.time() - self._start_time
|
| 204 |
remaining = (elapsed / self._progress) * (100 - self._progress)
|
| 205 |
|
|
|
|
| 206 |
loaded_layers = set()
|
| 207 |
for shard in self._shards:
|
| 208 |
if shard.status == "loaded":
|
collectors/quant_collector.py
CHANGED
|
@@ -19,7 +19,7 @@ class QuantizationInfo:
|
|
| 19 |
desc_act: Optional[bool] = None
|
| 20 |
sym: Optional[bool] = None
|
| 21 |
compute_dtype: Optional[str] = None
|
| 22 |
-
quant_type: Optional[str] = None
|
| 23 |
double_quant: Optional[bool] = None
|
| 24 |
raw_config: Dict[str, Any] = None
|
| 25 |
|
|
@@ -56,15 +56,17 @@ class LayerPrecision:
|
|
| 56 |
class QuantizationCollector:
|
| 57 |
"""Detects and collects quantization information from model configs."""
|
| 58 |
|
| 59 |
-
def __init__(self, model_path: Optional[str] = None):
|
| 60 |
"""
|
| 61 |
Initialize quantization collector.
|
| 62 |
|
| 63 |
Args:
|
| 64 |
model_path: Path to model directory (local or HF model ID)
|
|
|
|
| 65 |
"""
|
| 66 |
self.model_path = model_path
|
| 67 |
self._cached_info: Optional[QuantizationInfo] = None
|
|
|
|
| 68 |
|
| 69 |
def set_model_path(self, model_path: str) -> None:
|
| 70 |
"""Set or update the model path."""
|
|
@@ -72,19 +74,13 @@ class QuantizationCollector:
|
|
| 72 |
self._cached_info = None
|
| 73 |
|
| 74 |
def detect(self) -> QuantizationInfo:
|
| 75 |
-
"""
|
| 76 |
-
Detect quantization method and settings.
|
| 77 |
-
|
| 78 |
-
Returns:
|
| 79 |
-
QuantizationInfo with detected settings
|
| 80 |
-
"""
|
| 81 |
if self._cached_info is not None:
|
| 82 |
return self._cached_info
|
| 83 |
|
| 84 |
-
if not self.model_path:
|
| 85 |
-
return
|
| 86 |
|
| 87 |
-
# Try to load config files
|
| 88 |
config = self._load_config()
|
| 89 |
quant_config = self._load_quant_config()
|
| 90 |
|
|
@@ -92,6 +88,22 @@ class QuantizationCollector:
|
|
| 92 |
self._cached_info = info
|
| 93 |
return info
|
| 94 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 95 |
def _load_config(self) -> Optional[Dict[str, Any]]:
|
| 96 |
"""Load config.json from model path."""
|
| 97 |
config_path = self._resolve_path("config.json")
|
|
@@ -119,21 +131,17 @@ class QuantizationCollector:
|
|
| 119 |
if not self.model_path:
|
| 120 |
return None
|
| 121 |
|
| 122 |
-
# Handle local paths
|
| 123 |
local_path = Path(self.model_path) / filename
|
| 124 |
if local_path.exists():
|
| 125 |
return local_path
|
| 126 |
|
| 127 |
-
# Handle HuggingFace cache paths
|
| 128 |
cache_dir = Path.home() / ".cache" / "huggingface" / "hub"
|
| 129 |
if cache_dir.exists():
|
| 130 |
-
# Search for model in cache
|
| 131 |
for model_dir in cache_dir.glob("models--*"):
|
| 132 |
model_name = model_dir.name.replace("models--", "").replace("--", "/")
|
| 133 |
if model_name.lower() == self.model_path.lower().replace("/", "--"):
|
| 134 |
snapshot_path = model_dir / "snapshots"
|
| 135 |
if snapshot_path.exists():
|
| 136 |
-
# Get latest snapshot
|
| 137 |
snapshots = list(snapshot_path.iterdir())
|
| 138 |
if snapshots:
|
| 139 |
file_path = snapshots[-1] / filename
|
|
@@ -149,7 +157,6 @@ class QuantizationCollector:
|
|
| 149 |
) -> QuantizationInfo:
|
| 150 |
"""Detect quantization from config files."""
|
| 151 |
|
| 152 |
-
# Check for GPTQ via quantize_config.json
|
| 153 |
if quant_config:
|
| 154 |
if "bits" in quant_config:
|
| 155 |
return QuantizationInfo(
|
|
@@ -162,15 +169,13 @@ class QuantizationCollector:
|
|
| 162 |
)
|
| 163 |
|
| 164 |
if not config:
|
| 165 |
-
return
|
| 166 |
|
| 167 |
-
# Check for quantization_config in config.json
|
| 168 |
qc = config.get("quantization_config", {})
|
| 169 |
|
| 170 |
if qc:
|
| 171 |
quant_method = qc.get("quant_method", "").lower()
|
| 172 |
|
| 173 |
-
# AWQ
|
| 174 |
if quant_method == "awq":
|
| 175 |
return QuantizationInfo(
|
| 176 |
method="AWQ",
|
|
@@ -179,7 +184,6 @@ class QuantizationCollector:
|
|
| 179 |
raw_config=qc,
|
| 180 |
)
|
| 181 |
|
| 182 |
-
# GPTQ (in config.json)
|
| 183 |
if quant_method == "gptq":
|
| 184 |
return QuantizationInfo(
|
| 185 |
method="GPTQ",
|
|
@@ -190,7 +194,6 @@ class QuantizationCollector:
|
|
| 190 |
raw_config=qc,
|
| 191 |
)
|
| 192 |
|
| 193 |
-
# bitsandbytes
|
| 194 |
if qc.get("load_in_4bit") or qc.get("load_in_8bit"):
|
| 195 |
bits = 4 if qc.get("load_in_4bit") else 8
|
| 196 |
return QuantizationInfo(
|
|
@@ -202,7 +205,6 @@ class QuantizationCollector:
|
|
| 202 |
raw_config=qc,
|
| 203 |
)
|
| 204 |
|
| 205 |
-
# Check torch_dtype for fp16/bf16
|
| 206 |
torch_dtype = config.get("torch_dtype", "float16")
|
| 207 |
if torch_dtype in ("float16", "bfloat16"):
|
| 208 |
return QuantizationInfo(
|
|
@@ -211,19 +213,25 @@ class QuantizationCollector:
|
|
| 211 |
compute_dtype=torch_dtype,
|
| 212 |
)
|
| 213 |
|
| 214 |
-
return
|
| 215 |
|
| 216 |
def get_layer_precisions(self) -> List[LayerPrecision]:
|
| 217 |
-
"""
|
| 218 |
-
Get per-layer precision information.
|
| 219 |
-
|
| 220 |
-
Returns:
|
| 221 |
-
List of LayerPrecision for each layer
|
| 222 |
-
"""
|
| 223 |
info = self.detect()
|
| 224 |
|
| 225 |
-
|
| 226 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 227 |
|
| 228 |
index_path = self._resolve_path("model.safetensors.index.json")
|
| 229 |
if not index_path or not index_path.exists():
|
|
@@ -238,7 +246,6 @@ class QuantizationCollector:
|
|
| 238 |
seen_layers = set()
|
| 239 |
|
| 240 |
for weight_name in weight_map.keys():
|
| 241 |
-
# Extract layer name
|
| 242 |
parts = weight_name.split(".")
|
| 243 |
if len(parts) >= 3:
|
| 244 |
layer_name = ".".join(parts[:3])
|
|
|
|
| 19 |
desc_act: Optional[bool] = None
|
| 20 |
sym: Optional[bool] = None
|
| 21 |
compute_dtype: Optional[str] = None
|
| 22 |
+
quant_type: Optional[str] = None
|
| 23 |
double_quant: Optional[bool] = None
|
| 24 |
raw_config: Dict[str, Any] = None
|
| 25 |
|
|
|
|
| 56 |
class QuantizationCollector:
|
| 57 |
"""Detects and collects quantization information from model configs."""
|
| 58 |
|
| 59 |
+
def __init__(self, model_path: Optional[str] = None, demo_mode: bool = True):
|
| 60 |
"""
|
| 61 |
Initialize quantization collector.
|
| 62 |
|
| 63 |
Args:
|
| 64 |
model_path: Path to model directory (local or HF model ID)
|
| 65 |
+
demo_mode: Whether to use demo data
|
| 66 |
"""
|
| 67 |
self.model_path = model_path
|
| 68 |
self._cached_info: Optional[QuantizationInfo] = None
|
| 69 |
+
self._demo_mode = demo_mode
|
| 70 |
|
| 71 |
def set_model_path(self, model_path: str) -> None:
|
| 72 |
"""Set or update the model path."""
|
|
|
|
| 74 |
self._cached_info = None
|
| 75 |
|
| 76 |
def detect(self) -> QuantizationInfo:
|
| 77 |
+
"""Detect quantization method and settings."""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 78 |
if self._cached_info is not None:
|
| 79 |
return self._cached_info
|
| 80 |
|
| 81 |
+
if self._demo_mode or not self.model_path:
|
| 82 |
+
return self._get_demo_info()
|
| 83 |
|
|
|
|
| 84 |
config = self._load_config()
|
| 85 |
quant_config = self._load_quant_config()
|
| 86 |
|
|
|
|
| 88 |
self._cached_info = info
|
| 89 |
return info
|
| 90 |
|
| 91 |
+
def _get_demo_info(self) -> QuantizationInfo:
|
| 92 |
+
"""Return demo quantization info."""
|
| 93 |
+
return QuantizationInfo(
|
| 94 |
+
method="AWQ",
|
| 95 |
+
bits=4,
|
| 96 |
+
group_size=128,
|
| 97 |
+
compute_dtype="float16",
|
| 98 |
+
raw_config={
|
| 99 |
+
"quant_method": "awq",
|
| 100 |
+
"bits": 4,
|
| 101 |
+
"group_size": 128,
|
| 102 |
+
"zero_point": True,
|
| 103 |
+
"version": "GEMM"
|
| 104 |
+
},
|
| 105 |
+
)
|
| 106 |
+
|
| 107 |
def _load_config(self) -> Optional[Dict[str, Any]]:
|
| 108 |
"""Load config.json from model path."""
|
| 109 |
config_path = self._resolve_path("config.json")
|
|
|
|
| 131 |
if not self.model_path:
|
| 132 |
return None
|
| 133 |
|
|
|
|
| 134 |
local_path = Path(self.model_path) / filename
|
| 135 |
if local_path.exists():
|
| 136 |
return local_path
|
| 137 |
|
|
|
|
| 138 |
cache_dir = Path.home() / ".cache" / "huggingface" / "hub"
|
| 139 |
if cache_dir.exists():
|
|
|
|
| 140 |
for model_dir in cache_dir.glob("models--*"):
|
| 141 |
model_name = model_dir.name.replace("models--", "").replace("--", "/")
|
| 142 |
if model_name.lower() == self.model_path.lower().replace("/", "--"):
|
| 143 |
snapshot_path = model_dir / "snapshots"
|
| 144 |
if snapshot_path.exists():
|
|
|
|
| 145 |
snapshots = list(snapshot_path.iterdir())
|
| 146 |
if snapshots:
|
| 147 |
file_path = snapshots[-1] / filename
|
|
|
|
| 157 |
) -> QuantizationInfo:
|
| 158 |
"""Detect quantization from config files."""
|
| 159 |
|
|
|
|
| 160 |
if quant_config:
|
| 161 |
if "bits" in quant_config:
|
| 162 |
return QuantizationInfo(
|
|
|
|
| 169 |
)
|
| 170 |
|
| 171 |
if not config:
|
| 172 |
+
return self._get_demo_info()
|
| 173 |
|
|
|
|
| 174 |
qc = config.get("quantization_config", {})
|
| 175 |
|
| 176 |
if qc:
|
| 177 |
quant_method = qc.get("quant_method", "").lower()
|
| 178 |
|
|
|
|
| 179 |
if quant_method == "awq":
|
| 180 |
return QuantizationInfo(
|
| 181 |
method="AWQ",
|
|
|
|
| 184 |
raw_config=qc,
|
| 185 |
)
|
| 186 |
|
|
|
|
| 187 |
if quant_method == "gptq":
|
| 188 |
return QuantizationInfo(
|
| 189 |
method="GPTQ",
|
|
|
|
| 194 |
raw_config=qc,
|
| 195 |
)
|
| 196 |
|
|
|
|
| 197 |
if qc.get("load_in_4bit") or qc.get("load_in_8bit"):
|
| 198 |
bits = 4 if qc.get("load_in_4bit") else 8
|
| 199 |
return QuantizationInfo(
|
|
|
|
| 205 |
raw_config=qc,
|
| 206 |
)
|
| 207 |
|
|
|
|
| 208 |
torch_dtype = config.get("torch_dtype", "float16")
|
| 209 |
if torch_dtype in ("float16", "bfloat16"):
|
| 210 |
return QuantizationInfo(
|
|
|
|
| 213 |
compute_dtype=torch_dtype,
|
| 214 |
)
|
| 215 |
|
| 216 |
+
return self._get_demo_info()
|
| 217 |
|
| 218 |
def get_layer_precisions(self) -> List[LayerPrecision]:
|
| 219 |
+
"""Get per-layer precision information."""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 220 |
info = self.detect()
|
| 221 |
|
| 222 |
+
if self._demo_mode:
|
| 223 |
+
# Return demo layer data
|
| 224 |
+
layers = []
|
| 225 |
+
for i in range(32):
|
| 226 |
+
layers.append(
|
| 227 |
+
LayerPrecision(
|
| 228 |
+
layer_name=f"model.layers.{i}",
|
| 229 |
+
bits=info.bits,
|
| 230 |
+
group_size=info.group_size,
|
| 231 |
+
dtype="float16",
|
| 232 |
+
)
|
| 233 |
+
)
|
| 234 |
+
return layers
|
| 235 |
|
| 236 |
index_path = self._resolve_path("model.safetensors.index.json")
|
| 237 |
if not index_path or not index_path.exists():
|
|
|
|
| 246 |
seen_layers = set()
|
| 247 |
|
| 248 |
for weight_name in weight_map.keys():
|
|
|
|
| 249 |
parts = weight_name.split(".")
|
| 250 |
if len(parts) >= 3:
|
| 251 |
layer_name = ".".join(parts[:3])
|
collectors/vllm_collector.py
CHANGED
|
@@ -1,5 +1,7 @@
|
|
| 1 |
"""vLLM metrics collector via Prometheus endpoint."""
|
| 2 |
|
|
|
|
|
|
|
| 3 |
import requests
|
| 4 |
import logging
|
| 5 |
from dataclasses import dataclass, field
|
|
@@ -53,41 +55,49 @@ class InferenceMetrics:
|
|
| 53 |
class VLLMCollector:
|
| 54 |
"""Collects metrics from vLLM Prometheus endpoint."""
|
| 55 |
|
| 56 |
-
def __init__(self, metrics_url: str = "http://localhost:8000/metrics"):
|
| 57 |
"""
|
| 58 |
Initialize the vLLM collector.
|
| 59 |
|
| 60 |
Args:
|
| 61 |
metrics_url: URL to vLLM's /metrics endpoint
|
|
|
|
| 62 |
"""
|
| 63 |
self.metrics_url = metrics_url
|
| 64 |
self._last_prompt_tokens = 0
|
| 65 |
self._last_generation_tokens = 0
|
| 66 |
self._last_collect_time: Optional[datetime] = None
|
| 67 |
self._connected = False
|
|
|
|
|
|
|
|
|
|
| 68 |
|
| 69 |
def check_connection(self) -> bool:
|
| 70 |
"""Check if vLLM server is accessible."""
|
|
|
|
|
|
|
|
|
|
| 71 |
try:
|
| 72 |
response = requests.get(self.metrics_url, timeout=2)
|
| 73 |
self._connected = response.status_code == 200
|
|
|
|
|
|
|
| 74 |
return self._connected
|
| 75 |
except Exception:
|
| 76 |
self._connected = False
|
|
|
|
| 77 |
return False
|
| 78 |
|
| 79 |
@property
|
| 80 |
def is_connected(self) -> bool:
|
| 81 |
"""Return connection status."""
|
| 82 |
-
return self._connected
|
| 83 |
|
| 84 |
def collect(self) -> InferenceMetrics:
|
| 85 |
-
"""
|
| 86 |
-
|
|
|
|
| 87 |
|
| 88 |
-
Returns:
|
| 89 |
-
InferenceMetrics dataclass with current values
|
| 90 |
-
"""
|
| 91 |
metrics = InferenceMetrics()
|
| 92 |
|
| 93 |
try:
|
|
@@ -100,12 +110,59 @@ class VLLMCollector:
|
|
| 100 |
|
| 101 |
except requests.exceptions.ConnectionError:
|
| 102 |
self._connected = False
|
| 103 |
-
|
|
|
|
| 104 |
except Exception as e:
|
| 105 |
logger.error(f"Error collecting vLLM metrics: {e}")
|
|
|
|
| 106 |
|
| 107 |
return metrics
|
| 108 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 109 |
def _parse_metrics(self, raw: Dict[str, List[MetricSample]]) -> InferenceMetrics:
|
| 110 |
"""Parse raw Prometheus metrics into InferenceMetrics."""
|
| 111 |
now = datetime.now()
|
|
@@ -179,7 +236,6 @@ class VLLMCollector:
|
|
| 179 |
|
| 180 |
def _get_model_name(self, raw: Dict[str, List[MetricSample]]) -> Optional[str]:
|
| 181 |
"""Extract model name from metrics labels."""
|
| 182 |
-
# Look for model name in any metric with model_name label
|
| 183 |
for metric_name, samples in raw.items():
|
| 184 |
for sample in samples:
|
| 185 |
if "model_name" in sample.labels:
|
|
@@ -187,23 +243,33 @@ class VLLMCollector:
|
|
| 187 |
return None
|
| 188 |
|
| 189 |
def get_rank_mapping(self) -> Dict[int, int]:
|
| 190 |
-
"""
|
| 191 |
-
Get tensor parallel rank to GPU mapping.
|
| 192 |
-
|
| 193 |
-
Returns:
|
| 194 |
-
Dictionary mapping TP rank to GPU ID
|
| 195 |
-
"""
|
| 196 |
-
# This would typically come from vLLM's internal state
|
| 197 |
-
# For now, return empty mapping - can be extended
|
| 198 |
return {}
|
| 199 |
|
| 200 |
def get_latency_percentiles(self) -> Dict[str, Dict[str, float]]:
|
| 201 |
-
"""
|
| 202 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 203 |
|
| 204 |
-
Returns:
|
| 205 |
-
Dictionary with P50, P95, P99 for each latency metric
|
| 206 |
-
"""
|
| 207 |
try:
|
| 208 |
response = requests.get(self.metrics_url, timeout=5)
|
| 209 |
raw = parse_prometheus_metrics(response.text)
|
|
|
|
| 1 |
"""vLLM metrics collector via Prometheus endpoint."""
|
| 2 |
|
| 3 |
+
import random
|
| 4 |
+
import time
|
| 5 |
import requests
|
| 6 |
import logging
|
| 7 |
from dataclasses import dataclass, field
|
|
|
|
| 55 |
class VLLMCollector:
|
| 56 |
"""Collects metrics from vLLM Prometheus endpoint."""
|
| 57 |
|
| 58 |
+
def __init__(self, metrics_url: str = "http://localhost:8000/metrics", demo_mode: bool = True):
|
| 59 |
"""
|
| 60 |
Initialize the vLLM collector.
|
| 61 |
|
| 62 |
Args:
|
| 63 |
metrics_url: URL to vLLM's /metrics endpoint
|
| 64 |
+
demo_mode: Whether to use simulated demo data
|
| 65 |
"""
|
| 66 |
self.metrics_url = metrics_url
|
| 67 |
self._last_prompt_tokens = 0
|
| 68 |
self._last_generation_tokens = 0
|
| 69 |
self._last_collect_time: Optional[datetime] = None
|
| 70 |
self._connected = False
|
| 71 |
+
self._demo_mode = demo_mode
|
| 72 |
+
self._demo_start_time = time.time()
|
| 73 |
+
self._demo_total_tokens = 0
|
| 74 |
|
| 75 |
def check_connection(self) -> bool:
|
| 76 |
"""Check if vLLM server is accessible."""
|
| 77 |
+
if self._demo_mode:
|
| 78 |
+
return True # Demo mode always "connected"
|
| 79 |
+
|
| 80 |
try:
|
| 81 |
response = requests.get(self.metrics_url, timeout=2)
|
| 82 |
self._connected = response.status_code == 200
|
| 83 |
+
if self._connected:
|
| 84 |
+
self._demo_mode = False
|
| 85 |
return self._connected
|
| 86 |
except Exception:
|
| 87 |
self._connected = False
|
| 88 |
+
self._demo_mode = True
|
| 89 |
return False
|
| 90 |
|
| 91 |
@property
|
| 92 |
def is_connected(self) -> bool:
|
| 93 |
"""Return connection status."""
|
| 94 |
+
return self._connected or self._demo_mode
|
| 95 |
|
| 96 |
def collect(self) -> InferenceMetrics:
|
| 97 |
+
"""Collect all inference metrics from vLLM."""
|
| 98 |
+
if self._demo_mode:
|
| 99 |
+
return self._get_demo_metrics()
|
| 100 |
|
|
|
|
|
|
|
|
|
|
| 101 |
metrics = InferenceMetrics()
|
| 102 |
|
| 103 |
try:
|
|
|
|
| 110 |
|
| 111 |
except requests.exceptions.ConnectionError:
|
| 112 |
self._connected = False
|
| 113 |
+
self._demo_mode = True
|
| 114 |
+
return self._get_demo_metrics()
|
| 115 |
except Exception as e:
|
| 116 |
logger.error(f"Error collecting vLLM metrics: {e}")
|
| 117 |
+
return self._get_demo_metrics()
|
| 118 |
|
| 119 |
return metrics
|
| 120 |
|
| 121 |
+
def _get_demo_metrics(self) -> InferenceMetrics:
|
| 122 |
+
"""Generate realistic demo metrics."""
|
| 123 |
+
elapsed = time.time() - self._demo_start_time
|
| 124 |
+
now = datetime.now()
|
| 125 |
+
|
| 126 |
+
# Simulate varying load
|
| 127 |
+
load_factor = 0.5 + 0.3 * abs((elapsed % 30) - 15) / 15 # 0.5-0.8
|
| 128 |
+
|
| 129 |
+
# Simulate token generation
|
| 130 |
+
tokens_this_second = int(45 * load_factor + random.uniform(-5, 5))
|
| 131 |
+
self._demo_total_tokens += tokens_this_second
|
| 132 |
+
|
| 133 |
+
# Batch size varies with load
|
| 134 |
+
batch_size = int(4 + 8 * load_factor + random.randint(-1, 1))
|
| 135 |
+
|
| 136 |
+
# Queue depth
|
| 137 |
+
queue_depth = int(max(0, (load_factor - 0.6) * 20 + random.randint(-2, 2)))
|
| 138 |
+
|
| 139 |
+
# KV cache usage correlates with batch size
|
| 140 |
+
kv_cache = 35 + batch_size * 4 + random.uniform(-3, 3)
|
| 141 |
+
|
| 142 |
+
# Latencies
|
| 143 |
+
base_ttft = 80 + (1 - load_factor) * 40 # Lower load = faster
|
| 144 |
+
base_e2e = 800 + batch_size * 50
|
| 145 |
+
|
| 146 |
+
return InferenceMetrics(
|
| 147 |
+
timestamp=now,
|
| 148 |
+
num_requests_running=batch_size,
|
| 149 |
+
num_requests_waiting=queue_depth,
|
| 150 |
+
num_requests_swapped=0,
|
| 151 |
+
prompt_tokens_total=int(self._demo_total_tokens * 0.3),
|
| 152 |
+
generation_tokens_total=int(self._demo_total_tokens * 0.7),
|
| 153 |
+
tokens_per_second=tokens_this_second + random.uniform(-3, 3),
|
| 154 |
+
ttft_ms=base_ttft + random.uniform(-10, 20),
|
| 155 |
+
tpot_ms=22 + random.uniform(-2, 3),
|
| 156 |
+
e2e_latency_ms=base_e2e + random.uniform(-50, 100),
|
| 157 |
+
kv_cache_usage_percent=min(95, kv_cache),
|
| 158 |
+
gpu_cache_usage_percent=min(95, kv_cache),
|
| 159 |
+
cpu_cache_usage_percent=0,
|
| 160 |
+
model_name="Qwen/Qwen2.5-3B-Instruct",
|
| 161 |
+
max_model_len=4096,
|
| 162 |
+
prefill_ratio=0.3 + random.uniform(-0.05, 0.05),
|
| 163 |
+
batch_size=batch_size,
|
| 164 |
+
)
|
| 165 |
+
|
| 166 |
def _parse_metrics(self, raw: Dict[str, List[MetricSample]]) -> InferenceMetrics:
|
| 167 |
"""Parse raw Prometheus metrics into InferenceMetrics."""
|
| 168 |
now = datetime.now()
|
|
|
|
| 236 |
|
| 237 |
def _get_model_name(self, raw: Dict[str, List[MetricSample]]) -> Optional[str]:
|
| 238 |
"""Extract model name from metrics labels."""
|
|
|
|
| 239 |
for metric_name, samples in raw.items():
|
| 240 |
for sample in samples:
|
| 241 |
if "model_name" in sample.labels:
|
|
|
|
| 243 |
return None
|
| 244 |
|
| 245 |
def get_rank_mapping(self) -> Dict[int, int]:
|
| 246 |
+
"""Get tensor parallel rank to GPU mapping."""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 247 |
return {}
|
| 248 |
|
| 249 |
def get_latency_percentiles(self) -> Dict[str, Dict[str, float]]:
|
| 250 |
+
"""Get latency percentiles for detailed analysis."""
|
| 251 |
+
if self._demo_mode:
|
| 252 |
+
base_ttft = 90
|
| 253 |
+
base_tpot = 22
|
| 254 |
+
base_e2e = 900
|
| 255 |
+
return {
|
| 256 |
+
"vllm:time_to_first_token_seconds": {
|
| 257 |
+
"p50": base_ttft,
|
| 258 |
+
"p95": base_ttft * 1.8,
|
| 259 |
+
"p99": base_ttft * 2.5,
|
| 260 |
+
},
|
| 261 |
+
"vllm:time_per_output_token_seconds": {
|
| 262 |
+
"p50": base_tpot,
|
| 263 |
+
"p95": base_tpot * 1.5,
|
| 264 |
+
"p99": base_tpot * 2.0,
|
| 265 |
+
},
|
| 266 |
+
"vllm:e2e_request_latency_seconds": {
|
| 267 |
+
"p50": base_e2e,
|
| 268 |
+
"p95": base_e2e * 1.6,
|
| 269 |
+
"p99": base_e2e * 2.2,
|
| 270 |
+
},
|
| 271 |
+
}
|
| 272 |
|
|
|
|
|
|
|
|
|
|
| 273 |
try:
|
| 274 |
response = requests.get(self.metrics_url, timeout=5)
|
| 275 |
raw = parse_prometheus_metrics(response.text)
|