Traffic-Control / agent /base_agent.py
Dhaerya's picture
Add files
b00d5d5
"""
BaseAgent — abstract interface that all RL agents must implement.

Every concrete agent (Q-Learning, DQN, …) must override:
    • select_action(state, training) → int
    • train_step(state, action, reward, next_state, done) → float | None
    • save(filepath)
    • load(filepath)

Optional hooks:
    • update_target_network() – used by DQN
    • reset() – called between episodes if needed
"""
from abc import ABC, abstractmethod
# PyTorch is an optional dependency: this module must stay importable on an
# install without it, so the import failure is recorded in a flag instead of
# being allowed to propagate. When the import succeeds, ``torch`` is bound at
# module scope for use by the classes below.
try:
    import torch
except ImportError:
    _TORCH_AVAILABLE = False
else:
    _TORCH_AVAILABLE = True
class BaseAgent(ABC):
    """Abstract base class for RL agents.

    Concrete agents must implement ``select_action``, ``train_step``,
    ``save`` and ``load``. ``update_target_network`` and ``reset`` are
    optional hooks whose default implementations do nothing.

    Attributes:
        state_size: Value passed to the constructor.
        action_size: Value passed to the constructor.
        config: Configuration dict, stored by reference (not copied).
        device: ``torch.device("cuda")`` when torch is available and CUDA
            is available, ``torch.device("cpu")`` when torch is available
            without CUDA, and ``None`` when ``_TORCH_AVAILABLE`` is False.
    """

    def __init__(self, state_size: int, action_size: int, config: dict):
        """Store sizes and config, then choose the compute device.

        Args:
            state_size: Size of the state representation (presumably the
                observation vector length -- confirm against callers).
            action_size: Presumably the number of discrete actions, since
                ``select_action`` returns an ``int`` -- confirm in agents.
            config: Agent configuration; kept by reference, so later
                mutations by the caller are visible to the agent.
        """
        self.state_size = state_size
        self.action_size = action_size
        self.config = config
        # ``torch`` is already bound at module scope whenever
        # ``_TORCH_AVAILABLE`` is True, so it is used directly instead of
        # being re-imported here. Without torch the device is ``None``,
        # which lets torch-free agents run on installs lacking PyTorch.
        if _TORCH_AVAILABLE:
            self.device = torch.device(
                "cuda" if torch.cuda.is_available() else "cpu"
            )
        else:
            self.device = None

    @abstractmethod
    def select_action(self, state, training: bool = True) -> int:
        """Return an action integer for the given state.

        Args:
            state: Current state; its format is not fixed by this interface.
            training: Whether the call happens during training (presumably
                enables exploration -- confirm in concrete agents).
        """

    @abstractmethod
    def train_step(self, state, action, reward, next_state, done):
        """Perform one update step from a single transition.

        Args:
            state, action, reward, next_state, done: One transition
                (presumably as produced by a single environment step).

        Returns:
            The loss as a float, or None if not applicable.
        """

    @abstractmethod
    def save(self, filepath: str):
        """Persist the agent to *filepath*."""

    @abstractmethod
    def load(self, filepath: str):
        """Restore the agent from *filepath*."""

    # Optional hooks: no-ops by default; override when an agent needs them.
    def update_target_network(self) -> None:
        """Sync the target network (DQN only); no-op by default."""

    def reset(self) -> None:
        """Reset any per-episode internal state; no-op by default."""