# Repo: AuriStreamDistillLarge_100M40PredTeacher_bad / feature_extraction_distilled_speech.py
# Hugging Face upload metadata: klemenk — "Upload distilled speech model" — commit 6f980ab (verified)
"""
Feature extractor for Distilled Speech Model.
Handles audio preprocessing: normalization to zero mean and unit variance.
"""
from typing import List, Optional, Union
import numpy as np
import torch
class DistilledSpeechFeatureExtractor:
    """
    Feature extractor for DistilledSpeechModel.

    Normalizes audio to zero mean and unit variance (per-sample).
    Expected input: 16kHz mono audio.

    Example:
        >>> extractor = DistilledSpeechFeatureExtractor()
        >>> audio = np.random.randn(16000)  # 1 second
        >>> inputs = extractor(audio, return_tensors="pt", sample_rate=16000)
        >>> inputs.input_values.shape
        torch.Size([1, 16000])
    """

    def __init__(
        self,
        sampling_rate: int = 16000,
        do_normalize: bool = True,
        return_attention_mask: bool = False,
    ):
        # Rate the model expects; inputs declaring a different rate are rejected.
        self.sampling_rate = sampling_rate
        self.do_normalize = do_normalize
        self.return_attention_mask = return_attention_mask

    def __call__(
        self,
        raw_speech: Union[np.ndarray, List[float], torch.Tensor],
        return_tensors: Optional[str] = "pt",
        sample_rate: Optional[int] = None,
        **kwargs,
    ):
        """
        Process raw audio into model inputs.

        Args:
            raw_speech: Raw audio waveform (1D array, list, or tensor).
            return_tensors: "pt" for PyTorch tensors, anything else for numpy.
            sample_rate: Sample rate of the input audio, used for validation
                only — no resampling is performed here.

        Returns:
            FeatureExtractorOutput with an ``input_values`` attribute of
            shape (1, num_samples).

        Raises:
            ValueError: If ``sample_rate`` does not match ``self.sampling_rate``,
                if the audio is empty, or if it cannot be squeezed to 1D.
        """
        # Validate sample rate.
        if sample_rate is not None and sample_rate != self.sampling_rate:
            raise ValueError(
                f"Expected sample rate {self.sampling_rate}, got {sample_rate}. "
                f"Please resample your audio to {self.sampling_rate}Hz."
            )

        # Convert to a float32 numpy array. detach().cpu() handles tensors
        # that live on an accelerator or carry autograd history — a plain
        # .numpy() raises RuntimeError for those. np.asarray covers lists.
        if isinstance(raw_speech, torch.Tensor):
            raw_speech = raw_speech.detach().cpu().numpy()
        raw_speech = np.asarray(raw_speech, dtype=np.float32)

        # Ensure 1D (squeeze drops singleton dims, e.g. (1, T) -> (T,)).
        if raw_speech.ndim > 1:
            raw_speech = raw_speech.squeeze()
        if raw_speech.ndim != 1:
            raise ValueError(f"Expected 1D audio, got shape {raw_speech.shape}")
        # mean()/std() of an empty array would silently yield NaN output.
        if raw_speech.size == 0:
            raise ValueError("Expected non-empty audio, got 0 samples")

        # Per-sample normalization; epsilon guards against constant signals.
        if self.do_normalize:
            raw_speech = (raw_speech - raw_speech.mean()) / (raw_speech.std() + 1e-7)

        # Add batch dimension: (T,) -> (1, T).
        raw_speech = raw_speech[np.newaxis, :]

        # Convert to the requested tensor type.
        if return_tensors == "pt":
            input_values = torch.from_numpy(raw_speech)
        else:
            input_values = raw_speech
        return FeatureExtractorOutput(input_values=input_values)

    def to_dict(self):
        """Serialize the configuration to a JSON-compatible dict."""
        return {
            "sampling_rate": self.sampling_rate,
            "do_normalize": self.do_normalize,
            "return_attention_mask": self.return_attention_mask,
            "feature_extractor_type": "DistilledSpeechFeatureExtractor",
        }

    @classmethod
    def from_dict(cls, config_dict):
        """Build an extractor from a config dict; missing keys use defaults."""
        return cls(
            sampling_rate=config_dict.get("sampling_rate", 16000),
            do_normalize=config_dict.get("do_normalize", True),
            return_attention_mask=config_dict.get("return_attention_mask", False),
        )

    def save_pretrained(self, save_directory: str):
        """Write preprocessor_config.json into ``save_directory``."""
        import json
        import os

        os.makedirs(save_directory, exist_ok=True)
        with open(os.path.join(save_directory, "preprocessor_config.json"), "w") as f:
            json.dump(self.to_dict(), f, indent=2)

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path: str, **kwargs):
        """Load the extractor config from a local directory or the HF Hub."""
        import json
        import os

        if os.path.isdir(pretrained_model_name_or_path):
            config_path = os.path.join(
                pretrained_model_name_or_path, "preprocessor_config.json"
            )
        else:
            # Not a local directory: fetch the config file from the Hub.
            from huggingface_hub import hf_hub_download

            config_path = hf_hub_download(
                repo_id=pretrained_model_name_or_path,
                filename="preprocessor_config.json",
            )
        with open(config_path, "r") as f:
            config = json.load(f)
        return cls.from_dict(config)
class FeatureExtractorOutput:
    """Lightweight container returned by the feature extractor.

    Exposes a single ``input_values`` attribute (torch.Tensor or
    numpy.ndarray) plus a ``.to(device)`` helper so callers can treat it
    like a Hugging Face ``BatchFeature``.
    """

    def __init__(self, input_values):
        # Processed audio, shape (batch, samples).
        self.input_values = input_values

    def to(self, device):
        """Move ``input_values`` to *device* if it is a tensor; return self."""
        values = self.input_values
        if torch.is_tensor(values):
            self.input_values = values.to(device)
        return self