|
|
""" |
|
|
Feature extractor for Distilled Speech Model. |
|
|
|
|
|
Handles audio preprocessing: normalization to zero mean and unit variance. |
|
|
""" |
|
|
|
|
|
from typing import List, Optional, Union |
|
|
|
|
|
import numpy as np |
|
|
import torch |
|
|
|
|
|
|
|
|
class DistilledSpeechFeatureExtractor:
    """
    Feature extractor for DistilledSpeechModel.

    Normalizes audio to zero mean and unit variance (per-sample).
    Expected input: 16kHz mono audio.

    Example:
        >>> extractor = DistilledSpeechFeatureExtractor()
        >>> audio = np.random.randn(16000)  # 1 second
        >>> inputs = extractor(audio, return_tensors="pt", sample_rate=16000)
        >>> inputs.input_values.shape
        torch.Size([1, 16000])
    """

    def __init__(
        self,
        sampling_rate: int = 16000,
        do_normalize: bool = True,
        return_attention_mask: bool = False,
    ):
        """
        Args:
            sampling_rate: Sample rate (Hz) this extractor expects; audio at
                a different declared rate is rejected in ``__call__``.
            do_normalize: If True, apply per-sample zero-mean / unit-variance
                normalization.
            return_attention_mask: Stored for config round-tripping.
                NOTE(review): ``__call__`` never produces an attention mask,
                regardless of this flag — confirm whether callers expect one.
        """
        self.sampling_rate = sampling_rate
        self.do_normalize = do_normalize
        self.return_attention_mask = return_attention_mask

    def __call__(
        self,
        raw_speech: Union[np.ndarray, List[float], torch.Tensor],
        return_tensors: Optional[str] = "pt",
        sample_rate: Optional[int] = None,
        **kwargs,
    ):
        """
        Process raw audio into model inputs.

        Args:
            raw_speech: Raw audio waveform (1D array, list of floats, or
                tensor). Extra singleton dimensions (e.g. shape ``(1, T)``)
                are squeezed away.
            return_tensors: "pt" for a PyTorch tensor; any other value
                (e.g. "np") yields a numpy array.
            sample_rate: Sample rate of the input audio, validated against
                ``self.sampling_rate`` when provided.

        Returns:
            FeatureExtractorOutput whose ``input_values`` has shape (1, T).

        Raises:
            ValueError: If ``sample_rate`` mismatches the configured rate,
                the audio is empty, or it cannot be reduced to 1D.
        """
        if sample_rate is not None and sample_rate != self.sampling_rate:
            raise ValueError(
                f"Expected sample rate {self.sampling_rate}, got {sample_rate}. "
                f"Please resample your audio to {self.sampling_rate}Hz."
            )

        if isinstance(raw_speech, torch.Tensor):
            # detach().cpu() so CUDA tensors and tensors that require grad
            # are accepted — plain .numpy() raises for both.
            raw_speech = raw_speech.detach().cpu().numpy()

        # np.asarray handles lists directly; work in float32 internally.
        raw_speech = np.asarray(raw_speech, dtype=np.float32)

        # Drop singleton dims (e.g. a (1, T) batch of one); anything still
        # multi-dimensional (true stereo / batches) is rejected.
        if raw_speech.ndim > 1:
            raw_speech = raw_speech.squeeze()
        if raw_speech.ndim != 1:
            raise ValueError(f"Expected 1D audio, got shape {raw_speech.shape}")

        # Guard empty input: mean()/std() of an empty array is NaN and would
        # silently propagate NaNs into the model.
        if raw_speech.size == 0:
            raise ValueError("Input audio is empty.")

        if self.do_normalize:
            # Per-sample zero mean / unit variance; epsilon avoids division
            # by zero on constant (e.g. silent) audio.
            raw_speech = (raw_speech - raw_speech.mean()) / (raw_speech.std() + 1e-7)

        # Add the batch dimension -> (1, T).
        raw_speech = raw_speech[np.newaxis, :]

        if return_tensors == "pt":
            input_values = torch.from_numpy(raw_speech)
        else:
            input_values = raw_speech

        return FeatureExtractorOutput(input_values=input_values)

    def to_dict(self):
        """Serialize configuration to a plain dict for saving."""
        return {
            "sampling_rate": self.sampling_rate,
            "do_normalize": self.do_normalize,
            "return_attention_mask": self.return_attention_mask,
            # Type tag so loaders can dispatch to the right extractor class.
            "feature_extractor_type": "DistilledSpeechFeatureExtractor",
        }

    @classmethod
    def from_dict(cls, config_dict):
        """Build an extractor from a config dict, tolerating missing keys."""
        return cls(
            sampling_rate=config_dict.get("sampling_rate", 16000),
            do_normalize=config_dict.get("do_normalize", True),
            return_attention_mask=config_dict.get("return_attention_mask", False),
        )

    def save_pretrained(self, save_directory: str):
        """Write the config as preprocessor_config.json under *save_directory*."""
        import json
        import os
        os.makedirs(save_directory, exist_ok=True)
        with open(os.path.join(save_directory, "preprocessor_config.json"), "w") as f:
            json.dump(self.to_dict(), f, indent=2)

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path: str, **kwargs):
        """
        Load a feature extractor from a local directory or the Hugging Face Hub.

        Args:
            pretrained_model_name_or_path: Local directory containing
                preprocessor_config.json, or a Hub repo id.

        Returns:
            A configured DistilledSpeechFeatureExtractor instance.
        """
        import json
        import os

        if os.path.isdir(pretrained_model_name_or_path):
            config_path = os.path.join(pretrained_model_name_or_path, "preprocessor_config.json")
        else:
            # Not a local directory: treat the string as a Hub repo id.
            from huggingface_hub import hf_hub_download
            config_path = hf_hub_download(
                repo_id=pretrained_model_name_or_path,
                filename="preprocessor_config.json",
            )

        with open(config_path, "r") as f:
            config = json.load(f)

        return cls.from_dict(config)
|
|
|
|
|
|
|
|
class FeatureExtractorOutput:
    """Lightweight container for the feature extractor's result."""

    def __init__(self, input_values):
        # Processed audio; either a torch.Tensor or a numpy array,
        # depending on the extractor's return_tensors argument.
        self.input_values = input_values

    def to(self, device):
        """Move any tensor payload to *device*; returns self for chaining."""
        values = self.input_values
        if torch.is_tensor(values):
            self.input_values = values.to(device)
        return self
|
|
|