Spaces:
Running
Running
| from __future__ import annotations | |
| import os | |
| import sys | |
| import time | |
| import traceback | |
| from pathlib import Path | |
| from typing import Any, Dict, List, Optional | |
| import uvicorn | |
| from fastapi.responses import HTMLResponse, JSONResponse, FileResponse, RedirectResponse | |
| from fastapi.staticfiles import StaticFiles | |
| from fastapi import Request | |
# Prefer the OpenEnv application factory. If it cannot be imported for any
# reason, define a stand-in with the same signature so this module still loads.
try:
    from openenv.core.env_server import create_fastapi_app
except Exception:
    from fastapi import FastAPI

    def create_fastapi_app(env_cls, action_cls, observation_cls):
        """Build a bare FastAPI app titled "DesignGym".

        The three class arguments mirror the OpenEnv factory's signature; this
        fallback accepts them but does not use them.
        """
        fallback_app = FastAPI(title="DesignGym")

        # NOTE(review): `health` is not attached to the app anywhere in this
        # view; a route decorator was presumably lost -- confirm.
        def health():
            """Return a static healthy-status payload."""
            return {"status": "healthy"}

        return fallback_app
| try: | |
| from ..models import DesignGymAction, DesignGymObservation | |
| except Exception: | |
| from models import DesignGymAction, DesignGymObservation | |
| try: | |
| from .DesignGym_environment import DesignGymEnvironment, TASKS | |
| except Exception: | |
| from server.DesignGym_environment import DesignGymEnvironment, TASKS | |
| _REPO_ROOT = Path(__file__).resolve().parent.parent | |
| if str(_REPO_ROOT) not in sys.path: | |
| sys.path.insert(0, str(_REPO_ROOT)) | |
| try: | |
| import inference as INF | |
| _INFERENCE_OK = True | |
| except Exception: | |
| INF = None | |
| _INFERENCE_OK = False | |
| try: | |
| from openai import OpenAI | |
| _OPENAI_OK = True | |
| except Exception: | |
| OpenAI = None | |
| _OPENAI_OK = False | |
| _LLM_CLIENT_CACHE: Dict[str, Any] = {"client": None, "key": None} | |
| def _get_llm_client(): | |
| backend = os.getenv("DESIGNGYM_BACKEND", "local") | |
| if backend == "local": | |
| try: | |
| from local_model import get_client | |
| return get_client() | |
| except Exception: | |
| return None | |
| if backend == "router": | |
| if not (_INFERENCE_OK and _OPENAI_OK): | |
| return None | |
| token = os.getenv("HF_TOKEN") or getattr(INF, "HF_TOKEN", None) | |
| base = os.getenv("API_BASE_URL", getattr(INF, "API_BASE_URL", "https://router.huggingface.co/v1")) | |
| if not token: | |
| return None | |
| cache_key = f"router::{token[:6]}::{base}" | |
| if _LLM_CLIENT_CACHE["key"] != cache_key: | |
| _LLM_CLIENT_CACHE["client"] = OpenAI(base_url=base, api_key=token) | |
| _LLM_CLIENT_CACHE["key"] = cache_key | |
| return _LLM_CLIENT_CACHE["client"] | |
| return None | |
# Build the ASGI application from the environment class and its
# action/observation models.
app = create_fastapi_app(DesignGymEnvironment, DesignGymAction, DesignGymObservation)

# With the local backend (the default) call `warm_up_async()` at import time.
# Any failure is reported on stdout and otherwise ignored so the server can
# still start.
if os.environ.get("DESIGNGYM_BACKEND", "local") == "local":
    try:
        from local_model import warm_up_async
        warm_up_async()
        print("[server] local model warm-up thread started", flush=True)
    except Exception as _warmup_err:
        print(f"[server] local model warm-up skipped: {_warmup_err}", flush=True)

# Single environment instance shared by the demo handlers in this module.
DEMO_ENV = DesignGymEnvironment()

# Most recent observation produced through the demo handlers; None until one
# of them stores a value.
_LAST_OBS: Dict[str, Any] = dict(obs=None)
def _current_obs():
    """Return the cached observation, or ask the environment for a snapshot.

    Returns ``None`` when nothing is cached and ``DEMO_ENV`` either has no
    callable ``_observation`` attribute or calling it raises.
    """
    cached = _LAST_OBS.get("obs")
    if cached is not None:
        return cached

    make_snapshot = getattr(DEMO_ENV, "_observation", None)
    if not callable(make_snapshot):
        return None
    try:
        return make_snapshot(message="snapshot")
    except Exception:
        # Best effort: a failing snapshot is reported as "no observation".
        return None
# Parent of the directory that contains this file, plus the two directories
# under it that this module serves files from.
ROOT_DIR = Path(__file__).resolve().parents[1]
ASSETS_DIR = ROOT_DIR / "assets"
WEB_DIR = ROOT_DIR / "web"

# Mount /assets only when the directory is actually present on disk.
if ASSETS_DIR.exists():
    app.mount(
        "/assets",
        StaticFiles(directory=str(ASSETS_DIR)),
        name="assets",
    )
def _html_response():
    """Return ``WEB_DIR/index.html`` as a FileResponse with caching disabled.

    The no-cache headers are there so browsers (Safari in particular) do not
    serve a stale copy of the page.
    """
    no_cache_headers = {
        "Cache-Control": "no-store, no-cache, must-revalidate, max-age=0",
        "Pragma": "no-cache",
    }
    page = FileResponse(str(WEB_DIR / "index.html"))
    for header_name, header_value in no_cache_headers.items():
        page.headers[header_name] = header_value
    return page
def home():
    """Serve the web UI's index page (uncached).

    NOTE(review): no route decorator is visible for this handler in this
    view; presumably it is registered on `app` -- confirm.
    """
    page = _html_response()
    return page
def web_index_no_slash():
    """Serve the web UI's index page (uncached)."""
    page = _html_response()
    return page
def web_index():
    """Serve the web UI's index page (uncached)."""
    page = _html_response()
    return page
def web_static(path: str):
    """Serve a file from ``WEB_DIR``, falling back to the index page.

    Args:
        path: Request-supplied path, interpreted relative to ``WEB_DIR``.

    Returns:
        A ``FileResponse`` for the file when it exists inside ``WEB_DIR``;
        otherwise the uncached index page, the same fallback used for
        missing files.
    """
    # `path` is untrusted. Joining it blindly lets ".." segments or an
    # absolute path escape WEB_DIR, so resolve it first and require the
    # result to stay under the web root.
    try:
        web_root = WEB_DIR.resolve()
        candidate = (web_root / path).resolve()
        candidate.relative_to(web_root)
    except (ValueError, OSError):
        # ValueError: outside the web root (or an embedded NUL byte).
        return _html_response()

    if not candidate.is_file():
        return _html_response()
    return FileResponse(str(candidate))
| def _task_description(task_id: str) -> str: | |
| if task_id == "poster_basic_v1": | |
| return "Poster layout optimization with hero image, title hierarchy, CTA placement, and alignment." | |
| if task_id == "editorial_cover_v1": | |
| return "Editorial cover optimization with masthead preservation, headline stack, and reading order." | |
| if task_id == "dense_flyer_v1": | |
| return "Dense flyer optimization with support-group reflow, spacing, occupancy, and caption alignment." | |
| return "Design layout optimization task." | |
def _task_catalog():
    """Return one metadata dict per entry in ``TASKS``.

    Each entry carries the task id, a difficulty label ("medium" when the id
    is not in the local table), fixed grader metadata, a description, and the
    ``max_steps`` / ``instance_id`` values read from the task spec.
    """
    difficulty_by_task = {
        "poster_basic_v1": "easy",
        "editorial_cover_v1": "medium",
        "dense_flyer_v1": "hard",
    }

    def _entry(task_id, spec):
        # Build a fresh dict (including the nested grader dict) per task so
        # entries never share mutable state.
        return {
            "task_id": task_id,
            "difficulty": difficulty_by_task.get(task_id, "medium"),
            "graded": True,
            "grader": {
                "type": "programmatic",
                "name": "deterministic_layout_utility",
                "deterministic": True,
                "source": "server/DesignGym_environment.py",
            },
            "description": _task_description(task_id),
            "max_steps": int(spec.get("max_steps", 0)),
            "instance_id": spec.get("instance_id"),
            "reward_range": [0.0, 1.0],
            "score_range": [0.0, 1.0],
        }

    return [_entry(task_id, spec) for task_id, spec in TASKS.items()]
def info():
    """Return environment metadata and the task catalog as a JSON response."""
    catalog = _task_catalog()
    body = {
        "name": "DesignGym",
        "description": "OpenEnv-compatible reinforcement learning environment for design layout optimization.",
        "task_count": len(catalog),
        "default_task_id": "poster_basic_v1",
        "tasks": catalog,
        "reward_range": [0.0, 1.0],
        "score_range": [0.0, 1.0],
        "supports_seeded_reset": True,
        "supports_task_id_reset": True,
    }
    return JSONResponse(body)
def demo_ping():
    """Return a fixed liveness payload for the demo endpoints."""
    return dict(ok=True, message="DesignGym 2.0 demo endpoints are live")
def demo_backend_info():
    """Describe the active LLM backend.

    Returns whatever ``local_model.describe_client`` reports for the current
    client, extended with the available adapters, the base model and the
    relevant environment settings. Any failure (including ``local_model`` not
    being importable) yields ``{"backend": "unknown", "error": ...}``.
    """
    try:
        from local_model import ADAPTERS as _ADAPTERS, BASE_MODEL as _BASE, describe_client

        details = describe_client(_get_llm_client())
        details["available_adapters"] = _ADAPTERS
        details["base_model"] = _BASE
        details["env"] = {
            "DESIGNGYM_BACKEND": os.getenv("DESIGNGYM_BACKEND", "local"),
            "DESIGNGYM_ADAPTER": os.getenv("DESIGNGYM_ADAPTER", "sft"),
            # Only whether a token is set is exposed, never its value.
            "HF_TOKEN_present": bool(os.getenv("HF_TOKEN")),
        }
        return details
    except Exception as e:
        return {"backend": "unknown", "error": str(e)}
async def demo_switch_adapter(request: Request):
    """Switch the active LoRA adapter.

    Triggers a fresh model load -- the first call after a switch will be slow.

    Body: ``{"adapter": <key in local_model.ADAPTERS>}``

    Returns ``{"ok": True, "adapter": key, "info": ...}`` on success, or a
    400 JSON response when ``local_model`` cannot be imported, the body is
    not a JSON object, or the adapter key is unknown.
    """
    try:
        from local_model import get_client, ADAPTERS as _ADAPTERS
    except ImportError:
        return JSONResponse(status_code=400, content={"error": "local_model not available"})

    # A malformed body raises a ValueError subclass from the JSON decoder;
    # report it as a client error instead of letting it surface as a 500.
    try:
        payload = await request.json()
    except ValueError:
        payload = None
    if not isinstance(payload, dict):
        return JSONResponse(
            status_code=400,
            content={"error": "Request body must be a JSON object", "valid": list(_ADAPTERS)},
        )

    key = payload.get("adapter")
    # The isinstance check keeps unhashable JSON values (lists, objects) from
    # raising inside the membership test.
    if not isinstance(key, str) or key not in _ADAPTERS:
        return JSONResponse(
            status_code=400,
            content={"error": f"Unknown adapter {key!r}", "valid": list(_ADAPTERS)},
        )

    client = get_client(key)
    # Drop any cached router client so later lookups do not reuse it.
    _LLM_CLIENT_CACHE["client"] = None
    _LLM_CLIENT_CACHE["key"] = None
    client._ensure_loaded()
    return {"ok": True, "adapter": key, "info": client.describe()}
def tasks():
    """Return the task catalog wrapped as ``{"tasks": [...]}``."""
    catalog = _task_catalog()
    return JSONResponse({"tasks": catalog})
async def demo_reset(request: Request):
    """Reset the shared demo environment and return its first observation.

    The JSON object in the request body is forwarded as keyword arguments to
    ``DEMO_ENV.reset``. An empty body is treated as ``{}``. A body that is
    not valid JSON, or is not a JSON object, yields a 400 response instead of
    an unhandled error.

    Returns a dict with the observation, the environment state, a reward of
    0.0 and ``done=False``.
    """
    raw_body = await request.body()
    try:
        payload = (await request.json()) if raw_body.strip() else {}
    except ValueError:
        payload = None
    if not isinstance(payload, dict):
        # `reset(**payload)` needs a mapping; reject anything else up front.
        return JSONResponse(
            status_code=400,
            content={"error": "invalid_json_body", "hint": "Send a JSON object of reset arguments."},
        )

    obs = DEMO_ENV.reset(**payload)
    _LAST_OBS["obs"] = obs
    return {
        "observation": obs.model_dump(),
        "state": DEMO_ENV.state.model_dump(),
        "reward": 0.0,
        "done": False,
    }
async def demo_step(request: Request):
    """Apply one client-supplied action to the shared demo environment.

    Body: either the action fields directly, or ``{"action": {...}}``.

    Returns a dict with the new observation, the environment state, and the
    reward / done flag read from that state. A body that is not a JSON object
    yields a 400 response; action fields that ``DesignGymAction`` rejects
    yield a 422 response.
    """
    try:
        payload = await request.json()
    except ValueError:
        payload = None
    if not isinstance(payload, dict):
        return JSONResponse(
            status_code=400,
            content={"error": "invalid_json_body", "hint": "Send a JSON object describing the action."},
        )

    action_payload = payload.get("action", payload)
    try:
        action = DesignGymAction(**action_payload)
    except (TypeError, ValueError) as exc:
        # TypeError: "action" was not a mapping or had bad keyword names.
        # ValueError: covers model validation errors, which subclass it.
        return JSONResponse(
            status_code=422,
            content={"error": "invalid_action", "detail": str(exc)},
        )

    obs = DEMO_ENV.step(action)
    _LAST_OBS["obs"] = obs
    return {
        "observation": obs.model_dump(),
        "state": DEMO_ENV.state.model_dump(),
        "reward": float(DEMO_ENV.state.last_reward),
        "done": bool(DEMO_ENV.state.done),
    }
def demo_state():
    """Return the demo environment's current state as ``{"state": {...}}``."""
    snapshot = DEMO_ENV.state.model_dump()
    return {"state": snapshot}
| def _summary_from_state(state, trajectory: List[Dict[str, Any]]) -> Dict[str, Any]: | |
| rewards = [t.get("reward", 0.0) for t in trajectory] | |
| valid = [t for t in trajectory if t.get("error") is None] | |
| finalized = any(t.get("action_type") == "finalize" for t in trajectory) | |
| return { | |
| "final_score": float(getattr(state, "current_score", 0.0) or 0.0), | |
| "instruction_score": float(getattr(state, "instruction_score", 0.0) or 0.0), | |
| "phase_score": float(getattr(state, "phase_score", 0.0) or 0.0), | |
| "best_score_so_far": float(getattr(state, "best_score_so_far", 0.0) or 0.0), | |
| "steps_taken": len(trajectory), | |
| "total_reward": sum(rewards), | |
| "valid_action_rate": (len(valid) / len(trajectory)) if trajectory else 0.0, | |
| "finalized": finalized, | |
| "done": bool(getattr(state, "done", False)), | |
| "phase": getattr(state, "phase", None), | |
| "task_id": getattr(state, "task_id", None), | |
| } | |
def _record_step(step: int, action, obs, prev_score: float) -> Dict[str, Any]:
    """Build the trajectory record for one executed step.

    Args:
        step: 1-based step number supplied by the caller.
        action: The action that was executed.
        obs: The observation returned for that action.
        prev_score: Score before the step, used to compute ``delta_score``.

    Returns:
        A dict with the action, reward, score, score delta, instruction
        score, worst metrics, error and done flag.
    """
    # last_reward lives on env.state, not on the observation; the observation
    # is consulted only when the state has no value.
    env_state = getattr(DEMO_ENV, "state", None)
    reward = getattr(env_state, "last_reward", None)
    if reward is None:
        reward = getattr(obs, "last_reward", 0.0)

    score = float(getattr(obs, "current_score", 0.0) or 0.0)
    # Use the action's own `canonical()` when it has one, else its str().
    canonical = getattr(action, "canonical", lambda: str(action))

    return {
        "step": step,
        "action_type": getattr(action, "action_type", None),
        "action": canonical(),
        "reward": float(reward or 0.0),
        "score": score,
        "delta_score": score - prev_score,
        "instruction_score": float(getattr(obs, "instruction_score", 0.0) or 0.0),
        "worst_metrics": list(getattr(obs, "worst_metrics", []) or []),
        "error": getattr(obs, "last_action_error", None),
        "done": bool(getattr(obs, "done", False)),
    }
def _choose_action(policy: str, step: int, obs, history: List[str], rewards: List[float], recent_actions: List[str]):
    """Pick one action with the requested policy.

    Args:
        policy: ``"heuristic"`` selects ``INF.heuristic_action``; any other
            value goes through ``INF.get_model_action_sync``.
        step: 1-based step number.
        obs: Current observation.
        history, rewards, recent_actions: Episode context passed through to
            the ``inference`` helpers.

    Returns:
        ``(action, label)`` where ``label`` names the policy actually used:
        ``"fallback_no_inference"`` (a ``finalize`` action, when the
        ``inference`` module is unavailable), ``"heuristic"``,
        ``"heuristic_fallback"`` (no LLM client available),
        ``"finetuned_<adapter_key>"``, ``"local_base"``, ``"router_base"``
        or ``"llm"``.
    """
    if not _INFERENCE_OK:
        return DesignGymAction(action_type="finalize"), "fallback_no_inference"
    if policy == "heuristic":
        return INF.heuristic_action(step, obs, rewards, recent_actions), "heuristic"

    # With no client this passes None, which is exactly what the no-LLM path
    # needs, so a single call covers both cases.
    client = _get_llm_client()
    action = INF.get_model_action_sync(client, step, obs, history, rewards, recent_actions)
    if client is None:
        return action, "heuristic_fallback"

    # The import only serves to label the result. Treat ANY import failure
    # (not just ImportError) as "local model unavailable", matching
    # `_get_llm_client`, so a broken optional dependency cannot discard an
    # action that was already computed.
    try:
        from local_model import LocalLoRAClient
    except Exception:
        LocalLoRAClient = None

    if LocalLoRAClient is not None and isinstance(client, LocalLoRAClient):
        label = f"finetuned_{client.adapter_key}" if client.adapter_id else "local_base"
    elif hasattr(client, "base_url"):
        label = "router_base"
    else:
        label = "llm"
    return action, label
async def demo_policy_step(request: Request):
    """Run a single step of the active episode with the requested policy.

    Body: ``{"policy": <name>}``. ``"heuristic"`` (the default, also used for
    an empty body) selects the rule-based planner; any other value is routed
    to the LLM path of ``_choose_action``.

    Returns the new observation, the environment state, the step record and
    its reward / done flag. Responds with 400 when the body is not a JSON
    object and 409 when there is no observation to act on.
    """
    raw_body = await request.body()
    try:
        payload = (await request.json()) if raw_body.strip() else {}
    except ValueError:
        payload = None
    if not isinstance(payload, dict):
        return JSONResponse(
            status_code=400,
            content={"error": "invalid_json_body", "hint": "Send a JSON object such as {\"policy\": \"heuristic\"}."},
        )
    # str() keeps a non-string policy value from raising on .lower().
    policy = str(payload.get("policy") or "heuristic").lower()

    obs = _current_obs()
    if obs is None:
        return JSONResponse(
            status_code=409,
            content={"error": "no_active_episode", "hint": "Call /demo/reset first."},
        )

    state = DEMO_ENV.state
    step = int(getattr(state, "step_count", 0) or 0) + 1
    recent_actions: List[str] = list(getattr(state, "action_history", []) or [])
    rewards: List[float] = []  # not tracked on state; OK since heuristic mostly checks last action
    history = recent_actions[-4:]
    prev_score = float(getattr(obs, "current_score", 0.0) or 0.0)

    action, used = _choose_action(policy, step, obs, history, rewards, recent_actions)
    obs_after = DEMO_ENV.step(action)
    _LAST_OBS["obs"] = obs_after

    record = _record_step(step, action, obs_after, prev_score)
    record["policy"] = used
    return {
        "observation": obs_after.model_dump(),
        "state": DEMO_ENV.state.model_dump(),
        "step_record": record,
        "reward": record["reward"],
        "done": record["done"],
    }
async def demo_run_episode(request: Request):
    """Run a full episode server-side and return the trajectory + summary in one call.

    Body: {"policy": "heuristic"|"sft", "task_id": str, "seed": int, "max_steps": int}

    Every field is optional; an empty body runs the heuristic policy on
    ``poster_basic_v1`` with seed 0. ``max_steps`` can only lower the step
    budget declared by the observation, never raise it.

    Returns a dict with ``summary``, ``trajectory``, ``final_observation`` and
    ``final_state``. Responds with 400 for a body that is not a JSON object or
    for non-integer ``seed`` / ``max_steps``, and with a 500 JSON payload
    (error, trace tail, partial trajectory) when the episode itself fails.
    """
    raw_body = await request.body()
    try:
        payload = (await request.json()) if raw_body.strip() else {}
    except ValueError:
        payload = None
    if not isinstance(payload, dict):
        return JSONResponse(
            status_code=400,
            content={"error": "invalid_json_body", "hint": "Send a JSON object of episode options."},
        )

    policy = str(payload.get("policy") or "heuristic").lower()
    task_id = payload.get("task_id") or "poster_basic_v1"
    raw_max_steps = payload.get("max_steps")
    try:
        seed = int(payload.get("seed") or 0)
        # Falsy values (missing, None, 0) mean "no override".
        max_steps_override = int(raw_max_steps) if raw_max_steps else None
    except (TypeError, ValueError):
        return JSONResponse(
            status_code=400,
            content={"error": "invalid_parameters", "hint": "seed and max_steps must be integers."},
        )

    trajectory: List[Dict[str, Any]] = []
    history: List[str] = []
    rewards: List[float] = []
    recent_actions: List[str] = []
    t0 = time.time()

    try:
        # The reset sits inside the try so a failing reset (for example a
        # task id the environment rejects) is reported in the same
        # structured form as a failing step.
        obs = DEMO_ENV.reset(task_id=task_id, seed=seed)
        _LAST_OBS["obs"] = obs

        declared_max = int(getattr(obs, "max_steps", 8) or 8)
        if max_steps_override is not None:
            declared_max = min(declared_max, max_steps_override)

        for step in range(1, declared_max + 1):
            if bool(getattr(obs, "done", False)):
                break
            prev_score = float(getattr(obs, "current_score", 0.0) or 0.0)
            action, used = _choose_action(policy, step, obs, history, rewards, recent_actions)
            obs = DEMO_ENV.step(action)
            _LAST_OBS["obs"] = obs

            record = _record_step(step, action, obs, prev_score)
            record["policy"] = used
            trajectory.append(record)
            history.append(record["action"])
            rewards.append(record["reward"])
            recent_actions.append(record["action"])
            if record["done"]:
                break

        summary = _summary_from_state(DEMO_ENV.state, trajectory)
        summary["policy_requested"] = policy
        summary["llm_available"] = _get_llm_client() is not None
        summary["wall_time_sec"] = round(time.time() - t0, 3)
        return {
            "summary": summary,
            "trajectory": trajectory,
            "final_observation": obs.model_dump(),
            "final_state": DEMO_ENV.state.model_dump(),
        }
    except Exception as exc:
        return JSONResponse(
            status_code=500,
            content={
                "error": str(exc),
                # Only the tail of the traceback, to keep the payload small.
                "trace": traceback.format_exc().splitlines()[-12:],
                "trajectory": trajectory,
            },
        )
def demo_policies():
    """Tell the frontend which policies are usable right now.

    Returns a dict with a ``policies`` list (the heuristic planner and the
    LLM-backed policy, each with an ``available`` flag), plus top-level
    ``llm_active``, ``backend`` and ``model_name`` fields.
    """
    client = _get_llm_client()
    llm_ok = client is not None

    # `_get_llm_client` treats any local_model failure as "no local model";
    # do the same here so a broken optional dependency degrades to the
    # router/none labels instead of failing this discovery call.
    try:
        from local_model import describe_client

        client_info = describe_client(client)
        backend = client_info.get("backend", "none")
        adapter_key = client_info.get("adapter_key", "")
    except Exception:
        backend = "router" if llm_ok else "none"
        adapter_key = ""

    if backend == "local-lora":
        llm_label = f"Fine-tuned {(adapter_key or '').upper()} · Qwen2.5-0.5B + LoRA (local CPU)"
    elif backend == "local-base":
        llm_label = "Base Qwen2.5-0.5B (local CPU, no adapter)"
    elif backend == "router":
        llm_label = "Base Qwen2.5-0.5B (HF Router) — NOT fine-tuned"
    else:
        llm_label = "Heuristic only — no LLM available"

    return {
        "policies": [
            {
                "id": "heuristic",
                "label": "Heuristic Planner",
                "available": _INFERENCE_OK,
                "description": "Rule-based planner from inference.py — fast, no API key needed.",
            },
            {
                "id": "llm",
                "label": llm_label,
                "available": _INFERENCE_OK and llm_ok,
                "description": f"Backend: {backend}. Uses the LLM client to pick from candidate actions.",
                "llm_active": llm_ok,
                "backend": backend,
            },
        ],
        "llm_active": llm_ok,
        "backend": backend,
        "model_name": getattr(INF, "MODEL_NAME", None) if _INFERENCE_OK else None,
    }
def main() -> None:
    """Run uvicorn serving ``server.app:app``.

    Host and port come from the ``HOST`` and ``PORT`` environment variables,
    defaulting to ``0.0.0.0`` and ``8000``; auto-reload is disabled.
    """
    uvicorn.run(
        "server.app:app",
        host=os.getenv("HOST", "0.0.0.0"),
        port=int(os.getenv("PORT", "8000")),
        reload=False,
    )


if __name__ == "__main__":
    main()