"""
backend/main.py — FastAPI server for the Cognitive Load Manager (OpenEnv).

Endpoints:
  GET  /health
  POST /reset          {"task_id": "easy|medium|hard|expert"}
  POST /step           {"session_id": "...", "action": {...}}
  GET  /state          ?session_id=...
  GET  /grader
  GET  /grade/easy|medium|hard|expert
  GET  /stream/run     ?difficulty=medium — SSE live episode (heuristic agent)
  GET  /benchmark      — heuristic scores all 4 levels
  GET  /training-log   — saved reward_curve.json
  POST /train/start    ?difficulty=medium&steps=25 — start demo training
  GET  /train/status   — current training state
  GET  /train/stream   — SSE live training progress
"""
| import asyncio |
| import json |
| import os |
| import random as _random |
| import sys |
| import threading |
| import time |
| import uuid |
| from datetime import datetime, timezone |
| from typing import Dict, Optional, List |
|
|
| from fastapi import FastAPI, HTTPException |
| from fastapi.middleware.cors import CORSMiddleware |
| from fastapi.responses import FileResponse, StreamingResponse |
| from fastapi.staticfiles import StaticFiles |
| from pydantic import BaseModel, Field |
|
|
| sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) |
|
|
| from models import ( |
| Action as ModelAction, |
| generate_tasks, |
| deterministic_grader, |
| CLMEnvironment, |
| PRIORITY_WEIGHT, |
| ) |
|
|
# Scores exposed by the API are kept strictly inside (0, 1).
_SCORE_MIN = 0.01
_SCORE_MAX = 0.99


def _safe(raw: float) -> float:
    """Clamp a raw score into ``[_SCORE_MIN, _SCORE_MAX]``, rounded to 4 places.

    Args:
        raw: Anything ``float()`` can convert (number, numeric string, ...).

    Returns:
        The clamped, rounded score. ``_SCORE_MIN`` is returned when ``raw``
        cannot be converted to a float or is NaN; this helper never raises.
    """
    import math  # function-scope import keeps the helper self-contained

    try:
        value = float(raw)
    except Exception:
        # Deliberately broad: a score sanitiser must not raise, whatever the
        # grader handed back.
        return _SCORE_MIN

    # NaN compares False with everything, so min()/max() would let it fall
    # through as _SCORE_MAX. An undefined score must not earn the best value.
    if math.isnan(value):
        return _SCORE_MIN

    return round(max(_SCORE_MIN, min(_SCORE_MAX, value)), 4)
|
|
|
|
| |
# Module-level registry mapping a session id to its live environment.
_sessions: Dict[str, CLMEnvironment] = {}


def _get_session(session_id: str) -> CLMEnvironment:
    """Look up the environment registered under ``session_id``.

    Raises:
        HTTPException: 404 when no session with that id is registered.
    """
    try:
        return _sessions[session_id]
    except KeyError:
        raise HTTPException(
            status_code=404,
            detail=f"Session '{session_id}' not found.",
        ) from None
|
|
|
|
def _avg_energy(env: CLMEnvironment) -> float:
    """Return the mean ``energy`` over ``env.state.workers``.

    Falls back to ``0.5`` when the environment has no workers.
    """
    crew = env.state.workers
    if not crew:
        return 0.5
    return sum(member.energy for member in crew) / len(crew)
|
|
|
|
| |
def _heuristic_action(env: CLMEnvironment) -> ModelAction:
    """Rule-based policy: rest when drained, otherwise do the most urgent task.

    Decision order:
      1. ``break`` if the first worker's energy is below 0.28 or its stress
         is above 0.72.
      2. ``delay`` if no unfinished, unblocked task exists.
      3. Otherwise pick the task with the highest ``PRIORITY_WEIGHT`` and,
         among equals, the earliest deadline (tasks without a deadline rank
         as 9999). Use ``focus`` instead of ``work`` only for a critical task
         whose deadline is at most 10 steps away while the first worker's
         energy is above 0.52.

    Every action is addressed to worker ``"w1"``.
    """
    snapshot = env.state
    blocked_ids = env._blocked_ids()
    lead = snapshot.workers[0] if snapshot.workers else None

    if lead and (lead.energy < 0.28 or lead.stress > 0.72):
        return ModelAction(type="break", task_id=None, worker_id="w1")

    candidates = [
        task for task in snapshot.tasks
        if task.progress < 1.0 and task.id not in blocked_ids
    ]
    if not candidates:
        return ModelAction(type="delay", task_id=None, worker_id="w1")

    def urgency(task):
        # Negated weight so that the heaviest priority compares smallest.
        due = task.deadline if task.deadline is not None else 9999
        return (-PRIORITY_WEIGHT[task.priority], due)

    # min() returns the first minimal element, matching a stable sort's head.
    best = min(candidates, key=urgency)

    mode = "work"
    if best.priority == "critical" and best.deadline is not None:
        slack = best.deadline - snapshot.time_step
        if slack <= 10 and lead is not None and lead.energy > 0.52:
            mode = "focus"

    return ModelAction(type=mode, task_id=best.id, worker_id="w1")
|
|
|
|
| |
def _random_action(env: CLMEnvironment) -> ModelAction:
    """Sample an action at random from a fresh, unseeded generator.

    With no unfinished task, or with probability 0.15, the action is
    ``break``; otherwise with probability 0.10 it is ``delay``; otherwise a
    uniformly chosen unfinished task is worked on, using ``focus`` one time
    in four. Every action is addressed to worker ``"w1"``.
    """
    snapshot = env.state
    rng = _random.Random()
    unfinished = [task for task in snapshot.tasks if task.progress < 1.0]

    # The draws below must stay in this order to keep the distribution.
    if not unfinished or rng.random() < 0.15:
        idle_kind = "break"
    elif rng.random() < 0.10:
        idle_kind = "delay"
    else:
        chosen = rng.choice(unfinished)
        verb = rng.choice(["work", "work", "work", "focus"])
        return ModelAction(type=verb, task_id=chosen.id, worker_id="w1")

    return ModelAction(type=idle_kind, task_id=None, worker_id="w1")
|
|
|
|
def _mixed_action(env: CLMEnvironment, heuristic_prob: float) -> ModelAction:
    """Use the heuristic policy with probability ``heuristic_prob``, else random.

    ``heuristic_prob=0`` is purely random and ``1`` is purely heuristic, so
    raising it over time blends random → heuristic behaviour.
    """
    if _random.random() < heuristic_prob:
        return _heuristic_action(env)
    return _random_action(env)
|
|
|
|
| |
def _run_episode(difficulty: str, agent: str = "heuristic",
                 heuristic_prob: float = 1.0) -> float:
    """Play one full episode and return its final score.

    Args:
        difficulty: Level passed to ``generate_tasks``; ``"expert"`` gets a
            60-step budget, every other value 50.
        agent: ``"heuristic"`` or ``"random"``; any other value selects the
            mixed policy.
        heuristic_prob: Probability of the heuristic move per step when the
            mixed policy is used.

    Returns:
        ``info["final_score"]`` from the last step when the environment
        reported one, otherwise the ``deterministic_grader`` result for the
        final state.
    """
    tasks = generate_tasks(difficulty)
    max_steps = 60 if difficulty == "expert" else 50
    env = CLMEnvironment(tasks=tasks, max_steps=max_steps)
    env.reset()

    # Resolve the policy once instead of re-testing `agent` on every step.
    if agent == "heuristic":
        choose = _heuristic_action
    elif agent == "random":
        choose = _random_action
    else:
        def choose(current_env):
            return _mixed_action(current_env, heuristic_prob)

    info: dict = {}  # stays defined even if the loop body never runs
    done = False
    steps_taken = 0
    while not done and steps_taken < max_steps:
        _, _, done, info = env.step(choose(env))
        steps_taken += 1

    if "final_score" in info:
        return float(info["final_score"])

    # Only grade here when the environment did not supply a score; as a
    # dict.get() default the grader would run on every call.
    return float(deterministic_grader(env.state.tasks,
                                      env.state.time_step,
                                      _avg_energy(env)))
|
|
|
|
| |
# Shared snapshot of the demo-training run. Guarded by `_training_lock`.
# `_version` is incremented by every `_bump` call.
_training_state: dict = {
    "running": False,
    "status": "idle",
    "current_step": 0,
    "total_steps": 25,
    "difficulty": "medium",
    "curve": [],
    "before": None,
    "after": None,
    "metadata": None,
    "error": None,
    "_version": 0,
}
_training_lock = threading.Lock()


def _bump(updates: dict) -> None:
    """Merge ``updates`` into the training state and advance its version.

    Both steps happen under ``_training_lock``.
    """
    with _training_lock:
        for key, value in updates.items():
            _training_state[key] = value
        _training_state["_version"] = _training_state["_version"] + 1
|
|
|
|
def _run_training_demo(difficulty: str, total_steps: int, root_dir: str) -> None:
    """Background-thread body for the training demo.

    No model is trained here: the reward progression is simulated by playing
    episodes whose policy blends from random → heuristic as steps advance.

    Phases, each published through ``_bump``:
      1. "before": mean of 3 random-agent episodes per difficulty level.
      2. curve: ``total_steps`` entries, each the mean/max/min of 4
         mixed-policy episodes on ``difficulty``.
      3. "after": mean of 3 heuristic-agent episodes per difficulty level.
      4. The result is written to ``<root_dir>/reward_curve.json``.

    Args:
        difficulty: Level used for the curve episodes.
        total_steps: Number of curve entries to produce.
        root_dir: Directory that receives ``reward_curve.json``.

    Any exception is caught and reported as ``status="error"`` in the shared
    training state; nothing propagates out of the thread.
    """
    levels = ("easy", "medium", "hard", "expert")

    def _mean_score(level: str, agent: str) -> float:
        # Average of three episodes, rounded like every other published score.
        scores = [_run_episode(level, agent=agent) for _ in range(3)]
        return round(sum(scores) / len(scores), 4)

    try:
        started = datetime.now(timezone.utc).isoformat()
        _bump({"running": True, "status": "running", "curve": [],
               "current_step": 0, "total_steps": total_steps,
               "difficulty": difficulty, "before": None, "after": None,
               "error": None, "metadata": {
                   "started_at": started, "completed_at": None,
                   "total_steps": total_steps, "difficulty": difficulty,
                   "status": "running",
               }})

        # Phase 1: random-agent baseline on every level.
        before: dict = {level: _mean_score(level, "random") for level in levels}
        _bump({"before": before})

        # Phase 2: simulated learning curve.
        curve: list = []
        batch_size = 4
        for step in range(total_steps):
            # progress runs 0 → 1; heuristic probability rises from 0.05 to 0.92.
            progress = step / max(total_steps - 1, 1)
            h_prob = 0.05 + 0.87 * (progress ** 1.4)
            rewards = [_run_episode(difficulty, agent="mixed",
                                    heuristic_prob=h_prob)
                       for _ in range(batch_size)]
            curve.append({
                "step": step,
                "mean": round(sum(rewards) / len(rewards), 4),
                "max": round(max(rewards), 4),
                "min": round(min(rewards), 4),
            })
            # Publish a copy so readers never see the list being appended to.
            _bump({"curve": list(curve), "current_step": step + 1})
            # Presumably paces the demo so stream clients see steps arrive.
            time.sleep(0.45)

        # Phase 3: heuristic-agent scores on every level.
        after: dict = {level: _mean_score(level, "heuristic") for level in levels}

        completed = datetime.now(timezone.utc).isoformat()
        result = {
            "metadata": {
                "started_at": started,
                "completed_at": completed,
                "total_steps": total_steps,
                "difficulty": difficulty,
                "status": "completed",
            },
            "before": before,
            "after": after,
            "curve": curve,
        }

        # Phase 4: write to a temp file and swap it in, so a concurrent reader
        # of reward_curve.json never sees a truncated file.
        rc_path = os.path.join(root_dir, "reward_curve.json")
        tmp_path = rc_path + ".tmp"
        with open(tmp_path, "w", encoding="utf-8") as f:
            json.dump(result, f, indent=2)
        os.replace(tmp_path, rc_path)

        _bump({"after": after, "status": "completed", "running": False,
               "metadata": result["metadata"]})

    except Exception as exc:
        # Thread boundary: report the failure through the shared state.
        _bump({"status": "error", "running": False, "error": str(exc)})
|
|
|
|
| |
class ResetRequest(BaseModel):
    """Request body for ``POST /reset``.

    ``task_id`` defaults to ``"medium"`` and ``seed`` to ``None``. The key
    ``task`` is accepted as an alias for ``task_id`` when ``task_id`` itself
    is absent.
    """

    task_id: str = "medium"
    seed: Optional[int] = None

    def __init__(self, **data):
        # Rename the alias only when it cannot clash with an explicit task_id.
        alias_only = "task" in data and "task_id" not in data
        if alias_only:
            data["task_id"] = data.pop("task")
        super().__init__(**data)
|
|
|
|
class ActionPayload(BaseModel):
    """Action as sent by API clients inside a ``/step`` request."""

    # Action kind; this module itself emits "work", "focus", "break", "delay".
    type: str
    # Target task, when the action refers to one.
    task_id: Optional[str] = None
    # Acting worker; the /step handler substitutes "w1" when omitted.
    worker_id: Optional[str] = None
|
|
|
|
class StepRequest(BaseModel):
    """Request body for ``POST /step``."""

    # When omitted, the handler uses the most recently created session.
    session_id: Optional[str] = None
    action: ActionPayload
|
|
|
|
| |
def _run_grader_episode(difficulty: str) -> dict:
    """Run the grader class for ``difficulty`` and wrap its result.

    Unknown difficulties use ``EasyGrader``. Any failure, including a failed
    import of the grader module, yields ``_SCORE_MIN`` with a
    ``"Grader error: ..."`` message instead of raising.

    Returns:
        A dict with ``task_id``, ``reward``, ``score``, ``done`` and
        ``grader_message``. ``reward`` and ``score`` carry the same sanitised
        value.
    """
    try:
        from grader.clm_graders import EasyGrader, MediumGrader, HardGrader, ExpertGrader

        if difficulty == "medium":
            grader_cls = MediumGrader
        elif difficulty == "hard":
            grader_cls = HardGrader
        elif difficulty == "expert":
            grader_cls = ExpertGrader
        else:
            grader_cls = EasyGrader

        raw_score, _done, msg = grader_cls().grade()
        score = _safe(raw_score)
    except Exception as ex:
        score = _SCORE_MIN
        msg = f"Grader error: {ex}"

    # NOTE(review): "done" is always False and the grader's own flag is
    # discarded; confirm this is intended.
    return {
        "task_id": difficulty,
        "reward": score,
        "score": score,
        "done": False,
        "grader_message": msg,
    }
|
|
|
|
| |
def build_app() -> FastAPI:
    """Create and configure the FastAPI application.

    Registers CORS, the environment endpoints (``/reset``, ``/step``,
    ``/state``), the grader and benchmark endpoints, the SSE streams, the
    demo-training endpoints and, when ``frontend/dist`` exists, the built
    single-page frontend.

    Returns:
        The configured ``FastAPI`` instance.
    """
    app = FastAPI(
        title="Cognitive Load Manager — OpenEnv API",
        version="2.0.0",
        description="Multi-agent RL environment for cognitive load scheduling.",
    )
    # NOTE(review): wildcard origins combined with allow_credentials=True is
    # very permissive; confirm this is intended outside of demos.
    app.add_middleware(
        CORSMiddleware, allow_origins=["*"], allow_credentials=True,
        allow_methods=["*"], allow_headers=["*"],
    )

    _ROOT = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
    _REWARD_CURVE = os.path.join(_ROOT, "reward_curve.json")
    _LEVELS = ("easy", "medium", "hard", "expert")
    # Upper bound on stored sessions so abandoned episodes cannot grow the
    # registry forever; the oldest session is evicted first.
    _MAX_SESSIONS = 1024
    _SSE_HEADERS = {"Cache-Control": "no-cache", "X-Accel-Buffering": "no",
                    "Connection": "keep-alive"}

    # ── small shared helpers ────────────────────────────────────────────

    def _max_steps(level: str) -> int:
        # Step budget per level: 60 for "expert", 50 otherwise.
        return 60 if level == "expert" else 50

    def _lead_worker(env):
        # First worker of the environment, or None when there are none.
        return env.state.workers[0] if env.state.workers else None

    def _obs_payload(obs) -> dict:
        # JSON-ready view of an observation.
        return {
            "tasks": [t.model_dump() for t in obs.tasks],
            "visible_state": obs.visible_state.model_dump(),
            "time_step": obs.time_step,
        }

    def _final_score(env, info) -> float:
        # Environment-reported score when present, else the grader's score.
        if "final_score" in info:
            return _safe(info["final_score"])
        return _safe(deterministic_grader(
            env.state.tasks, env.state.time_step, _avg_energy(env)))

    def _sse(payload: dict) -> str:
        # One Server-Sent-Events "data:" frame.
        return f"data: {json.dumps(payload)}\n\n"

    def _resolve_env(session_id: Optional[str]):
        # Explicit session if given, else the most recently created one.
        if session_id:
            return _get_session(session_id)
        if _sessions:
            return list(_sessions.values())[-1]
        raise HTTPException(status_code=400, detail="No active session.")

    # ── system ──────────────────────────────────────────────────────────

    @app.get("/health", tags=["System"])
    async def health():
        return {"status": "healthy", "sessions": len(_sessions),
                "training": _training_state["status"]}

    # ── environment ─────────────────────────────────────────────────────

    @app.post("/reset", tags=["Environment"])
    async def reset(req: ResetRequest):
        """Create a new session; unknown ``task_id`` values fall back to "easy"."""
        task_id = req.task_id if req.task_id in _LEVELS else "easy"
        tasks = generate_tasks(task_id, seed=req.seed)
        env = CLMEnvironment(tasks=tasks, max_steps=_max_steps(task_id),
                             seed=req.seed)
        obs = env.reset()

        # dicts keep insertion order, so the first key is the oldest session.
        while len(_sessions) >= _MAX_SESSIONS:
            _sessions.pop(next(iter(_sessions)))

        sid = str(uuid.uuid4())
        _sessions[sid] = env
        return {
            "session_id": sid,
            "observation": _obs_payload(obs),
            "done": False,
            "reward": 0.0,
        }

    @app.post("/step", tags=["Environment"])
    async def step(req: StepRequest):
        """Apply one action to a session and return the new observation."""
        env = _resolve_env(req.session_id)

        action = ModelAction(type=req.action.type, task_id=req.action.task_id,
                             worker_id=req.action.worker_id or "w1")
        obs, reward, done, info = env.step(action)

        if done:
            info["final_score"] = _final_score(env, info)
            # NOTE(review): only explicitly addressed sessions are dropped
            # here; a session reached without a session_id stays registered
            # after it finishes. Confirm that is intended.
            if req.session_id and req.session_id in _sessions:
                del _sessions[req.session_id]

        return {
            "session_id": req.session_id,
            "observation": _obs_payload(obs),
            "reward": _safe(float(reward)),
            "done": done,
            "info": {k: v for k, v in info.items()
                     if k in ("final_score", "schema_drift", "time_step")},
        }

    @app.get("/state", tags=["Environment"])
    async def state(session_id: Optional[str] = None):
        env = _resolve_env(session_id)
        return {"state": env.state_dict(), "session_id": session_id}

    # ── graders ─────────────────────────────────────────────────────────

    @app.get("/grader", tags=["Grader"])
    async def grader():
        return _run_grader_episode("easy")

    @app.get("/grade/easy", tags=["Grader"])
    async def grade_easy():
        return _run_grader_episode("easy")

    @app.get("/grade/medium", tags=["Grader"])
    async def grade_medium():
        return _run_grader_episode("medium")

    @app.get("/grade/hard", tags=["Grader"])
    async def grade_hard():
        return _run_grader_episode("hard")

    @app.get("/grade/expert", tags=["Grader"])
    async def grade_expert():
        return _run_grader_episode("expert")

    # ── live episode stream ─────────────────────────────────────────────

    @app.get("/stream/run", tags=["Streaming"])
    async def stream_run(difficulty: str = "medium", delay_ms: int = 350):
        """SSE stream of one heuristic-agent episode.

        Unknown difficulties fall back to "medium"; the pause between steps
        is ``delay_ms`` clamped to 0.1–2.0 seconds.
        """
        diff = difficulty if difficulty in _LEVELS else "medium"
        sleep_s = max(0.1, min(2.0, delay_ms / 1000))

        async def event_gen():
            try:
                tasks = generate_tasks(diff)
                env = CLMEnvironment(tasks=tasks, max_steps=_max_steps(diff))
                obs = env.reset()
                w0 = _lead_worker(env)

                yield _sse({
                    "type": "reset",
                    "difficulty": diff,
                    "step": 0,
                    "tasks": [t.model_dump() for t in obs.tasks],
                    "visible_state": obs.visible_state.model_dump(),
                    "energy": round(w0.energy if w0 else 1.0, 3),
                    "stress": round(w0.stress if w0 else 0.0, 3),
                })

                done = False
                total_r = 0.0
                while not done:
                    action = _heuristic_action(env)
                    obs, reward, done, info = env.step(action)
                    total_r = round(total_r + float(reward), 4)
                    w0 = _lead_worker(env)
                    completed = sum(1 for t in obs.tasks if t.progress >= 1.0)

                    event: dict = {
                        "type": "step",
                        "step": obs.time_step,
                        "action": {"type": action.type,
                                   "task_id": action.task_id},
                        "reward": round(float(reward), 4),
                        "total_reward": total_r,
                        "done": done,
                        "energy": round(w0.energy if w0 else 0.5, 3),
                        "stress": round(w0.stress if w0 else 0.0, 3),
                        "tasks_done": completed,
                        "tasks_total": len(obs.tasks),
                        "tasks": [t.model_dump() for t in obs.tasks],
                        "visible_state": obs.visible_state.model_dump(),
                    }
                    if info.get("schema_drift"):
                        event["schema_drift"] = info["schema_drift"]
                    if done:
                        # Same fallback as /step and /benchmark: grade the
                        # final state when no score was reported.
                        event["final_score"] = _final_score(env, info)
                        event["final_energy"] = round(
                            w0.energy if w0 else 0.5, 3)

                    yield _sse(event)
                    if not done:
                        await asyncio.sleep(sleep_s)

            except Exception as exc:
                # The response has already started, so report in-band.
                yield _sse({"type": "error", "message": str(exc)})

        return StreamingResponse(event_gen(), media_type="text/event-stream",
                                 headers=dict(_SSE_HEADERS))

    # ── benchmark ───────────────────────────────────────────────────────

    @app.get("/benchmark", tags=["Benchmark"])
    def benchmark():
        """Run the heuristic agent once per level (seed 42) and report details.

        A failure on one level is reported in that level's entry and does not
        abort the remaining levels.
        """
        results = {}
        baseline = {"easy": 0.856, "medium": 0.523,
                    "hard": 0.301, "expert": 0.221}
        for diff in _LEVELS:
            try:
                tasks = generate_tasks(diff, seed=42)
                max_s = _max_steps(diff)
                env = CLMEnvironment(tasks=tasks, max_steps=max_s, seed=42)
                env.reset()

                info: dict = {}
                done = False
                step = 0
                total_r = 0.0
                step_rewards: List[float] = []
                energy_trace: List[float] = []
                stress_trace: List[float] = []
                while not done and step < max_s:
                    action = _heuristic_action(env)
                    _, reward, done, info = env.step(action)
                    total_r += float(reward)
                    step_rewards.append(round(float(reward), 4))
                    w0 = _lead_worker(env)
                    energy_trace.append(round(w0.energy if w0 else 0.5, 3))
                    stress_trace.append(round(w0.stress if w0 else 0.0, 3))
                    step += 1

                avg_e = _avg_energy(env)
                final_score = _final_score(env, info)

                all_tasks = env.state.tasks
                now = env.state.time_step
                tasks_done = sum(1 for t in all_tasks if t.progress >= 1.0)
                dl_tasks = [t for t in all_tasks if t.deadline is not None]
                met_dl = sum(1 for t in dl_tasks
                             if t.progress >= 1.0 and now <= t.deadline)

                # Score breakdown. The weights presumably mirror
                # deterministic_grader (TODO confirm against models).
                total_w = sum(PRIORITY_WEIGHT[t.priority] for t in all_tasks)
                wc = sum(t.progress * PRIORITY_WEIGHT[t.priority]
                         for t in all_tasks) / max(total_w, 0.01)
                da = (met_dl / len(dl_tasks)) if dl_tasks else 1.0
                ee = max(0.0, (avg_e - 0.10) * 0.13)
                dep = min(0.05, sum(
                    0.015 for t in all_tasks
                    if t.depends_on and t.progress >= 1.0
                    and any(p.id == t.depends_on and p.progress >= 1.0
                            for p in all_tasks)))
                int_t = [t for t in all_tasks if t.is_interrupted]
                int_b = min(0.03, (sum(1 for t in int_t if t.progress >= 1.0)
                                   / len(int_t) * 0.03) if int_t else 0.0)

                results[diff] = {
                    "score": final_score,
                    "baseline": baseline[diff],
                    "total_reward": round(total_r, 4),
                    "steps": step,
                    "tasks_done": tasks_done,
                    "tasks_total": len(all_tasks),
                    "avg_energy": round(avg_e, 3),
                    "deadlines_met": met_dl,
                    "deadlines_total": len(dl_tasks),
                    "components": {
                        "weighted_completion": round(wc * 0.60, 4),
                        "deadline_adherence": round(da * 0.22, 4),
                        "energy_efficiency": round(ee, 4),
                        "dependency_bonus": round(dep, 4),
                        "interruption_bonus": round(int_b, 4),
                    },
                    "step_rewards": step_rewards,
                    "energy_trace": energy_trace,
                    "stress_trace": stress_trace,
                }
            except Exception as exc:
                results[diff] = {"error": str(exc), "score": 0.01,
                                 "baseline": baseline[diff]}
        return results

    # ── training demo ───────────────────────────────────────────────────

    @app.get("/training-log", tags=["Training"])
    async def training_log():
        """Return the saved reward curve, or an empty structure if unavailable."""
        empty = {"metadata": None, "before": None, "after": None, "curve": []}
        try:
            with open(_REWARD_CURVE, encoding="utf-8") as f:
                raw = json.load(f)
        except (OSError, ValueError):
            # Missing, unreadable or partially written file: report "no data"
            # instead of failing the request. JSONDecodeError is a ValueError.
            return empty

        # A bare list is wrapped into the same structure as the dict format.
        if isinstance(raw, list):
            return {"metadata": None, "before": None, "after": None,
                    "curve": raw}
        return raw

    @app.post("/train/start", tags=["Training"])
    async def train_start(difficulty: str = "medium", steps: int = 25):
        """Start the demo training thread unless one is already running.

        Unknown difficulties fall back to "medium"; ``steps`` is clamped to
        10–50.
        """
        diff = difficulty if difficulty in _LEVELS else "medium"
        steps = max(10, min(50, steps))

        # Check and claim the "running" flag in one critical section so two
        # concurrent requests cannot both start a thread. Publishing the
        # "running" status here also stops a stream client that connects
        # right away from seeing the previous run's final snapshot.
        # `_bump` cannot be used inside the block: the lock is not reentrant.
        with _training_lock:
            if _training_state["running"]:
                return {"status": "already_running",
                        "message": "Training already in progress."}
            _training_state.update({
                "running": True, "status": "running", "error": None,
                "current_step": 0, "total_steps": steps, "difficulty": diff,
                "curve": [], "before": None, "after": None,
            })
            _training_state["_version"] += 1

        t = threading.Thread(
            target=_run_training_demo,
            args=(diff, steps, _ROOT),
            daemon=True,
        )
        try:
            t.start()
        except RuntimeError as exc:
            # The thread could not be started: release the claimed flag.
            _bump({"running": False, "status": "error", "error": str(exc)})
            raise
        return {"status": "started", "difficulty": diff, "total_steps": steps}

    @app.get("/train/status", tags=["Training"])
    async def train_status():
        with _training_lock:
            return dict(_training_state)

    @app.get("/train/stream", tags=["Training"])
    async def train_stream():
        """
        SSE that pushes the training state whenever its version changes.
        Terminates once the status is "completed" or "error".
        """
        async def gen():
            last_version = -1
            while True:
                with _training_lock:
                    ver = _training_state["_version"]
                    status = _training_state["status"]
                    snap = dict(_training_state)

                if ver != last_version:
                    last_version = ver
                    # `_version` is bookkeeping and is not sent to clients.
                    payload = {k: v for k, v in snap.items()
                               if k != "_version"}
                    yield _sse(payload)
                    if status in ("completed", "error"):
                        break

                await asyncio.sleep(0.3)

        return StreamingResponse(gen(), media_type="text/event-stream",
                                 headers=dict(_SSE_HEADERS))

    # ── frontend / root ─────────────────────────────────────────────────

    _DIST = os.path.join(_ROOT, "frontend", "dist")
    _ASSETS = os.path.join(_DIST, "assets")

    if os.path.isdir(_ASSETS):
        app.mount("/assets", StaticFiles(directory=_ASSETS), name="assets")

    if os.path.isdir(_DIST):
        _INDEX = os.path.join(_DIST, "index.html")

        @app.get("/", include_in_schema=False)
        async def spa_root():
            return FileResponse(_INDEX)

        # Registered last so it only catches paths no API route matched.
        @app.get("/{full_path:path}", include_in_schema=False)
        async def spa_catchall(full_path: str):
            return FileResponse(_INDEX)
    else:
        @app.get("/", tags=["System"])
        async def api_root():
            return {"status": "ok", "service": "CLM OpenEnv API",
                    "docs": "/docs", "stream": "/stream/run?difficulty=medium",
                    "train": "POST /train/start", "benchmark": "/benchmark"}

    return app


app = build_app()
|
|