"""Reward function for voice model RL training."""

import logging
from typing import Dict, Optional, Tuple

import numpy as np
import torch

# The logger is created before the optional-import block below because the
# ImportError handler logs through it.
logger = logging.getLogger(__name__)

try:
    from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor
    import torchaudio
    ASR_AVAILABLE = True
except ImportError:
    ASR_AVAILABLE = False
    logger.warning("ASR dependencies not available. Transcription accuracy will use placeholder.")


class RewardFunction:
    """
    Computes rewards for voice model outputs based on multiple quality metrics.

    Reward components:
    - Clarity: clipping ratio, a rough SNR estimate and spectral flatness
    - Naturalness: sample-to-sample smoothness, frame-energy spread and
      correlation with a reference (if available)
    - Accuracy: similarity to a reference waveform and/or agreement between
      an ASR transcription and the expected text (if available)

    NOTE(review): the scoring helpers use ``len()`` and 1-D slicing on the
    audio, so they appear to expect a single 1-D waveform -- confirm with
    callers before passing batched or multi-channel tensors.
    """

    # Returned by compute_reward() whenever scoring fails.
    DEFAULT_PENALTY = -1.0

    # Keys that must be present in the ``weights`` mapping.
    _COMPONENTS = ('clarity', 'naturalness', 'accuracy')

    def __init__(
        self,
        weights: Optional[Dict[str, float]] = None,
        normalize_range: Tuple[float, float] = (0.0, 1.0),
        use_asr: bool = True,
        asr_model: Optional[str] = "facebook/wav2vec2-base-960h"
    ):
        """
        Initialize reward function.

        Args:
            weights: Component weights
                {'clarity': 0.33, 'naturalness': 0.33, 'accuracy': 0.34}
            normalize_range: Range for normalized rewards
            use_asr: Whether to use ASR for transcription accuracy
            asr_model: HuggingFace ASR model to use

        Raises:
            ValueError: If a required component weight is missing or the
                weights do not sum to 1.0.
        """
        if weights is None:
            weights = {
                'clarity': 0.33,
                'naturalness': 0.33,
                'accuracy': 0.34
            }

        # Validate weights. A missing key would otherwise only surface later
        # as a KeyError that compute_reward() silently turns into a penalty.
        missing = [name for name in self._COMPONENTS if name not in weights]
        if missing:
            raise ValueError(f"Weights missing required components: {missing}")
        if not np.isclose(sum(weights.values()), 1.0):
            raise ValueError(f"Weights must sum to 1.0, got {sum(weights.values())}")

        self.weights = weights
        self.normalize_range = normalize_range
        self.use_asr = bool(use_asr and ASR_AVAILABLE)

        # Initialize ASR model if requested
        self.asr_model = None
        self.asr_processor = None

        if self.use_asr:
            try:
                self.asr_processor = Wav2Vec2Processor.from_pretrained(asr_model)
                self.asr_model = Wav2Vec2ForCTC.from_pretrained(asr_model)
                self.asr_model.eval()
                logger.info(f"Loaded ASR model: {asr_model}")
            except Exception as e:
                logger.warning(f"Failed to load ASR model: {e}. Using placeholder accuracy.")
                # Do not keep a half-initialized processor/model pair around.
                self.asr_processor = None
                self.asr_model = None
                self.use_asr = False

        logger.info(f"Initialized RewardFunction with weights: {weights}, ASR: {self.use_asr}")

    @staticmethod
    def _to_numpy(audio) -> Optional[np.ndarray]:
        """Convert a tensor / array-like to a floating-point numpy array (None passes through)."""
        if audio is None:
            return None
        if isinstance(audio, torch.Tensor):
            audio = audio.detach().cpu()
            if audio.dtype not in (torch.float32, torch.float64):
                # Covers dtypes numpy cannot represent (e.g. bfloat16) as well
                # as half-precision and integer tensors.
                audio = audio.to(torch.float32)
            return audio.numpy()
        array = np.asarray(audio)
        if not np.issubdtype(array.dtype, np.floating):
            # Integer input would overflow / truncate in the power computations.
            array = array.astype(np.float64)
        return array

    @staticmethod
    def _correlation(first: np.ndarray, second: np.ndarray) -> float:
        """
        Pearson correlation of two aligned signals.

        Returns 0.0 (which callers map to the neutral score 0.5) when the
        correlation is undefined: fewer than two samples, or a constant
        signal, for which np.corrcoef yields NaN.
        """
        if first.size < 2 or second.size < 2:
            return 0.0
        with np.errstate(divide='ignore', invalid='ignore'):
            correlation = np.corrcoef(first, second)[0, 1]
        return float(correlation) if np.isfinite(correlation) else 0.0

    def _score_components(
        self,
        generated_audio,
        reference_audio=None,
        transcription: Optional[str] = None
    ) -> Dict[str, float]:
        """
        Compute the individual component scores and their weighted total.

        Shared by compute_reward() and get_reward_components() so both apply
        exactly the same conversion, weighting and validity checks.

        Raises:
            ValueError: If the generated audio is empty or the weighted total
                is not a finite number.
        """
        audio = self._to_numpy(generated_audio)
        reference = self._to_numpy(reference_audio)

        if audio is None or audio.size == 0:
            raise ValueError("Generated audio is empty")

        clarity = float(self._compute_clarity(audio))
        naturalness = float(self._compute_naturalness(audio, reference))
        accuracy = float(self._compute_accuracy(audio, reference, transcription))

        # Weighted combination
        total = float(
            self.weights['clarity'] * clarity +
            self.weights['naturalness'] * naturalness +
            self.weights['accuracy'] * accuracy
        )
        if not np.isfinite(total):
            # Never hand a NaN/inf reward to the trainer.
            raise ValueError(f"Non-finite reward: {total}")

        return {
            'clarity': clarity,
            'naturalness': naturalness,
            'accuracy': accuracy,
            'total': total
        }

    def compute_reward(
        self,
        generated_audio: torch.Tensor,
        reference_audio: Optional[torch.Tensor] = None,
        transcription: Optional[str] = None
    ) -> float:
        """
        Compute composite reward for generated audio.

        Args:
            generated_audio: Generated audio tensor
            reference_audio: Optional reference audio for comparison
            transcription: Optional expected transcription

        Returns:
            Normalized reward score, or DEFAULT_PENALTY if scoring fails
            (including empty audio or a non-finite result).
        """
        try:
            scores = self._score_components(generated_audio, reference_audio, transcription)
            # Normalize to target range
            return float(self._normalize_reward(scores['total']))
        except Exception as e:
            logger.error(f"Error computing reward: {e}")
            return self.DEFAULT_PENALTY

    def _compute_clarity(self, audio: np.ndarray) -> float:
        """
        Compute clarity score based on signal quality.

        Measures:
        - Absence of clipping (weight 0.3)
        - Signal-to-noise ratio (weight 0.4)
        - Spectral flatness (weight 0.3)

        Args:
            audio: Audio waveform

        Returns:
            Clarity score in [0, 1] (0.0 for empty audio)
        """
        if audio.size == 0:
            return 0.0

        score = 0.0

        # Check for clipping: fraction of samples with magnitude above 0.99
        clipping_ratio = np.mean(np.abs(audio) > 0.99)
        score += 0.3 * (1.0 - clipping_ratio)

        # Estimate SNR
        power = audio ** 2
        signal_power = np.mean(power)
        if signal_power > 1e-10:
            # Simple noise estimation from the quietest 5% of samples
            sorted_power = np.sort(power)
            noise_floor = np.mean(sorted_power[:max(1, len(sorted_power) // 20)])
            snr = 10 * np.log10(signal_power / max(noise_floor, 1e-10))
            score += 0.4 * np.clip(snr / 30.0, 0.0, 1.0)  # 30 dB maps to a full score
        # A (near-)silent signal contributes nothing for SNR.

        # Spectral flatness (lower is better for speech)
        try:
            magnitude = np.abs(np.fft.rfft(audio))
            geometric_mean = np.exp(np.mean(np.log(magnitude + 1e-10)))
            arithmetic_mean = np.mean(magnitude)
            flatness = geometric_mean / (arithmetic_mean + 1e-10)
            score += 0.3 * (1.0 - flatness)  # Invert: lower flatness is better
        except Exception as e:
            logger.debug(f"Error computing spectral flatness: {e}")
            score += 0.15  # Neutral score if computation fails

        return float(np.clip(score, 0.0, 1.0))

    def _compute_naturalness(
        self,
        audio: np.ndarray,
        reference: Optional[np.ndarray] = None
    ) -> float:
        """
        Compute naturalness score based on prosody and smoothness.

        Measures:
        - Smoothness (absence of abrupt changes, weight 0.4)
        - Energy distribution across frames (weight 0.3)
        - Similarity to reference if available (weight 0.3)

        Args:
            audio: Generated audio
            reference: Optional reference audio

        Returns:
            Naturalness score in [0, 1]
        """
        score = 0.0

        # Smoothness: penalize abrupt changes
        if len(audio) > 1:
            diff = np.diff(audio)
            smoothness = 1.0 - np.clip(np.std(diff) / 0.1, 0.0, 1.0)
            score += 0.4 * smoothness
        else:
            score += 0.2

        # Energy distribution: should not be too uniform or too spiky
        if len(audio) > 10:
            frame_size = len(audio) // 10
            # "+ 1" keeps the final full frame when the length is an exact
            # multiple of frame_size; a trailing partial frame is ignored.
            frame_energies = [
                np.mean(audio[start:start + frame_size] ** 2)
                for start in range(0, len(audio) - frame_size + 1, frame_size)
            ]
            energy_std = np.std(frame_energies)
            # Optimal std is around 0.01-0.1
            energy_score = 1.0 - np.clip(abs(energy_std - 0.05) / 0.1, 0.0, 1.0)
            score += 0.3 * energy_score
        else:
            score += 0.15

        # Similarity to reference if available
        if reference is not None:
            try:
                # Align lengths
                min_len = min(len(audio), len(reference))
                correlation = self._correlation(audio[:min_len], reference[:min_len])
                score += 0.3 * (correlation + 1.0) / 2.0  # Map [-1, 1] to [0, 1]
            except Exception as e:
                logger.debug(f"Error computing reference correlation: {e}")
                score += 0.15
        else:
            score += 0.3  # No reference: this term is credited in full

        return float(np.clip(score, 0.0, 1.0))

    def _compute_accuracy(
        self,
        audio: np.ndarray,
        reference: Optional[np.ndarray] = None,
        transcription: Optional[str] = None
    ) -> float:
        """
        Compute accuracy score based on similarity to reference and/or transcription.

        Args:
            audio: Generated audio
            reference: Optional reference audio
            transcription: Optional expected transcription

        Returns:
            Accuracy score in [0, 1]; the average of the components that could
            be computed, or the neutral value 0.5 when none could.
        """
        component_scores = []

        # Component 1: Audio similarity to reference
        if reference is not None:
            try:
                # Align lengths
                min_len = min(len(audio), len(reference))
                if min_len > 0:
                    audio_aligned = audio[:min_len]
                    reference_aligned = reference[:min_len]

                    # Mean squared error (lower is better)
                    mse = np.mean((audio_aligned - reference_aligned) ** 2)
                    mse_score = np.exp(-mse * 10)  # Exponential decay

                    # Correlation, mapped from [-1, 1] to [0, 1]
                    correlation = self._correlation(audio_aligned, reference_aligned)
                    correlation_score = (correlation + 1.0) / 2.0

                    # Combined audio similarity score
                    component_scores.append(0.5 * mse_score + 0.5 * correlation_score)
            except Exception as e:
                logger.debug(f"Error computing audio similarity: {e}")

        # Component 2: Transcription accuracy using ASR
        if transcription and self.use_asr and self.asr_model is not None:
            try:
                component_scores.append(
                    self._compute_transcription_accuracy(audio, transcription)
                )
            except Exception as e:
                logger.debug(f"Error computing transcription accuracy: {e}")

        # Return average score or neutral if no components
        if not component_scores:
            return 0.5
        average = sum(component_scores) / len(component_scores)
        return float(np.clip(average, 0.0, 1.0))

    def _compute_transcription_accuracy(
        self,
        audio: np.ndarray,
        expected_transcription: str,
        sample_rate: int = 16000
    ) -> float:
        """
        Compute transcription accuracy using ASR.

        Args:
            audio: Audio waveform
            expected_transcription: Expected transcription text
            sample_rate: Audio sample rate

        Returns:
            Transcription accuracy score in [0, 1]; 0.5 if the ASR pipeline fails
        """
        try:
            # Convert to tensor
            audio_tensor = torch.as_tensor(audio, dtype=torch.float32)

            # Resample if needed (the processor is called with 16 kHz below)
            if sample_rate != 16000:
                resampler = torchaudio.transforms.Resample(sample_rate, 16000)
                audio_tensor = resampler(audio_tensor)

            # Process audio; a numpy array is passed because that is an input
            # type the processor documents.
            input_values = self.asr_processor(
                audio_tensor.numpy(),
                sampling_rate=16000,
                return_tensors="pt"
            ).input_values

            # Get transcription (greedy decoding of the logits)
            with torch.no_grad():
                logits = self.asr_model(input_values).logits
                predicted_ids = torch.argmax(logits, dim=-1)
                transcription = self.asr_processor.decode(predicted_ids[0])

            # Compare predicted and expected text, case-insensitively
            return self._compute_text_similarity(
                transcription.lower().strip(),
                expected_transcription.lower().strip()
            )
        except Exception as e:
            logger.debug(f"Error in ASR transcription: {e}")
            return 0.5

    def _compute_text_similarity(self, predicted: str, expected: str) -> float:
        """
        Compute text similarity between predicted and expected transcriptions.

        Uses the Jaccard similarity of the two whitespace-split word sets, so
        word order and repeated words are ignored.

        Args:
            predicted: Predicted transcription
            expected: Expected transcription

        Returns:
            Similarity score in [0, 1]; 0.5 (neutral) when the expected text
            contains no words.
        """
        if not expected:
            return 0.5

        # Simple word-level comparison
        pred_words = set(predicted.split())
        exp_words = set(expected.split())

        if not exp_words:
            return 0.5

        # Jaccard similarity; the union is non-empty because exp_words is.
        return len(pred_words & exp_words) / len(pred_words | exp_words)

    def _normalize_reward(self, reward: float) -> float:
        """
        Normalize reward to target range.

        Args:
            reward: Raw reward value (clipped to [0, 1] before scaling)

        Returns:
            Reward linearly mapped onto ``self.normalize_range``
        """
        min_val, max_val = self.normalize_range
        return float(min_val + (max_val - min_val) * np.clip(reward, 0.0, 1.0))

    def get_reward_components(
        self,
        generated_audio: torch.Tensor,
        reference_audio: Optional[torch.Tensor] = None,
        transcription: Optional[str] = None
    ) -> Dict[str, float]:
        """
        Get breakdown of reward components.

        Args:
            generated_audio: Generated audio tensor
            reference_audio: Optional reference audio
            transcription: Optional expected transcription

        Returns:
            Dictionary with keys 'clarity', 'naturalness', 'accuracy', 'total'
            and 'normalized'. If scoring fails, the scores are 0.0 and
            'normalized' is DEFAULT_PENALTY.
        """
        try:
            scores = self._score_components(generated_audio, reference_audio, transcription)
            scores['normalized'] = self._normalize_reward(scores['total'])
            return scores
        except Exception as e:
            logger.error(f"Error getting reward components: {e}")
            return {
                'clarity': 0.0,
                'naturalness': 0.0,
                'accuracy': 0.0,
                'total': 0.0,
                'normalized': self.DEFAULT_PENALTY
            }