"""
BaseAgent — abstract interface that all RL agents must implement.

Every concrete agent (Q-Learning, DQN, …) must override:
    • select_action(state, training)                      → int
    • train_step(state, action, reward, next_state, done) → float | None
    • save(filepath)
    • load(filepath)

Optional hooks (no-ops by default, override when needed):
    • update_target_network() – used by DQN
    • reset()                 – called between episodes if needed
"""

from abc import ABC, abstractmethod
from typing import Optional

try:
    import torch

    _TORCH_AVAILABLE = True
except ImportError:
    # torch is optional: agents that do not need it (e.g. tabular ones)
    # still work, they just get ``device = None``.
    _TORCH_AVAILABLE = False


class BaseAgent(ABC):
    """Abstract base class for RL agents.

    Args:
        state_size: Size of the state space, stored as ``self.state_size``.
        action_size: Number of discrete actions, stored as ``self.action_size``.
        config: Agent configuration, stored as-is (not copied) as ``self.config``.
    """

    def __init__(self, state_size: int, action_size: int, config: dict) -> None:
        self.state_size = state_size
        self.action_size = action_size
        # Kept by reference, so callers may still mutate it after construction.
        self.config = config
        # A torch.device ("cuda" when a GPU is visible, else "cpu"), or None
        # when torch is not installed. Uses the module-level import; no need
        # to re-import torch here.
        if _TORCH_AVAILABLE:
            self.device = torch.device(
                "cuda" if torch.cuda.is_available() else "cpu"
            )
        else:
            self.device = None

    @abstractmethod
    def select_action(self, state, training: bool = True) -> int:
        """Return an action integer for the given state.

        *training* lets implementations distinguish exploration from
        evaluation behaviour.
        """

    @abstractmethod
    def train_step(self, state, action, reward, next_state, done) -> Optional[float]:
        """Perform one update step; return loss (or None if not applicable)."""

    @abstractmethod
    def save(self, filepath: str) -> None:
        """Persist the agent to *filepath*."""

    @abstractmethod
    def load(self, filepath: str) -> None:
        """Restore the agent from *filepath*."""

    # Optional hooks: default implementations do nothing and return None.

    def update_target_network(self) -> None:
        """Sync target network (DQN only). No-op by default."""

    def reset(self) -> None:
        """Reset any per-episode internal state. No-op by default."""