"""
server/app.py
─────────────────────────────────────────────────────────────────────────────
FastAPI backend for Survival Island.

Endpoints
---------
GET  /api/health   liveness probe (used by start.sh and HF healthcheck)
GET  /api/config   safe public config (model name, pipeline status)
POST /api/infer    LLM survival-action inference
"""
from __future__ import annotations
import logging
import os
import time
import traceback
from contextlib import asynccontextmanager
from fastapi import FastAPI, HTTPException, Request
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse
from server.environment import GameState, InferResponse, build_prompt
import models
# ── ADDED: Import the environment ─────────────────────────────────────────────
from inference import SurvivalIslandEnvironment
# ── Logging ───────────────────────────────────────────────────────────────────
# Configure the root logger once, at import time: INFO and above are emitted
# with a timestamp, a padded level name, and the originating logger's name.
logging.basicConfig(
    format="%(asctime)s [%(levelname)-8s] %(name)s: %(message)s",
    level=logging.INFO,
)
# Logger shared by every route and helper in this module.
logger = logging.getLogger("survival.api")
# ── Lifespan (startup / shutdown) ─────────────────────────────────────────────
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Pre-warm the model pipeline on startup and log on shutdown.

    Startup launches a daemon thread that calls ``models.get_pipeline()`` and
    then yields immediately, so the server does not wait for the warm-up to
    finish before accepting requests. After the application stops, a single
    shutdown line is logged.

    Args:
        app: The FastAPI application. Not used here; it is part of the
            lifespan callable signature.

    Yields:
        None, once the warm-up thread has been started.
    """
    import threading

    def _warm():
        logger.info("Pre-warming model pipeline…")
        try:
            models.get_pipeline()
        except Exception:
            # Without this, a failed warm-up would only reach stderr through
            # the threading excepthook and bypass the configured log format.
            logger.exception("Pipeline warm-up failed.")
            return
        logger.info("Pipeline warm-up complete.")

    # Daemon thread: it must never keep the process alive at shutdown.
    threading.Thread(target=_warm, daemon=True).start()
    yield
    logger.info("Survival Island API shutting down.")
# ── App ───────────────────────────────────────────────────────────────────────
# The ASGI application object; startup/shutdown behaviour comes from `lifespan`.
app = FastAPI(
    title="Survival Island API",
    description="\n".join(
        (
            "LLM-powered survival agent backend.",
            "Built for the Meta PyTorch Hackathon.",
        )
    ),
    version="1.0.0",
    lifespan=lifespan,
)
# CORS: every origin and header is accepted; only GET/POST/OPTIONS are allowed.
app.add_middleware(
    CORSMiddleware,
    allow_headers=["*"],
    allow_methods=["GET", "POST", "OPTIONS"],
    allow_origins=["*"],  # tighten in production if needed
)
# ── Simple in-process rate limiter ────────────────────────────────────────────
# Monotonic timestamp of the most recent accepted /api/infer call, per client IP.
_last_call: dict[str, float] = {}
_MIN_INTERVAL = 4.0  # seconds between /api/infer calls per IP
# Once this many IPs are tracked, expired entries are pruned before inserting.
_MAX_TRACKED_IPS = 10_000


def _rate_ok(ip: str) -> bool:
    """Return whether ``ip`` may call ``/api/infer`` now, recording the call if so.

    A call is accepted when the IP has never been seen, or when at least
    ``_MIN_INTERVAL`` seconds have passed since its last accepted call.
    Rejected calls do not refresh the stored timestamp.

    Args:
        ip: Client identifier used as the rate-limit key.

    Returns:
        True if the call is accepted (and recorded), False if it is too soon.
    """
    now = time.monotonic()
    previous = _last_call.get(ip)
    # ``None`` marks a first-time caller. Comparing against a 0.0 default would
    # wrongly reject first calls whenever the monotonic clock itself is still
    # below ``_MIN_INTERVAL`` (its reference point is unspecified).
    if previous is not None and now - previous < _MIN_INTERVAL:
        return False
    if len(_last_call) >= _MAX_TRACKED_IPS:
        # Entries older than the interval would be accepted anyway, so removing
        # them changes no decision; it only stops the table growing forever.
        expired = [
            key for key, seen in _last_call.items()
            if now - seen >= _MIN_INTERVAL
        ]
        for key in expired:
            del _last_call[key]
    _last_call[ip] = now
    return True
# ── Routes ────────────────────────────────────────────────────────────────────
@app.get("/api/health")
async def health():
    """Liveness probe — report that the process is up.

    Returns:
        A dict with a fixed ``"status": "ok"``, the ``HF_MODEL`` environment
        variable (``"unset"`` when absent), and the ``models._use_local`` flag.
    """
    payload = {"status": "ok"}
    payload["model"] = os.getenv("HF_MODEL", "unset")
    payload["localPipeline"] = models._use_local
    return payload
@app.get("/", include_in_schema=False)
async def root():
    """Describe the service for someone opening the base URL in a browser.

    Returns:
        A dict with a status message and the list of ``/api`` endpoint paths.
    """
    endpoints = ["/api/health", "/api/config", "/api/infer"]
    return {"message": "Survival Island API is running.", "endpoints": endpoints}
@app.get("/api/config")
async def config():
    """Expose the runtime settings the frontend is allowed to see.

    Returns:
        A dict with the model name from ``HF_MODEL`` (falling back to the
        Mistral-7B-Instruct default), whether ``HF_TOKEN`` is set (the token
        itself is never returned), and the ``models._use_local`` flag.
    """
    model_name = os.getenv("HF_MODEL", "mistralai/Mistral-7B-Instruct-v0.2")
    token_present = bool(os.getenv("HF_TOKEN"))
    return {
        "model": model_name,
        "hasToken": token_present,
        "localPipeline": models._use_local,
    }
@app.post("/api/infer", response_model=InferResponse)
async def infer(state: GameState, request: Request):
    """Choose a survival action for the submitted game state.

    The client IP is rate-limited through ``_rate_ok``, a prompt is built
    from ``state`` with ``build_prompt``, and inference runs in a worker
    thread so the event loop (and therefore ``/api/health``) stays responsive
    while the model is busy.

    Inference priority: local pipeline → HF Inference API → rule fallback.

    Args:
        state: Parsed ``GameState`` request body.
        request: Incoming request; used only to read the client address.

    Returns:
        An ``InferResponse`` carrying the chosen ``action``, the ``thought``
        and the ``source`` that produced them (``"local"``, ``"api"`` or
        ``"fallback"``).

    Raises:
        HTTPException: 429 when the same IP calls again too soon.
    """
    # Local import: fastapi is already a dependency of this module, and this
    # keeps the function self-contained.
    from fastapi.concurrency import run_in_threadpool

    ip = request.client.host if request.client else "unknown"
    if not _rate_ok(ip):
        raise HTTPException(
            status_code=429,
            detail="Too many requests — wait a few seconds.",
        )
    prompt = build_prompt(state)
    logger.info(
        f"[infer] gen={state.generation} ip={ip} "
        f"challenge={state.activeChallenge.type if state.activeChallenge else 'none'}"
    )
    # The inference calls are synchronous; running them inline in a coroutine
    # would block every other request until they finish.
    # NOTE(review): concurrent requests can now reach the `models` helpers
    # from different worker threads at the same time — confirm that is safe.
    result, source = await run_in_threadpool(_run_inference_chain, prompt)
    return InferResponse(
        action=result["action"],
        thought=result["thought"],
        source=source,
    )


def _run_inference_chain(prompt):
    """Run the first applicable inference path and report which one was used.

    Uses ``models.infer_local`` when ``models._use_local`` is truthy, else
    ``models.infer_api`` when ``HF_TOKEN`` is set, else
    ``models.run_inference``. Any exception from the chosen path is logged
    and ``models.run_inference`` is called as the fallback.

    Args:
        prompt: Prompt produced by ``build_prompt``.

    Returns:
        A ``(result, source)`` tuple. ``result`` is whatever the ``models``
        helper returned and ``source`` is ``"local"``, ``"api"`` or
        ``"fallback"``.
    """
    try:
        if models._use_local:
            return models.infer_local(prompt), "local"
        if os.getenv("HF_TOKEN"):
            return models.infer_api(prompt), "api"
        return models.run_inference(prompt), "fallback"
    except Exception as exc:
        logger.warning(f"[infer] Primary inference failed ({exc}), using fallback.")
        return models.run_inference(prompt), "fallback"
# ── Global error handler ──────────────────────────────────────────────────────
@app.exception_handler(Exception)
async def _global_error(request: Request, exc: Exception):
    """Turn any unhandled exception into a logged JSON 500 response.

    Args:
        request: The request being served when the exception escaped.
        exc: The unhandled exception.

    Returns:
        A ``JSONResponse`` with status 500 whose ``detail`` is
        ``"<ExceptionClass>: <message>"``.
    """
    logger.error(f"Unhandled error on {request.url}: {exc}", exc_info=True)
    # TEMPORARY: Expose the actual Python error so the validator prints it!
    # NOTE(review): this returns internal error text to clients; revisit
    # before treating the service as production-ready.
    error_kind = type(exc).__name__
    body = {"detail": f"{error_kind}: {str(exc)}"}
    return JSONResponse(content=body, status_code=500)
# ─────────────────────────────────────────────
# OpenEnv REQUIRED ENDPOINTS
# ─────────────────────────────────────────────
# ── ADDED: Instantiate the environment here so the routes can use it ──────────
# One module-level instance, created at import time; the /reset, /step and
# /state handlers below all read or mutate this same object.
env = SurvivalIslandEnvironment()
@app.post("/reset")
async def root_reset(request: Request):
    """Reset the shared environment and return its first observation.

    Args:
        request: Incoming request; its body is not read.

    Returns:
        A dict holding the value returned by ``env.reset()`` as
        ``observation``, with ``reward`` 0.0, ``done`` False and an empty
        ``info`` mapping.
    """
    initial_observation = env.reset()
    response = {"observation": initial_observation}
    response.update(reward=0.0, done=False, info={})
    return response
@app.post("/step")
async def root_step(request: Request):
    """Apply one action to the shared environment and return the transition.

    The body is expected to be a JSON object with an ``"action"`` key. A
    missing, unparsable or non-object body is treated as an empty object, so
    the action defaults to ``"FORAGE"``.

    Args:
        request: Incoming request whose JSON body names the action.

    Returns:
        A dict with ``observation``, ``reward`` (as float), ``done`` (as bool)
        and ``info``, taken from the four values returned by ``env.step``.
    """
    try:
        data = await request.json()
    except Exception:
        # Best-effort parsing is intentional, but a bare ``except:`` would
        # also swallow task cancellation and interpreter-exit signals.
        data = {}
    if not isinstance(data, dict):
        # Valid JSON that is not an object (list, string, null, ...) has no
        # ``.get``; fall back to the default action instead of a 500.
        data = {}
    action = data.get("action", "FORAGE")
    state, reward, done, info = env.step(action)
    return {
        "observation": state,
        "reward": float(reward),
        "done": bool(done),
        "info": info,
    }
@app.get("/state")
async def root_state():
    """Return the shared environment's ``current_state`` as the response body."""
    snapshot = env.current_state
    return snapshot
def main():
    """Serve ``server.app:app`` with uvicorn on all interfaces, port 8000."""
    # Imported here so merely importing this module does not require uvicorn.
    import uvicorn

    uvicorn.run("server.app:app", host="0.0.0.0", port=8000)


if __name__ == "__main__":
    main()