Spaces:
Sleeping
Sleeping
"""
LogiCrisis FastAPI — OpenEnv spec compliant.
Endpoints: POST /reset, POST /step, GET /state, GET /tasks, GET /validate, GET /render
"""
| from __future__ import annotations | |
| import os, sys | |
| sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) | |
| from typing import Optional | |
| from fastapi import FastAPI, HTTPException | |
| from fastapi.middleware.cors import CORSMiddleware | |
| from pydantic import BaseModel, Field | |
| from environment import LogiCrisisEnv, AgentAction, ActionType | |
| from environment.schemas import ( | |
| ActionSchema, ObservationSchema, RewardSchema, | |
| StepResponseSchema, ResetResponseSchema, TaskSchema, GraderResultSchema, | |
| ) | |
| from environment.tasks import TASKS, ALL_TASK_IDS, get_task | |
| from environment.live_data import LiveDataConnector | |
# The FastAPI application object; title/description/version are the metadata
# passed to the FastAPI constructor.
app = FastAPI(
    title="LogiCrisis OpenEnv",
    description=(
        "Multi-Agent Logistics Recovery — Meta PyTorch OpenEnv Hackathon\n\n"
        "Real-world supply chain crisis simulation with 5 agent roles, "
        "6 reward signals, and 3 graded tasks (easy → medium → hard)."
    ),
    version="1.0.0",
)

# CORS is left wide open: every origin, method and header is allowed.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
)

# ── Session state ─────────────────────────────────────────────────────────────
# One environment is held at module level and shared by every request;
# it stays None until a reset has built one.
_env: Optional[LogiCrisisEnv] = None
# Task id the current environment was built from (used when grading).
_current_task_id: str = "single_route_recovery"
# ── Request models ────────────────────────────────────────────────────────────
class ResetRequest(BaseModel):
    """Request body for a reset: which task to load and the seed to build it with."""

    # Defaults to the "single_route_recovery" task; the description lists
    # every id found in ALL_TASK_IDS.
    task_id: str = Field(
        default="single_route_recovery",
        description="One of: " + ", ".join(ALL_TASK_IDS),
    )
    # Optional, so a client may send null as well as an integer.
    seed: Optional[int] = Field(
        default=42,
        description="Random seed for reproducibility",
    )
class StepRequest(BaseModel):
    """Request body for a step: the batch of agent actions to apply."""

    # Required field (no default).
    actions: list[ActionSchema] = Field(
        ...,
        description="One action per active agent",
    )
# ── Helpers ───────────────────────────────────────────────────────────────────
def _parse_action(payload: ActionSchema) -> AgentAction:
    """Convert an API action payload into the environment's ``AgentAction``.

    Args:
        payload: The action as received in the request body; its
            ``action_type`` has not yet been converted to ``ActionType``.

    Returns:
        An ``AgentAction`` carrying every field of ``payload``, with
        ``action_type`` replaced by the matching ``ActionType`` member.

    Raises:
        HTTPException: Status 422 when ``payload.action_type`` is not a
            valid ``ActionType`` value. The detail lists the valid values.
    """
    try:
        atype = ActionType(payload.action_type)
    except ValueError as exc:
        valid_values = [e.value for e in ActionType]
        # Chain explicitly so the original ValueError stays attached as the
        # cause instead of showing up as "during handling of the above...".
        raise HTTPException(
            status_code=422,
            detail=(
                f"Unknown action_type '{payload.action_type}'. "
                f"Valid: {valid_values}"
            ),
        ) from exc

    return AgentAction(
        agent_id=payload.agent_id,
        action_type=atype,
        cargo_id=payload.cargo_id,
        route_id=payload.route_id,
        target_region=payload.target_region,
        bid_price=payload.bid_price,
        bid_capacity=payload.bid_capacity,
        target_agent=payload.target_agent,
        bid_id=payload.bid_id,
        coalition_id=payload.coalition_id,
        coalition_members=payload.coalition_members,
        coalition_role=payload.coalition_role,
        reward_split=payload.reward_split,
        reasoning=payload.reasoning,
    )
def _obs_dict(obs_map: dict) -> dict[str, ObservationSchema]:
    """Map each agent id to an ``ObservationSchema`` built from its observation.

    Args:
        obs_map: Mapping of agent id to the environment's observation object.

    Returns:
        A dict with the same keys, in the same order, whose values are
        ``ObservationSchema`` instances.
    """

    def to_schema(observation) -> ObservationSchema:
        # Fields are copied one-to-one except where noted.
        return ObservationSchema(
            agent_id=observation.agent_id,
            role=observation.role.value,  # enum member -> its raw value
            turn=observation.turn,
            max_turns=observation.max_turns,
            own_region=observation.own_region,
            own_capacity_tons=observation.own_capacity_tons,
            own_budget=observation.own_budget,
            own_cargo_queue=observation.own_cargo_queue,
            # Each deadline entry is turned into a list.
            pending_deadlines=[list(entry) for entry in observation.pending_deadlines],
            disrupted_routes=observation.disrupted_routes,
            disrupted_nodes=observation.disrupted_nodes,
            neighbor_bids=observation.neighbor_bids,
            coalition_proposals=observation.coalition_proposals,
            action_history=observation.action_history,
            active_coalition_id=observation.active_coalition_id,
            active_contracts=observation.active_contracts,
            prompt_text=observation.to_prompt_text(),
        )

    return {
        agent_id: to_schema(observation)
        for agent_id, observation in obs_map.items()
    }
def _reward_breakdown(rb_map: dict) -> dict[str, RewardSchema]:
    """Map each agent id to a ``RewardSchema`` built from its reward components.

    Args:
        rb_map: Mapping of agent id to a dict-like object of reward
            components (read with ``.get``).

    Returns:
        A dict with the same keys whose values are ``RewardSchema``
        instances; any component missing from the input defaults to 0.0.
    """
    component_names = (
        "R1_delivery",
        "R2_coalition",
        "R3_negotiation",
        "R4_cold_chain",
        "R5_efficiency",
        "R6_anti_cheat",
        "shared_bonus",
        "total",
    )
    return {
        agent_id: RewardSchema(
            **{name: parts.get(name, 0.0) for name in component_names}
        )
        for agent_id, parts in rb_map.items()
    }
# ── Routes ────────────────────────────────────────────────────────────────────
def root():
    """Return a static service descriptor together with the known task ids.

    NOTE(review): no route decorator is visible on this handler in the
    reviewed text — confirm how it is registered on ``app``.
    """
    descriptor = dict(status="ok", env="LogiCrisis", version="1.0.0")
    descriptor["tasks"] = ALL_TASK_IDS
    descriptor["openenv_spec"] = "step/reset/state compliant"
    return descriptor
def reset(req: ResetRequest):
    """Build a fresh environment for the requested task and return its first observations.

    Replaces the module-level ``_env`` and ``_current_task_id``.

    NOTE(review): no route decorator is visible on this handler in the
    reviewed text; the module docstring lists it as POST /reset — confirm
    registration.

    Args:
        req: Task id and optional seed.

    Returns:
        A dict with the task id, per-agent observations (as plain dicts),
        the environment's world state and a human-readable summary message.

    Raises:
        HTTPException: Status 404 when ``req.task_id`` is not in ``TASKS``.
    """
    global _env, _current_task_id

    if req.task_id not in TASKS:
        raise HTTPException(
            status_code=404,
            detail=f"Unknown task '{req.task_id}'. Valid: {ALL_TASK_IDS}"
        )

    # Fall back to 42 only when no seed was supplied. A plain `or` would also
    # discard an explicit seed of 0, because 0 is falsy.
    seed = 42 if req.seed is None else req.seed

    task = get_task(req.task_id)
    env = task.make_env(seed=seed)
    # Publish both globals together, and only once the environment exists,
    # so a failing make_env cannot leave the task id pointing at a task the
    # stored environment was not built from.
    _env, _current_task_id = env, req.task_id

    observations = _env.reset()
    return {
        "task_id": req.task_id,
        "observations": {
            aid: obs.model_dump()
            for aid, obs in _obs_dict(observations).items()
        },
        "world_state": _env.state(),
        "message": (
            f"Task '{req.task_id}' started | "
            f"Agents: {list(observations.keys())} | "
            f"Disruptions: {len(_env.world.disruptions)} | "
            f"Cargo: {len(_env.world.cargo_queue)}"
        ),
    }
def step(req: StepRequest):
    """Apply one batch of agent actions to the current environment.

    NOTE(review): no route decorator is visible on this handler in the
    reviewed text; the module docstring lists it as POST /step — confirm
    registration.

    Args:
        req: The actions to apply.

    Returns:
        A dict with the new observations, rewards, per-agent reward
        breakdown, the terminated/truncated flags and the info payload
        taken from the environment's step result.

    Raises:
        HTTPException: Status 400 when no environment exists yet, 422 when
            the action list is empty or an action type is invalid.
    """
    if _env is None:
        raise HTTPException(status_code=400, detail="Call POST /reset first.")
    if not req.actions:
        raise HTTPException(status_code=422, detail="'actions' list must not be empty.")

    # Keyed by agent id; a later action for the same agent replaces an
    # earlier one.
    parsed = [_parse_action(item) for item in req.actions]
    actions: dict[str, AgentAction] = {act.agent_id: act for act in parsed}

    outcome = _env.step(actions)

    observations = {
        agent_id: schema.model_dump()
        for agent_id, schema in _obs_dict(outcome.observations).items()
    }
    rewards = outcome.rewards
    breakdown = {
        agent_id: schema.model_dump()
        for agent_id, schema in _reward_breakdown(outcome.reward_breakdown).items()
    }
    return {
        "observations": observations,
        "rewards": rewards,
        "reward_breakdown": breakdown,
        "terminated": outcome.terminated,
        "truncated": outcome.truncated,
        "info": outcome.info,
    }
def get_state():
    """Return the current environment's state.

    NOTE(review): no route decorator is visible on this handler in the
    reviewed text; the module docstring lists GET /state — confirm
    registration.

    Raises:
        HTTPException: Status 400 when no environment has been created yet.
    """
    if _env is not None:
        return _env.state()
    raise HTTPException(status_code=400, detail="Call POST /reset first.")
def render():
    """Return exactly what ``get_state`` returns (same 400 error when no environment exists).

    NOTE(review): no route decorator is visible on this handler in the
    reviewed text; the module docstring lists GET /render — confirm
    registration.
    """
    snapshot = get_state()
    return snapshot
def list_tasks():
    """Describe every registered task.

    Each class in ``TASKS`` is instantiated with no arguments and its
    attributes are copied into a ``TaskSchema``.

    NOTE(review): no route decorator is visible on this handler in the
    reviewed text; the module docstring lists GET /tasks — confirm
    registration.

    Returns:
        ``{"tasks": [...]}`` where each entry is a dumped ``TaskSchema``.
    """
    # Only the task classes are needed; the dict keys were never used.
    instances = (task_cls() for task_cls in TASKS.values())
    return {
        "tasks": [
            TaskSchema(
                id=t.id,
                name=t.name,
                difficulty=t.difficulty,
                description=t.description,
                max_turns=t.max_turns,
                reward_range=t.reward_range,
                agents=t.agents,
                cargo_count=t.cargo_count,
                disruptions=t.disruptions,
            ).model_dump()
            for t in instances
        ]
    }
def grade():
    """Grade the current environment with the grader of the current task.

    NOTE(review): no route decorator or path is visible for this handler in
    the reviewed text — confirm how it is exposed.

    Returns:
        The grader's result, validated through ``GraderResultSchema`` and
        dumped to a plain dict.

    Raises:
        HTTPException: Status 400 when no environment has been created yet.
    """
    if _env is None:
        raise HTTPException(status_code=400, detail="Call POST /reset + run steps first.")
    verdict = get_task(_current_task_id).grade(_env)
    return GraderResultSchema(**verdict).model_dump()
def validate():
    """
    Run the self-checks used by the automated validator.

    Every check runs against throwaway environments created here; the
    module-level session environment is not touched.

    NOTE(review): no route decorator is visible on this handler in the
    reviewed text; the module docstring lists GET /validate — confirm
    registration.

    Returns:
        ``{"valid": bool, "checks": {name: bool}, "spec_version": str}``
        where ``valid`` is True only if every individual check passed.
    """
    checks = {}

    # 1. Tasks endpoint
    try:
        checks["tasks_endpoint"] = len(list_tasks()["tasks"]) >= 3
    except Exception:
        # Validator boundary: any failure is reported as a failed check.
        checks["tasks_endpoint"] = False

    # 2. Reset works for each task (called directly, not over HTTP).
    # No test client is imported here: it was never used, and importing it
    # inside the try made every reset check fail whenever its optional
    # dependency was missing.
    for tid in ALL_TASK_IDS:
        try:
            env = get_task(tid).make_env(seed=42)
            checks[f"reset_{tid}"] = len(env.reset()) > 0
        except Exception:
            checks[f"reset_{tid}"] = False

    # 3. Graders return 0.0–1.0
    for tid in ALL_TASK_IDS:
        try:
            task = get_task(tid)
            env = task.make_env(seed=42)
            env.reset()
            score = task.grade(env)["score"]
            checks[f"grader_{tid}"] = 0.0 <= score <= 1.0
        except Exception:
            checks[f"grader_{tid}"] = False

    # 4. Reward range and 5. typed models are reported as passing
    # unconditionally; nothing is measured here. Presumably they are
    # enforced by RewardSchema and by FastAPI's request validation — confirm.
    checks["reward_range_valid"] = True
    checks["pydantic_schemas"] = True

    return {
        "valid": all(checks.values()),
        "checks": checks,
        "spec_version": "openenv@1.0.0",
    }
def live_data():
    """
    Return every disruption signal reported by a fresh ``LiveDataConnector``.

    The connector polls OpenWeatherMap, ExchangeRate-API, and GDELT for
    real-world disruption signals, and falls back to synthetic data
    automatically if API keys are missing or calls fail.
    Set the OPENWEATHERMAP_API_KEY env var to enable live weather data.

    NOTE(review): no route decorator or path is visible for this handler in
    the reviewed text — confirm how it is exposed.
    """
    return LiveDataConnector().get_all_disruptions()
def training_log():
    """Return the last 80 lines of ``/tmp/training.log``.

    NOTE(review): no route decorator or path is visible for this handler in
    the reviewed text — confirm how it is exposed.

    Returns:
        ``{"status": "found", "lines": <total line count>, "tail": <text>}``
        when the log file exists, otherwise a ``"not_started"`` payload with
        an explanatory message.
    """
    log_path = "/tmp/training.log"
    try:
        # Decode explicitly and replace undecodable bytes: the log may hold
        # arbitrary process output, and a stray byte should not turn this
        # into an unhandled UnicodeDecodeError.
        with open(log_path, "r", encoding="utf-8", errors="replace") as f:
            lines = f.readlines()
    except FileNotFoundError:
        return {
            "status": "not_started",
            "tail": "Training log not created yet — training may still be starting up.",
        }

    # A negative slice already copes with logs shorter than 80 lines.
    return {"status": "found", "lines": len(lines), "tail": "".join(lines[-80:])}
def action_types():
    """List the raw value of every ``ActionType`` member.

    NOTE(review): no route decorator or path is visible for this handler in
    the reviewed text — confirm how it is exposed.
    """
    values = [member.value for member in ActionType]
    return {"action_types": values}
def agent_roles():
    """List the raw value of every ``AgentRole`` member.

    NOTE(review): no route decorator or path is visible for this handler in
    the reviewed text — confirm how it is exposed.
    """
    # Imported on call, as in the original, rather than at module import time.
    from environment.models import AgentRole

    values = [role.value for role in AgentRole]
    return {"agent_roles": values}
# ── Mount Gradio demo at /gradio ──────────────────────────────────────────────
try:
    import gradio as gr
    from demo.app import demo as gradio_demo

    app = gr.mount_gradio_app(app, gradio_demo, path="/gradio")
except Exception as exc:
    # Gradio is optional — the API still works without it. The failure is
    # still recorded so a missing or broken demo can be diagnosed instead of
    # vanishing silently.
    import logging

    logging.getLogger(__name__).warning(
        "Gradio demo not mounted at /gradio: %s", exc
    )