Spaces:
Sleeping
Sleeping
| import sys | |
| import os | |
| from pathlib import Path | |
| from fastapi import FastAPI | |
| from fastapi.middleware.cors import CORSMiddleware | |
| from fastapi.staticfiles import StaticFiles | |
| from pydantic import BaseModel | |
# Make the project root (the parent of this file's directory) importable so the
# RL modules imported below (config, environment, agent) resolve.
# - resolve(): a relative __file__ (e.g. "app.py") would make .parent.parent
#   collapse to "." instead of the real parent directory.
# - insert(0, ...): appending would place the root after site-packages, letting
#   an unrelated installed package with the same name (e.g. a third-party
#   ``config`` distribution) shadow the project's modules.
ROOT_DIR = Path(__file__).resolve().parent.parent
sys.path.insert(0, str(ROOT_DIR))
| import config as cfg | |
| from environment import TrafficEnvironment | |
# DQN performed best, so it is the preferred agent. DQN_AVAILABLE records
# whether its class could be imported; when it is False the loader further
# down falls back to the Q-learning agent.
DQN_AVAILABLE = False
try:
    from agent import DQNAgent
except ImportError:
    pass
else:
    DQN_AVAILABLE = True
app = FastAPI(title="Traffic RL Simulation API")

# Allow cross-origin requests so a separately hosted frontend (e.g. a local dev
# server) can call the API. Credentials are deliberately NOT enabled:
# - none of the handlers in this module read cookies or auth headers, so
#   nothing needs them;
# - the CORS spec forbids a wildcard origin on credentialed responses;
# - with credentials enabled, Starlette reflects the caller's Origin instead of
#   sending "*", which would let any website make credentialed requests.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=False,
    allow_methods=["*"],
    allow_headers=["*"],
)
# The built frontend (frontend/dist) is mounted at "/" at the end of this
# module, after all API routes: a mount at "/" matches every path, so it must
# be registered last or it would shadow the API endpoints.
# Process-wide singletons used by reset_env() / step_env().
env = TrafficEnvironment(cfg)
agent = None


def _load_checkpoint(policy, checkpoint, loaded_msg, missing_msg):
    """Load ``checkpoint`` into ``policy`` if the file exists and report it.

    Args:
        policy: Agent instance exposing ``load(path_str)``.
        checkpoint: ``Path`` of the saved model file.
        loaded_msg: Message printed after a successful load.
        missing_msg: Message printed when the file is absent; the policy is
            then left exactly as constructed.
    """
    if not checkpoint.exists():
        print(missing_msg)
        return
    policy.load(str(checkpoint))
    print(loaded_msg)


# Build the policy: DQN when its class imported cleanly, otherwise fall back
# to the Q-learning agent.
if DQN_AVAILABLE:
    agent = DQNAgent(cfg.STATE_SIZE, cfg.ACTION_SIZE, cfg.DQN_CONFIG)
    model_path = ROOT_DIR / "models" / "dqn_best.pth"
    _load_checkpoint(
        agent,
        model_path,
        "Backend: Loaded DQN Model successfully.",
        "Backend: DQN Model file not found, using untrained agent.",
    )
else:
    from agent import QLearningAgent

    agent = QLearningAgent(cfg.STATE_SIZE, cfg.ACTION_SIZE, cfg.Q_LEARNING_CONFIG)
    model_path = ROOT_DIR / "models" / "q_learning_best.pth"
    _load_checkpoint(
        agent,
        model_path,
        "Backend: Loaded Q-Learning Model successfully.",
        "Backend: Q-Learning Model file not found.",
    )
# Schema for one simulation snapshot. The dicts returned by reset_env() and
# step_env() use exactly these keys.
# NOTE(review): nothing in this view references this model (e.g. via
# response_model=...); confirm it is wired up where the routes are registered.
# Comments are used instead of a docstring because pydantic would copy a class
# docstring into the generated JSON schema.
class StateResponse(BaseModel):
    queues: list[float]  # env.queue_lengths converted with .tolist()
    phase: int  # env.current_phase
    reward: float  # reward of the latest step (0.0 right after a reset)
    vehicles_passed: int  # env.vehicles_passed
    step: int  # env.current_step
    total_reward: float  # sum of step rewards since the last reset
    is_done: bool  # True when the step terminated or truncated the episode
# Module-level episode bookkeeping, rebound by reset_env() and step_env():
# the reward accumulated so far and the observation the agent acts on next.
total_reward = 0.0
current_state, _ = env.reset()
def reset_env():
    """Start a fresh episode and return its initial snapshot.

    Rebinds the module-level ``current_state`` and ``total_reward``. Scalar
    fields are converted to built-in ``int`` so the payload matches
    ``StateResponse`` and stays JSON-serializable even if the environment
    exposes NumPy scalars.

    Returns:
        dict with the ``StateResponse`` keys; ``reward`` is 0.0 and
        ``is_done`` is False for a freshly reset episode.
    """
    # NOTE(review): no route decorator is visible on this function in this
    # view; confirm it is registered with ``app`` before the "/" static mount.
    global current_state, total_reward
    current_state, _ = env.reset()
    total_reward = 0.0
    return {
        "queues": env.queue_lengths.tolist(),
        "phase": int(env.current_phase),
        "reward": 0.0,
        "vehicles_passed": int(env.vehicles_passed),
        "step": int(env.current_step),
        "total_reward": total_reward,
        "is_done": False,
    }
def step_env():
    """Advance the simulation by one agent-chosen step and return a snapshot.

    The agent selects an action for ``current_state`` with ``training=False``
    (evaluation mode). The environment is stepped, the running
    ``total_reward`` is updated, and a snapshot is returned. When the episode
    ends (terminated or truncated), the snapshot reports ``is_done=True`` and
    the environment is immediately reset so the next call starts a new
    episode.

    Scalar fields are converted to built-in ``float``/``int``/``bool`` so the
    payload matches ``StateResponse``. Environments often return NumPy scalars
    (e.g. ``np.float32`` rewards, ``np.bool_`` flags). These are not
    JSON-serializable by default, and accumulating them would also turn
    ``total_reward`` into a NumPy scalar.

    NOTE(review): mutates module globals without locking; if this is served as
    a sync FastAPI route, concurrent requests may interleave.

    Returns:
        dict with the ``StateResponse`` keys describing the state after the step.
    """
    global current_state, total_reward
    # Evaluation mode (training=False) for the loaded agent.
    action = agent.select_action(current_state, training=False)
    next_state, reward, terminated, truncated, _info = env.step(action)
    reward = float(reward)
    done = bool(terminated or truncated)

    current_state = next_state
    total_reward += reward

    response = {
        "queues": env.queue_lengths.tolist(),
        "phase": int(env.current_phase),
        "reward": reward,
        "vehicles_passed": int(env.vehicles_passed),
        "step": int(env.current_step),
        "total_reward": total_reward,
        "is_done": done,
    }
    if done:
        # The snapshot above already reports the final step; start the next
        # episode now so the following call continues seamlessly.
        current_state, _ = env.reset()
        total_reward = 0.0
    return response
# Serve the compiled frontend (frontend/dist) at "/" when it has been built.
# This mount is registered after every API route on purpose: a mount at "/"
# matches any path, so routes registered after it would never be reached.
frontend_path = ROOT_DIR.joinpath("frontend", "dist")
if frontend_path.exists():
    app.mount(
        "/",
        StaticFiles(directory=str(frontend_path), html=True),
        name="static",
    )
if __name__ == "__main__":
    import uvicorn

    # Use port 7860 for Hugging Face Spaces compatibility; PORT overrides it.
    port = int(os.environ.get("PORT", 7860))
    # Pass the already-built app object instead of an "app:app" import string.
    # When run as a script this file is the ``__main__`` module, so importing
    # it again by name would execute all module-level setup a second time
    # (environment constructed and model loaded twice). An import string is
    # only required for reload/multi-worker mode, which is disabled here.
    uvicorn.run(app, host="0.0.0.0", port=port, reload=False)
| # Restart to load fixed working 9D model | |