| """Alert engine and webhook dispatch for monitoring thresholds.""" | |
| import asyncio | |
| import logging | |
| from dataclasses import dataclass, field | |
| from datetime import datetime | |
| from typing import Dict, List, Optional, Any, Callable | |
| from enum import Enum | |
| import aiohttp | |
| from storage.database import MetricsDB | |
| from storage.models import AlertRecord | |
| logger = logging.getLogger(__name__) | |

class AlertSeverity(Enum):
    INFO = "info"
    WARNING = "warning"
    CRITICAL = "critical"


@dataclass
class AlertRule:
    """Configuration for an alert rule."""

    name: str
    metric: str
    condition: str  # >, <, >=, <=, ==
    threshold: float
    severity: AlertSeverity
    message: str
    # For dynamic thresholds
    threshold_type: str = "static"  # static, baseline_multiplier, baseline_percent
    multiplier: float = 1.0
    percent: float = 100.0
    cooldown_seconds: int = 60


@dataclass
class Alert:
    """A triggered alert instance."""

    rule_name: str
    metric: str
    value: float
    threshold: float
    severity: AlertSeverity
    message: str
    timestamp: datetime = field(default_factory=datetime.now)
    resolved: bool = False

    def to_dict(self) -> Dict[str, Any]:
        return {
            "rule_name": self.rule_name,
            "metric": self.metric,
            "value": self.value,
            "threshold": self.threshold,
            "severity": self.severity.value,
            "message": self.message,
            "timestamp": self.timestamp.isoformat(),
            "resolved": self.resolved,
        }


# Default alert rules
DEFAULT_RULES = {
    "kv_cache_high": AlertRule(
        name="kv_cache_high",
        metric="kv_cache_percent",
        condition=">",
        threshold=90.0,
        severity=AlertSeverity.WARNING,
        message="KV cache utilization above 90%",
    ),
    "gpu_memory_critical": AlertRule(
        name="gpu_memory_critical",
        metric="gpu_memory_percent",
        condition=">",
        threshold=95.0,
        severity=AlertSeverity.CRITICAL,
        message="GPU memory critically high (>95%)",
    ),
    "ttft_spike": AlertRule(
        name="ttft_spike",
        metric="ttft_ms",
        condition=">",
        threshold=0,  # Dynamic
        threshold_type="baseline_multiplier",
        multiplier=2.0,
        severity=AlertSeverity.WARNING,
        message="Time to first token spiked to 2x baseline",
    ),
    "throughput_drop": AlertRule(
        name="throughput_drop",
        metric="tokens_per_second",
        condition="<",
        threshold=0,  # Dynamic
        threshold_type="baseline_percent",
        percent=50.0,
        severity=AlertSeverity.WARNING,
        message="Throughput dropped below 50% of baseline",
    ),
    "queue_buildup": AlertRule(
        name="queue_buildup",
        metric="queue_depth",
        condition=">",
        threshold=50.0,
        severity=AlertSeverity.WARNING,
        message="Request queue depth exceeds 50",
    ),
}
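
# Worked example of how the dynamic rules above resolve (illustrative values,
# not produced by this module): with a recorded baseline of ttft_ms = 120.0,
# "ttft_spike" (baseline_multiplier, 2.0) resolves to a threshold of
# 120.0 * 2.0 = 240.0 ms; with a baseline of tokens_per_second = 40.0,
# "throughput_drop" (baseline_percent, 50.0) resolves to 40.0 * 0.5 = 20.0.
# Until a baseline exists for the metric, evaluate() skips these rules.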


class AlertEngine:
    """Evaluates metrics against alert rules."""

    def __init__(self, db: Optional[MetricsDB] = None):
        """
        Initialize alert engine.

        Args:
            db: Optional database for persisting alerts
        """
        self.db = db
        self.rules: Dict[str, AlertRule] = dict(DEFAULT_RULES)
        self.active_alerts: Dict[str, Alert] = {}
        self.baselines: Dict[str, float] = {}
        self._last_trigger_times: Dict[str, datetime] = {}
        self._callbacks: List[Callable[[Alert], None]] = []

    def add_rule(self, rule: AlertRule) -> None:
        """Add or update an alert rule."""
        self.rules[rule.name] = rule

    def remove_rule(self, name: str) -> None:
        """Remove an alert rule."""
        self.rules.pop(name, None)

    def set_baseline(self, metric: str, value: float) -> None:
        """Set baseline value for a metric."""
        self.baselines[metric] = value

    def update_baselines(self, metrics: Dict[str, float]) -> None:
        """Update baseline values from current metrics."""
        for metric, value in metrics.items():
            if metric not in self.baselines and value > 0:
                self.baselines[metric] = value

    def on_alert(self, callback: Callable[[Alert], None]) -> None:
        """Register callback for new alerts."""
        self._callbacks.append(callback)

    def evaluate(self, metrics: Dict[str, float]) -> List[Alert]:
        """
        Evaluate metrics against all rules.

        Args:
            metrics: Current metric values

        Returns:
            List of newly triggered alerts
        """
        new_alerts = []
        for rule_name, rule in self.rules.items():
            if rule.metric not in metrics:
                continue

            value = metrics[rule.metric]
            threshold = self._get_threshold(rule)
            if threshold is None:
                continue

            triggered = self._check_condition(value, rule.condition, threshold)

            if triggered:
                # Check cooldown
                if rule_name in self._last_trigger_times:
                    elapsed = (
                        datetime.now() - self._last_trigger_times[rule_name]
                    ).total_seconds()
                    if elapsed < rule.cooldown_seconds:
                        continue

                # Create alert
                alert = Alert(
                    rule_name=rule_name,
                    metric=rule.metric,
                    value=value,
                    threshold=threshold,
                    severity=rule.severity,
                    message=rule.message,
                )
                self.active_alerts[rule_name] = alert
                self._last_trigger_times[rule_name] = datetime.now()
                new_alerts.append(alert)

                # Persist to database
                if self.db:
                    record = AlertRecord(
                        rule_name=rule_name,
                        severity=rule.severity.value,
                        metric_name=rule.metric,
                        value=value,
                        threshold=threshold,
                        message=rule.message,
                    )
                    self.db.insert_alert(record)

                # Notify callbacks
                for callback in self._callbacks:
                    try:
                        callback(alert)
                    except Exception as e:
                        logger.error(f"Alert callback error: {e}")

            elif rule_name in self.active_alerts:
                # Resolve alert
                self.active_alerts[rule_name].resolved = True
                del self.active_alerts[rule_name]

        return new_alerts

    def _get_threshold(self, rule: AlertRule) -> Optional[float]:
        """Calculate threshold for a rule."""
        if rule.threshold_type == "static":
            return rule.threshold

        baseline = self.baselines.get(rule.metric)
        if baseline is None:
            return None

        if rule.threshold_type == "baseline_multiplier":
            return baseline * rule.multiplier
        if rule.threshold_type == "baseline_percent":
            return baseline * (rule.percent / 100.0)
        return rule.threshold

    def _check_condition(
        self, value: float, condition: str, threshold: float
    ) -> bool:
        """Check if condition is met."""
        if condition == ">":
            return value > threshold
        if condition == ">=":
            return value >= threshold
        if condition == "<":
            return value < threshold
        if condition == "<=":
            return value <= threshold
        if condition == "==":
            return abs(value - threshold) < 0.001
        return False

    def get_active_alerts(self) -> List[Alert]:
        """Get all active (unresolved) alerts."""
        return list(self.active_alerts.values())

    def clear_alert(self, rule_name: str) -> None:
        """Manually clear an alert."""
        if rule_name in self.active_alerts:
            del self.active_alerts[rule_name]


class AlertDispatcher:
    """Dispatches alerts to external services."""

    def __init__(
        self,
        slack_webhook: Optional[str] = None,
        pagerduty_key: Optional[str] = None,
        generic_webhooks: Optional[List[str]] = None,
    ):
        """
        Initialize alert dispatcher.

        Args:
            slack_webhook: Slack incoming webhook URL
            pagerduty_key: PagerDuty routing key
            generic_webhooks: List of generic webhook URLs
        """
        self.slack_webhook = slack_webhook
        self.pagerduty_key = pagerduty_key
        self.generic_webhooks = generic_webhooks or []

    async def dispatch(self, alert: Alert) -> None:
        """
        Dispatch alert to all configured services.

        Args:
            alert: Alert to dispatch
        """
        tasks = []
        if self.slack_webhook:
            tasks.append(self._send_slack(alert))
        if self.pagerduty_key and alert.severity == AlertSeverity.CRITICAL:
            tasks.append(self._send_pagerduty(alert))
        for webhook in self.generic_webhooks:
            tasks.append(self._send_generic(webhook, alert))

        if tasks:
            await asyncio.gather(*tasks, return_exceptions=True)

    async def _send_slack(self, alert: Alert) -> None:
        """Send alert to Slack."""
        color = "danger" if alert.severity == AlertSeverity.CRITICAL else "warning"
        emoji = "🚨" if alert.severity == AlertSeverity.CRITICAL else "⚠️"

        payload = {
            "text": f"{emoji} *{alert.severity.value.upper()}*: {alert.message}",
            "attachments": [
                {
                    "color": color,
                    "fields": [
                        {
                            "title": "Metric",
                            "value": alert.metric,
                            "short": True,
                        },
                        {
                            "title": "Value",
                            "value": f"{alert.value:.2f}",
                            "short": True,
                        },
                        {
                            "title": "Threshold",
                            "value": f"{alert.threshold:.2f}",
                            "short": True,
                        },
                        {
                            "title": "Time",
                            "value": alert.timestamp.strftime("%Y-%m-%d %H:%M:%S"),
                            "short": True,
                        },
                    ],
                }
            ],
        }

        try:
            async with aiohttp.ClientSession() as session:
                async with session.post(
                    self.slack_webhook,
                    json=payload,
                    timeout=aiohttp.ClientTimeout(total=10),
                ) as response:
                    if response.status != 200:
                        logger.error(f"Slack webhook failed: {response.status}")
        except Exception as e:
            logger.error(f"Error sending Slack alert: {e}")

    async def _send_pagerduty(self, alert: Alert) -> None:
        """Send alert to PagerDuty."""
        payload = {
            "routing_key": self.pagerduty_key,
            "event_action": "trigger",
            "dedup_key": f"llm-dashboard-{alert.rule_name}",
            "payload": {
                "summary": alert.message,
                "severity": "critical",
                "source": "llm-inference-dashboard",
                "custom_details": {
                    "metric": alert.metric,
                    "value": alert.value,
                    "threshold": alert.threshold,
                },
            },
        }

        try:
            async with aiohttp.ClientSession() as session:
                async with session.post(
                    "https://events.pagerduty.com/v2/enqueue",
                    json=payload,
                    timeout=aiohttp.ClientTimeout(total=10),
                ) as response:
                    if response.status != 202:
                        logger.error(f"PagerDuty failed: {response.status}")
        except Exception as e:
            logger.error(f"Error sending PagerDuty alert: {e}")

    async def _send_generic(self, webhook_url: str, alert: Alert) -> None:
        """Send alert to a generic webhook."""
        payload = alert.to_dict()

        try:
            async with aiohttp.ClientSession() as session:
                async with session.post(
                    webhook_url,
                    json=payload,
                    timeout=aiohttp.ClientTimeout(total=10),
                ) as response:
                    if response.status >= 400:
                        logger.error(f"Webhook {webhook_url} failed: {response.status}")
        except Exception as e:
            logger.error(f"Error sending to webhook {webhook_url}: {e}")

    async def send_test_alert(self) -> bool:
        """Send a test alert to verify configuration."""
        test_alert = Alert(
            rule_name="test_alert",
            metric="test_metric",
            value=100.0,
            threshold=50.0,
            severity=AlertSeverity.INFO,
            message="This is a test alert from LLM Inference Dashboard",
        )
        try:
            await self.dispatch(test_alert)
            return True
        except Exception as e:
            logger.error(f"Test alert failed: {e}")
            return False
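

# --- Usage sketch (illustrative; not part of the module's public API) -------
# A minimal example of wiring AlertEngine to AlertDispatcher. The metric values
# and baselines below are made up, and SLACK_WEBHOOK_URL is an assumed
# environment variable; with no webhooks configured, dispatch() does nothing.
# In a long-running collector you would more likely register a callback via
# engine.on_alert() that schedules dispatch() on the running event loop.
if __name__ == "__main__":
    import os

    async def _demo() -> None:
        engine = AlertEngine()
        dispatcher = AlertDispatcher(slack_webhook=os.environ.get("SLACK_WEBHOOK_URL"))

        # Seed baselines so the dynamic rules (ttft_spike, throughput_drop) can resolve.
        engine.update_baselines({"ttft_ms": 120.0, "tokens_per_second": 40.0})

        # One simulated metrics sample; it violates kv_cache_high and ttft_spike.
        sample = {
            "kv_cache_percent": 93.0,   # > 90 -> kv_cache_high
            "ttft_ms": 300.0,           # > 2x baseline (240 ms) -> ttft_spike
            "tokens_per_second": 35.0,  # above 50% of baseline -> no alert
            "queue_depth": 12.0,
        }
        for alert in engine.evaluate(sample):
            print(f"[{alert.severity.value}] {alert.message} (value={alert.value})")
            await dispatcher.dispatch(alert)

    asyncio.run(_demo())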