"""Abstract base class for RL algorithms."""

from abc import ABC, abstractmethod
from typing import Dict, Any

import torch


class RLAlgorithm(ABC):
    """
    Common interface for reinforcement-learning algorithms used to train voice models.

    Concrete subclasses must provide ``compute_loss`` and ``update_policy``.
    The base class stores the learning rate as a dedicated attribute and keeps
    every other constructor keyword argument in ``self.hyperparameters``.
    """

    def __init__(self, learning_rate: float, **kwargs):
        """
        Store the learning rate and any extra algorithm-specific settings.

        Args:
            learning_rate: Step size used for optimization.
            **kwargs: Extra algorithm-specific hyperparameters; they are kept
                as-is in ``self.hyperparameters``.
        """
        # Everything except the learning rate lives in one name -> value mapping.
        self.hyperparameters: Dict[str, Any] = kwargs
        self.learning_rate = learning_rate

    @abstractmethod
    def compute_loss(
        self,
        states: torch.Tensor,
        actions: torch.Tensor,
        rewards: torch.Tensor,
        next_states: torch.Tensor,
        **kwargs
    ) -> torch.Tensor:
        """
        Produce the training loss for one batch of transitions.

        Args:
            states: States observed before acting.
            actions: Actions that were taken.
            rewards: Rewards that were received.
            next_states: States observed after acting.
            **kwargs: Any further inputs a particular algorithm needs.

        Returns:
            The loss as a tensor.
        """

    @abstractmethod
    def update_policy(self, loss: torch.Tensor) -> Dict[str, Any]:
        """
        Apply one optimization step driven by ``loss``.

        Args:
            loss: Loss tensor, typically the output of ``compute_loss``.

        Returns:
            Metrics describing the update (for example, gradient norms).
        """

    def get_hyperparameters(self) -> Dict[str, Any]:
        """
        Return a fresh dict holding all hyperparameters.

        The learning rate comes first. Entries from ``self.hyperparameters``
        follow and take precedence on a key clash.

        Returns:
            Mapping from hyperparameter name to its current value.
        """
        snapshot: Dict[str, Any] = {'learning_rate': self.learning_rate}
        snapshot.update(self.hyperparameters)
        return snapshot

    def set_hyperparameter(self, name: str, value: Any) -> None:
        """
        Assign a new value to one hyperparameter.

        ``'learning_rate'`` updates the dedicated attribute. Any other name is
        stored in (or added to) ``self.hyperparameters``.

        Args:
            name: Name of the hyperparameter to change.
            value: Value to assign.
        """
        if name != 'learning_rate':
            self.hyperparameters[name] = value
            return
        self.learning_rate = value