"""FastAPI server for DesignGym: OpenEnv app, static web UI, and /demo endpoints."""
from __future__ import annotations

import hashlib
import os
import sys
import time
import traceback
from pathlib import Path
from typing import Any, Dict, List, Optional

import uvicorn
from fastapi import Request
from fastapi.responses import HTMLResponse, JSONResponse, FileResponse, RedirectResponse
from fastapi.staticfiles import StaticFiles

try:
    from openenv.core.env_server import create_fastapi_app
except Exception:
    from fastapi import FastAPI

    def create_fastapi_app(env_cls, action_cls, observation_cls):
        """Fallback app factory used when ``openenv`` cannot be imported.

        The three class arguments are accepted for signature compatibility but
        are not used; the returned app only exposes ``GET /health``.
        """
        app = FastAPI(title="DesignGym")

        @app.get("/health")
        def health():
            return {"status": "healthy"}

        return app

# Package-relative imports first; plain imports as a fallback when this module
# is not loaded as part of a package.
try:
    from ..models import DesignGymAction, DesignGymObservation
except Exception:
    from models import DesignGymAction, DesignGymObservation

try:
    from .DesignGym_environment import DesignGymEnvironment, TASKS
except Exception:
    from server.DesignGym_environment import DesignGymEnvironment, TASKS

# Make the repository root importable so ``inference`` / ``local_model`` resolve.
_REPO_ROOT = Path(__file__).resolve().parent.parent
if str(_REPO_ROOT) not in sys.path:
    sys.path.insert(0, str(_REPO_ROOT))

try:
    import inference as INF

    _INFERENCE_OK = True
except Exception:
    INF = None
    _INFERENCE_OK = False

try:
    from openai import OpenAI

    _OPENAI_OK = True
except Exception:
    OpenAI = None
    _OPENAI_OK = False

# Single-slot cache for the router client: "key" identifies the (token, base URL)
# pair the cached "client" was built for.
_LLM_CLIENT_CACHE: Dict[str, Any] = {"client": None, "key": None}


def _get_llm_client():
    """Return an LLM client for the backend named by ``DESIGNGYM_BACKEND``.

    * ``"local"`` (the default): returns ``local_model.get_client()``, or
      ``None`` if importing or calling it fails.
    * ``"router"``: returns a cached ``OpenAI`` client built from ``HF_TOKEN``
      and ``API_BASE_URL`` (environment first, then attributes of the
      ``inference`` module); ``None`` when ``inference``/``openai`` are
      unavailable or no token is configured.
    * anything else: ``None``.
    """
    backend = os.getenv("DESIGNGYM_BACKEND", "local")

    if backend == "local":
        try:
            from local_model import get_client

            return get_client()
        except Exception:
            # Deliberate best-effort: callers treat None as "no LLM available".
            return None

    if backend != "router" or not (_INFERENCE_OK and _OPENAI_OK):
        return None

    token = os.getenv("HF_TOKEN") or getattr(INF, "HF_TOKEN", None)
    base = os.getenv(
        "API_BASE_URL",
        getattr(INF, "API_BASE_URL", "https://router.huggingface.co/v1"),
    )
    if not token:
        return None

    # Key on a digest of the *whole* token. A short prefix is not enough to
    # notice a rotated token, so a stale client could keep being served.
    token_digest = hashlib.sha256(str(token).encode("utf-8")).hexdigest()
    cache_key = f"router::{token_digest}::{base}"
    if _LLM_CLIENT_CACHE["key"] != cache_key:
        _LLM_CLIENT_CACHE["client"] = OpenAI(base_url=base, api_key=token)
        _LLM_CLIENT_CACHE["key"] = cache_key
    return _LLM_CLIENT_CACHE["client"]


app = create_fastapi_app(
    DesignGymEnvironment,
    DesignGymAction,
    DesignGymObservation,
)

# Kick off the local model warm-up at import time; failure is logged, not fatal.
if os.getenv("DESIGNGYM_BACKEND", "local") == "local":
    try:
        from local_model import warm_up_async

        warm_up_async()
        print("[server] local model warm-up thread started", flush=True)
    except Exception as _warmup_err:
        print(f"[server] local model warm-up skipped: {_warmup_err}", flush=True)

# One shared environment instance backs every /demo endpoint.
DEMO_ENV = DesignGymEnvironment()
_LAST_OBS: Dict[str, Any] = {"obs": None}


def _current_obs():
    """Return the most recent observation, falling back to the env's internal builder.

    When nothing has been stored in ``_LAST_OBS`` yet, this calls
    ``DEMO_ENV._observation(message="snapshot")`` if that attribute is callable.
    Returns ``None`` when neither source yields an observation.
    """
    cached = _LAST_OBS.get("obs")
    if cached is not None:
        return cached

    builder = getattr(DEMO_ENV, "_observation", None)
    if not callable(builder):
        return None
    try:
        return builder(message="snapshot")
    except Exception:
        return None


ROOT_DIR = _REPO_ROOT
ASSETS_DIR = ROOT_DIR / "assets"

if ASSETS_DIR.exists():
    app.mount("/assets", StaticFiles(directory=str(ASSETS_DIR)), name="assets")

WEB_DIR = ROOT_DIR / "web"


def _html_response():
    """Serve index.html with no-cache headers so Safari doesn't serve stale versions."""
    resp = FileResponse(str(WEB_DIR / "index.html"))
    resp.headers["Cache-Control"] = "no-store, no-cache, must-revalidate, max-age=0"
    resp.headers["Pragma"] = "no-cache"
    return resp


@app.get("/", include_in_schema=False)
def home():
    """Serve the web UI entry page."""
    return _html_response()


@app.get("/web", include_in_schema=False)
def web_index_no_slash():
    """Serve the web UI entry page for ``/web``."""
    return _html_response()


@app.get("/web/", include_in_schema=False)
def web_index():
    """Serve the web UI entry page for ``/web/``."""
    return _html_response()


@app.get("/web/{path:path}", include_in_schema=False)
def web_static(path: str):
    """Serve a file from ``WEB_DIR``, falling back to index.html.

    ``path`` comes straight from the URL, so it is resolved and required to stay
    inside the web root. Without that check an absolute segment
    (``/web//etc/passwd``) or a ``..`` sequence would escape the directory,
    because joining a ``Path`` with an absolute path discards the left side.
    Anything outside the root, missing, or not a regular file gets index.html.
    """
    try:
        web_root = WEB_DIR.resolve()
        candidate = (web_root / path).resolve()
        candidate.relative_to(web_root)  # raises ValueError if outside the root
    except (OSError, ValueError):
        return _html_response()

    if not candidate.is_file():
        return _html_response()
    return FileResponse(str(candidate))


_TASK_DESCRIPTIONS: Dict[str, str] = {
    "poster_basic_v1": (
        "Poster layout optimization with hero image, title hierarchy, "
        "CTA placement, and alignment."
    ),
    "editorial_cover_v1": (
        "Editorial cover optimization with masthead preservation, "
        "headline stack, and reading order."
    ),
    "dense_flyer_v1": (
        "Dense flyer optimization with support-group reflow, spacing, "
        "occupancy, and caption alignment."
    ),
}

_TASK_DIFFICULTY: Dict[str, str] = {
    "poster_basic_v1": "easy",
    "editorial_cover_v1": "medium",
    "dense_flyer_v1": "hard",
}


def _task_description(task_id: str) -> str:
    """Return the human-readable description for ``task_id`` (generic text if unknown)."""
    return _TASK_DESCRIPTIONS.get(task_id, "Design layout optimization task.")


def _task_catalog():
    """Build the list of task descriptors exposed by ``/info`` and ``/tasks``.

    One entry per item in ``TASKS``; difficulty defaults to ``"medium"`` for
    task ids without an explicit mapping.
    """
    return [
        {
            "task_id": task_id,
            "difficulty": _TASK_DIFFICULTY.get(task_id, "medium"),
            "graded": True,
            "grader": {
                "type": "programmatic",
                "name": "deterministic_layout_utility",
                "deterministic": True,
                "source": "server/DesignGym_environment.py",
            },
            "description": _task_description(task_id),
            "max_steps": int(spec.get("max_steps", 0)),
            "instance_id": spec.get("instance_id"),
            "reward_range": [0.0, 1.0],
            "score_range": [0.0, 1.0],
        }
        for task_id, spec in TASKS.items()
    ]


@app.get("/info")
def info():
    """Return environment metadata together with the task catalog."""
    catalog = _task_catalog()
    return JSONResponse(
        {
            "name": "DesignGym",
            "description": "OpenEnv-compatible reinforcement learning environment for design layout optimization.",
            "task_count": len(catalog),
            "default_task_id": "poster_basic_v1",
            "tasks": catalog,
            "reward_range": [0.0, 1.0],
            "score_range": [0.0, 1.0],
            "supports_seeded_reset": True,
            "supports_task_id_reset": True,
        }
    )


@app.get("/demo/ping")
def demo_ping():
    """Liveness probe for the demo endpoints."""
    return {"ok": True, "message": "DesignGym 2.0 demo endpoints are live"}


@app.get("/demo/backend_info")
def demo_backend_info():
    """Describe the active LLM backend, available adapters and relevant env vars.

    Any failure (including ``local_model`` being unimportable) is reported as
    ``{"backend": "unknown", "error": ...}`` rather than raised.
    """
    try:
        from local_model import describe_client, ADAPTERS as _ADAPTERS, BASE_MODEL as _BASE

        details = describe_client(_get_llm_client())
        details["available_adapters"] = _ADAPTERS
        details["base_model"] = _BASE
        details["env"] = {
            "DESIGNGYM_BACKEND": os.getenv("DESIGNGYM_BACKEND", "local"),
            "DESIGNGYM_ADAPTER": os.getenv("DESIGNGYM_ADAPTER", "sft"),
            # Only presence is reported; the token value is never returned.
            "HF_TOKEN_present": bool(os.getenv("HF_TOKEN")),
        }
        return details
    except Exception as e:
        return {"backend": "unknown", "error": str(e)}


async def _read_json_body(request: Request):
    """Return the parsed JSON request body, treating an empty body as ``{}``.

    Malformed JSON still raises from ``request.json()`` exactly as a direct call
    would; only the "no body at all" case is made tolerant.
    """
    raw = await request.body()
    if not raw.strip():
        return {}
    return await request.json()


@app.post("/demo/switch_adapter")
async def demo_switch_adapter(request: Request):
    """Switch the active LoRA adapter. Triggers a fresh model load -- first call after switch will be slow."""
    try:
        from local_model import get_client, ADAPTERS as _ADAPTERS
    except ImportError:
        return JSONResponse(status_code=400, content={"error": "local_model not available"})

    payload = await _read_json_body(request)
    key = payload.get("adapter")
    if key not in _ADAPTERS:
        return JSONResponse(
            status_code=400,
            content={"error": f"Unknown adapter {key!r}", "valid": list(_ADAPTERS)},
        )

    client = get_client(key)
    # Drop any cached router client so it is rebuilt on next use.
    _LLM_CLIENT_CACHE["client"] = None
    _LLM_CLIENT_CACHE["key"] = None
    client._ensure_loaded()
    return {"ok": True, "adapter": key, "info": client.describe()}


@app.get("/tasks")
def tasks():
    """Return the task catalog."""
    return JSONResponse({"tasks": _task_catalog()})


@app.post("/demo/reset")
async def demo_reset(request: Request):
    """Reset the shared demo environment.

    The JSON body is forwarded as keyword arguments to ``DEMO_ENV.reset``; an
    empty body resets with the environment's defaults.
    """
    payload = await _read_json_body(request)
    obs = DEMO_ENV.reset(**payload)
    _LAST_OBS["obs"] = obs
    return {
        "observation": obs.model_dump(),
        "state": DEMO_ENV.state.model_dump(),
        "reward": 0.0,
        "done": False,
    }


@app.post("/demo/step")
async def demo_step(request: Request):
    """Apply one client-supplied action to the shared demo environment.

    The action fields are read from ``payload["action"]`` when present,
    otherwise from the top-level payload itself.
    """
    payload = await _read_json_body(request)
    action = DesignGymAction(**payload.get("action", payload))
    obs = DEMO_ENV.step(action)
    _LAST_OBS["obs"] = obs
    return {
        "observation": obs.model_dump(),
        "state": DEMO_ENV.state.model_dump(),
        "reward": float(DEMO_ENV.state.last_reward),
        "done": bool(DEMO_ENV.state.done),
    }


@app.get("/demo/state")
def demo_state():
    """Return the current state of the shared demo environment."""
    return {"state": DEMO_ENV.state.model_dump()}


def _float_attr(obj, name: str) -> float:
    """Read ``obj.<name>`` as a float, mapping a missing or falsy value to 0.0."""
    return float(getattr(obj, name, 0.0) or 0.0)


def _summary_from_state(state, trajectory: List[Dict[str, Any]]) -> Dict[str, Any]:
    """Summarize an episode from the final ``state`` and its step records.

    ``valid_action_rate`` is the share of steps whose ``error`` is ``None``
    (0.0 for an empty trajectory); ``finalized`` is true when any step used the
    ``"finalize"`` action type.
    """
    step_count = len(trajectory)
    error_free = sum(1 for rec in trajectory if rec.get("error") is None)
    return {
        "final_score": _float_attr(state, "current_score"),
        "instruction_score": _float_attr(state, "instruction_score"),
        "phase_score": _float_attr(state, "phase_score"),
        "best_score_so_far": _float_attr(state, "best_score_so_far"),
        "steps_taken": step_count,
        "total_reward": sum(rec.get("reward", 0.0) for rec in trajectory),
        "valid_action_rate": (error_free / step_count) if trajectory else 0.0,
        "finalized": any(rec.get("action_type") == "finalize" for rec in trajectory),
        "done": bool(getattr(state, "done", False)),
        "phase": getattr(state, "phase", None),
        "task_id": getattr(state, "task_id", None),
    }


def _record_step(step: int, action, obs, prev_score: float) -> Dict[str, Any]:
    """Build the per-step trajectory record for ``action`` and its resulting ``obs``.

    ``delta_score`` is the observation's current score minus ``prev_score``.
    """
    # last_reward lives on env.state, not on the observation
    state_reward = getattr(getattr(DEMO_ENV, "state", None), "last_reward", None)
    reward = state_reward if state_reward is not None else getattr(obs, "last_reward", 0.0)
    score = _float_attr(obs, "current_score")
    # Use the action's canonical() form when it has one, else its str().
    describe = getattr(action, "canonical", lambda: str(action))
    return {
        "step": step,
        "action_type": getattr(action, "action_type", None),
        "action": describe(),
        "reward": float(reward or 0.0),
        "score": score,
        "delta_score": score - prev_score,
        "instruction_score": _float_attr(obs, "instruction_score"),
        "worst_metrics": list(getattr(obs, "worst_metrics", []) or []),
        "error": getattr(obs, "last_action_error", None),
        "done": bool(getattr(obs, "done", False)),
    }


def _choose_action(policy: str, step: int, obs, history: List[str], rewards: List[float], recent_actions: List[str]):
    """Pick one action using the requested policy. Falls back to heuristic if LLM unavailable.

    Returns ``(action, label)`` where ``label`` names the policy that actually
    produced the action. Without the ``inference`` module the only possible
    action is ``finalize``.
    """
    if not _INFERENCE_OK:
        return DesignGymAction(action_type="finalize"), "fallback_no_inference"

    if policy == "heuristic":
        return INF.heuristic_action(step, obs, rewards, recent_actions), "heuristic"

    client = _get_llm_client()
    action = INF.get_model_action_sync(client, step, obs, history, rewards, recent_actions)
    if client is None:
        return action, "heuristic_fallback"

    # Label used for any client that is not a local LoRA client.
    remote_label = "router_base" if hasattr(client, "base_url") else "llm"
    try:
        from local_model import LocalLoRAClient
    except ImportError:
        return action, remote_label

    if not isinstance(client, LocalLoRAClient):
        return action, remote_label
    if client.adapter_id:
        return action, f"finetuned_{client.adapter_key}"
    return action, "local_base"


@app.post("/demo/policy_step")
async def demo_policy_step(request: Request):
    """Run a single step using the requested policy ({"policy": "heuristic"|"sft"}).

    Responds with 409 when there is no observation to act on (no prior reset).
    An empty request body selects the heuristic policy.
    """
    payload = await _read_json_body(request)
    policy = (payload.get("policy") or "heuristic").lower()

    obs = _current_obs()
    if obs is None:
        return JSONResponse(
            status_code=409,
            content={"error": "no_active_episode", "hint": "Call /demo/reset first."},
        )

    state = DEMO_ENV.state
    step = int(getattr(state, "step_count", 0) or 0) + 1
    recent_actions: List[str] = list(getattr(state, "action_history", []) or [])
    rewards: List[float] = []  # not tracked on state; OK since heuristic mostly checks last action
    history = recent_actions[-4:]
    prev_score = _float_attr(obs, "current_score")

    action, used = _choose_action(policy, step, obs, history, rewards, recent_actions)
    obs_after = DEMO_ENV.step(action)
    _LAST_OBS["obs"] = obs_after

    record = _record_step(step, action, obs_after, prev_score)
    record["policy"] = used
    return {
        "observation": obs_after.model_dump(),
        "state": DEMO_ENV.state.model_dump(),
        "step_record": record,
        "reward": record["reward"],
        "done": record["done"],
    }
@app.post("/demo/run_episode")
async def demo_run_episode(request: Request):
    """Run a full episode server-side and return the trajectory + summary in one call.

    Body: {"policy": "heuristic"|"sft", "task_id": str, "seed": int, "max_steps": int}

    Every field is optional and an empty body runs with the defaults. A body
    that is not a JSON object, or a non-integer ``seed``/``max_steps``, yields a
    400 response. Failures during reset or stepping yield a 500 response that
    carries the error text, the tail of the traceback and the partial
    trajectory collected so far.
    """
    raw_body = await request.body()
    payload = await request.json() if raw_body.strip() else {}
    if not isinstance(payload, dict):
        return JSONResponse(
            status_code=400,
            content={"error": "Request body must be a JSON object."},
        )

    policy = (payload.get("policy") or "heuristic").lower()
    task_id = payload.get("task_id") or "poster_basic_v1"
    try:
        seed = int(payload.get("seed") or 0)
        requested_cap = payload.get("max_steps")
        # A falsy max_steps (missing, null, 0) means "no override".
        step_cap = int(requested_cap) if requested_cap else None
    except (TypeError, ValueError):
        return JSONResponse(
            status_code=400,
            content={"error": "'seed' and 'max_steps' must be integers."},
        )

    trajectory: List[Dict[str, Any]] = []
    history: List[str] = []
    rewards: List[float] = []
    recent_actions: List[str] = []

    t0 = time.time()
    try:
        # Reset inside the try so a bad task_id is reported like any other failure.
        obs = DEMO_ENV.reset(task_id=task_id, seed=seed)
        _LAST_OBS["obs"] = obs

        declared_max = int(getattr(obs, "max_steps", 8) or 8)
        if step_cap is not None:
            # The override can only shorten the episode, never extend it.
            declared_max = min(declared_max, step_cap)

        for step in range(1, declared_max + 1):
            if bool(getattr(obs, "done", False)):
                break
            prev_score = float(getattr(obs, "current_score", 0.0) or 0.0)
            action, used = _choose_action(policy, step, obs, history, rewards, recent_actions)
            obs = DEMO_ENV.step(action)
            _LAST_OBS["obs"] = obs

            record = _record_step(step, action, obs, prev_score)
            record["policy"] = used
            trajectory.append(record)
            history.append(record["action"])
            rewards.append(record["reward"])
            recent_actions.append(record["action"])
            if record["done"]:
                break

        summary = _summary_from_state(DEMO_ENV.state, trajectory)
        summary["policy_requested"] = policy
        summary["llm_available"] = _get_llm_client() is not None
        summary["wall_time_sec"] = round(time.time() - t0, 3)
        return {
            "summary": summary,
            "trajectory": trajectory,
            "final_observation": obs.model_dump(),
            "final_state": DEMO_ENV.state.model_dump(),
        }
    except Exception as exc:
        return JSONResponse(
            status_code=500,
            content={
                "error": str(exc),
                "trace": traceback.format_exc().splitlines()[-12:],
                "trajectory": trajectory,
            },
        )


@app.get("/demo/policies")
def demo_policies():
    """Tell the frontend which policies are usable right now.

    The backend name and adapter key come from ``local_model.describe_client``
    when that module is importable; otherwise the backend is reported as
    ``"router"`` if a client exists and ``"none"`` if not.
    """
    client = _get_llm_client()
    llm_ok = client is not None

    try:
        from local_model import describe_client

        details = describe_client(client)
        backend = details.get("backend", "none")
        adapter_key = details.get("adapter_key", "")
    except ImportError:
        backend = "router" if llm_ok else "none"
        adapter_key = ""

    # Non-ASCII punctuation is written as \u escapes (middle dot, em dash) so the
    # labels cannot be corrupted by a wrong source-file encoding.
    if backend == "local-lora":
        llm_label = f"Fine-tuned {(adapter_key or '').upper()} \u00b7 Qwen2.5-0.5B + LoRA (local CPU)"
    elif backend == "local-base":
        llm_label = "Base Qwen2.5-0.5B (local CPU, no adapter)"
    elif backend == "router":
        llm_label = "Base Qwen2.5-0.5B (HF Router) \u2014 NOT fine-tuned"
    else:
        llm_label = "Heuristic only \u2014 no LLM available"

    return {
        "policies": [
            {
                "id": "heuristic",
                "label": "Heuristic Planner",
                "available": _INFERENCE_OK,
                "description": "Rule-based planner from inference.py \u2014 fast, no API key needed.",
            },
            {
                "id": "llm",
                "label": llm_label,
                "available": _INFERENCE_OK and llm_ok,
                "description": f"Backend: {backend}. Uses the LLM client to pick from candidate actions.",
                "llm_active": llm_ok,
                "backend": backend,
            },
        ],
        "llm_active": llm_ok,
        "backend": backend,
        "model_name": getattr(INF, "MODEL_NAME", None) if _INFERENCE_OK else None,
    }


def main() -> None:
    """Run the server with uvicorn, reading ``HOST`` and ``PORT`` from the environment."""
    host = os.getenv("HOST", "0.0.0.0")
    port = int(os.getenv("PORT", "8000"))
    uvicorn.run("server.app:app", host=host, port=port, reload=False)


if __name__ == "__main__":
    main()