# voice-model-rl-training / voice_rl/rl/reward_function.py
# Initial deployment (commit c3efd49)
"""Reward function for voice model RL training."""
import torch
import numpy as np
import logging
from typing import Dict, Optional, Tuple

# The logger must exist BEFORE the guarded import below: the except branch
# logs a warning, and defining the logger afterwards raised NameError
# whenever the ASR dependencies were missing.
logger = logging.getLogger(__name__)

try:
    from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor
    import torchaudio
    ASR_AVAILABLE = True
except ImportError:
    ASR_AVAILABLE = False
    logger.warning("ASR dependencies not available. Transcription accuracy will use placeholder.")
class RewardFunction:
    """
    Computes rewards for voice model outputs based on multiple quality metrics.

    Reward components:
    - Clarity: Signal quality and spectral characteristics
    - Naturalness: Prosody and smoothness
    - Accuracy: Similarity to reference (if available)
    """

    # Reward returned when scoring fails outright (see compute_reward).
    DEFAULT_PENALTY = -1.0

    def __init__(
        self,
        weights: Optional[Dict[str, float]] = None,
        normalize_range: Tuple[float, float] = (0.0, 1.0),
        use_asr: bool = True,
        asr_model: Optional[str] = "facebook/wav2vec2-base-960h"
    ):
        """
        Initialize reward function.

        Args:
            weights: Component weights {'clarity': 0.33, 'naturalness': 0.33, 'accuracy': 0.34}
            normalize_range: Range for normalized rewards
            use_asr: Whether to use ASR for transcription accuracy
            asr_model: HuggingFace ASR model to use

        Raises:
            ValueError: If the supplied weights do not sum to 1.0.
        """
        if weights is None:
            weights = {
                'clarity': 0.33,
                'naturalness': 0.33,
                'accuracy': 0.34
            }
        # Validate weights
        if not np.isclose(sum(weights.values()), 1.0):
            raise ValueError(f"Weights must sum to 1.0, got {sum(weights.values())}")
        self.weights = weights
        self.normalize_range = normalize_range
        # ASR is enabled only when requested AND its dependencies imported.
        self.use_asr = use_asr and ASR_AVAILABLE
        # Initialize ASR model if requested
        self.asr_model = None
        self.asr_processor = None
        if self.use_asr:
            try:
                self.asr_processor = Wav2Vec2Processor.from_pretrained(asr_model)
                self.asr_model = Wav2Vec2ForCTC.from_pretrained(asr_model)
                self.asr_model.eval()
                logger.info(f"Loaded ASR model: {asr_model}")
            except Exception as e:
                logger.warning(f"Failed to load ASR model: {e}. Using placeholder accuracy.")
                self.use_asr = False
        logger.info(f"Initialized RewardFunction with weights: {weights}, ASR: {self.use_asr}")

    @staticmethod
    def _pearson_score(a: np.ndarray, b: np.ndarray) -> float:
        """
        Map the Pearson correlation of two equal-length signals to [0, 1].

        Returns a neutral 0.5 instead of NaN when the correlation is
        undefined (e.g. one signal is constant/silent). Previously the NaN
        from np.corrcoef propagated unchecked into the final reward.
        """
        with np.errstate(divide='ignore', invalid='ignore'):
            correlation = np.corrcoef(a, b)[0, 1]
        if np.isnan(correlation):
            return 0.5
        return (correlation + 1.0) / 2.0  # Map [-1, 1] to [0, 1]

    def compute_reward(
        self,
        generated_audio: torch.Tensor,
        reference_audio: Optional[torch.Tensor] = None,
        transcription: Optional[str] = None
    ) -> float:
        """
        Compute composite reward for generated audio.

        Args:
            generated_audio: Generated audio tensor
            reference_audio: Optional reference audio for comparison
            transcription: Optional expected transcription

        Returns:
            Normalized reward score, or DEFAULT_PENALTY on failure.
        """
        try:
            # Convert to numpy for processing
            if isinstance(generated_audio, torch.Tensor):
                generated_audio = generated_audio.detach().cpu().numpy()
            if reference_audio is not None and isinstance(reference_audio, torch.Tensor):
                reference_audio = reference_audio.detach().cpu().numpy()
            # Compute individual components
            clarity_score = self._compute_clarity(generated_audio)
            naturalness_score = self._compute_naturalness(generated_audio, reference_audio)
            accuracy_score = self._compute_accuracy(generated_audio, reference_audio, transcription)
            # Weighted combination
            reward = (
                self.weights['clarity'] * clarity_score +
                self.weights['naturalness'] * naturalness_score +
                self.weights['accuracy'] * accuracy_score
            )
            # Normalize to target range
            reward = self._normalize_reward(reward)
            return float(reward)
        except Exception as e:
            # Boundary catch: RL training must not crash on one bad sample.
            logger.error(f"Error computing reward: {e}")
            return self.DEFAULT_PENALTY

    def _compute_clarity(self, audio: np.ndarray) -> float:
        """
        Compute clarity score based on signal quality.

        Measures:
        - Signal-to-noise ratio
        - Spectral flatness
        - Absence of clipping

        Args:
            audio: Audio waveform

        Returns:
            Clarity score in [0, 1]
        """
        score = 0.0
        # Check for clipping (fraction of near-full-scale samples)
        clipping_ratio = np.mean(np.abs(audio) > 0.99)
        clipping_score = 1.0 - clipping_ratio
        score += 0.3 * clipping_score
        # Estimate SNR
        signal_power = np.mean(audio ** 2)
        if signal_power > 1e-10:
            # Simple noise estimation from the quietest 5% of samples
            sorted_power = np.sort(audio ** 2)
            noise_floor = np.mean(sorted_power[:max(1, len(sorted_power) // 20)])
            snr = 10 * np.log10(signal_power / max(noise_floor, 1e-10))
            snr_score = np.clip(snr / 30.0, 0.0, 1.0)  # Normalize to [0, 1]
            score += 0.4 * snr_score
        # Near-silent audio contributes nothing to the SNR term.
        # Spectral flatness (lower is better for speech)
        try:
            fft = np.fft.rfft(audio)
            magnitude = np.abs(fft)
            geometric_mean = np.exp(np.mean(np.log(magnitude + 1e-10)))
            arithmetic_mean = np.mean(magnitude)
            flatness = geometric_mean / (arithmetic_mean + 1e-10)
            flatness_score = 1.0 - flatness  # Invert: lower flatness is better
            score += 0.3 * flatness_score
        except Exception:
            score += 0.15  # Neutral score if computation fails
        return np.clip(score, 0.0, 1.0)

    def _compute_naturalness(
        self,
        audio: np.ndarray,
        reference: Optional[np.ndarray] = None
    ) -> float:
        """
        Compute naturalness score based on prosody and smoothness.

        Measures:
        - Smoothness (absence of abrupt changes)
        - Energy distribution
        - Similarity to reference if available

        Args:
            audio: Generated audio
            reference: Optional reference audio

        Returns:
            Naturalness score in [0, 1]
        """
        score = 0.0
        # Smoothness: penalize abrupt changes
        if len(audio) > 1:
            diff = np.diff(audio)
            smoothness = 1.0 - np.clip(np.std(diff) / 0.1, 0.0, 1.0)
            score += 0.4 * smoothness
        else:
            score += 0.2
        # Energy distribution: should not be too uniform or too spiky
        if len(audio) > 10:
            frame_size = len(audio) // 10
            frame_energies = [
                np.mean(audio[i:i+frame_size] ** 2)
                for i in range(0, len(audio) - frame_size, frame_size)
            ]
            energy_std = np.std(frame_energies)
            # Optimal std is around 0.01-0.1
            energy_score = 1.0 - np.clip(abs(energy_std - 0.05) / 0.1, 0.0, 1.0)
            score += 0.3 * energy_score
        else:
            score += 0.15
        # Similarity to reference if available
        if reference is not None:
            try:
                # Align lengths, then use NaN-safe correlation in [0, 1]
                min_len = min(len(audio), len(reference))
                correlation_score = self._pearson_score(audio[:min_len], reference[:min_len])
                score += 0.3 * correlation_score
            except Exception:
                score += 0.15
        else:
            score += 0.3  # Neutral score if no reference
        return np.clip(score, 0.0, 1.0)

    def _compute_accuracy(
        self,
        audio: np.ndarray,
        reference: Optional[np.ndarray] = None,
        transcription: Optional[str] = None
    ) -> float:
        """
        Compute accuracy score based on similarity to reference and/or transcription.

        Args:
            audio: Generated audio
            reference: Optional reference audio
            transcription: Optional expected transcription

        Returns:
            Accuracy score in [0, 1]; 0.5 (neutral) when neither a
            reference nor a transcription is usable.
        """
        score = 0.0
        num_components = 0
        # Component 1: Audio similarity to reference
        if reference is not None:
            try:
                # Align lengths
                min_len = min(len(audio), len(reference))
                audio_aligned = audio[:min_len]
                reference_aligned = reference[:min_len]
                # Mean squared error (lower is better)
                mse = np.mean((audio_aligned - reference_aligned) ** 2)
                mse_score = np.exp(-mse * 10)  # Exponential decay
                # NaN-safe correlation mapped to [0, 1]
                correlation_score = self._pearson_score(audio_aligned, reference_aligned)
                # Combined audio similarity score
                audio_sim_score = 0.5 * mse_score + 0.5 * correlation_score
                score += audio_sim_score
                num_components += 1
            except Exception as e:
                logger.debug(f"Error computing audio similarity: {e}")
        # Component 2: Transcription accuracy using ASR
        if transcription and self.use_asr and self.asr_model is not None:
            try:
                trans_score = self._compute_transcription_accuracy(audio, transcription)
                score += trans_score
                num_components += 1
            except Exception as e:
                logger.debug(f"Error computing transcription accuracy: {e}")
        # Return average score or neutral if no components
        if num_components > 0:
            return np.clip(score / num_components, 0.0, 1.0)
        else:
            return 0.5

    def _compute_transcription_accuracy(
        self,
        audio: np.ndarray,
        expected_transcription: str,
        sample_rate: int = 16000
    ) -> float:
        """
        Compute transcription accuracy using ASR.

        Args:
            audio: Audio waveform
            expected_transcription: Expected transcription text
            sample_rate: Audio sample rate

        Returns:
            Transcription accuracy score in [0, 1]; 0.5 (neutral) if ASR fails.
        """
        try:
            # Convert to a float32 tensor for the ASR pipeline
            audio_tensor = torch.as_tensor(audio, dtype=torch.float32)
            # Resample if needed (ASR models typically use 16kHz)
            if sample_rate != 16000:
                resampler = torchaudio.transforms.Resample(sample_rate, 16000)
                audio_tensor = resampler(audio_tensor)
            # Process audio
            input_values = self.asr_processor(
                audio_tensor,
                sampling_rate=16000,
                return_tensors="pt"
            ).input_values
            # Get transcription (greedy CTC decode)
            with torch.no_grad():
                logits = self.asr_model(input_values).logits
            predicted_ids = torch.argmax(logits, dim=-1)
            transcription = self.asr_processor.decode(predicted_ids[0])
            # Compare case- and whitespace-insensitively
            score = self._compute_text_similarity(
                transcription.lower().strip(),
                expected_transcription.lower().strip()
            )
            return score
        except Exception as e:
            logger.debug(f"Error in ASR transcription: {e}")
            return 0.5

    def _compute_text_similarity(self, predicted: str, expected: str) -> float:
        """
        Compute text similarity between predicted and expected transcriptions.

        Uses word-level Jaccard similarity (intersection over union of the
        word sets); word order and repetitions are ignored.

        Args:
            predicted: Predicted transcription
            expected: Expected transcription

        Returns:
            Similarity score in [0, 1]; 0.5 (neutral) when the expected
            text is empty.
        """
        if not expected:
            return 0.5
        # Simple word-level comparison
        pred_words = set(predicted.split())
        exp_words = set(expected.split())
        if not exp_words:
            return 0.5
        # Jaccard similarity
        intersection = len(pred_words & exp_words)
        union = len(pred_words | exp_words)
        if union == 0:
            return 0.0
        return intersection / union

    def _normalize_reward(self, reward: float) -> float:
        """
        Normalize reward to target range.

        Args:
            reward: Raw reward value (assumed to be in [0, 1])

        Returns:
            Reward rescaled linearly into self.normalize_range.
        """
        min_val, max_val = self.normalize_range
        return min_val + (max_val - min_val) * np.clip(reward, 0.0, 1.0)

    def get_reward_components(
        self,
        generated_audio: torch.Tensor,
        reference_audio: Optional[torch.Tensor] = None,
        transcription: Optional[str] = None
    ) -> Dict[str, float]:
        """
        Get breakdown of reward components.

        Args:
            generated_audio: Generated audio tensor
            reference_audio: Optional reference audio
            transcription: Optional expected transcription

        Returns:
            Dictionary with 'clarity', 'naturalness', 'accuracy', 'total'
            and 'normalized' scores; zeros with a DEFAULT_PENALTY
            'normalized' entry on failure.
        """
        try:
            # Convert to numpy
            if isinstance(generated_audio, torch.Tensor):
                generated_audio = generated_audio.detach().cpu().numpy()
            if reference_audio is not None and isinstance(reference_audio, torch.Tensor):
                reference_audio = reference_audio.detach().cpu().numpy()
            clarity = self._compute_clarity(generated_audio)
            naturalness = self._compute_naturalness(generated_audio, reference_audio)
            accuracy = self._compute_accuracy(generated_audio, reference_audio, transcription)
            total = (
                self.weights['clarity'] * clarity +
                self.weights['naturalness'] * naturalness +
                self.weights['accuracy'] * accuracy
            )
            return {
                'clarity': clarity,
                'naturalness': naturalness,
                'accuracy': accuracy,
                'total': total,
                'normalized': self._normalize_reward(total)
            }
        except Exception as e:
            logger.error(f"Error getting reward components: {e}")
            return {
                'clarity': 0.0,
                'naturalness': 0.0,
                'accuracy': 0.0,
                'total': 0.0,
                'normalized': self.DEFAULT_PENALTY
            }