Spaces:
Sleeping
Sleeping
File size: 7,292 Bytes
9f7b0e1 cb64216 9f7b0e1 cb64216 9f7b0e1 3890926 9f7b0e1 94158b3 9f7b0e1 94158b3 9f7b0e1 94158b3 9f7b0e1 94158b3 9f7b0e1 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 | import asyncio
import json
import math
import random
import uuid
import os
import time
from pathlib import Path
from typing import Optional
from fastapi import FastAPI, HTTPException, WebSocket, WebSocketDisconnect
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel, Field
from dotenv import load_dotenv
load_dotenv()
load_dotenv(Path(__file__).resolve().parents[1] / ".env")
from .models import SimulationState, AgentModel, TickResponse, FireScenario, WaterSource
from .simulation import SimulationEngine, TICK_INTERVAL_SECONDS
from . import groq_client
from . import hf_spaces
app = FastAPI(title="Unhinged 2.0", version="0.2.0")

# Local front-end dev servers on ports 3000-3002, under both the "localhost"
# host name and the loopback IP.
_DEFAULT_ALLOWED_ORIGINS = [
    "http://localhost:3000",
    "http://localhost:3001",
    "http://localhost:3002",
    "http://127.0.0.1:3000",
    "http://127.0.0.1:3001",
    "http://127.0.0.1:3002",
]

# A non-blank ALLOWED_ORIGINS env var (comma-separated) replaces the defaults;
# whitespace around entries is trimmed and empty entries are dropped.
_configured_origins = os.environ.get("ALLOWED_ORIGINS", "").strip()
ALLOWED_ORIGINS = (
    [entry for entry in (piece.strip() for piece in _configured_origins.split(",")) if entry]
    if _configured_origins
    else _DEFAULT_ALLOWED_ORIGINS
)

app.add_middleware(
    CORSMiddleware,
    allow_origins=ALLOWED_ORIGINS,
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# In-memory registry of simulations keyed by simulation_id (process-local;
# nothing here persists it).
active_simulations: dict[str, SimulationState] = {}
# Epoch seconds at import time; /wake reports uptime relative to this.
START_TIME = time.time()
def _safe_randint(low: int, high: int) -> int:
    """Draw a uniform random integer between ``low`` and ``high`` inclusive.

    The bounds may be passed in either order; they are put in ascending order
    before calling ``random.randint``, so an inverted range never raises.
    """
    lo, hi = sorted((low, high))
    return random.randint(lo, hi)
class StartSimulationRequest(BaseModel):
    """Request body of ``POST /start-simulation``.

    Map dimensions must be positive integers; non-positive sizes are rejected
    during request validation instead of failing later while agents spawn.
    """

    # Between 2 and 6 model identifiers; one agent is spawned per entry.
    model_names: list[str] = Field(..., min_length=2, max_length=6)
    # Only "fire" is currently accepted; the endpoint answers 400 otherwise.
    scenario: str = "fire"
    # Map size. gt=0 turns a zero/negative size (which previously crashed agent
    # spawning with an empty random range) into a validation error.
    map_width: int = Field(default=1200, gt=0)
    map_height: int = Field(default=800, gt=0)
class StartSimulationResponse(BaseModel):
    """Response body of ``POST /start-simulation``: the new id plus its initial state."""

    simulation_id: str
    state: SimulationState
class PlaceFireRequest(BaseModel):
    """Request body of ``POST /place-fire``.

    ``x``/``y`` are the requested fire position; the endpoint clamps them into
    the map bounds, so out-of-range values are accepted rather than rejected.
    """

    simulation_id: str
    x: int
    y: int
class TickRequest(BaseModel):
    """Request body carrying only a simulation id.

    NOTE(review): no endpoint in this module references this model — it may be
    used elsewhere or be left over from an HTTP tick endpoint; confirm before
    removing.
    """

    simulation_id: str
@app.get("/")
async def root():
    """Liveness probe: service name, a fixed "ok" status, and Groq readiness."""
    payload = {"service": "rush-agents-backend", "status": "ok"}
    payload["groq_available"] = groq_client.is_ready()
    return payload
@app.get("/wake")
async def wake():
    """Warm-up ping: confirms the process is up and reports whole-second uptime."""
    ready = groq_client.is_ready()
    uptime = int(time.time() - START_TIME)
    return {"warm": True, "groq_available": ready, "uptime_seconds": uptime}
@app.get("/available-models")
async def get_available_models():
    """Get list of available models (Groq + HF Spaces) for the UI.

    Thin pass-through: the response is exactly what
    ``hf_spaces.get_available_models()`` returns.
    """
    catalogue = await hf_spaces.get_available_models()
    return catalogue
@app.post("/start-simulation", response_model=StartSimulationResponse)
async def start_simulation(req: StartSimulationRequest):
    """Create and register a new simulation with one agent per requested model.

    The simulation starts in ``"waiting_for_scenario"`` with no fire, no water
    sources and round 0; it only begins ticking once a fire is placed.

    Raises:
        HTTPException: 400 if ``req.scenario`` is anything other than "fire".
    """
    scenario_supported = req.scenario == "fire"
    if not scenario_supported:
        raise HTTPException(status_code=400, detail="Only 'fire' scenario supported.")

    spawned = _spawn_agents(req.model_names, req.map_width, req.map_height)
    state = SimulationState(
        simulation_id=str(uuid.uuid4()),
        scenario=req.scenario,
        map_width=req.map_width,
        map_height=req.map_height,
        agents=spawned,
        fire=None,
        water_sources=[],
        round=0,
        status="waiting_for_scenario",
    )
    key = state.simulation_id
    active_simulations[key] = state
    return StartSimulationResponse(simulation_id=key, state=state)
@app.post("/place-fire", response_model=SimulationState)
def place_fire(req: PlaceFireRequest):
    """Place the fire for a waiting simulation, scatter water sources, and start it.

    The fire position is the requested point clamped to ``[0, map_width]`` x
    ``[0, map_height]``. Then 3-5 water sources are generated inside an 80-unit
    inset of the map. For each source a side of the fire is chosen at random and
    its x is drawn from that side at least 180 units from the fire. If that side
    has no room the other side is used; if neither has room, x comes from the
    whole inset. y is always drawn from the whole inset. Finally the simulation
    status becomes ``"running"``.

    Raises:
        HTTPException: 404 for an unknown simulation id; 409 if the simulation
            is not ``"waiting_for_scenario"``.
    """
    sim = _get_or_404(req.simulation_id)
    if sim.status != "waiting_for_scenario":
        raise HTTPException(status_code=409, detail="Fire already placed or simulation finished.")

    fire_x, fire_y = (
        max(0, min(req.x, sim.map_width)),
        max(0, min(req.y, sim.map_height)),
    )
    sim.fire = FireScenario(x=fire_x, y=fire_y)

    well_count = random.randint(3, 5)

    inset = 80  # keep wells away from the map edges
    clearance = 180  # preferred horizontal distance between a well and the fire
    bounds_x = (inset, max(inset, sim.map_width - inset))
    bounds_y = (inset, max(inset, sim.map_height - inset))

    # The candidate bands depend only on the fire position, so compute them once.
    left_band = (bounds_x[0], min(fire_x - clearance, bounds_x[1]))
    right_band = (max(fire_x + clearance, bounds_x[0]), bounds_x[1])
    left_ok = left_band[0] <= left_band[1]
    right_ok = right_band[0] <= right_band[1]

    for idx in range(well_count):
        wants_left = random.random() < 0.5
        if wants_left and left_ok:
            band = left_band
        elif right_ok:
            band = right_band
        elif left_ok:
            band = left_band
        else:
            band = bounds_x
        # Keyword arguments evaluate left to right: x is drawn before y.
        sim.water_sources.append(
            WaterSource(
                id=f"water_{idx}",
                x=_safe_randint(*band),
                y=_safe_randint(*bounds_y),
            )
        )

    sim.status = "running"
    return sim
@app.websocket("/ws/{simulation_id}")
async def simulation_ws(websocket: WebSocket, simulation_id: str):
    """Drive a simulation tick by tick and stream every tick to the client.

    Flow:
      * Unknown ``simulation_id``: the socket is closed with code 1008.
      * While the simulation is ``"waiting_for_scenario"``, poll once a second
        without ticking.
      * Otherwise run one ``SimulationEngine(...).tick()``, store the resulting
        state in ``active_simulations``, send ``result.model_dump()``, and sleep
        ``TICK_INTERVAL_SECONDS``.
      * When the simulation is (or becomes) ``"finished"``, close with code 1000.

    A client disconnect (``WebSocketDisconnect``) ends the loop quietly.

    NOTE(review): every connected socket runs its own loop, so two clients on
    the same simulation id advance it twice per interval.
    """
    await websocket.accept()
    sim = active_simulations.get(simulation_id)
    if sim is None:
        await websocket.close(code=1008)
        return
    try:
        while True:
            # Always work on the registry's current state. After a tick the
            # registry holds ``result.state``; if tick() returns a new object,
            # reusing the reference captured before the loop would re-tick a
            # stale state.
            sim = active_simulations.get(simulation_id, sim)
            if sim.status == "waiting_for_scenario":
                await asyncio.sleep(1)
                continue
            if sim.status == "finished":
                await websocket.send_json({"type": "finished", "state": sim.model_dump()})
                await websocket.close(code=1000)
                return
            engine = SimulationEngine(sim)
            result = await engine.tick()
            sim = result.state
            active_simulations[simulation_id] = sim
            # DEBUG: log outgoing TickResponse summary for troubleshooting.
            # Best-effort: fall back to the state's str() if agents can't be summarised.
            try:
                agent_states = [(a.model_name, a.alive) for a in sim.agents]
            except Exception:
                agent_states = str(sim)
            print(f"WS_SEND sim={simulation_id} round={result.round} agents={agent_states} events={len(result.events)}")
            await websocket.send_json(result.model_dump())
            if sim.status == "finished":
                await websocket.close(code=1000)
                return
            await asyncio.sleep(TICK_INTERVAL_SECONDS)
    except WebSocketDisconnect:
        # Client went away; nothing to clean up here.
        pass
def _spawn_agents(model_names: list[str], width: int, height: int) -> list[AgentModel]:
    """Create one live agent per model name at a random position on the map.

    Agents are kept at least ``min_gap`` apart when possible. Each agent gets up
    to 100 random placement attempts; if none is far enough from the agents
    already placed, the last attempt is used anyway (best effort).

    Positions are drawn with a 100-unit margin from every edge. On maps whose
    side is under 200 that margin would leave an empty range, which made
    ``random.randint`` raise ``ValueError``. There the margin shrinks to half
    the side length instead. Maps of at least 200x200 sample exactly as before.

    Args:
        model_names: Model identifiers, one agent per entry, in order.
        width: Map width.
        height: Map height.

    Returns:
        The spawned agents, in the same order as ``model_names``.
    """
    min_gap = 100
    edge_margin = 100
    max_attempts = 100
    # Shrink the margin on small maps so the sampling range is never empty.
    x_margin = max(0, min(edge_margin, width // 2))
    y_margin = max(0, min(edge_margin, height // 2))

    positions: list[tuple[int, int]] = []
    agents: list[AgentModel] = []
    for name in model_names:
        for _ in range(max_attempts):
            # x is drawn before y, matching the original sampling order.
            candidate = (
                _safe_randint(x_margin, width - x_margin),
                _safe_randint(y_margin, height - y_margin),
            )
            if all(math.dist(candidate, p) >= min_gap for p in positions):
                break
        # After a break this is the first well-spaced candidate; otherwise it is
        # the last attempt (the map is too crowded to honour min_gap).
        positions.append(candidate)
        x, y = candidate
        agents.append(AgentModel(
            model_name=name,
            # e.g. "org/llama-3-8b" -> "Llama": last path segment, text before the first "-".
            display_name=name.split("/")[-1].split("-")[0].capitalize(),
            x=x,
            y=y,
            alive=True,
        ))
    return agents
def _get_or_404(simulation_id: str) -> SimulationState:
    """Look up a registered simulation by id.

    Raises:
        HTTPException: 404 when no (truthy) simulation is stored under the id.
    """
    if sim := active_simulations.get(simulation_id):
        return sim
    raise HTTPException(status_code=404, detail="Simulation not found")
|