"""
BaseAgent — abstract interface that all RL agents must implement.
Every concrete agent (Q-Learning, DQN, …) must override:
    • select_action(state, training) → int
    • train_step(state, action, reward, next_state, done) → float | None
    • save(filepath)
    • load(filepath)
Optional hooks:
    • update_target_network() — used by DQN
    • reset() — called between episodes if needed
"""
from abc import ABC, abstractmethod
try:
import torch
_TORCH_AVAILABLE = True
except ImportError:
_TORCH_AVAILABLE = False
class BaseAgent(ABC):
    """Shared contract for reinforcement-learning agents.

    Subclasses have to provide ``select_action``, ``train_step``, ``save``
    and ``load``; until all four are overridden the class cannot be
    instantiated.  ``update_target_network`` and ``reset`` are optional
    hooks whose default implementations do nothing.
    """

    def __init__(self, state_size: int, action_size: int, config: dict):
        """Store the problem dimensions and config, then pick a device.

        Args:
            state_size: Stored unchanged as ``self.state_size``.
            action_size: Stored unchanged as ``self.action_size``.
            config: Stored by reference as ``self.config`` (not copied).

        ``self.device`` is ``None`` when the module-level
        ``_TORCH_AVAILABLE`` flag is false; otherwise it is a
        ``torch.device`` for ``"cuda"`` when CUDA is available and
        ``"cpu"`` when it is not.
        """
        self.state_size = state_size
        self.action_size = action_size
        self.config = config

        # Without torch there is no device object to build.
        if not _TORCH_AVAILABLE:
            self.device = None
            return

        import torch

        backend = "cuda" if torch.cuda.is_available() else "cpu"
        self.device = torch.device(backend)

    # ------------------------------------------------------------------
    # Mandatory interface: every concrete agent overrides these.
    # ------------------------------------------------------------------
    @abstractmethod
    def select_action(self, state, training: bool = True) -> int:
        """Choose the action for ``state`` and return it as an integer."""

    @abstractmethod
    def train_step(self, state, action, reward, next_state, done):
        """Run a single update; return the loss, or ``None`` if not applicable."""

    @abstractmethod
    def save(self, filepath: str):
        """Write the agent out to ``filepath``."""

    @abstractmethod
    def load(self, filepath: str):
        """Read the agent back in from ``filepath``."""

    # ------------------------------------------------------------------
    # Optional hooks: no-ops unless a subclass overrides them.
    # ------------------------------------------------------------------
    def update_target_network(self):
        """Synchronise the target network (only meaningful for DQN)."""

    def reset(self):
        """Clear any internal state kept per episode."""
|