"""Alert engine and webhook dispatch for monitoring thresholds."""
import asyncio
import logging
from dataclasses import dataclass, field
from datetime import datetime
from typing import Dict, List, Optional, Any, Callable
from enum import Enum
import aiohttp
from storage.database import MetricsDB
from storage.models import AlertRecord
logger = logging.getLogger(__name__)
class AlertSeverity(Enum):
INFO = "info"
WARNING = "warning"
CRITICAL = "critical"
@dataclass
class AlertRule:
"""Configuration for an alert rule."""
name: str
metric: str
condition: str # >, <, >=, <=, ==
threshold: float
severity: AlertSeverity
message: str
# For dynamic thresholds
threshold_type: str = "static" # static, baseline_multiplier, baseline_percent
multiplier: float = 1.0
percent: float = 100.0
cooldown_seconds: int = 60
@dataclass
class Alert:
"""A triggered alert instance."""
rule_name: str
metric: str
value: float
threshold: float
severity: AlertSeverity
message: str
timestamp: datetime = field(default_factory=datetime.now)
resolved: bool = False
def to_dict(self) -> Dict[str, Any]:
return {
"rule_name": self.rule_name,
"metric": self.metric,
"value": self.value,
"threshold": self.threshold,
"severity": self.severity.value,
"message": self.message,
"timestamp": self.timestamp.isoformat(),
"resolved": self.resolved,
}
# Default alert rules
DEFAULT_RULES = {
"kv_cache_high": AlertRule(
name="kv_cache_high",
metric="kv_cache_percent",
condition=">",
threshold=90.0,
severity=AlertSeverity.WARNING,
message="KV cache utilization above 90%",
),
"gpu_memory_critical": AlertRule(
name="gpu_memory_critical",
metric="gpu_memory_percent",
condition=">",
threshold=95.0,
severity=AlertSeverity.CRITICAL,
message="GPU memory critically high (>95%)",
),
"ttft_spike": AlertRule(
name="ttft_spike",
metric="ttft_ms",
condition=">",
threshold=0, # Dynamic
threshold_type="baseline_multiplier",
multiplier=2.0,
severity=AlertSeverity.WARNING,
message="Time to first token spiked to 2x baseline",
),
"throughput_drop": AlertRule(
name="throughput_drop",
metric="tokens_per_second",
condition="<",
threshold=0, # Dynamic
threshold_type="baseline_percent",
percent=50.0,
severity=AlertSeverity.WARNING,
message="Throughput dropped below 50% of baseline",
),
"queue_buildup": AlertRule(
name="queue_buildup",
metric="queue_depth",
condition=">",
threshold=50.0,
severity=AlertSeverity.WARNING,
message="Request queue depth exceeds 50",
),
}
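
# Note on dynamic thresholds: for "baseline_multiplier" rules the effective
# threshold is baseline * multiplier, and for "baseline_percent" rules it is
# baseline * (percent / 100). Either kind is skipped until a baseline exists
# for the metric (see AlertEngine._get_threshold, set_baseline, update_baselines).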
class AlertEngine:
"""Evaluates metrics against alert rules."""
def __init__(self, db: Optional[MetricsDB] = None):
"""
Initialize alert engine.
Args:
db: Optional database for persisting alerts
"""
self.db = db
self.rules: Dict[str, AlertRule] = dict(DEFAULT_RULES)
self.active_alerts: Dict[str, Alert] = {}
self.baselines: Dict[str, float] = {}
self._last_trigger_times: Dict[str, datetime] = {}
self._callbacks: List[Callable[[Alert], None]] = []
def add_rule(self, rule: AlertRule) -> None:
"""Add or update an alert rule."""
self.rules[rule.name] = rule
def remove_rule(self, name: str) -> None:
"""Remove an alert rule."""
self.rules.pop(name, None)
def set_baseline(self, metric: str, value: float) -> None:
"""Set baseline value for a metric."""
self.baselines[metric] = value
    def update_baselines(self, metrics: Dict[str, float]) -> None:
        """Seed baselines from current metrics, only for metrics without one yet."""
        for metric, value in metrics.items():
            if metric not in self.baselines and value > 0:
                self.baselines[metric] = value
def on_alert(self, callback: Callable[[Alert], None]) -> None:
"""Register callback for new alerts."""
self._callbacks.append(callback)
def evaluate(self, metrics: Dict[str, float]) -> List[Alert]:
"""
Evaluate metrics against all rules.
Args:
metrics: Current metric values
Returns:
List of newly triggered alerts
"""
new_alerts = []
for rule_name, rule in self.rules.items():
if rule.metric not in metrics:
continue
value = metrics[rule.metric]
threshold = self._get_threshold(rule)
if threshold is None:
continue
triggered = self._check_condition(value, rule.condition, threshold)
if triggered:
# Check cooldown
if rule_name in self._last_trigger_times:
elapsed = (
datetime.now() - self._last_trigger_times[rule_name]
).total_seconds()
if elapsed < rule.cooldown_seconds:
continue
# Create alert
alert = Alert(
rule_name=rule_name,
metric=rule.metric,
value=value,
threshold=threshold,
severity=rule.severity,
message=rule.message,
)
self.active_alerts[rule_name] = alert
self._last_trigger_times[rule_name] = datetime.now()
new_alerts.append(alert)
# Persist to database
if self.db:
record = AlertRecord(
rule_name=rule_name,
severity=rule.severity.value,
metric_name=rule.metric,
value=value,
threshold=threshold,
message=rule.message,
)
self.db.insert_alert(record)
# Notify callbacks
for callback in self._callbacks:
try:
callback(alert)
except Exception as e:
logger.error(f"Alert callback error: {e}")
elif rule_name in self.active_alerts:
# Resolve alert
self.active_alerts[rule_name].resolved = True
del self.active_alerts[rule_name]
return new_alerts
def _get_threshold(self, rule: AlertRule) -> Optional[float]:
"""Calculate threshold for a rule."""
if rule.threshold_type == "static":
return rule.threshold
baseline = self.baselines.get(rule.metric)
if baseline is None:
return None
if rule.threshold_type == "baseline_multiplier":
return baseline * rule.multiplier
if rule.threshold_type == "baseline_percent":
return baseline * (rule.percent / 100.0)
return rule.threshold
def _check_condition(
self, value: float, condition: str, threshold: float
) -> bool:
"""Check if condition is met."""
if condition == ">":
return value > threshold
if condition == ">=":
return value >= threshold
if condition == "<":
return value < threshold
if condition == "<=":
return value <= threshold
if condition == "==":
return abs(value - threshold) < 0.001
return False
def get_active_alerts(self) -> List[Alert]:
"""Get all active (unresolved) alerts."""
return list(self.active_alerts.values())
def clear_alert(self, rule_name: str) -> None:
"""Manually clear an alert."""
if rule_name in self.active_alerts:
del self.active_alerts[rule_name]
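
# Illustrative sketch only: the metric names and values below are made up to
# show the evaluation flow; they are not produced by this module. The helper is
# not called at import time.
def _example_evaluate() -> List[Alert]:
    """Register a custom rule, seed a baseline, and evaluate one metrics snapshot."""
    engine = AlertEngine()
    engine.add_rule(
        AlertRule(
            name="error_rate_high",
            metric="error_rate_percent",
            condition=">",
            threshold=5.0,
            severity=AlertSeverity.CRITICAL,
            message="Error rate above 5%",
        )
    )
    # Dynamic rules (ttft_spike, throughput_drop) need a baseline before they can fire.
    engine.set_baseline("ttft_ms", 120.0)
    return engine.evaluate(
        {
            "kv_cache_percent": 93.0,   # triggers kv_cache_high (> 90)
            "ttft_ms": 300.0,           # triggers ttft_spike (> 2x baseline of 120)
            "error_rate_percent": 2.0,  # below the custom 5% threshold
        }
    )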
class AlertDispatcher:
"""Dispatches alerts to external services."""
def __init__(
self,
slack_webhook: Optional[str] = None,
pagerduty_key: Optional[str] = None,
generic_webhooks: Optional[List[str]] = None,
):
"""
Initialize alert dispatcher.
Args:
slack_webhook: Slack incoming webhook URL
pagerduty_key: PagerDuty routing key
generic_webhooks: List of generic webhook URLs
"""
self.slack_webhook = slack_webhook
self.pagerduty_key = pagerduty_key
self.generic_webhooks = generic_webhooks or []
async def dispatch(self, alert: Alert) -> None:
"""
Dispatch alert to all configured services.
Args:
alert: Alert to dispatch
"""
tasks = []
if self.slack_webhook:
tasks.append(self._send_slack(alert))
if self.pagerduty_key and alert.severity == AlertSeverity.CRITICAL:
tasks.append(self._send_pagerduty(alert))
for webhook in self.generic_webhooks:
tasks.append(self._send_generic(webhook, alert))
if tasks:
await asyncio.gather(*tasks, return_exceptions=True)
async def _send_slack(self, alert: Alert) -> None:
"""Send alert to Slack."""
color = "danger" if alert.severity == AlertSeverity.CRITICAL else "warning"
emoji = "🚨" if alert.severity == AlertSeverity.CRITICAL else "⚠️"
payload = {
"text": f"{emoji} *{alert.severity.value.upper()}*: {alert.message}",
"attachments": [
{
"color": color,
"fields": [
{
"title": "Metric",
"value": alert.metric,
"short": True,
},
{
"title": "Value",
"value": f"{alert.value:.2f}",
"short": True,
},
{
"title": "Threshold",
"value": f"{alert.threshold:.2f}",
"short": True,
},
{
"title": "Time",
"value": alert.timestamp.strftime("%Y-%m-%d %H:%M:%S"),
"short": True,
},
],
}
],
}
try:
async with aiohttp.ClientSession() as session:
async with session.post(
self.slack_webhook,
json=payload,
timeout=aiohttp.ClientTimeout(total=10),
) as response:
if response.status != 200:
logger.error(f"Slack webhook failed: {response.status}")
except Exception as e:
logger.error(f"Error sending Slack alert: {e}")
async def _send_pagerduty(self, alert: Alert) -> None:
"""Send alert to PagerDuty."""
payload = {
"routing_key": self.pagerduty_key,
"event_action": "trigger",
"dedup_key": f"llm-dashboard-{alert.rule_name}",
"payload": {
"summary": alert.message,
"severity": "critical",
"source": "llm-inference-dashboard",
"custom_details": {
"metric": alert.metric,
"value": alert.value,
"threshold": alert.threshold,
},
},
}
try:
async with aiohttp.ClientSession() as session:
async with session.post(
"https://events.pagerduty.com/v2/enqueue",
json=payload,
timeout=aiohttp.ClientTimeout(total=10),
) as response:
if response.status != 202:
logger.error(f"PagerDuty failed: {response.status}")
except Exception as e:
logger.error(f"Error sending PagerDuty alert: {e}")
async def _send_generic(self, webhook_url: str, alert: Alert) -> None:
"""Send alert to a generic webhook."""
payload = alert.to_dict()
try:
async with aiohttp.ClientSession() as session:
async with session.post(
webhook_url,
json=payload,
timeout=aiohttp.ClientTimeout(total=10),
) as response:
if response.status >= 400:
logger.error(f"Webhook {webhook_url} failed: {response.status}")
except Exception as e:
logger.error(f"Error sending to webhook {webhook_url}: {e}")
async def send_test_alert(self) -> bool:
"""Send a test alert to verify configuration."""
test_alert = Alert(
rule_name="test_alert",
metric="test_metric",
value=100.0,
threshold=50.0,
severity=AlertSeverity.INFO,
message="This is a test alert from LLM Inference Dashboard",
)
try:
await self.dispatch(test_alert)
return True
except Exception as e:
logger.error(f"Test alert failed: {e}")
return False
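
# Minimal end-to-end sketch (assumptions: the webhook URL below is a local
# placeholder, and the metric values are illustrative, not real telemetry).
if __name__ == "__main__":
    async def _demo() -> None:
        engine = AlertEngine()
        dispatcher = AlertDispatcher(generic_webhooks=["http://localhost:9000/alerts"])

        # Bridge the synchronous on_alert callback to the async dispatcher by
        # scheduling dispatch() as a background task on the running loop.
        loop = asyncio.get_running_loop()
        engine.on_alert(lambda alert: loop.create_task(dispatcher.dispatch(alert)))

        # kv_cache_high (> 90) and queue_buildup (> 50) should trigger here.
        engine.evaluate({"kv_cache_percent": 95.0, "queue_depth": 60.0})

        # Give the scheduled dispatch tasks a moment to complete.
        await asyncio.sleep(1)
        for alert in engine.get_active_alerts():
            print(alert.to_dict())

    asyncio.run(_demo())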