Spaces:
Sleeping
Sleeping
"""
BaseAgent -- abstract interface that all RL agents must implement.

Every concrete agent (Q-Learning, DQN, ...) must override:
    * select_action(state, training) -> int
    * train_step(state, action, reward, next_state, done) -> float | None
    * save(filepath)
    * load(filepath)

Optional hooks:
    * update_target_network() -- used by DQN
    * reset() -- called between episodes if needed

PyTorch is optional: when it cannot be imported, ``_TORCH_AVAILABLE`` is
False and agents are constructed with ``device = None``.
"""
from abc import ABC, abstractmethod

# PyTorch is an optional dependency. Record whether it imports so that
# agents which don't need it keep working on machines without it.
try:
    import torch
    _TORCH_AVAILABLE = True
except ImportError:
    _TORCH_AVAILABLE = False
class BaseAgent(ABC):
    """Abstract base class for RL agents.

    Concrete agents must implement :meth:`select_action`,
    :meth:`train_step`, :meth:`save` and :meth:`load`; these are declared
    with ``@abstractmethod`` so a subclass that omits any of them cannot be
    instantiated. :meth:`update_target_network` and :meth:`reset` are
    optional hooks whose default implementations do nothing.

    Attributes:
        state_size: State size passed to the constructor.
        action_size: Action-space size passed to the constructor.
        config: Configuration dict passed to the constructor (stored as-is).
        device: ``torch.device("cuda")`` if PyTorch sees a CUDA device,
            ``torch.device("cpu")`` otherwise, or ``None`` when PyTorch is
            not installed.
    """

    def __init__(self, state_size: int, action_size: int, config: dict):
        """Store the problem dimensions/config and choose a compute device.

        Args:
            state_size: Size of the state representation.
            action_size: Size of the action space.
            config: Agent configuration; not validated here.
        """
        self.state_size = state_size
        self.action_size = action_size
        self.config = config
        # ``torch`` is bound by the guarded module-level import whenever
        # _TORCH_AVAILABLE is true, so no re-import is needed here.
        if _TORCH_AVAILABLE:
            self.device = torch.device(
                "cuda" if torch.cuda.is_available() else "cpu"
            )
        else:
            self.device = None

    @abstractmethod
    def select_action(self, state, training: bool = True) -> int:
        """Return an action integer for the given state.

        Args:
            state: Current observation of the environment.
            training: Whether the agent is being trained; each concrete
                agent decides how this affects action selection.

        Returns:
            The chosen action as an ``int``.
        """

    @abstractmethod
    def train_step(self, state, action, reward, next_state, done):
        """Perform one update step from a single transition.

        Args:
            state: Observation before the action.
            action: Action taken.
            reward: Reward received.
            next_state: Observation after the action.
            done: Whether the episode ended on this transition.

        Returns:
            The training loss as a float, or ``None`` if not applicable.
        """

    @abstractmethod
    def save(self, filepath: str):
        """Persist the agent to *filepath*."""

    @abstractmethod
    def load(self, filepath: str):
        """Restore the agent from *filepath*."""

    # Optional hooks: the defaults are no-ops returning None.
    def update_target_network(self):
        """Sync target network (DQN only). Default: do nothing."""

    def reset(self):
        """Reset any per-episode internal state. Default: do nothing."""