| """ |
| Feature extractor for Distilled Speech Model. |
| |
| Handles audio preprocessing: normalization to zero mean and unit variance. |
| """ |
|
|
| from typing import List, Optional, Union |
|
|
| import numpy as np |
| import torch |
|
|
|
|
class DistilledSpeechFeatureExtractor:
    """
    Feature extractor for DistilledSpeechModel.

    Normalizes audio to zero mean and unit variance (per-sample).
    Expected input: 16kHz mono audio.

    Attributes:
        sampling_rate: Sample rate (Hz) the model expects; calls that declare a
            different rate are rejected.
        do_normalize: Whether ``__call__`` standardizes the waveform.
        return_attention_mask: Stored and serialized for config compatibility;
            ``__call__`` does not currently produce an attention mask.

    Example:
        >>> extractor = DistilledSpeechFeatureExtractor()
        >>> audio = np.random.randn(16000)  # 1 second
        >>> inputs = extractor(audio, return_tensors="pt", sample_rate=16000)
        >>> inputs.input_values.shape
        torch.Size([1, 16000])
    """

    def __init__(
        self,
        sampling_rate: int = 16000,
        do_normalize: bool = True,
        return_attention_mask: bool = False,
    ):
        self.sampling_rate = sampling_rate
        self.do_normalize = do_normalize
        self.return_attention_mask = return_attention_mask

    def __call__(
        self,
        raw_speech: Union[np.ndarray, List[float], torch.Tensor],
        return_tensors: Optional[str] = "pt",
        sample_rate: Optional[int] = None,
        **kwargs,
    ) -> "FeatureExtractorOutput":
        """
        Process raw audio into model inputs.

        Args:
            raw_speech: Raw audio waveform (1D array, list, or tensor). Inputs
                with extra singleton dimensions (e.g. shape ``(1, T)``) are
                squeezed to 1D.
            return_tensors: "pt" for a PyTorch tensor; any other value
                (including "np" or None) yields a numpy array.
            sample_rate: Sample rate of input audio (for validation).
            **kwargs: ``sampling_rate`` is accepted as an alias of
                ``sample_rate`` (the Hugging Face spelling) and is validated
                too; other keys are ignored.

        Returns:
            FeatureExtractorOutput whose ``input_values`` has shape ``(1, T)``
            and dtype float32.

        Raises:
            ValueError: If a declared sample rate differs from
                ``self.sampling_rate``, or the audio is not 1D after squeezing.
        """
        # Check both spellings: HF-style callers pass `sampling_rate=`, which
        # would otherwise land in **kwargs and skip validation silently.
        for declared_rate in (sample_rate, kwargs.get("sampling_rate")):
            if declared_rate is not None and declared_rate != self.sampling_rate:
                raise ValueError(
                    f"Expected sample rate {self.sampling_rate}, got {declared_rate}. "
                    f"Please resample your audio to {self.sampling_rate}Hz."
                )

        waveform = self._prepare_waveform(raw_speech)

        # Add the batch dimension: (T,) -> (1, T).
        batched = waveform[np.newaxis, :]

        if return_tensors == "pt":
            input_values = torch.from_numpy(batched)
        else:
            input_values = batched

        return FeatureExtractorOutput(input_values=input_values)

    def _prepare_waveform(
        self, raw_speech: Union[np.ndarray, List[float], torch.Tensor]
    ) -> np.ndarray:
        """Convert ``raw_speech`` to a 1D float32 numpy array, normalized if configured.

        When ``do_normalize`` is False and the input is already a float32
        numpy array (or CPU float32 tensor), the result may share memory with
        the input, because ``np.asarray`` does not copy in that case.

        Raises:
            ValueError: If the audio is not 1D after squeezing singleton dims
                (this includes 0-d scalars).
        """
        if isinstance(raw_speech, torch.Tensor):
            # detach(): .numpy() refuses tensors that require grad.
            # to("cpu", float32): .numpy() refuses CUDA tensors and bfloat16.
            raw_speech = raw_speech.detach().to(device="cpu", dtype=torch.float32).numpy()

        waveform = np.asarray(raw_speech, dtype=np.float32)

        if waveform.ndim > 1:
            waveform = waveform.squeeze()
        # Checked unconditionally so 0-d input fails with a clear message
        # instead of an IndexError when the batch axis is added.
        if waveform.ndim != 1:
            raise ValueError(f"Expected 1D audio, got shape {waveform.shape}")

        # Skip empty input: mean/std of an empty array emit RuntimeWarnings and
        # the (empty) result would be the same anyway.
        if self.do_normalize and waveform.size > 0:
            # The epsilon keeps constant (e.g. silent) inputs from dividing by zero.
            waveform = (waveform - waveform.mean()) / (waveform.std() + 1e-7)

        return waveform

    def to_dict(self):
        """Serialize to dict for saving."""
        return {
            "sampling_rate": self.sampling_rate,
            "do_normalize": self.do_normalize,
            "return_attention_mask": self.return_attention_mask,
            "feature_extractor_type": "DistilledSpeechFeatureExtractor",
        }

    @classmethod
    def from_dict(cls, config_dict):
        """Load from dict; missing keys fall back to the constructor defaults."""
        return cls(
            sampling_rate=config_dict.get("sampling_rate", 16000),
            do_normalize=config_dict.get("do_normalize", True),
            return_attention_mask=config_dict.get("return_attention_mask", False),
        )

    def save_pretrained(self, save_directory: str):
        """Write ``preprocessor_config.json`` into ``save_directory`` (created if needed)."""
        import json
        import os

        os.makedirs(save_directory, exist_ok=True)
        config_path = os.path.join(save_directory, "preprocessor_config.json")
        with open(config_path, "w", encoding="utf-8") as f:
            json.dump(self.to_dict(), f, indent=2)

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path: str, **kwargs):
        """Load feature extractor from a local directory or a Hugging Face Hub repo.

        Args:
            pretrained_model_name_or_path: Local directory containing
                ``preprocessor_config.json``, or a Hub repo id.
            **kwargs: For Hub loads, ``revision``, ``cache_dir``, ``token``,
                ``local_files_only``, ``force_download`` and ``subfolder`` are
                forwarded to ``hf_hub_download``; other keys are ignored.
        """
        import json
        import os

        if os.path.isdir(pretrained_model_name_or_path):
            config_path = os.path.join(pretrained_model_name_or_path, "preprocessor_config.json")
        else:
            # Imported lazily so the hub client is only needed for remote loads.
            from huggingface_hub import hf_hub_download

            # Forward download options (e.g. a pinned `revision`) instead of
            # silently dropping them.
            hub_options = {
                key: kwargs[key]
                for key in (
                    "revision",
                    "cache_dir",
                    "token",
                    "local_files_only",
                    "force_download",
                    "subfolder",
                )
                if key in kwargs
            }
            config_path = hf_hub_download(
                repo_id=pretrained_model_name_or_path,
                filename="preprocessor_config.json",
                **hub_options,
            )

        with open(config_path, "r", encoding="utf-8") as f:
            config = json.load(f)

        return cls.from_dict(config)
|
|
|
|
class FeatureExtractorOutput:
    """Simple container for feature extractor output.

    Supports attribute access (``out.input_values``) and read-only mapping
    access (``out["input_values"]``), so it can also be unpacked as keyword
    arguments (``model(**out)``).
    """

    _FIELDS = ("input_values",)

    def __init__(self, input_values):
        self.input_values = input_values

    def keys(self):
        """Return the field names; together with ``__getitem__`` enables ``**out``."""
        return list(self._FIELDS)

    def __getitem__(self, key):
        if key not in self._FIELDS:
            raise KeyError(key)
        return getattr(self, key)

    def __repr__(self):
        # Report only the shape to avoid dumping large audio arrays.
        shape = getattr(self.input_values, "shape", None)
        return f"{type(self).__name__}(input_values=<shape {tuple(shape) if shape is not None else None}>)"

    def to(self, device):
        """Move tensors to ``device`` in place and return ``self`` for chaining.

        Non-tensor values (e.g. numpy arrays) are left untouched.
        """
        if isinstance(self.input_values, torch.Tensor):
            self.input_values = self.input_values.to(device)
        return self
|
|