# mbellan's picture
# Initial deployment
# c3efd49
"""Abstract base class for RL algorithms."""
from abc import ABC, abstractmethod
from typing import Dict, Any
import torch
class RLAlgorithm(ABC):
    """
    Abstract base class for reinforcement learning algorithms.

    Defines the interface that all RL algorithms must implement
    for training voice models. Concrete subclasses must provide
    ``compute_loss`` and ``update_policy``; hyperparameter storage and
    access are implemented here.

    Attributes:
        learning_rate: Learning rate for optimization, kept as its own
            attribute rather than inside ``hyperparameters``.
        hyperparameters: Dict of every other keyword argument given at
            construction or added later via ``set_hyperparameter``.
    """

    def __init__(self, learning_rate: float, **kwargs):
        """
        Initialize the RL algorithm.

        Args:
            learning_rate: Learning rate for optimization
            **kwargs: Additional algorithm-specific parameters, stored
                as-is (no validation) in ``self.hyperparameters``
        """
        self.learning_rate = learning_rate
        self.hyperparameters = kwargs

    @abstractmethod
    def compute_loss(
        self,
        states: torch.Tensor,
        actions: torch.Tensor,
        rewards: torch.Tensor,
        next_states: torch.Tensor,
        **kwargs
    ) -> torch.Tensor:
        """
        Compute the loss for the current batch.

        Args:
            states: Current states
            actions: Actions taken
            rewards: Rewards received
            next_states: Next states
            **kwargs: Additional algorithm-specific inputs

        Returns:
            Loss tensor
        """

    @abstractmethod
    def update_policy(self, loss: torch.Tensor) -> Dict[str, Any]:
        """
        Update the policy based on computed loss.

        Args:
            loss: Computed loss tensor

        Returns:
            Dictionary containing update metrics (e.g., gradient norms)
        """

    def get_hyperparameters(self) -> Dict[str, Any]:
        """
        Get the hyperparameters for this algorithm.

        Returns:
            A new dictionary of hyperparameter names and values, with
            ``'learning_rate'`` as the first key. Mutating the returned
            dict does not affect the algorithm's stored values.
        """
        merged = {'learning_rate': self.learning_rate, **self.hyperparameters}
        # The attribute is the source of truth for the learning rate (see
        # set_hyperparameter); do not let a stray 'learning_rate' entry in
        # the extras dict shadow it. Re-assigning keeps the key first.
        merged['learning_rate'] = self.learning_rate
        return merged

    def set_hyperparameter(self, name: str, value: Any) -> None:
        """
        Set a hyperparameter value.

        ``'learning_rate'`` updates the ``learning_rate`` attribute; any
        other name is created or overwritten in ``self.hyperparameters``.

        Args:
            name: Hyperparameter name
            value: New value
        """
        if name == 'learning_rate':
            self.learning_rate = value
            return
        self.hyperparameters[name] = value