# Source: Traffic-Control/agent/q_learning_agent.py
# Repository page residue: "Dhaerya's picture", "Add files", commit b00d5d5
"""
Tabular Q-Learning Agent.
Implements Q(s,a) ← Q(s,a) + α [r + γ·max_a' Q(s',a') − Q(s,a)]
Because Q-learning requires a finite state space, the continuous
observation is discretised into equal-width bins per dimension.
Key results from PROJECT_EXPLANATION.md:
• Mean reward: −916.97 (best among all methods)
• 5-feature state + 10 bins per dimension performs well
• Epsilon-greedy exploration with decay 0.995/episode
"""
import os

import numpy as np

from .base_agent import BaseAgent
class QLearningAgent(BaseAgent):
    """
    Tabular Q-Learning with adaptive state discretisation.

    The Q-table is stored as a sparse dictionary
    {(discrete_state_tuple, action): q_value} for memory efficiency;
    entries that have never been written read back as 0.0.

    NOTE: the per-dimension bounds used for normalisation widen every time
    a state is discretised (in ``select_action`` as well as ``train_step``,
    and regardless of the ``training`` flag), so the bin that a given raw
    value maps to can shift as new extremes are observed.

    NOTE(review): ``self.state_size`` and ``self.action_size`` are read but
    never assigned here; presumably ``BaseAgent.__init__`` sets them —
    confirm against base_agent.py.
    """

    def __init__(self, state_size: int, action_size: int, config: dict):
        """
        Build an agent with an empty Q-table.

        Args:
            state_size: Number of dimensions in the observation vector.
            action_size: Number of discrete actions.
            config: Hyperparameter mapping. Recognised keys (with defaults):
                ``learning_rate`` (0.1), ``gamma`` (0.99),
                ``epsilon_start`` (1.0), ``epsilon_end`` (0.01),
                ``epsilon_decay`` (0.995), ``num_bins`` (10).
        """
        super().__init__(state_size, action_size, config)

        # Hyperparameters
        self.learning_rate = config.get("learning_rate", 0.1)
        self.gamma = config.get("gamma", 0.99)
        self.epsilon = config.get("epsilon_start", 1.0)
        self.epsilon_end = config.get("epsilon_end", 0.01)
        self.epsilon_decay = config.get("epsilon_decay", 0.995)
        self.num_bins = config.get("num_bins", 10)

        # Adaptive bounds for normalisation; start at [0, 1] per dimension
        # and only ever widen (see _discretise).
        self.state_mins = np.zeros(state_size, dtype=np.float32)
        self.state_maxs = np.ones(state_size, dtype=np.float32)

        # Sparse Q-table
        self.q_table: dict = {}

        # Stats
        self.steps = 0
        self.episodes = 0

        print(f"[Q-Learning] Initialised state={state_size} "
              f"actions={action_size} bins={self.num_bins} "
              f"lr={self.learning_rate} gamma={self.gamma}")

    # ------------------------------------------------------------------
    # Helpers
    # ------------------------------------------------------------------
    def _discretise(self, state: np.ndarray) -> tuple:
        """
        Convert continuous state → discrete tuple (hashable dict key).

        Side effect: widens ``state_mins`` / ``state_maxs`` to include
        ``state`` before normalising.

        Returns:
            Tuple of Python ints, one bin index in ``[0, num_bins - 1]``
            per state dimension.
        """
        state = np.asarray(state, dtype=np.float32)

        # Update running bounds
        self.state_mins = np.minimum(self.state_mins, state)
        self.state_maxs = np.maximum(self.state_maxs, state)

        # Guard against a zero-width range to avoid division by zero.
        ranges = np.maximum(self.state_maxs - self.state_mins, 1e-8)
        normalised = np.clip((state - self.state_mins) / ranges, 0.0, 1.0)

        # The cast truncates, so index num_bins - 1 is reached only when
        # the normalised value is exactly 1.0. Kept as-is so existing
        # Q-tables keep their meaning.
        indices = (normalised * (self.num_bins - 1)).astype(np.int32)

        # Plain Python ints hash/compare equal to the np.int32 scalars used
        # previously, so old tables still match, and they pickle smaller.
        return tuple(indices.tolist())

    def _get_q(self, discrete_state: tuple, action: int) -> float:
        """Return Q(discrete_state, action), or 0.0 if never written."""
        return self.q_table.get((discrete_state, action), 0.0)

    def _set_q(self, discrete_state: tuple, action: int, value: float):
        """Store ``value`` as a Python float for (discrete_state, action)."""
        self.q_table[(discrete_state, action)] = float(value)

    @staticmethod
    def _written_path(filepath):
        """Return the location ``np.save`` actually writes for *filepath*."""
        if hasattr(filepath, "write"):
            # File objects are written to directly; no name is changed.
            return filepath
        path = os.fspath(filepath)
        # np.save appends ".npy" to path-like targets that lack it.
        return path if path.endswith(".npy") else path + ".npy"

    @staticmethod
    def _readable_path(filepath):
        """Return *filepath*, or ``filepath + '.npy'`` if only that exists."""
        if hasattr(filepath, "read"):
            return filepath
        path = os.fspath(filepath)
        if not path.endswith(".npy") and not os.path.exists(path):
            with_ext = path + ".npy"
            if os.path.exists(with_ext):
                return with_ext
        return path

    # ------------------------------------------------------------------
    # BaseAgent interface
    # ------------------------------------------------------------------
    def select_action(self, state, training: bool = True) -> int:
        """
        Epsilon-greedy action selection.

        With probability ``epsilon`` (only when ``training`` is true) a
        uniformly random action is returned; otherwise the greedy action,
        with ties broken uniformly at random.
        """
        # Discretise first: this also updates the running bounds, even on
        # the exploration branch.
        ds = self._discretise(state)

        if training and np.random.random() < self.epsilon:
            return int(np.random.randint(0, self.action_size))

        q_values = [self._get_q(ds, a) for a in range(self.action_size)]
        max_q = max(q_values)
        best = [a for a, q in enumerate(q_values) if q == max_q]
        return int(np.random.choice(best))

    def train_step(self, state, action, reward, next_state, done):
        """
        One Bellman update.

        When ``done`` is true the target is the reward alone, epsilon is
        decayed (floored at ``epsilon_end``) and the episode counter is
        incremented. The step counter is incremented on every call.

        Returns:
            td_error (float): Temporal-difference error for this update.
        """
        ds = self._discretise(state)
        dns = self._discretise(next_state)
        action = int(action)
        reward = float(reward)
        done = bool(done)

        current_q = self._get_q(ds, action)
        if done:
            target_q = reward
        else:
            next_qs = [self._get_q(dns, a) for a in range(self.action_size)]
            target_q = reward + self.gamma * max(next_qs)

        td_error = target_q - current_q
        self._set_q(ds, action, current_q + self.learning_rate * td_error)

        if done:
            self.epsilon = max(self.epsilon_end, self.epsilon * self.epsilon_decay)
            self.episodes += 1
        self.steps += 1
        return float(td_error)

    def save(self, filepath: str):
        """
        Serialise Q-table to a .npy file.

        ``np.save`` appends ``.npy`` to ``filepath`` when it does not
        already end with it; the log line reports the path actually written.
        """
        payload = {
            "q_table": dict(self.q_table),
            "state_mins": self.state_mins.tolist(),
            "state_maxs": self.state_maxs.tolist(),
            "epsilon": self.epsilon,
            "steps": self.steps,
            "episodes": self.episodes,
            "num_bins": self.num_bins,
        }
        np.save(filepath, payload, allow_pickle=True)
        written = self._written_path(filepath)
        print(f"[Q-Learning] Saved Q-table ({len(self.q_table)} entries) -> {written}")

    def load(self, filepath: str):
        """
        Deserialise Q-table from a .npy file.

        Accepts the same ``filepath`` that was passed to ``save``: if that
        path does not exist but ``filepath + '.npy'`` does, the latter is
        loaded.

        Raises:
            FileNotFoundError: If neither candidate path exists.
        """
        source = self._readable_path(filepath)
        # SECURITY: allow_pickle=True executes arbitrary pickled code.
        # Only load files from a trusted source.
        payload = np.load(source, allow_pickle=True).item()

        self.q_table = payload["q_table"]
        self.state_mins = np.array(payload["state_mins"], dtype=np.float32)
        self.state_maxs = np.array(payload["state_maxs"], dtype=np.float32)
        self.epsilon = payload["epsilon"]
        self.steps = payload["steps"]
        self.episodes = payload["episodes"]
        self.num_bins = payload["num_bins"]
        print(f"[Q-Learning] Loaded Q-table ({len(self.q_table)} entries) <- {source}")

    # ------------------------------------------------------------------
    # Diagnostics
    # ------------------------------------------------------------------
    def stats(self) -> dict:
        """
        Summarise the Q-table.

        Returns:
            For an empty table, only ``entries`` and ``unique_states``
            (both 0). Otherwise also ``mean_q``, ``max_q``, ``min_q``,
            ``epsilon`` (rounded to 4 places) and ``episodes``.
        """
        if not self.q_table:
            return {"entries": 0, "unique_states": 0}

        states = {s for s, _ in self.q_table}
        vals = list(self.q_table.values())
        return {
            "entries": len(self.q_table),
            "unique_states": len(states),
            "mean_q": float(np.mean(vals)),
            "max_q": float(np.max(vals)),
            "min_q": float(np.min(vals)),
            "epsilon": round(self.epsilon, 4),
            "episodes": self.episodes,
        }

    def __repr__(self):
        # Read the fields directly rather than via stats(): stats() omits
        # "epsilon" for an empty table, which made repr() of a fresh agent
        # raise KeyError.
        return (
            f"QLearningAgent(state={self.state_size}, actions={self.action_size}, "
            f"bins={self.num_bins}, entries={len(self.q_table)}, "
            f"ε={round(self.epsilon, 4)})"
        )