"""
Feature extractor for Distilled Speech Model.

Handles audio preprocessing: normalization to zero mean and unit variance.
"""
from typing import List, Optional, Union

import numpy as np
import torch


class DistilledSpeechFeatureExtractor:
    """
    Feature extractor for DistilledSpeechModel.

    Normalizes audio to zero mean and unit variance (per-sample).
    Expected input: 16kHz mono audio.

    Example:
        >>> extractor = DistilledSpeechFeatureExtractor()
        >>> audio = np.random.randn(16000)  # 1 second
        >>> inputs = extractor(audio, return_tensors="pt", sample_rate=16000)
        >>> inputs.input_values.shape
        torch.Size([1, 16000])
    """

    def __init__(
        self,
        sampling_rate: int = 16000,
        do_normalize: bool = True,
        return_attention_mask: bool = False,
    ):
        # Expected sample rate of incoming audio; __call__ validates against it.
        self.sampling_rate = sampling_rate
        # Whether to apply per-sample zero-mean / unit-variance normalization.
        self.do_normalize = do_normalize
        # When True, __call__ also returns an all-ones attention mask
        # (no padding is performed here, so every frame is valid).
        self.return_attention_mask = return_attention_mask

    def __call__(
        self,
        raw_speech: Union[np.ndarray, List[float], torch.Tensor],
        return_tensors: Optional[str] = "pt",
        sample_rate: Optional[int] = None,
        **kwargs,
    ):
        """
        Process raw audio into model inputs.

        Args:
            raw_speech: Raw audio waveform (1D array, list, or tensor).
            return_tensors: "pt" for PyTorch tensors, "np" for numpy.
            sample_rate: Sample rate of input audio (for validation).

        Returns:
            FeatureExtractorOutput with ``input_values`` of shape (1, T) and,
            if ``return_attention_mask`` is set, ``attention_mask`` of the
            same shape (all ones — no padding is applied here).

        Raises:
            ValueError: If the sample rate mismatches, the audio is not 1D
                after squeezing, or the audio is empty.
        """
        # Validate sample rate
        if sample_rate is not None and sample_rate != self.sampling_rate:
            raise ValueError(
                f"Expected sample rate {self.sampling_rate}, got {sample_rate}. "
                f"Please resample your audio to {self.sampling_rate}Hz."
            )

        # Convert to numpy if needed.
        # detach().cpu() is required: .numpy() alone fails for CUDA tensors
        # and for tensors that require grad.
        if isinstance(raw_speech, torch.Tensor):
            raw_speech = raw_speech.detach().cpu().numpy()
        elif isinstance(raw_speech, list):
            raw_speech = np.array(raw_speech)
        raw_speech = np.asarray(raw_speech, dtype=np.float32)

        # Ensure 1D (drop singleton dims such as a leading channel axis)
        if raw_speech.ndim > 1:
            raw_speech = raw_speech.squeeze()
        if raw_speech.ndim != 1:
            raise ValueError(f"Expected 1D audio, got shape {raw_speech.shape}")

        # Guard against empty input: mean()/std() of an empty array would
        # silently produce NaNs that propagate into the model.
        if raw_speech.size == 0:
            raise ValueError("Expected non-empty audio, got 0 samples")

        # Normalize to zero mean, unit variance (epsilon avoids div-by-zero
        # on constant signals such as silence)
        if self.do_normalize:
            raw_speech = (raw_speech - raw_speech.mean()) / (raw_speech.std() + 1e-7)

        # Add batch dimension -> (1, T)
        raw_speech = raw_speech[np.newaxis, :]

        # No padding happens here, so the mask is all ones.
        attention_mask = None
        if self.return_attention_mask:
            attention_mask = np.ones_like(raw_speech, dtype=np.int32)

        # Convert to tensors
        if return_tensors == "pt":
            input_values = torch.from_numpy(raw_speech)
            if attention_mask is not None:
                attention_mask = torch.from_numpy(attention_mask)
        else:
            input_values = raw_speech

        return FeatureExtractorOutput(
            input_values=input_values, attention_mask=attention_mask
        )

    def to_dict(self):
        """Serialize to dict for saving."""
        return {
            "sampling_rate": self.sampling_rate,
            "do_normalize": self.do_normalize,
            "return_attention_mask": self.return_attention_mask,
            "feature_extractor_type": "DistilledSpeechFeatureExtractor",
        }

    @classmethod
    def from_dict(cls, config_dict):
        """Load from dict (unknown keys are ignored; missing keys default)."""
        return cls(
            sampling_rate=config_dict.get("sampling_rate", 16000),
            do_normalize=config_dict.get("do_normalize", True),
            return_attention_mask=config_dict.get("return_attention_mask", False),
        )

    def save_pretrained(self, save_directory: str):
        """Save feature extractor config as ``preprocessor_config.json``."""
        import json
        import os

        os.makedirs(save_directory, exist_ok=True)
        with open(os.path.join(save_directory, "preprocessor_config.json"), "w") as f:
            json.dump(self.to_dict(), f, indent=2)

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path: str, **kwargs):
        """Load feature extractor from a local directory or the HF hub."""
        import json
        import os

        if os.path.isdir(pretrained_model_name_or_path):
            config_path = os.path.join(
                pretrained_model_name_or_path, "preprocessor_config.json"
            )
        else:
            # Not a local dir: try to download from the hub
            from huggingface_hub import hf_hub_download

            config_path = hf_hub_download(
                repo_id=pretrained_model_name_or_path,
                filename="preprocessor_config.json",
            )

        with open(config_path, "r") as f:
            config = json.load(f)
        return cls.from_dict(config)


class FeatureExtractorOutput:
    """Simple container for feature extractor output."""

    def __init__(self, input_values, attention_mask=None):
        # Processed audio, shape (1, T).
        self.input_values = input_values
        # Optional all-ones mask (None unless return_attention_mask was set).
        self.attention_mask = attention_mask

    def to(self, device):
        """Move tensors to device; returns self for chaining."""
        if isinstance(self.input_values, torch.Tensor):
            self.input_values = self.input_values.to(device)
        if isinstance(self.attention_mask, torch.Tensor):
            self.attention_mask = self.attention_mask.to(device)
        return self