File size: 503 Bytes
b00d5d5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
"""
Agent package.

Provides:
  BaseAgent        — abstract interface
  QLearningAgent   — tabular Q-learning (NumPy only)
  DQNAgent         — Deep Q-Network (PyTorch, GPU-accelerated)
"""

from .base_agent import BaseAgent
from .q_learning_agent import QLearningAgent

# DQN requires PyTorch
try:
    from .dqn_agent import DQNAgent
    DQN_AVAILABLE = True
except ImportError:
    DQN_AVAILABLE = False
    DQNAgent = None  # type: ignore

__all__ = ["BaseAgent", "QLearningAgent", "DQNAgent"]