# Captured from a Hugging Face Spaces file view (Space status: Sleeping).
"""
FastAPI OpenEnv Server for Adaptive Alert Triage Environment — v0.3.0

Root-cause fixes:

FIX 1 — "No active episode" on /agent/recommend
    Startup now calls env.reset() immediately AND starts an asyncio
    background task (_episode_loop) that keeps the environment always live.
    Every STEP_INTERVAL seconds it checks alerts, picks an action (PPO or
    rule-based fallback), calls env.step(), and resets when done.

FIX 2 — Queued alerts (real_alerts_queue) never appeared in env.alerts
    env.py only drains real_alerts_queue inside _generate_new_alerts(), which
    runs during env.step(). The episode loop calls step() continuously, so
    real alerts are consumed automatically within ~1 s of being queued.

FIX 3 — alert.dict() / obs.dict() are deprecated in Pydantic v2
    Switched to model_dump() everywhere.

FIX 4 — task_score missing from the info dict
    Computed server-side as the running average of action_correct and
    injected into info["task_score"] so train_external.py receives it.

FIX 5 — real_alerts_queue dropped on /env/reset
    The queue is saved and re-attached to the new env object.

FIX 6 — state.system_load AttributeError
    Changed to state.observation.system_load (EpisodeState structure).
"""
from __future__ import annotations

import asyncio
import os
import signal
import sys
import traceback
from collections import deque
from typing import Any, Dict, List, Optional

from fastapi import FastAPI, WebSocket, WebSocketDisconnect
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import FileResponse
from pydantic import BaseModel

from .env import AdaptiveAlertTriageEnv
from .models import Action, Observation, Reward
# -- Try to load the trained PPO agent (lazy import; server starts without it) --
# rl_agent lives at the project root, three directory levels above this file.
# When it cannot be imported, _project_root is cleared so that later path
# lookups fall back to the current working directory.
_PPO_AVAILABLE = False
try:
    _project_root = os.path.dirname(
        os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
    )
    if _project_root not in sys.path:
        sys.path.insert(0, _project_root)
    from rl_agent import PPOTrainer, encode_state, _ACTION_NAMES  # type: ignore
    _PPO_AVAILABLE = True
except ImportError as _import_err:
    # Say why PPO is unavailable instead of failing silently; otherwise the
    # startup banner only suggests that weights are missing.
    print(f" [PPO] rl_agent import failed, rule-based fallback only: {_import_err}")
    _project_root = ""
# -- Request / response schemas ------------------------------------------------


class IngestAlert(BaseModel):
    """One externally generated alert submitted for ingestion."""

    id: str
    visible_severity: float
    confidence: float
    type: str  # free-form; the ingest endpoints canonicalise it with _norm()


class StepRequest(BaseModel):
    """A manual step: the alert to act on and the action to apply."""

    alert_id: str
    action_type: str


class HealthResponse(BaseModel):
    """Payload returned by the health endpoint."""

    status: str
    env_ready: bool
    queue_size: int
# -- Alert-type normaliser -------------------------------------------------------
# Lower-case aliases used by external monitoring tools, mapped onto the
# canonical alert-type names.
_TYPE_REMAP: Dict[str, str] = {
    "cpu": "CPU", "cpu_spike": "CPU",
    "memory": "MEMORY", "memory_leak": "MEMORY",
    "disk": "DISK", "disk_full": "DISK",
    "network": "NETWORK", "net": "NETWORK", "network_latency": "NETWORK",
    "application": "APPLICATION", "app": "APPLICATION",
    "security": "SECURITY", "sec": "SECURITY",
}
# The canonical alert types; anything else is coerced to "APPLICATION".
_VALID = {"CPU", "MEMORY", "DISK", "NETWORK", "APPLICATION", "SECURITY"}


def _norm(raw: str) -> str:
    """Return the canonical alert type for the free-form string *raw*.

    Surrounding whitespace is ignored. Known aliases (case-insensitive) are
    mapped via _TYPE_REMAP, and canonical names are accepted in any case.
    Empty input, or a value that does not resolve to a member of _VALID,
    yields "APPLICATION" so that only canonical types reach the queue.
    """
    if not raw:
        return "APPLICATION"
    key = raw.strip()
    resolved = _TYPE_REMAP.get(key.lower(), key.upper())
    return resolved if resolved in _VALID else "APPLICATION"
# -- Application ---------------------------------------------------------------

app = FastAPI(title="Adaptive Alert Triage RL Server", version="0.3.0")

# Permissive CORS: any origin, method and header; credentialed cross-origin
# requests are not allowed (allow_credentials=False).
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=False,
    allow_methods=["*"],
    allow_headers=["*"],
)


# NOTE(review): no registration decorator is visible for this function in
# this view. Confirm it is wired up (e.g. via @app.middleware("http")).
async def log_requests(request, call_next):
    """Print the method and URL of every request, then hand it downstream."""
    method, url = request.method, request.url
    print(f"REQUEST: {method} {url}")
    response = await call_next(request)
    return response
# -- Global state ----------------------------------------------------------------

STEP_INTERVAL = 1.0  # seconds between autonomous episode-loop steps

# The environment instance shared by all endpoints and the episode loop.
env: Optional[AdaptiveAlertTriageEnv] = None
# Scores of completed episodes (fraction of correct actions per episode).
episode_scores: List[float] = []
# Loaded PPO agents, keyed by task id.
_ppo_agents: Dict[str, Any] = {}
# Handle of the background _episode_loop task.
_loop_task: Optional[asyncio.Task] = None
# Action type most recently chosen by the episode loop.
_last_action: Optional[str] = None
# Running correct / total counters for the episode in progress.
_step_correct: int = 0
_step_total: int = 0
# -- Score helpers -------------------------------------------------------------


def _reset_score() -> None:
    """Clear the running correct/total counters for a new episode."""
    global _step_correct, _step_total
    _step_total = 0
    _step_correct = 0


def _tick(info: Dict) -> None:
    """Count one step, and a correct one when info["action_correct"] is truthy."""
    global _step_correct, _step_total
    _step_total += 1
    if not info.get("action_correct", False):
        return
    _step_correct += 1


def _score() -> float:
    """Return the fraction of correct steps so far; 0.0 before any step."""
    if not _step_total:
        return 0.0
    return _step_correct / _step_total
# -- PPO helpers ---------------------------------------------------------------


def _load_ppo(task_id: str) -> Optional[Any]:
    """Load the PPO agent for *task_id* from weights/ppo_{task_id}.json.

    Returns None silently when rl_agent could not be imported. Returns None
    after printing a diagnostic when the weights file is missing or when
    building or loading the agent raises.
    """
    if not _PPO_AVAILABLE:
        return None
    weights_file = os.path.join(_project_root, "weights", f"ppo_{task_id}.json")
    if os.path.exists(weights_file):
        try:
            trainer = PPOTrainer(task_id=task_id)
            trainer.load(weights_file)
            print(f" [PPO] loaded {weights_file}")
            return trainer
        except Exception as err:
            print(f" [PPO] load error: {err}")
            return None
    print(f" [PPO] weights not found: {weights_file}")
    return None
def _ppo_act() -> Optional[Action]:
    """Ask the loaded PPO agent for an action on the current observation.

    Returns None when there is no environment, no active alerts, or no agent
    for the current task. Also returns None when building the observation or
    running the agent raises; that error is printed, so a broken model does
    not silently degrade into the rule-based fallback.
    """
    if not env or not env.alerts:
        return None
    agent = _ppo_agents.get(env.task_id)
    if agent is None:
        return None
    try:
        obs = Observation(
            alerts=list(env.alerts),
            # 0.5 is used when the env exposes no _last_system_load attribute.
            system_load=getattr(env, "_last_system_load", 0.5),
            queue_length=len(env.alerts),
            time_remaining=env.max_steps - env.current_step,
            resource_budget=(
                env.max_investigations_per_step - env.investigations_used
                if env.max_investigations_per_step is not None
                else None
            ),
            episode_step=env.current_step,
        )
        return agent.act(obs)
    except Exception as exc:
        print(f"[PPO] act error, falling back to rules: {exc}")
        return None
def _rule_act() -> Optional[Action]:
    """Heuristic policy applied to the most severe visible alert.

    - severity >= 0.75 and confidence >= 0.60: INVESTIGATE, or ESCALATE once
      the per-step investigation budget is used up;
    - confidence below 0.30 or severity below 0.30: IGNORE;
    - severity >= 0.55: ESCALATE;
    - anything else: DELAY.
    Returns None when there is no environment or no active alert.
    """
    if not env or not env.alerts:
        return None
    target = max(env.alerts, key=lambda alert: alert.visible_severity)
    severity, confidence = target.visible_severity, target.confidence
    if env.max_investigations_per_step is None:
        remaining = None
    else:
        remaining = env.max_investigations_per_step - env.investigations_used

    if severity >= 0.75 and confidence >= 0.60:
        budget_exhausted = remaining is not None and remaining <= 0
        choice = "ESCALATE" if budget_exhausted else "INVESTIGATE"
    elif confidence < 0.30 or severity < 0.30:
        choice = "IGNORE"
    elif severity >= 0.55:
        choice = "ESCALATE"
    else:
        choice = "DELAY"
    return Action(alert_id=target.id, action_type=choice)
# -- Always-live episode loop ----------------------------------------------------


async def _episode_loop() -> None:
    """Background task that keeps the shared environment permanently live.

    On every iteration (one pause of STEP_INTERVAL seconds per iteration):
      1. If there are no active alerts or the episode is terminal, close the
         episode (recording its score if any step was taken) and reset.
      2. Choose an action: the loaded PPO agent first, rule-based fallback
         second.
      3. Call env.step(). Per the module notes, this is where queued real
         alerts get drained.
      4. Track the running score; when the step reports done, close the
         episode and reset.

    Keeping the environment populated is what lets the recommend endpoint
    answer at any time. An exception in one iteration is printed with its
    traceback and the loop carries on. Cancellation is not caught, so
    cancelling the task stops it.
    """
    history_limit = 1000  # maximum number of stored episode scores

    def finish_episode() -> None:
        # Record the episode score (only if at least one step was scored),
        # trim the history, clear the counters and start a new episode. Both
        # episode-end paths go through here, so the history cap is always
        # enforced.
        if _step_total > 0:
            episode_scores.append(_score())
            if len(episode_scores) > history_limit:
                del episode_scores[:-history_limit]
        _reset_score()
        env.reset()

    def step_once() -> None:
        # One iteration of work; returns early whenever there is nothing to do.
        global _last_action
        if env is None:
            return
        if not env.alerts or env._is_terminal():
            finish_episode()
            if not env.alerts:
                return
        action = _ppo_act() or _rule_act()
        if action is None:
            return
        _last_action = action.action_type
        _, _, done, info = env.step(action)
        _tick(info)
        if done:
            finish_episode()

    while True:
        try:
            step_once()
        except Exception as exc:
            print(f"[episode_loop] {exc}")
            traceback.print_exc()
        await asyncio.sleep(STEP_INTERVAL)
# -- Startup / shutdown --------------------------------------------------------


async def startup():
    """Create the shared environment, load any PPO weights and start the loop.

    NOTE(review): no startup-hook registration is visible in this view.
    Confirm this is wired up (e.g. @app.on_event("startup")).
    """
    global env, _loop_task
    env = AdaptiveAlertTriageEnv(task_id="hard")
    env.real_alerts_queue = deque(maxlen=50)
    # Populate env.alerts right away, so requests arriving before the first
    # loop tick already see an active episode (FIX 1).
    env.reset()

    for task in ("easy", "medium", "hard"):
        trained = _load_ppo(task)
        if trained:
            _ppo_agents[task] = trained

    _loop_task = asyncio.create_task(_episode_loop())

    loaded_summary = list(_ppo_agents.keys()) or 'none (run train_rl.py first)'
    print("β Alert Triage RL Server v0.3.0")
    print(f" Active alerts : {len(env.alerts)}")
    print(f" PPO loaded : {loaded_summary}")
    print(f" Episode loop : every {STEP_INTERVAL}s")
async def shutdown():
    """Cancel the background episode loop and wait for it to finish.

    Awaiting the cancelled task lets it unwind inside the running event loop,
    instead of being destroyed while still pending. The handle is cleared so
    a repeated call is a no-op.
    """
    global _loop_task
    task, _loop_task = _loop_task, None
    if task is None:
        return
    task.cancel()
    try:
        await task
    except asyncio.CancelledError:
        pass
# -- Health ----------------------------------------------------------------------


async def health():
    """Liveness probe: env readiness (alerts present) and real-alert queue size."""
    ready = env is not None and bool(env.alerts)
    pending = (
        len(env.real_alerts_queue)
        if env and hasattr(env, "real_alerts_queue")
        else 0
    )
    return HealthResponse(status="ok", env_ready=ready, queue_size=pending)
def _format_vs_baseline(mean: float, baseline: float = 0.61) -> str:
    """Format ``(mean - baseline)`` in percentage points as a signed string.

    The default baseline is the hard-coded 0.61 reference score used by the
    metrics endpoint (presumably a baseline agent's mean score; confirm).
    The delta is rounded before the sign is chosen, so small negative deltas
    render as "+0%" rather than "-0%".
    """
    points = round((mean - baseline) * 100)
    return f"{points:+d}%"


async def metrics():
    """Summarise training/evaluation progress of the shared environment.

    mean_score is the average of the most recent (up to 100) completed
    episode scores, or 0.0 before any episode has finished.
    """
    if not env:
        return {"error": "not initialized"}
    recent = episode_scores[-100:]
    mean = sum(recent) / len(recent) if recent else 0.0
    return {
        "mean_score": round(mean, 3),
        "vs_baseline": _format_vs_baseline(mean),
        "active_alerts": len(env.alerts),
        "episodes_completed": len(episode_scores),
        "current_step_score": round(_score(), 3),
        "current_step": env.current_step,
        "last_action": _last_action,
        "queue_size": len(env.real_alerts_queue) if hasattr(env, "real_alerts_queue") else 0,
        "ppo_loaded": list(_ppo_agents.keys()),
    }
# -- Alert ingestion -------------------------------------------------------------


async def ingest_one(alert: IngestAlert):
    """Canonicalise one external alert and push it onto the real-alert queue.

    The alert is only queued here. Per the module notes, a later env.step()
    made by the episode loop drains it into the environment.
    """
    if not env:
        return {"error": "not initialized"}
    if not hasattr(env, "real_alerts_queue"):
        env.real_alerts_queue = deque(maxlen=50)
    payload = alert.model_dump()
    payload["type"] = _norm(payload.get("type", "APPLICATION"))
    queue = env.real_alerts_queue
    queue.appendleft(payload)
    return {
        "status": "queued",
        "queued": len(queue),
        "alert_id": alert.id,
        "resolved_type": payload["type"],
        "note": "Episode loop will process this within ~1s",
    }
async def ingest_batch(alerts: List[IngestAlert]):
    """Canonicalise and queue several alerts, reporting each resolved type."""
    if not env:
        return {"error": "not initialized"}
    if not hasattr(env, "real_alerts_queue"):
        env.real_alerts_queue = deque(maxlen=50)
    queue = env.real_alerts_queue
    summary = []
    for item in alerts:
        payload = item.model_dump()
        payload["type"] = _norm(payload.get("type", "APPLICATION"))
        queue.appendleft(payload)
        summary.append({"alert_id": item.id, "resolved_type": payload["type"]})
    return {"status": "queued", "queued": len(queue), "ingested": summary}
# -- Environment control ---------------------------------------------------------


async def reset_env(task_id: str = "hard"):
    """Replace the shared env with a fresh one for *task_id* and reset it.

    The pending real-alert queue carries over to the new env (FIX 5). The
    task's PPO agent is reloaded when its weights exist, and the running
    score counters are cleared. Failures are returned as a JSON payload with
    a traceback rather than raised.
    """
    global env
    if task_id not in ("easy", "medium", "hard"):
        return {"error": f"Invalid task_id '{task_id}'"}
    try:
        previous_queue = None
        if env and hasattr(env, "real_alerts_queue"):
            previous_queue = env.real_alerts_queue
        env = AdaptiveAlertTriageEnv(task_id=task_id)
        if previous_queue is None:
            env.real_alerts_queue = deque(maxlen=50)
        else:
            env.real_alerts_queue = previous_queue
        trained = _load_ppo(task_id)
        if trained:
            _ppo_agents[task_id] = trained
        observation = env.reset()
        _reset_score()
        return {"status": "reset", "task_id": task_id, "obs": observation.model_dump()}
    except Exception as err:
        return {"error": str(err), "traceback": traceback.format_exc()}
async def step_env(request: StepRequest):
    """Apply one manual action to the shared env and return the transition.

    The running score is updated, injected into info["task_score"] (FIX 4)
    and also echoed as "score". When the episode finishes, that score is
    appended to episode_scores and the counters are cleared.
    """
    if not env:
        return {"error": "not initialized"}
    allowed_actions = {"INVESTIGATE", "IGNORE", "ESCALATE", "DELAY"}
    if request.action_type not in allowed_actions:
        return {"error": f"Invalid action '{request.action_type}'"}
    try:
        chosen = Action(alert_id=request.alert_id, action_type=request.action_type)
        obs, reward, done, info = env.step(chosen)
        _tick(info)
        running = _score()
        info["task_score"] = running
        if done:
            episode_scores.append(running)
            _reset_score()
        return {
            "obs": obs.model_dump(),
            "reward": reward.value,
            "done": done,
            "info": info,
            "score": running,
        }
    except Exception as err:
        return {"error": str(err), "traceback": traceback.format_exc()}
async def get_state():
    """Return a snapshot of the shared environment.

    "visible_state" gathers the observable fields (alerts, step counters,
    failure count, system load from env.state().observation, task id and
    real-alert queue size). "hidden_state" and "cumulative_reward" come
    directly from env.state(). Failures are returned with a traceback.
    """
    if not env:
        return {"error": "not initialized"}
    try:
        snapshot = env.state()
        visible = {
            "alerts": [alert.model_dump() for alert in env.alerts],
            "current_step": env.current_step,
            "max_steps": env.max_steps,
            "failures_count": env.failures_count,
            "system_load": snapshot.observation.system_load,  # FIX 6
            "queue_length": len(env.alerts),
            "task_id": env.task_id,
            "real_queue_size": (
                len(env.real_alerts_queue) if hasattr(env, "real_alerts_queue") else 0
            ),
        }
        return {
            "visible_state": visible,
            "hidden_state": snapshot.hidden_state,
            "cumulative_reward": snapshot.cumulative_reward,
        }
    except Exception as err:
        return {"error": str(err), "traceback": traceback.format_exc()}
# -- Agent recommendation --------------------------------------------------------


def _recommend_with_ppo(ppo: Any, top: Any, task_id: str) -> Dict[str, Any]:
    """Build a recommendation from the PPO network's action distribution.

    Propagates any exception raised while building the observation, encoding
    it or running the forward pass; the caller falls back to rules.
    """
    import numpy as np

    obs = Observation(
        alerts=list(env.alerts),
        # 0.5 is used when the env exposes no _last_system_load attribute.
        system_load=getattr(env, "_last_system_load", 0.5),
        queue_length=len(env.alerts),
        time_remaining=env.max_steps - env.current_step,
        resource_budget=(
            env.max_investigations_per_step - env.investigations_used
            if env.max_investigations_per_step is not None
            else None
        ),
        episode_step=env.current_step,
    )
    probs, val = ppo.net.forward(encode_state(obs))
    idx = int(np.argmax(probs))
    conf = round(float(probs[idx]) * 100, 1)
    return {
        "alert_id": top.id,
        "action_type": _ACTION_NAMES[idx],
        "reasoning": f"PPO ({conf:.1f}% confidence)",
        "source": "trained_ppo",
        "model_confidence": conf,
        # One entry per known action name, rather than a hard-coded count.
        "probabilities": {
            _ACTION_NAMES[i]: round(float(probs[i]), 4)
            for i in range(len(_ACTION_NAMES))
        },
        "value_estimate": round(float(val), 3),
        "alert_severity": top.visible_severity,
        "alert_confidence": top.confidence,
        "alert_age": top.age,
        "alert_type": top.alert_type,
        "active_alerts": len(env.alerts),
        "episode_step": env.current_step,
        "task_id": task_id,
    }


def _recommend_with_rules(top: Any, task_id: str) -> Dict[str, Any]:
    """Heuristic recommendation for *top*, using the same thresholds as _rule_act."""
    sev, conf = top.visible_severity, top.confidence
    rem = (
        env.max_investigations_per_step - env.investigations_used
        if env.max_investigations_per_step is not None
        else None
    )
    if sev >= 0.75 and conf >= 0.60:
        act = "ESCALATE" if (rem is not None and rem <= 0) else "INVESTIGATE"
    elif conf < 0.30 or sev < 0.30:
        act = "IGNORE"
    elif sev >= 0.55:
        act = "ESCALATE"
    else:
        act = "DELAY"
    return {
        "alert_id": top.id,
        "action_type": act,
        "source": "rule_based",
        "alert_severity": sev,
        "alert_confidence": conf,
        "alert_type": top.alert_type,
        "active_alerts": len(env.alerts),
        "task_id": task_id,
        "hint": "Run `python train_rl.py --episodes 300` to load PPO weights",
    }


async def recommend():
    """Recommend an action for the most severe active alert.

    Uses the trained PPO agent for the current task when one is loaded,
    reporting its action probabilities and value estimate. If no agent is
    loaded, or PPO inference raises (the error is printed), it falls back to
    the rule-based heuristic. Returns an error payload while there are no
    active alerts yet.
    """
    if not env or not env.alerts:
        return {
            "error": "No alerts yet — episode loop is starting, retry in 2s",
            "active_alerts": len(env.alerts) if env else 0,
        }
    task_id = env.task_id
    top = max(env.alerts, key=lambda a: a.visible_severity)
    ppo = _ppo_agents.get(task_id)
    if ppo is not None:
        try:
            return _recommend_with_ppo(ppo, top, task_id)
        except Exception as exc:
            print(f"PPO recommend error: {exc}")
    return _recommend_with_rules(top, task_id)
# -- WebSocket -------------------------------------------------------------------


async def ws_train(websocket: WebSocket):
    """Drive the shared environment over a WebSocket session.

    Accepted JSON messages (keyed by "type"):
      - "reset" (optional "task_id", default "hard"): replace the global env
        with a fresh one for the task, carrying over the real-alert queue, and
        reply with its observation. An invalid task id gets an error reply and
        the session continues, matching the reset endpoint's validation.
      - "step" (with "action": {"alert_id", "action_type"}): apply the action.
        The reply carries this connection's running score in both
        "task_score" and info["task_score"].
      - "close": end the session.
    Other message types are ignored. A client disconnect ends the session
    quietly. Any other exception is reported with a best-effort error message
    and also ends the session.
    """
    global env
    await websocket.accept()
    correct = total = 0
    try:
        while True:
            data = await websocket.receive_json()
            kind = data.get("type")
            if kind == "reset":
                tid = data.get("task_id", "hard")
                if tid not in ("easy", "medium", "hard"):
                    await websocket.send_json({"error": f"Invalid task_id '{tid}'"})
                    continue
                saved = env.real_alerts_queue if (env and hasattr(env, "real_alerts_queue")) else None
                env = AdaptiveAlertTriageEnv(task_id=tid)
                # Test for None explicitly: an empty deque is falsy, and we
                # still want to keep the existing queue object.
                env.real_alerts_queue = saved if saved is not None else deque(maxlen=50)
                obs = env.reset()
                correct = total = 0
                await websocket.send_json({"obs": obs.model_dump(), "task_id": tid})
            elif kind == "step":
                if not env:
                    await websocket.send_json({"error": "Reset first"})
                    continue
                ad = data.get("action", {})
                act = Action(
                    alert_id=ad.get("alert_id", ""),
                    action_type=ad.get("action_type", "IGNORE"),
                )
                obs, reward, done, info = env.step(act)
                total += 1
                if info.get("action_correct", False):
                    correct += 1
                s = correct / total if total else 0.0
                if done:
                    episode_scores.append(s)
                info["task_score"] = s
                await websocket.send_json({
                    "obs": obs.model_dump(),
                    "reward": reward.value,
                    "done": done,
                    "info": info,
                    "task_score": s,
                    "action_correct": info.get("action_correct", False),
                    "failures_this_step": info.get("failures_this_step", 0),
                })
            elif kind == "close":
                break
    except WebSocketDisconnect:
        pass
    except Exception as e:
        try:
            await websocket.send_json({"error": str(e)})
        except Exception:
            pass
# -- Utility ---------------------------------------------------------------------


async def root():
    """Describe the server and list the quick-start commands."""
    steps = [
        "1. python train_rl.py --episodes 300",
        "2. uvicorn src.adaptive_alert_triage.server:app --port 8000",
        "3. curl -X POST localhost:8000/ingest/alerts -H 'Content-Type: application/json' -d '{\"id\":\"p1\",\"visible_severity\":0.9,\"confidence\":0.85,\"type\":\"CPU\"}'",
        "4. curl localhost:8000/agent/recommend",
    ]
    return {
        "name": "Adaptive Alert Triage RL Server",
        "version": "0.3.0",
        "quick_start": steps,
    }
# -- Background training ---------------------------------------------------------
import subprocess
import threading

# Handle of the most recent train_rl.py subprocess (None until one starts).
_training_proc: Optional[subprocess.Popen] = None
# Captured output lines of the current / most recent training run.
_training_logs: List[str] = []
def _describe_exit(returncode: int) -> str:
    """Return the log line for a finished training subprocess.

    Popen reports death by signal N as returncode -N. That case is logged as
    a signal instead of being flipped into a misleading positive exit code.
    """
    if returncode >= 0:
        return f"Training finished with exit code {returncode}"
    signum = -returncode
    try:
        name = signal.Signals(signum).name
    except ValueError:
        name = "unknown signal"
    return f"Training terminated by signal {signum} ({name})"


def _reload_ppo_agents() -> None:
    """Reload PPO weights for every task into _ppo_agents after training."""
    for tid in ("easy", "medium", "hard"):
        agent = _load_ppo(tid)
        if agent:
            _ppo_agents[tid] = agent
    _training_logs.append("Successfully reloaded PPO weights for all tasks.")


def _push_weights_to_hub(root: str) -> None:
    """Best-effort upload of the weights folder under *root* to the HF Space repo.

    Requires the HF_TOKEN environment variable. The target repo is SPACE_ID
    (default "tusharp2006/scaler-deployment"). Every outcome (skipped,
    succeeded or failed) is written to _training_logs; nothing is raised.
    """
    try:
        from huggingface_hub import HfApi

        hf_token = os.environ.get("HF_TOKEN")
        repo_id = os.environ.get("SPACE_ID", "tusharp2006/scaler-deployment")
        if not hf_token:
            _training_logs.append("No HF_TOKEN found in environment. Skipping weight cloud-persistence.")
            return
        _training_logs.append(f"Pushing updated weights back to HF Hub ({repo_id})...")
        api = HfApi(token=hf_token)
        weights_dir = os.path.join(root, "weights")
        if not os.path.exists(weights_dir):
            _training_logs.append(f"Weights directory not found ({weights_dir}). Nothing to push.")
            return
        api.upload_folder(
            repo_id=repo_id,
            folder_path=weights_dir,
            path_in_repo="weights",
            repo_type="space",
            commit_message="Auto-sync weights after RL training",
        )
        _training_logs.append("Weights successfully pushed and persisted to Hugging Face!")
    except ImportError:
        _training_logs.append("huggingface_hub not installed. Skipping cloud backup.")
    except Exception as e:
        _training_logs.append(f"Failed to push weights to Hub: {e}")


def _run_training(episodes: int):
    """Run train_rl.py in a subprocess and stream its output into _training_logs.

    Intended to run in a background thread (see start_training). Only the
    most recent ~1000 lines are kept. On a zero exit code, PPO weights are
    reloaded and pushed to the Hugging Face Hub on a best-effort basis.
    Failures are recorded in the log rather than raised.
    """
    global _training_proc, _training_logs
    _training_logs = [f"Starting training with --episodes {episodes}..."]
    try:
        root = _project_root if _project_root else os.getcwd()
        # The context manager closes the stdout pipe and waits for the
        # process on exit, so no file descriptor is leaked.
        with subprocess.Popen(
            [sys.executable, "train_rl.py", "--episodes", str(episodes)],
            stdout=subprocess.PIPE,
            stderr=subprocess.STDOUT,
            text=True,
            bufsize=1,
            cwd=root,
        ) as proc:
            _training_proc = proc
            for line in proc.stdout:
                _training_logs.append(line.rstrip("\n"))
                if len(_training_logs) > 1000:
                    del _training_logs[0]
        returncode = proc.returncode
        _training_logs.append(_describe_exit(returncode))
        if returncode == 0:
            _reload_ppo_agents()
            _push_weights_to_hub(root)
    except Exception as e:
        _training_logs.append(f"Error starting training: {e}")
# Worker thread of the current / most recent _run_training call. It is
# tracked alongside _training_proc because the thread only assigns the
# subprocess handle after spawning it, and it keeps working (weight reload,
# Hub upload) after the subprocess has exited.
_training_thread: Optional[threading.Thread] = None


async def start_training(episodes: int = 300):
    """Start train_rl.py in a background thread unless a run is in progress.

    Returns {"status": "already running"} while either the worker thread is
    alive or the training subprocess has not exited; otherwise starts a new
    daemon worker and returns {"status": "started"}.
    """
    global _training_thread
    thread_busy = _training_thread is not None and _training_thread.is_alive()
    proc_busy = _training_proc is not None and _training_proc.poll() is None
    if thread_busy or proc_busy:
        return {"status": "already running"}
    _training_thread = threading.Thread(
        target=_run_training, args=(episodes,), daemon=True
    )
    _training_thread.start()
    return {"status": "started"}
async def get_training_status():
    """Report whether a training subprocess is running, plus its captured log."""
    running = False
    if _training_proc is not None:
        running = _training_proc.poll() is None
    return {"is_running": running, "logs": _training_logs}
async def web_ui():
    """Serve the interactive web dashboard (dashboard.html) for monitoring.

    Intended for the `/web` endpoint (the HF Spaces convention noted by the
    original author). The file is looked up three directory levels above this
    module, the same project root the PPO import uses. The path is made
    absolute first, so a relative __file__ cannot yield a wrong or empty
    directory.
    """
    project_root = os.path.dirname(
        os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
    )
    dashboard_path = os.path.join(project_root, "dashboard.html")
    return FileResponse(dashboard_path, media_type="text/html")
async def list_tasks():
    """List the task difficulties with their success thresholds and step limits."""
    specs = (
        ("easy", 0.70, 30),
        ("medium", 0.55, 40),
        ("hard", 0.50, 50),
    )
    return {
        "tasks": [
            {"id": tid, "success_threshold": threshold, "max_steps": steps}
            for tid, threshold, steps in specs
        ]
    }