# Source snapshot: commit b00d5d5 ("Add files") by Dhaerya.
import os
import sys
import threading
from pathlib import Path

from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from fastapi.staticfiles import StaticFiles
from pydantic import BaseModel
# Make the project root (the parent of this file's directory) importable so the
# RL modules (config, environment, agent) can be imported below.
# resolve() makes the path absolute: with a relative __file__ such as "app.py",
# Path("app.py").parent.parent is just "." and would point at the wrong place.
ROOT_DIR = Path(__file__).resolve().parent.parent
# Prepend rather than append so these project modules take precedence over any
# installed distribution that uses the same top-level name (e.g. "config").
sys.path.insert(0, str(ROOT_DIR))

import config as cfg
from environment import TrafficEnvironment

# Prefer the DQN agent since it performed best; record whether it is importable
# so the agent-construction code can fall back otherwise.
try:
    from agent import DQNAgent
    DQN_AVAILABLE = True
except ImportError:
    DQN_AVAILABLE = False
app = FastAPI(title="Traffic RL Simulation API")

# CORS: accept every origin, method and header so a separately hosted frontend
# (for example a local dev server) can call the API.
# NOTE(review): allow_credentials=True together with a "*" origin effectively
# permits credentialed requests from any site; nothing visible in this module
# uses cookies or auth, so consider setting it to False.
app.add_middleware(
    CORSMiddleware,
    allow_headers=["*"],
    allow_methods=["*"],
    allow_credentials=True,
    allow_origins=["*"],
)
# Static files from frontend/dist are mounted at the end of this module, after
# all API routes: routes are matched in registration order, so a mount at "/"
# registered earlier would intercept the /api/* requests.
# Global instances shared by every request handler.
env = TrafficEnvironment(cfg)
agent = None

# Choose the agent implementation: DQN when its import succeeded, otherwise the
# Q-learning agent. In both cases load saved weights from models/ when present.
if DQN_AVAILABLE:
    agent = DQNAgent(cfg.STATE_SIZE, cfg.ACTION_SIZE, cfg.DQN_CONFIG)
    model_path = ROOT_DIR.joinpath("models", "dqn_best.pth")
    if not model_path.exists():
        print("Backend: DQN Model file not found, using untrained agent.")
    else:
        agent.load(str(model_path))
        print("Backend: Loaded DQN Model successfully.")
else:
    # NOTE(review): if the earlier ImportError came from the `agent` module
    # itself (e.g. a missing dependency), this import fails the same way;
    # confirm the fallback is actually reachable.
    from agent import QLearningAgent
    agent = QLearningAgent(cfg.STATE_SIZE, cfg.ACTION_SIZE, cfg.Q_LEARNING_CONFIG)
    model_path = ROOT_DIR.joinpath("models", "q_learning_best.pth")
    if not model_path.exists():
        print("Backend: Q-Learning Model file not found.")
    else:
        agent.load(str(model_path))
        print("Backend: Loaded Q-Learning Model successfully.")
class StateResponse(BaseModel):
    """Simulation snapshot returned by the ``/api/reset`` and ``/api/step`` endpoints."""

    # Queue lengths from the environment (presumably one entry per approach/lane; TODO confirm).
    queues: list[float]
    # Current traffic-light phase as reported by the environment.
    phase: int
    # Reward of the most recent transition; 0.0 immediately after a reset.
    reward: float
    # The environment's vehicles_passed counter.
    vehicles_passed: int
    # The environment's current step counter.
    step: int
    # Reward accumulated over the current episode.
    total_reward: float
    # True when the last step terminated or truncated the episode.
    is_done: bool
# Mutable episode state shared by the request handlers below.
current_state, _ = env.reset()
total_reward = 0.0

# FastAPI runs plain `def` endpoints in a worker threadpool, so overlapping
# requests (e.g. a frontend polling /api/step faster than a step completes)
# would otherwise mutate `env`, `current_state` and `total_reward` concurrently.
# This lock serialises every read-modify-write of that shared state.
_state_lock = threading.Lock()


def _snapshot(reward, done):
    """Build a ``StateResponse``-shaped dict from the environment's current state.

    Args:
        reward: Reward to report for the most recent transition.
        done: Whether the most recent transition ended the episode.

    Returns:
        A dict with the ``StateResponse`` keys. Environment values may be NumPy
        scalars, so they are cast to builtins for predictable serialisation.
    """
    return {
        "queues": env.queue_lengths.tolist(),
        "phase": int(env.current_phase),
        "reward": float(reward),
        "vehicles_passed": int(env.vehicles_passed),
        "step": int(env.current_step),
        "total_reward": float(total_reward),
        "is_done": bool(done),
    }


@app.post("/api/reset", response_model=StateResponse)
def reset_env():
    """Start a new episode and return its initial state (reward 0, not done)."""
    global current_state, total_reward
    with _state_lock:
        current_state, _ = env.reset()
        total_reward = 0.0
        return _snapshot(0.0, False)


@app.post("/api/step", response_model=StateResponse)
def step_env():
    """Advance the simulation by one agent-chosen action and return the result.

    The action comes from the loaded agent in evaluation mode
    (``training=False``). When the episode terminates or is truncated, the
    terminal state is returned and the environment is reset immediately, so the
    next call starts a fresh episode.
    """
    global current_state, total_reward
    with _state_lock:
        action = agent.select_action(current_state, training=False)
        next_state, reward, terminated, truncated, _info = env.step(action)
        done = terminated or truncated
        current_state = next_state
        total_reward += float(reward)
        # Snapshot before any reset so the client sees the terminal state.
        response = _snapshot(reward, done)
        if done:
            current_state, _ = env.reset()
            total_reward = 0.0
        return response
# Serve the built frontend (frontend/dist) at the site root; html=True puts
# StaticFiles in HTML mode so directory requests get their index.html.
# This must remain the final route registration: a "/" mount matches every
# path and routes are checked in registration order, so the /api/* routes
# have to be registered before it.
frontend_path = ROOT_DIR.joinpath("frontend", "dist")
if frontend_path.exists():
    app.mount(
        "/",
        StaticFiles(html=True, directory=str(frontend_path)),
        name="static",
    )
if __name__ == "__main__":
    import uvicorn

    # Use port 7860 for Hugging Face Spaces compatibility unless PORT overrides it.
    port = int(os.environ.get("PORT", 7860))
    # Pass the already-built app object instead of the "app:app" import string.
    # With the string, uvicorn imports this file a second time as module `app`,
    # which reruns the environment/model setup above. The string also only
    # works when the file is named app.py. An object is valid because reload
    # is disabled.
    uvicorn.run(app, host="0.0.0.0", port=port, reload=False)
# Restart to load fixed working 9D model