File size: 3,880 Bytes
b00d5d5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
import sys
import os
from pathlib import Path
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from fastapi.staticfiles import StaticFiles
from pydantic import BaseModel

# Make the project root (the parent of this file's directory) importable so the
# RL modules imported below (config, environment, agent) can be found.
# resolve() turns a relative or symlinked __file__ into the real absolute path.
# The root goes to the FRONT of sys.path: appending it would let an installed
# distribution with the same module name (e.g. a third-party package called
# ``config``) shadow the project's own module.
ROOT_DIR = Path(__file__).resolve().parent.parent
if str(ROOT_DIR) not in sys.path:
    sys.path.insert(0, str(ROOT_DIR))

import config as cfg
from environment import TrafficEnvironment

# DQN gave the best results, so it is the preferred agent. Record whether its
# import succeeded; the agent set-up code falls back to Q-learning otherwise.
DQN_AVAILABLE = False
try:
    from agent import DQNAgent
except ImportError:
    pass
else:
    DQN_AVAILABLE = True

app = FastAPI(title="Traffic RL Simulation API")

# Open the API to cross-origin requests from any origin, with any method and
# header, so the frontend can call it when served from a different origin.
# NOTE(review): the CORS spec does not allow a wildcard origin on credentialed
# requests; confirm that allow_credentials=True is actually needed here.
_cors_settings = {
    "allow_origins": ["*"],
    "allow_credentials": True,
    "allow_methods": ["*"],
    "allow_headers": ["*"],
}
app.add_middleware(CORSMiddleware, **_cors_settings)

# The built frontend (frontend/dist) is served as static files, but that mount
# is defined at the end of this module: a catch-all mount at "/" must be
# registered after the API routes so that it does not shadow them.

# Process-wide singletons shared by every request: one simulation environment
# and the agent whose policy chooses actions for it.
env = TrafficEnvironment(cfg)
agent = None

# Prefer the DQN agent when its import succeeded; otherwise fall back to the
# Q-learning agent. Each branch also picks its checkpoint file under models/
# and the status messages to print.
if DQN_AVAILABLE:
    agent = DQNAgent(cfg.STATE_SIZE, cfg.ACTION_SIZE, cfg.DQN_CONFIG)
    model_path = ROOT_DIR / "models" / "dqn_best.pth"
    _loaded_msg = "Backend: Loaded DQN Model successfully."
    _missing_msg = "Backend: DQN Model file not found, using untrained agent."
else:
    from agent import QLearningAgent
    agent = QLearningAgent(cfg.STATE_SIZE, cfg.ACTION_SIZE, cfg.Q_LEARNING_CONFIG)
    model_path = ROOT_DIR / "models" / "q_learning_best.pth"
    _loaded_msg = "Backend: Loaded Q-Learning Model successfully."
    _missing_msg = "Backend: Q-Learning Model file not found."

# Load saved weights when the checkpoint exists; otherwise the freshly
# constructed agent is used as-is.
if model_path.exists():
    agent.load(str(model_path))
    print(_loaded_msg)
else:
    print(_missing_msg)

class StateResponse(BaseModel):
    """Simulation snapshot returned by the /api/reset and /api/step endpoints."""

    # env.queue_lengths converted to a list (one value per queue the env tracks).
    queues: list[float]
    # env.current_phase; presumably the active signal-phase index — confirm in TrafficEnvironment.
    phase: int
    # Reward of the step just taken; always 0.0 in a /api/reset response.
    reward: float
    # env.vehicles_passed as read right after the reset/step.
    vehicles_passed: int
    # env.current_step.
    step: int
    # Sum of rewards since the last reset, including this step's reward.
    total_reward: float
    # True when this step terminated or truncated the episode (the server then
    # resets the env for the next call); always False after /api/reset.
    is_done: bool

# Episode bookkeeping shared by the request handlers: the reward summed since
# the last reset, and the latest observation the agent will act on next.
total_reward: float = 0.0
current_state, _ = env.reset()

@app.post("/api/reset", response_model=StateResponse)
def reset_env():
    """Begin a new episode and return the environment's initial snapshot.

    Resets both the environment and the running reward total. ``reward`` is
    reported as 0.0 and ``is_done`` as False because no step has been taken.

    Returns:
        dict matching ``StateResponse``.
    """
    global current_state, total_reward
    observation, _ = env.reset()
    current_state = observation
    total_reward = 0.0
    snapshot = dict(
        queues=env.queue_lengths.tolist(),
        phase=env.current_phase,
        reward=0.0,
        vehicles_passed=env.vehicles_passed,
        step=env.current_step,
        total_reward=total_reward,
        is_done=False,
    )
    return snapshot

@app.post("/api/step", response_model=StateResponse)
def step_env():
    """Advance the simulation by one step using the loaded agent's policy.

    The agent selects an action for the current observation in evaluation
    mode (``training=False``), the environment applies it, and the step's
    reward is added to the running episode total. If the step ends the
    episode (terminated or truncated), the response still describes that
    final step, and the environment is then reset so the next call starts a
    fresh episode.

    Returns:
        dict matching ``StateResponse``.
    """
    global current_state, total_reward

    action = agent.select_action(current_state, training=False)
    next_state, reward, terminated, truncated, _info = env.step(action)

    # env.step may return NumPy scalars (the env already exposes NumPy arrays,
    # cf. queue_lengths.tolist()). Convert to builtins so the running total
    # stays a plain float and the response validates/serializes predictably.
    reward = float(reward)
    done = bool(terminated or truncated)

    current_state = next_state
    total_reward += reward

    response = {
        "queues": env.queue_lengths.tolist(),
        "phase": env.current_phase,
        "reward": reward,
        "vehicles_passed": env.vehicles_passed,
        "step": env.current_step,
        "total_reward": total_reward,
        "is_done": done,
    }

    if done:
        # Auto-reset so the following call begins a new episode; the response
        # built above still reports this episode's final step.
        current_state, _ = env.reset()
        total_reward = 0.0

    return response

# Serve the built frontend at the site root when it exists. html=True makes
# directory requests return their index.html. This mount matches every path
# under "/", so it must stay after all API route definitions.
frontend_path = ROOT_DIR.joinpath("frontend", "dist")
if frontend_path.exists():
    app.mount(
        "/",
        StaticFiles(directory=str(frontend_path), html=True),
        name="static",
    )

if __name__ == "__main__":
    import uvicorn

    # Port 7860 is used for Hugging Face Spaces compatibility; $PORT overrides it.
    port = int(os.environ.get("PORT", 7860))
    # Pass the application object rather than the "app:app" import string.
    # This file is already running as __main__, so an import string makes
    # uvicorn import it again under the module name "app". That repeats all
    # module-level setup (environment construction, model loading) and only
    # works if the file is named app.py. An import string is needed only for
    # reload/workers, which are not used here.
    uvicorn.run(app, host="0.0.0.0", port=port, reload=False)