File size: 5,037 Bytes
6f980ab
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
"""
Feature extractor for Distilled Speech Model.

Handles audio preprocessing: normalization to zero mean and unit variance.
"""

from typing import List, Optional, Union

import numpy as np
import torch


class DistilledSpeechFeatureExtractor:
    """
    Feature extractor for DistilledSpeechModel.

    Normalizes audio to zero mean and unit variance (per-sample).
    Expected input: 16kHz mono audio.

    Example:
        >>> extractor = DistilledSpeechFeatureExtractor()
        >>> audio = np.random.randn(16000)  # 1 second
        >>> inputs = extractor(audio, return_tensors="pt", sample_rate=16000)
        >>> inputs.input_values.shape
        torch.Size([1, 16000])
    """

    def __init__(
        self,
        sampling_rate: int = 16000,
        do_normalize: bool = True,
        return_attention_mask: bool = False,
    ):
        # sampling_rate: rate (Hz) the model was trained on; inputs at any
        # other rate are rejected in __call__ rather than silently accepted.
        self.sampling_rate = sampling_rate
        # do_normalize: apply per-sample zero-mean / unit-variance scaling.
        self.do_normalize = do_normalize
        # return_attention_mask: kept for config compatibility; this extractor
        # does not build attention masks itself.
        self.return_attention_mask = return_attention_mask

    def __call__(
        self,
        raw_speech: Union[np.ndarray, List[float], torch.Tensor],
        return_tensors: Optional[str] = "pt",
        sample_rate: Optional[int] = None,
        **kwargs,
    ):
        """
        Process raw audio into model inputs.

        Args:
            raw_speech: Raw audio waveform (1D array, list, or tensor).
            return_tensors: "pt" for PyTorch tensors, anything else for numpy.
            sample_rate: Sample rate of input audio (for validation).

        Returns:
            FeatureExtractorOutput with an ``input_values`` attribute of
            shape (1, num_samples).

        Raises:
            ValueError: If ``sample_rate`` mismatches ``self.sampling_rate``,
                the input cannot be reduced to 1D, or the input is empty.
        """
        # Validate sample rate up front so callers get an actionable error
        # instead of silently degraded model output.
        if sample_rate is not None and sample_rate != self.sampling_rate:
            raise ValueError(
                f"Expected sample rate {self.sampling_rate}, got {sample_rate}. "
                f"Please resample your audio to {self.sampling_rate}Hz."
            )

        # Convert to numpy. detach().cpu() is required: .numpy() alone raises
        # for tensors on an accelerator or carrying autograd history.
        if isinstance(raw_speech, torch.Tensor):
            raw_speech = raw_speech.detach().cpu().numpy()

        # np.asarray also covers plain Python lists, so no separate branch.
        raw_speech = np.asarray(raw_speech, dtype=np.float32)

        # Ensure 1D (e.g. drop a leading channel/batch axis of size 1).
        if raw_speech.ndim > 1:
            raw_speech = raw_speech.squeeze()
        if raw_speech.ndim != 1:
            raise ValueError(f"Expected 1D audio, got shape {raw_speech.shape}")
        if raw_speech.size == 0:
            # mean()/std() of an empty array would yield NaN output; fail loudly.
            raise ValueError("Expected non-empty audio input, got 0 samples")

        # Per-sample normalization; epsilon guards against zero variance
        # (e.g. pure silence) causing a division by zero.
        if self.do_normalize:
            raw_speech = (raw_speech - raw_speech.mean()) / (raw_speech.std() + 1e-7)

        # Add batch dimension -> (1, num_samples).
        raw_speech = raw_speech[np.newaxis, :]

        # Convert to the requested container.
        if return_tensors == "pt":
            input_values = torch.from_numpy(raw_speech)
        else:
            input_values = raw_speech

        return FeatureExtractorOutput(input_values=input_values)

    def to_dict(self):
        """Serialize configuration to a JSON-safe dict for saving."""
        return {
            "sampling_rate": self.sampling_rate,
            "do_normalize": self.do_normalize,
            "return_attention_mask": self.return_attention_mask,
            "feature_extractor_type": "DistilledSpeechFeatureExtractor",
        }

    @classmethod
    def from_dict(cls, config_dict):
        """Build an extractor from a config dict (missing keys use defaults)."""
        return cls(
            sampling_rate=config_dict.get("sampling_rate", 16000),
            do_normalize=config_dict.get("do_normalize", True),
            return_attention_mask=config_dict.get("return_attention_mask", False),
        )

    def save_pretrained(self, save_directory: str):
        """Write the config as preprocessor_config.json under save_directory."""
        import json
        import os
        os.makedirs(save_directory, exist_ok=True)
        with open(os.path.join(save_directory, "preprocessor_config.json"), "w") as f:
            json.dump(self.to_dict(), f, indent=2)

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path: str, **kwargs):
        """Load feature extractor config from a local directory or the hub."""
        import json
        import os

        if os.path.isdir(pretrained_model_name_or_path):
            config_path = os.path.join(pretrained_model_name_or_path, "preprocessor_config.json")
        else:
            # Not a local directory -> treat it as a hub repo id and download.
            from huggingface_hub import hf_hub_download
            config_path = hf_hub_download(
                repo_id=pretrained_model_name_or_path,
                filename="preprocessor_config.json",
            )

        with open(config_path, "r") as f:
            config = json.load(f)

        return cls.from_dict(config)


class FeatureExtractorOutput:
    """Lightweight result object holding the processed ``input_values``."""

    def __init__(self, input_values):
        # input_values: torch.Tensor or numpy array of shape (batch, samples).
        self.input_values = input_values

    def to(self, device):
        """Move held tensors to *device*; returns self for call chaining."""
        values = self.input_values
        # Numpy arrays have no device concept, so only tensors are moved.
        if isinstance(values, torch.Tensor):
            self.input_values = values.to(device)
        return self