XcodeAddy committed on
Commit
939dba8
·
1 Parent(s): c47715e

Expose cluster trust API endpoints

Browse files
Files changed (1) hide show
  1. app.py +273 -2
app.py CHANGED
@@ -17,7 +17,7 @@ from fastapi.staticfiles import StaticFiles
17
  from fastapi.responses import FileResponse, HTMLResponse, JSONResponse, StreamingResponse
18
  from pydantic import BaseModel
19
 
20
- from cluster_trust_env import ClusterTrustEnv
21
  from difficulty_controller import GLOBAL_DIFFICULTY_CONTROLLER
22
  from environment import SentinelEnv
23
  from mission_context import build_orchestrator_prompt, mission_for_task, problem_statement
@@ -130,6 +130,19 @@ def _get_env(session_id: str) -> SentinelEnv | ClusterTrustEnv:
130
  return env
131
 
132
 
 
 
 
 
 
 
 
 
 
 
 
 
 
133
  def _resolve_env_mode(task_type: str | None, mode: str | None = None) -> tuple[str, str]:
134
  requested_task = task_type or "task3"
135
  requested_mode = (mode or "").lower()
@@ -218,6 +231,27 @@ class StepRequest(BaseModel):
218
  reasoning: str | None = None
219
 
220
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
221
  # ---------------------------------------------------------------------------
222
  # Endpoints
223
  # ---------------------------------------------------------------------------
@@ -255,6 +289,11 @@ def root():
255
  "/grader", "/reward-report", "/difficulty", "/stream", "/trust-dashboard",
256
  "/cluster-dashboard",
257
  "/reset", "/step", "/state",
 
 
 
 
 
258
  ],
259
  }
260
  )
@@ -308,6 +347,11 @@ def api_root():
308
  "/grader", "/reward-report", "/difficulty", "/stream", "/trust-dashboard",
309
  "/cluster-dashboard",
310
  "/reset", "/step", "/state",
 
 
 
 
 
311
  ],
312
  }
313
 
@@ -369,8 +413,13 @@ def metadata():
369
  },
370
  "adaptive_curriculum": GLOBAL_DIFFICULTY_CONTROLLER.state(),
371
  "cluster_mode": {
372
- "how_to_enable": "POST /reset with {\"mode\":\"cluster\",\"task_type\":\"task3\"} or {\"task_type\":\"cluster_task3\"}.",
 
 
 
 
373
  "live_dashboard": "/cluster-dashboard?session_id=<session_id>",
 
374
  },
375
  }
376
 
@@ -578,6 +627,228 @@ def mcp(body: dict[str, Any]):
578
  raise HTTPException(status_code=400, detail=f"Unknown method: {method}")
579
 
580
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
581
  def _trust_dashboard_html(session_id: str) -> str:
582
  escaped_session = html.escape(session_id, quote=True)
583
  return f"""<!doctype html>
 
17
  from fastapi.responses import FileResponse, HTMLResponse, JSONResponse, StreamingResponse
18
  from pydantic import BaseModel
19
 
20
+ from cluster_trust_env import CLUSTER_TASK_CONFIG, ClusterTrustEnv
21
  from difficulty_controller import GLOBAL_DIFFICULTY_CONTROLLER
22
  from environment import SentinelEnv
23
  from mission_context import build_orchestrator_prompt, mission_for_task, problem_statement
 
130
  return env
131
 
132
 
133
def _get_cluster_env(session_id: str) -> ClusterTrustEnv:
    """Return the session's environment, insisting that it is a cluster env.

    The environment is obtained through ``_get_env``; anything that is not a
    ``ClusterTrustEnv`` is rejected.

    Raises:
        HTTPException: 400 when the session's environment is not a
            ``ClusterTrustEnv``.
    """
    candidate = _get_env(session_id)
    if isinstance(candidate, ClusterTrustEnv):
        return candidate

    message = (
        "Session is in abstract SentinelEnv mode. Start a cluster session via "
        "POST /cluster/reset (or POST /reset with mode='cluster')."
    )
    raise HTTPException(status_code=400, detail=message)
144
+
145
+
146
  def _resolve_env_mode(task_type: str | None, mode: str | None = None) -> tuple[str, str]:
147
  requested_task = task_type or "task3"
148
  requested_mode = (mode or "").lower()
 
231
  reasoning: str | None = None
232
 
233
 
234
# Request models used only by the /cluster/* endpoints. They are deliberately
# distinct from ResetRequest/StepRequest so the OpenAPI schema spells out the
# GPU-cluster contract.

# Action names accepted by POST /cluster/step; anything else is rejected.
CLUSTER_ACTION_TYPES = (
    "allocate",
    "preempt",
    "request_info",
    "verify",
    "tick",
)
238
+
239
+
240
class ClusterResetRequest(BaseModel):
    """Body of ``POST /cluster/reset``; every field is optional."""

    # "task1" | "task2" | "task3" (also accepts "cluster_task*")
    task_type: str | None = None
    # Passed through unchanged to ClusterTrustEnv.reset(seed=...).
    seed: int | None = None
    # Passed through unchanged to ClusterTrustEnv.reset(adaptive=...).
    adaptive: bool = False
244
+
245
+
246
class ClusterStepRequest(BaseModel):
    """Body of ``POST /cluster/step``.

    Only ``action_type`` is required. The handler drops fields left as
    ``None`` before handing the action to the environment.
    """

    # allocate | preempt | request_info | verify | tick
    action_type: str
    job_id: str | None = None
    gpu_id: str | None = None
    worker_id: str | None = None
    force_flag: bool | None = None
    reasoning: str | None = None
253
+
254
+
255
  # ---------------------------------------------------------------------------
256
  # Endpoints
257
  # ---------------------------------------------------------------------------
 
289
  "/grader", "/reward-report", "/difficulty", "/stream", "/trust-dashboard",
290
  "/cluster-dashboard",
291
  "/reset", "/step", "/state",
292
+ "/cluster", "/cluster/metadata", "/cluster/tasks",
293
+ "/cluster/reset", "/cluster/step", "/cluster/state",
294
+ "/cluster/gpus", "/cluster/jobs", "/cluster/workers",
295
+ "/cluster/audit", "/cluster/audit/investigate",
296
+ "/cluster/ai-failure-coverage", "/cluster/reward-report", "/cluster/stream",
297
  ],
298
  }
299
  )
 
347
  "/grader", "/reward-report", "/difficulty", "/stream", "/trust-dashboard",
348
  "/cluster-dashboard",
349
  "/reset", "/step", "/state",
350
+ "/cluster", "/cluster/metadata", "/cluster/tasks",
351
+ "/cluster/reset", "/cluster/step", "/cluster/state",
352
+ "/cluster/gpus", "/cluster/jobs", "/cluster/workers",
353
+ "/cluster/audit", "/cluster/audit/investigate",
354
+ "/cluster/ai-failure-coverage", "/cluster/reward-report", "/cluster/stream",
355
  ],
356
  }
357
 
 
413
  },
414
  "adaptive_curriculum": GLOBAL_DIFFICULTY_CONTROLLER.state(),
415
  "cluster_mode": {
416
+ "how_to_enable": (
417
+ "POST /cluster/reset with {\"task_type\":\"task3\"} (preferred), "
418
+ "or POST /reset with {\"mode\":\"cluster\",\"task_type\":\"task3\"} "
419
+ "or {\"task_type\":\"cluster_task3\"}."
420
+ ),
421
  "live_dashboard": "/cluster-dashboard?session_id=<session_id>",
422
+ "api_root": "/cluster",
423
  },
424
  }
425
 
 
627
  raise HTTPException(status_code=400, detail=f"Unknown method: {method}")
628
 
629
 
630
+ # ---------------------------------------------------------------------------
631
+ # Cluster API (GPU cluster trust mission, namespaced under /cluster/*)
632
+ # ---------------------------------------------------------------------------
633
+
634
+
635
def _cluster_task_type(raw: str | None) -> str:
    """Normalise a requested task type to a key of ``CLUSTER_TASK_CONFIG``.

    A missing or empty value falls back to ``"task3"``, and a leading
    ``"cluster_"`` prefix is stripped before the lookup.

    Raises:
        HTTPException: 400 when the normalised name is not a configured task.
    """
    normalised = (raw or "task3").removeprefix("cluster_")
    if normalised in CLUSTER_TASK_CONFIG:
        return normalised

    known = ", ".join(sorted(CLUSTER_TASK_CONFIG))
    raise HTTPException(
        status_code=400,
        detail=(
            f"Unknown cluster task_type '{raw}'. "
            f"Expected one of: {known}."
        ),
    )
646
+
647
+
648
@app.get("/cluster")
def cluster_root():
    """Describe the cluster API: a summary, the session lifecycle and routes."""
    summary = (
        "GPU cluster trust calibration API. The orchestrator schedules jobs across "
        "GPUs, audits worker reports, and routes around adversarial false completions "
        "while keeping cluster health and AI reliability high."
    )
    session_lifecycle = [
        "POST /cluster/reset -> {info.session_id}",
        "POST /cluster/step?session_id=...",
        "GET /cluster/state?session_id=... (or /cluster/stream for SSE)",
    ]
    routes = [
        "POST /cluster/reset",
        "POST /cluster/step",
        "GET /cluster/state",
        "GET /cluster/gpus",
        "GET /cluster/jobs",
        "GET /cluster/workers",
        "GET /cluster/audit",
        "GET /cluster/audit/investigate",
        "GET /cluster/ai-failure-coverage",
        "GET /cluster/reward-report",
        "GET /cluster/stream",
        "GET /cluster/metadata",
        "GET /cluster/tasks",
        "GET /cluster-dashboard",
    ]
    return {
        "name": "sentinel-cluster",
        "summary": summary,
        "session_lifecycle": session_lifecycle,
        "routes": routes,
    }
679
+
680
+
681
@app.get("/cluster/metadata")
def cluster_metadata():
    """Return static metadata for the cluster mission.

    The payload holds the three task configurations (each extended with a
    display name), the supported action types with their optional fields, the
    public worker ids, scoring/terminal summaries and the current state of the
    global difficulty controller.
    """
    return {
        "tasks": {
            "task1": {**CLUSTER_TASK_CONFIG["task1"], "name": "Cluster Basics"},
            "task2": {**CLUSTER_TASK_CONFIG["task2"], "name": "Unreliable Workers"},
            "task3": {**CLUSTER_TASK_CONFIG["task3"], "name": "Full Adversarial Cluster"},
        },
        "action_types": {
            "allocate": {
                "description": "Place a queued job on a GPU and assign a worker.",
                "fields": ["job_id?", "gpu_id?", "worker_id?"],
            },
            "preempt": {
                "description": "Free a running job from its GPU.",
                "fields": ["job_id?"],
            },
            "request_info": {
                "description": "Ask the assigned worker for a fresh progress report.",
                "fields": ["job_id?", "worker_id?"],
            },
            "verify": {
                "description": "Audit a worker's report. Catches false completions and lying.",
                "fields": ["job_id?", "worker_id?", "force_flag?"],
            },
            "tick": {
                "description": "Advance the cluster clock without acting.",
                "fields": [],
            },
        },
        # A plain list literal; wrapping it in list(...) only made a needless copy.
        "workers": ["S0", "S1", "S2", "S3", "S4"],
        "scoring": "global_reward = weighted(orchestrator, resource_manager, auditor, worker) × cluster_health × ai_reliability_modifier",
        "terminal": "task1: jobs+util | task2: jobs+calibration+deadlines | task3: jobs+detection+plan_coherence+efficiency",
        "controller": GLOBAL_DIFFICULTY_CONTROLLER.state(),
    }
706
+
707
+
708
@app.get("/cluster/tasks")
def cluster_tasks():
    """List every configured cluster task with its difficulty and key settings.

    Returns:
        A dict keyed by task id. Each entry carries a difficulty label, a
        human-readable description and the ``adversary``, ``jobs``, ``gpus``,
        ``max_steps`` and ``failure_probability`` values from
        ``CLUSTER_TASK_CONFIG``.
    """
    descriptions = {
        "task1": "10-job warmup. No adversary, no GPU failures. Learn the allocate/preempt/tick loop.",
        "task2": "20-job stream with unreliable/slow/degrading workers and rare GPU failures.",
        "task3": "30-job adversarial cluster: false memory reports, false completions, poisoned reward claims.",
    }
    # Built once instead of on every loop iteration.
    difficulties = {"task1": "easy", "task2": "medium", "task3": "hard"}

    # CLUSTER_TASK_CONFIG decides which tasks exist. Fall back gracefully for
    # ids that have no label here, so a new task in the config does not turn
    # this endpoint into a 500.
    return {
        tid: {
            "difficulty": difficulties.get(tid, "unknown"),
            "description": descriptions.get(tid, ""),
            "adversary_active": cfg["adversary"],
            "jobs": cfg["jobs"],
            "gpus": cfg["gpus"],
            "max_steps": cfg["max_steps"],
            "failure_probability": cfg["failure_probability"],
        }
        for tid, cfg in CLUSTER_TASK_CONFIG.items()
    }
727
+
728
+
729
@app.post("/cluster/reset")
def cluster_reset(req: ClusterResetRequest = ClusterResetRequest()):
    """Start a new cluster session and return its initial reset payload.

    A fresh ``ClusterTrustEnv`` is reset with the validated task type, seed
    and adaptive flag, then stored in the session registry under the session
    id found in the reset result.

    Raises:
        HTTPException: 400 when ``req.task_type`` is not a known cluster task.
    """
    validated_task = _cluster_task_type(req.task_type)
    cluster_env = ClusterTrustEnv()
    payload = cluster_env.reset(
        task_type=validated_task,
        seed=req.seed,
        adaptive=req.adaptive,
    )
    _sessions.set(payload["info"]["session_id"], cluster_env)
    return _add_demo_context(payload, cluster_env)
737
+
738
+
739
@app.post("/cluster/step")
def cluster_step(req: ClusterStepRequest, session_id: str = Query(...)):
    """Apply one action to a cluster session and return the step result.

    The action type is validated before the session is looked up. Fields of
    the request left as ``None`` are omitted from the action dict passed to
    ``env.step``. When the result reports ``done`` the session is removed from
    the registry; otherwise the result is passed through ``_add_demo_context``.

    Raises:
        HTTPException: 400 for an unknown ``action_type``, for a session that
            is not a cluster session, or when the environment raises
            ``RuntimeError``/``ValueError`` while stepping.
    """
    if req.action_type not in CLUSTER_ACTION_TYPES:
        raise HTTPException(
            status_code=400,
            detail=f"Unknown cluster action_type '{req.action_type}'. Expected one of: {', '.join(CLUSTER_ACTION_TYPES)}.",
        )
    env = _get_cluster_env(session_id)
    try:
        result = env.step(req.model_dump(exclude_none=True))
    except (RuntimeError, ValueError) as exc:
        # Chain the original error so tracebacks and logs keep the root cause.
        raise HTTPException(status_code=400, detail=str(exc)) from exc

    if result["done"]:
        _sessions.pop(session_id)
    else:
        _add_demo_context(result, env)
    return result
757
+
758
+
759
@app.get("/cluster/state")
def cluster_state(session_id: str = Query(...)):
    """Return ``state()`` of the cluster environment bound to ``session_id``."""
    return _get_cluster_env(session_id).state()
763
+
764
+
765
@app.get("/cluster/gpus")
def cluster_gpus(session_id: str = Query(...), include_hidden: bool = Query(False)):
    """Return the GPU pool summary and per-GPU snapshot for a cluster session.

    ``include_hidden`` is forwarded unchanged to the pool's ``snapshot`` call.
    """
    pool = _get_cluster_env(session_id)._pool
    pool_summary = pool.summary()
    gpu_snapshot = pool.snapshot(include_hidden=include_hidden)
    return {"summary": pool_summary, "gpus": gpu_snapshot}
772
+
773
+
774
@app.get("/cluster/jobs")
def cluster_jobs(
    session_id: str = Query(...),
    include_hidden: bool = Query(False),
    deadline_window: int = Query(10, ge=1, le=240),
):
    """Return the job summary, job snapshot and ids of deadline-pressured jobs.

    ``deadline_window`` (1..240, default 10) is passed as ``window`` to the
    job store's ``deadline_pressure`` together with the env's ``step_count``.
    """
    env = _get_cluster_env(session_id)
    job_store = env._jobs
    job_summary = job_store.summary()
    job_snapshot = job_store.snapshot(include_hidden=include_hidden)
    pressured = job_store.deadline_pressure(env.step_count, window=deadline_window)
    return {
        "summary": job_summary,
        "jobs": job_snapshot,
        "deadline_pressure": [entry.job_id for entry in pressured],
    }
788
+
789
+
790
@app.get("/cluster/workers")
def cluster_workers(session_id: str = Query(...)):
    """Return worker availability plus trust data for a cluster session."""
    env = _get_cluster_env(session_id)
    worker_pool = env._workers
    trust_ledger = env._trust

    available = worker_pool.available_ids()
    trust_snapshot = trust_ledger.snapshot()
    fingerprints = trust_ledger.behavioral_fingerprints()
    reliability = worker_pool.public_ground_truth_reliability()
    return {
        "available": available,
        "trust_snapshot": trust_snapshot,
        "behavioral_fingerprints": fingerprints,
        "public_ground_truth_reliability": reliability,
    }
799
+
800
+
801
@app.get("/cluster/audit")
def cluster_audit(session_id: str = Query(...)):
    """Return the audit component's snapshot for a cluster session."""
    return _get_cluster_env(session_id)._audit.snapshot()
805
+
806
+
807
@app.get("/cluster/audit/investigate")
def cluster_audit_investigate(
    session_id: str = Query(...),
    agent_id: str = Query(..., description="Worker public id (S0..S4) or 'cluster'/'adversary'/'auditor'."),
    window: int = Query(10, ge=1, le=240),
):
    """Return the audit component's investigation of one agent.

    ``agent_id`` and ``window`` (1..240, default 10) are forwarded unchanged
    to the audit component's ``investigate`` call.
    """
    auditor = _get_cluster_env(session_id)._audit
    return auditor.investigate(agent_id, window=window)
815
+
816
+
817
@app.get("/cluster/ai-failure-coverage")
def cluster_ai_failure_coverage(session_id: str = Query(...)):
    """Return ``ai_failure_coverage()`` of the session's cluster environment."""
    return _get_cluster_env(session_id).ai_failure_coverage()
821
+
822
+
823
@app.get("/cluster/reward-report")
def cluster_reward_report(session_id: str = Query(...)):
    """Return ``reward_report()`` of the session's cluster environment."""
    return _get_cluster_env(session_id).reward_report()
827
+
828
+
829
@app.get("/cluster/stream")
async def cluster_stream(session_id: str = Query(...)):
    """Stream cluster snapshots for a session as server-sent events.

    Every 0.5 seconds the session is looked up again and its
    ``stream_snapshot()`` is sent as a ``data:`` frame. The stream ends after
    the snapshot of a finished environment, or with an ``event: close`` frame
    when the session is missing or is not a cluster session.
    """
    close_frame = (
        "event: close\n"
        "data: {\"reason\":\"session_not_found_or_not_cluster\"}\n\n"
    )

    async def event_gen():
        while True:
            current = _sessions.get(session_id)
            # isinstance() is False for None, so this also covers a missing session.
            if not isinstance(current, ClusterTrustEnv):
                yield close_frame
                return
            yield f"data: {json.dumps(current.stream_snapshot())}\n\n"
            if current.done:
                return
            await asyncio.sleep(0.5)

    sse_headers = {"Cache-Control": "no-cache", "X-Accel-Buffering": "no"}
    return StreamingResponse(
        event_gen(),
        media_type="text/event-stream",
        headers=sse_headers,
    )
850
+
851
+
852
  def _trust_dashboard_html(session_id: str) -> str:
853
  escaped_session = html.escape(session_id, quote=True)
854
  return f"""<!doctype html>