Spaces:
Runtime error
Runtime error
| """Reward function for voice model RL training.""" | |
| import torch | |
| import numpy as np | |
| import logging | |
| from typing import Dict, Optional, Tuple | |
# The logger has to exist BEFORE the optional-import guard: the ImportError
# branch below logs a warning, and referencing an undefined ``logger`` there
# would turn a missing optional dependency into a NameError at import time.
logger = logging.getLogger(__name__)

# ASR support is optional. ASR_AVAILABLE records whether the transformers /
# torchaudio imports succeeded so the rest of the module can degrade to a
# placeholder transcription score instead of failing.
try:
    from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor
    import torchaudio
    ASR_AVAILABLE = True
except ImportError:
    ASR_AVAILABLE = False
    logger.warning("ASR dependencies not available. Transcription accuracy will use placeholder.")
class RewardFunction:
    """
    Computes rewards for voice model outputs based on multiple quality metrics.

    Reward components:
    - Clarity: Signal quality and spectral characteristics
    - Naturalness: Prosody and smoothness
    - Accuracy: Similarity to reference (if available)

    All component scores are plain Python floats in [0, 1]. The composite
    reward is mapped into ``normalize_range``; failures (exceptions, empty
    audio, non-finite results) yield ``DEFAULT_PENALTY`` instead.
    """

    DEFAULT_PENALTY = -1.0

    # Component names that every weights mapping must provide.
    _COMPONENTS = ('clarity', 'naturalness', 'accuracy')

    def __init__(
        self,
        weights: Optional[Dict[str, float]] = None,
        normalize_range: Tuple[float, float] = (0.0, 1.0),
        use_asr: bool = True,
        asr_model: Optional[str] = "facebook/wav2vec2-base-960h"
    ):
        """
        Initialize reward function.

        Args:
            weights: Component weights {'clarity': 0.33, 'naturalness': 0.33, 'accuracy': 0.34}
            normalize_range: Range for normalized rewards
            use_asr: Whether to use ASR for transcription accuracy
            asr_model: HuggingFace ASR model to use

        Raises:
            ValueError: If a required component weight is missing or the
                weights do not sum to 1.0.
        """
        if weights is None:
            weights = {
                'clarity': 0.33,
                'naturalness': 0.33,
                'accuracy': 0.34
            }

        # Validate weights. A missing key would otherwise surface only as a
        # KeyError inside compute_reward, where it is swallowed and silently
        # turned into DEFAULT_PENALTY on every call.
        missing = [name for name in self._COMPONENTS if name not in weights]
        if missing:
            raise ValueError(f"Weights missing required components: {missing}")
        if not np.isclose(sum(weights.values()), 1.0):
            raise ValueError(f"Weights must sum to 1.0, got {sum(weights.values())}")

        # Copy so later mutation of the caller's dict cannot break the invariant.
        self.weights = dict(weights)
        self.normalize_range = normalize_range
        self.use_asr = use_asr and ASR_AVAILABLE

        # Initialize ASR model if requested
        self.asr_model = None
        self.asr_processor = None
        if self.use_asr:
            try:
                self.asr_processor = Wav2Vec2Processor.from_pretrained(asr_model)
                self.asr_model = Wav2Vec2ForCTC.from_pretrained(asr_model)
                self.asr_model.eval()
                logger.info(f"Loaded ASR model: {asr_model}")
            except Exception as e:
                # Best effort: fall back to the placeholder accuracy score.
                logger.warning(f"Failed to load ASR model: {e}. Using placeholder accuracy.")
                self.use_asr = False

        logger.info(f"Initialized RewardFunction with weights: {weights}, ASR: {self.use_asr}")

    @staticmethod
    def _to_numpy(audio) -> np.ndarray:
        """Convert a tensor/array-like waveform to a flat floating-point array.

        The scoring code relies on ``len(audio)`` being the number of samples,
        so inputs with extra axes (e.g. shape (1, N)) are flattened to 1-D.
        Non-float inputs are converted to float32 so squaring cannot overflow
        an integer dtype.
        """
        if isinstance(audio, torch.Tensor):
            audio = audio.detach().cpu()
            if audio.dtype not in (torch.float32, torch.float64):
                audio = audio.float()
            audio = audio.numpy()
        else:
            audio = np.asarray(audio)
            if not np.issubdtype(audio.dtype, np.floating):
                audio = audio.astype(np.float32)
        return audio.reshape(-1)

    @staticmethod
    def _safe_correlation(audio: np.ndarray, reference: np.ndarray) -> Optional[float]:
        """Return the Pearson correlation of two waveforms, or None if undefined.

        The signals are truncated to their common length. The correlation is
        undefined (None) when fewer than two samples overlap, when either
        signal has zero variance, or when the result is not finite.
        ``np.corrcoef`` reports those cases as NaN rather than raising, so
        they have to be detected explicitly.
        """
        min_len = min(len(audio), len(reference))
        if min_len < 2:
            return None
        audio_aligned = audio[:min_len]
        reference_aligned = reference[:min_len]
        with np.errstate(invalid='ignore', divide='ignore'):
            if np.std(audio_aligned) == 0 or np.std(reference_aligned) == 0:
                return None
            correlation = np.corrcoef(audio_aligned, reference_aligned)[0, 1]
        if not np.isfinite(correlation):
            return None
        return float(correlation)

    def _score_components(
        self,
        audio: np.ndarray,
        reference: Optional[np.ndarray],
        transcription: Optional[str]
    ) -> Dict[str, float]:
        """Compute the three component scores and their weighted total."""
        clarity = self._compute_clarity(audio)
        naturalness = self._compute_naturalness(audio, reference)
        accuracy = self._compute_accuracy(audio, reference, transcription)
        total = (
            self.weights['clarity'] * clarity +
            self.weights['naturalness'] * naturalness +
            self.weights['accuracy'] * accuracy
        )
        return {
            'clarity': clarity,
            'naturalness': naturalness,
            'accuracy': accuracy,
            'total': float(total)
        }

    def _failure_components(self) -> Dict[str, float]:
        """Component breakdown reported when scoring is not possible."""
        return {
            'clarity': 0.0,
            'naturalness': 0.0,
            'accuracy': 0.0,
            'total': 0.0,
            'normalized': self.DEFAULT_PENALTY
        }

    def compute_reward(
        self,
        generated_audio: torch.Tensor,
        reference_audio: Optional[torch.Tensor] = None,
        transcription: Optional[str] = None
    ) -> float:
        """
        Compute composite reward for generated audio.

        Args:
            generated_audio: Generated audio tensor
            reference_audio: Optional reference audio for comparison
            transcription: Optional expected transcription

        Returns:
            Normalized reward score, or DEFAULT_PENALTY if the audio is empty,
            the result is not finite, or an error occurs.
        """
        try:
            audio = self._to_numpy(generated_audio)
            reference = None if reference_audio is None else self._to_numpy(reference_audio)

            if audio.size == 0:
                logger.warning("Empty generated audio; returning default penalty.")
                return self.DEFAULT_PENALTY

            components = self._score_components(audio, reference, transcription)

            # Normalize to target range
            reward = float(self._normalize_reward(components['total']))

            # NaN/inf samples in the input propagate through every metric; a
            # non-finite reward must never reach the training loop.
            if not np.isfinite(reward):
                logger.warning("Non-finite reward computed; returning default penalty.")
                return self.DEFAULT_PENALTY
            return reward
        except Exception as e:
            logger.error(f"Error computing reward: {e}")
            return self.DEFAULT_PENALTY

    def _compute_clarity(self, audio: np.ndarray) -> float:
        """
        Compute clarity score based on signal quality.

        Measures:
        - Absence of clipping (weight 0.3)
        - Estimated signal-to-noise ratio (weight 0.4)
        - Spectral flatness, inverted (weight 0.3)

        Args:
            audio: 1-D audio waveform

        Returns:
            Clarity score in [0, 1] (0.0 for empty audio)
        """
        if len(audio) == 0:
            return 0.0

        score = 0.0

        # Check for clipping: fraction of samples at (or near) full scale.
        clipping_ratio = np.mean(np.abs(audio) > 0.99)
        score += 0.3 * (1.0 - clipping_ratio)

        # Estimate SNR; (near-)silent audio contributes nothing here.
        power = audio ** 2
        signal_power = np.mean(power)
        if signal_power > 1e-10:
            # Simple noise estimation from the quietest 5% of samples.
            sorted_power = np.sort(power)
            noise_floor = np.mean(sorted_power[:max(1, len(sorted_power) // 20)])
            snr = 10 * np.log10(signal_power / max(noise_floor, 1e-10))
            score += 0.4 * np.clip(snr / 30.0, 0.0, 1.0)  # 30 dB maps to 1.0

        # Spectral flatness (lower is better for speech)
        try:
            magnitude = np.abs(np.fft.rfft(audio))
            geometric_mean = np.exp(np.mean(np.log(magnitude + 1e-10)))
            arithmetic_mean = np.mean(magnitude)
            flatness = geometric_mean / (arithmetic_mean + 1e-10)
            score += 0.3 * (1.0 - flatness)  # Invert: lower flatness is better
        except (ValueError, FloatingPointError):
            score += 0.15  # Neutral score if computation fails

        return float(np.clip(score, 0.0, 1.0))

    def _compute_naturalness(
        self,
        audio: np.ndarray,
        reference: Optional[np.ndarray] = None
    ) -> float:
        """
        Compute naturalness score based on prosody and smoothness.

        Measures:
        - Smoothness (absence of abrupt changes, weight 0.4)
        - Energy distribution across frames (weight 0.3)
        - Correlation with reference if available (weight 0.3)

        Args:
            audio: Generated audio (1-D waveform)
            reference: Optional reference audio (1-D waveform)

        Returns:
            Naturalness score in [0, 1]
        """
        score = 0.0
        num_samples = len(audio)

        # Smoothness: penalize abrupt changes
        if num_samples > 1:
            smoothness = 1.0 - np.clip(np.std(np.diff(audio)) / 0.1, 0.0, 1.0)
            score += 0.4 * smoothness
        else:
            score += 0.2

        # Energy distribution: should not be too uniform or too spiky
        if num_samples > 10:
            frame_size = num_samples // 10
            # Use every complete frame, including the last one.
            num_frames = num_samples // frame_size
            frames = audio[:num_frames * frame_size].reshape(num_frames, frame_size)
            energy_std = np.std(np.mean(frames ** 2, axis=1))
            # Optimal std is around 0.01-0.1
            energy_score = 1.0 - np.clip(abs(energy_std - 0.05) / 0.1, 0.0, 1.0)
            score += 0.3 * energy_score
        else:
            score += 0.15

        # Similarity to reference if available
        if reference is not None:
            correlation = self._safe_correlation(audio, reference)
            if correlation is None:
                score += 0.15  # Neutral score when correlation is undefined
            else:
                score += 0.3 * ((correlation + 1.0) / 2.0)  # Map [-1, 1] to [0, 1]
        else:
            score += 0.3  # Neutral score if no reference

        return float(np.clip(score, 0.0, 1.0))

    def _compute_accuracy(
        self,
        audio: np.ndarray,
        reference: Optional[np.ndarray] = None,
        transcription: Optional[str] = None
    ) -> float:
        """
        Compute accuracy score based on similarity to reference and/or transcription.

        The result is the average of the components that could be computed:
        audio similarity (needs a reference) and ASR transcription accuracy
        (needs a transcription and a loaded ASR model).

        Args:
            audio: Generated audio
            reference: Optional reference audio
            transcription: Optional expected transcription

        Returns:
            Accuracy score in [0, 1]; 0.5 if no component could be computed
        """
        score = 0.0
        num_components = 0

        # Component 1: Audio similarity to reference
        if reference is not None:
            try:
                # Align lengths
                min_len = min(len(audio), len(reference))
                if min_len > 0:
                    audio_aligned = audio[:min_len]
                    reference_aligned = reference[:min_len]

                    # Mean squared error (lower is better), exponential decay
                    mse = np.mean((audio_aligned - reference_aligned) ** 2)
                    mse_score = np.exp(-mse * 10)

                    # Correlation; neutral 0.5 when undefined (e.g. silence)
                    correlation = self._safe_correlation(audio_aligned, reference_aligned)
                    if correlation is None:
                        correlation_score = 0.5
                    else:
                        correlation_score = (correlation + 1.0) / 2.0

                    # Combined audio similarity score
                    score += 0.5 * mse_score + 0.5 * correlation_score
                    num_components += 1
            except Exception as e:
                logger.debug(f"Error computing audio similarity: {e}")

        # Component 2: Transcription accuracy using ASR
        if transcription and self.use_asr and self.asr_model is not None:
            try:
                score += self._compute_transcription_accuracy(audio, transcription)
                num_components += 1
            except Exception as e:
                logger.debug(f"Error computing transcription accuracy: {e}")

        # Return average score or neutral if no components
        if num_components > 0:
            return float(np.clip(score / num_components, 0.0, 1.0))
        return 0.5

    def _compute_transcription_accuracy(
        self,
        audio: np.ndarray,
        expected_transcription: str,
        sample_rate: int = 16000
    ) -> float:
        """
        Compute transcription accuracy using ASR.

        Args:
            audio: Audio waveform
            expected_transcription: Expected transcription text
            sample_rate: Audio sample rate

        Returns:
            Transcription accuracy score in [0, 1]; 0.5 if ASR fails
        """
        try:
            # Convert to tensor
            audio_tensor = torch.as_tensor(np.ascontiguousarray(audio), dtype=torch.float32)

            # Resample if needed (the processor below is called with 16kHz)
            if sample_rate != 16000:
                resampler = torchaudio.transforms.Resample(sample_rate, 16000)
                audio_tensor = resampler(audio_tensor)

            # Process audio
            input_values = self.asr_processor(
                audio_tensor,
                sampling_rate=16000,
                return_tensors="pt"
            ).input_values

            # Get transcription (greedy decoding of the CTC logits)
            with torch.no_grad():
                logits = self.asr_model(input_values).logits
            predicted_ids = torch.argmax(logits, dim=-1)
            transcription = self.asr_processor.decode(predicted_ids[0])

            # Case-insensitive word-set overlap between hypothesis and target
            return self._compute_text_similarity(
                transcription.lower().strip(),
                expected_transcription.lower().strip()
            )
        except Exception as e:
            logger.debug(f"Error in ASR transcription: {e}")
            return 0.5

    def _compute_text_similarity(self, predicted: str, expected: str) -> float:
        """
        Compute text similarity between predicted and expected transcriptions.

        Uses the Jaccard index over the sets of whitespace-separated words:
        |predicted & expected| / |predicted | expected|. Word order and
        repeated words are ignored, and the comparison is case-sensitive
        (callers lowercase beforehand).

        Args:
            predicted: Predicted transcription
            expected: Expected transcription

        Returns:
            Similarity score in [0, 1]; 0.5 (neutral) if expected has no words
        """
        exp_words = set(expected.split())
        if not exp_words:
            # Covers both the empty string and whitespace-only text.
            return 0.5

        pred_words = set(predicted.split())

        # Jaccard similarity; the union is non-empty because exp_words is.
        return len(pred_words & exp_words) / len(pred_words | exp_words)

    def _normalize_reward(self, reward: float) -> float:
        """
        Normalize reward to target range.

        Args:
            reward: Raw reward value (clipped to [0, 1] before mapping)

        Returns:
            Reward mapped linearly into ``self.normalize_range``
        """
        min_val, max_val = self.normalize_range
        return float(min_val + (max_val - min_val) * np.clip(reward, 0.0, 1.0))

    def get_reward_components(
        self,
        generated_audio: torch.Tensor,
        reference_audio: Optional[torch.Tensor] = None,
        transcription: Optional[str] = None
    ) -> Dict[str, float]:
        """
        Get breakdown of reward components.

        Args:
            generated_audio: Generated audio tensor
            reference_audio: Optional reference audio
            transcription: Optional expected transcription

        Returns:
            Dictionary with 'clarity', 'naturalness', 'accuracy', 'total' and
            'normalized' scores. If the audio is empty, the result is not
            finite, or an error occurs, the components are 0.0 and
            'normalized' is DEFAULT_PENALTY.
        """
        try:
            audio = self._to_numpy(generated_audio)
            reference = None if reference_audio is None else self._to_numpy(reference_audio)

            if audio.size == 0:
                logger.warning("Empty generated audio; returning default penalty.")
                return self._failure_components()

            components = self._score_components(audio, reference, transcription)
            components['normalized'] = self._normalize_reward(components['total'])

            if not np.isfinite(components['normalized']):
                logger.warning("Non-finite reward computed; returning default penalty.")
                return self._failure_components()
            return components
        except Exception as e:
            logger.error(f"Error getting reward components: {e}")
            return self._failure_components()