# Hugging Face Space backend (scraped page info: status "Sleeping"; file size 3,880 bytes; commit b00d5d5).
import sys
import os
from pathlib import Path
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from fastapi.staticfiles import StaticFiles
from pydantic import BaseModel
# Make the project root (the parent of this file's directory) importable so the
# RL modules (config, environment, agent) can be imported below.
# resolve() makes the path absolute first: with a relative __file__ such as
# "./app.py", Path(__file__).parent.parent collapses to "." (the current
# working directory) instead of the real parent directory.
ROOT_DIR = Path(__file__).resolve().parent.parent
sys.path.append(str(ROOT_DIR))

import config as cfg
from environment import TrafficEnvironment

# Prefer the DQN agent, since it performed best. Its import may fail (presumably
# when its dependencies are not installed); DQN_AVAILABLE records the outcome so
# the agent-loading code can fall back to Q-learning.
try:
    from agent import DQNAgent
    DQN_AVAILABLE = True
except ImportError:
    DQN_AVAILABLE = False
# Cross-origin settings: any origin, method and header is accepted, so a frontend
# hosted on a different origin (e.g. a local dev server) can call the API.
# NOTE(review): Starlette's docs say wildcard origins/methods/headers should not be
# combined with allow_credentials=True. The endpoints here read no cookies or auth
# headers, so credentials are probably unnecessary. Confirm with the frontend first.
_CORS_OPTIONS = {
    "allow_origins": ["*"],
    "allow_credentials": True,
    "allow_methods": ["*"],
    "allow_headers": ["*"],
}

app = FastAPI(title="Traffic RL Simulation API")
app.add_middleware(CORSMiddleware, **_CORS_OPTIONS)
# The built frontend (frontend/dist) is served as static files. That mount is
# registered at the very end of this module, after all API routes, so the
# catch-all "/" mount cannot shadow the /api/* endpoints.
# Process-wide simulation objects shared by every request handler.
env = TrafficEnvironment(cfg)
agent = None

# Pick the agent type: DQN when its import succeeded, otherwise tabular
# Q-learning. Each choice also selects its checkpoint and status messages.
if DQN_AVAILABLE:
    agent = DQNAgent(cfg.STATE_SIZE, cfg.ACTION_SIZE, cfg.DQN_CONFIG)
    model_path = ROOT_DIR / "models" / "dqn_best.pth"
    _loaded_msg = "Backend: Loaded DQN Model successfully."
    _missing_msg = "Backend: DQN Model file not found, using untrained agent."
else:
    from agent import QLearningAgent
    agent = QLearningAgent(cfg.STATE_SIZE, cfg.ACTION_SIZE, cfg.Q_LEARNING_CONFIG)
    model_path = ROOT_DIR / "models" / "q_learning_best.pth"
    _loaded_msg = "Backend: Loaded Q-Learning Model successfully."
    _missing_msg = "Backend: Q-Learning Model file not found."

# Restore saved weights when a checkpoint is present. Otherwise keep the freshly
# constructed agent and just report that no checkpoint was found.
if not model_path.exists():
    print(_missing_msg)
else:
    agent.load(str(model_path))
    print(_loaded_msg)
class StateResponse(BaseModel):
    """Simulation snapshot returned by the /api/reset and /api/step endpoints."""

    # env.queue_lengths as a list (presumably one queue length per lane/approach; confirm).
    queues: list[float]
    # Current signal phase (env.current_phase).
    phase: int
    # Reward from the most recent step; 0.0 right after a reset.
    reward: float
    # env.vehicles_passed (presumably cumulative within the episode; confirm in environment).
    vehicles_passed: int
    # Episode step counter (env.current_step).
    step: int
    # Sum of step rewards since the last reset.
    total_reward: float
    # True when the step terminated or truncated the episode; always False from /api/reset.
    is_done: bool
# Per-episode state shared across requests. /api/reset and /api/step rebind
# these globals: the latest observation and the running reward total.
total_reward = 0.0
current_state, _ = env.reset()
@app.post("/api/reset", response_model=StateResponse)
def reset_env():
    """Begin a fresh episode and report its initial state.

    Resets the shared environment, stores the new observation, zeroes the
    running reward total, and returns a StateResponse-shaped dict with
    ``reward`` 0.0 and ``is_done`` False.
    """
    global current_state, total_reward

    observation, _ = env.reset()
    current_state = observation
    total_reward = 0.0

    return dict(
        queues=env.queue_lengths.tolist(),
        phase=env.current_phase,
        reward=0.0,
        vehicles_passed=env.vehicles_passed,
        step=env.current_step,
        total_reward=total_reward,
        is_done=False,
    )
@app.post("/api/step", response_model=StateResponse)
def step_env():
    """Advance the simulation by one agent-chosen action and report the result.

    The loaded agent selects an action in evaluation mode (``training=False``)
    from the current observation. The environment is stepped once and the
    running reward total is updated. If the step ended the episode (terminated
    or truncated), the returned snapshot describes that final step. The
    environment is then reset so the next call starts a new episode.
    """
    global current_state, total_reward

    action = agent.select_action(current_state, training=False)
    next_state, reward, terminated, truncated, _info = env.step(action)

    # Coerce to builtin types. The environment is numpy-backed (see
    # queue_lengths.tolist()), so reward/flags may arrive as numpy scalars
    # (e.g. np.float32, np.bool_). Those would otherwise leak into the running
    # total and into response-model validation.
    reward = float(reward)
    done = bool(terminated or truncated)

    current_state = next_state
    total_reward += reward

    # Snapshot the post-step state before any reset below overwrites it.
    response = {
        "queues": env.queue_lengths.tolist(),
        "phase": env.current_phase,
        "reward": reward,
        "vehicles_passed": env.vehicles_passed,
        "step": env.current_step,
        "total_reward": total_reward,
        "is_done": done,
    }

    if done:
        # Start the next episode now so the following call steps from a valid state.
        current_state, _ = env.reset()
        total_reward = 0.0

    return response
# Serve the built frontend at "/". Keep this after every API route: Starlette
# matches routes in registration order, so a "/" mount registered earlier would
# swallow the /api/* requests.
frontend_path = ROOT_DIR / "frontend" / "dist"

# Use is_dir() rather than exists(): StaticFiles requires a directory and raises
# at startup if "dist" exists but is a regular file.
if frontend_path.is_dir():
    app.mount("/", StaticFiles(directory=str(frontend_path), html=True), name="static")
else:
    print(f"Backend: Frontend build directory not found at {frontend_path}; serving API only.")
if __name__ == "__main__":
    import uvicorn

    # Default to port 7860 for Hugging Face Spaces compatibility; PORT overrides it.
    port = int(os.environ.get("PORT", 7860))

    # Pass the app object rather than the "app:app" import string. When this file
    # runs as a script it is the __main__ module, so the string makes uvicorn
    # import it a second time as module "app". That re-runs every module-level
    # side effect (environment creation, model loading), and it fails outright if
    # the file is not named app.py. reload is off, so an import string is not required.
    uvicorn.run(app, host="0.0.0.0", port=port, reload=False)
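# NOTE: models are loaded once, when this module is imported. Restart the server
# to pick up an updated checkpoint (e.g. the fixed, working 9D model).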