Automatic Speech Recognition
Safetensors
Chinese
whisper
gpric024 committed on
Commit
b1ef1ed
·
1 Parent(s): 234d407

Adding handler and requirements.txt

Browse files
Files changed (2) hide show
  1. handler.py +197 -0
  2. requirements.txt +6 -0
handler.py ADDED
@@ -0,0 +1,197 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Custom Inference Handler for StutteredSpeechASR Model
3
+ Handles audio input and returns transcriptions for stuttered speech.
4
+ """
5
+
6
+ import torch
7
+ import librosa
8
+ import numpy as np
9
+ import base64
10
+ import io
11
+ import logging
12
+ from typing import Dict, Any
13
+ from transformers import AutoModelForSpeechSeq2Seq, AutoProcessor
14
+
15
+ # Configure logging
16
+ logging.basicConfig(level=logging.INFO)
17
+ logger = logging.getLogger(__name__)
18
+
19
+
20
class EndpointHandler:
    """
    Custom handler for StutteredSpeechASR inference endpoint.

    Processes audio inputs (base64 strings, raw bytes, or wrapped dicts)
    and returns transcriptions using the fine-tuned Whisper model for
    stuttered Mandarin speech.
    """

    def __init__(self, path: str = ""):
        """
        Initialize the handler by loading the model and processor.

        Args:
            path: Path to the model directory (provided by Inference Endpoints)
        """
        logger.info("Initializing StutteredSpeechASR handler...")

        # fp16 halves GPU memory and speeds up inference; CPU needs fp32.
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.torch_dtype = torch.float16 if self.device == "cuda" else torch.float32

        # Lazy %-style args: the message is only formatted if emitted.
        logger.info("Using device: %s", self.device)
        logger.info("Using dtype: %s", self.torch_dtype)

        # Load model and processor
        try:
            self.model = AutoModelForSpeechSeq2Seq.from_pretrained(
                path,
                torch_dtype=self.torch_dtype,
            )
            self.processor = AutoProcessor.from_pretrained(path)
            self.model.to(self.device)
            self.model.eval()  # disable dropout etc. for deterministic inference

            # BUGFIX: the original code *claimed* to run inference "with forced
            # Mandarin Chinese language" but never configured it, so Whisper
            # would auto-detect the language per request. Precompute the forced
            # decoder prompt so every request is transcribed as Chinese.
            try:
                self.forced_decoder_ids = self.processor.get_decoder_prompt_ids(
                    language="zh", task="transcribe"
                )
            except Exception as e:
                # Best effort: fall back to the old auto-detect behavior rather
                # than failing endpoint startup.
                logger.warning("Could not build forced decoder ids: %s", e)
                self.forced_decoder_ids = None

            logger.info("Model and processor loaded successfully!")
        except Exception as e:
            logger.error("Error loading model: %s", e)
            raise

    def _load_audio_from_bytes(self, audio_bytes: bytes) -> np.ndarray:
        """
        Load audio from bytes and resample to 16kHz.

        Args:
            audio_bytes: Raw audio bytes (any container librosa can decode)

        Returns:
            Audio waveform as a 1-D numpy array (16 kHz, mono)

        Raises:
            ValueError: If the bytes cannot be decoded as audio.
        """
        try:
            # Load audio from bytes using librosa; Whisper's feature
            # extractor expects 16 kHz mono input.
            audio_buffer = io.BytesIO(audio_bytes)
            waveform, _ = librosa.load(audio_buffer, sr=16000, mono=True)
            return waveform
        except Exception as e:
            logger.error("Error loading audio from bytes: %s", e)
            # Chain the original exception so the root cause is preserved.
            raise ValueError(f"Failed to load audio: {e}") from e

    def _load_audio_from_base64(self, base64_string: str) -> np.ndarray:
        """
        Load audio from base64-encoded string.

        Args:
            base64_string: Base64-encoded audio data

        Returns:
            Audio waveform as a 1-D numpy array (16 kHz, mono)

        Raises:
            ValueError: If the string is not valid base64 or not audio.
        """
        try:
            # Decode base64 string, then reuse the byte loader.
            audio_bytes = base64.b64decode(base64_string)
        except Exception as e:
            logger.error("Error decoding base64 audio: %s", e)
            raise ValueError(f"Failed to decode base64 audio: {e}") from e
        return self._load_audio_from_bytes(audio_bytes)

    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """
        Process incoming requests and return transcriptions.

        Expected input formats:
            1. {"inputs": "base64_encoded_audio_string"}
            2. {"inputs": {"audio": "base64_encoded_audio_string"}}
            3. Binary audio data in request body

        Args:
            data: Input data dictionary (or raw bytes for format 3)

        Returns:
            Dictionary containing transcription and metadata, or an
            {"error": ..., "transcription": None} dict on failure (errors are
            reported in-band rather than raised, matching endpoint semantics).
        """
        try:
            logger.info("Processing inference request...")

            # Extract audio data from the supported input formats.
            waveform = None

            if isinstance(data, dict):
                if "inputs" in data:
                    inputs = data["inputs"]

                    if isinstance(inputs, str):
                        # Format 1: base64-encoded audio
                        waveform = self._load_audio_from_base64(inputs)
                    elif isinstance(inputs, dict):
                        # Format 2: {"inputs": {"audio": "base64_string"}}
                        if "audio" in inputs:
                            waveform = self._load_audio_from_base64(inputs["audio"])
                        else:
                            raise ValueError("Missing 'audio' field in inputs dictionary")
                    elif isinstance(inputs, bytes):
                        # Binary audio data nested under "inputs"
                        waveform = self._load_audio_from_bytes(inputs)
                    else:
                        raise ValueError(f"Unsupported input type: {type(inputs)}")

                elif "audio" in data:
                    # Direct top-level audio field
                    audio_data = data["audio"]
                    if isinstance(audio_data, str):
                        waveform = self._load_audio_from_base64(audio_data)
                    elif isinstance(audio_data, bytes):
                        waveform = self._load_audio_from_bytes(audio_data)

                else:
                    raise ValueError("No valid audio data found in request. Expected 'inputs' or 'audio' field.")

            elif isinstance(data, (bytes, bytearray)):
                # Format 3: raw binary request body
                waveform = self._load_audio_from_bytes(bytes(data))

            else:
                raise ValueError(f"Unsupported data type: {type(data)}")

            if waveform is None:
                raise ValueError("Failed to extract audio from request")
            # Guard: a decodable-but-empty file would otherwise fail deep
            # inside the feature extractor with a cryptic error.
            if len(waveform) == 0:
                raise ValueError("Decoded audio is empty")

            logger.info("Audio loaded: %d samples at 16kHz", len(waveform))

            # Convert the waveform to Whisper log-mel input features.
            input_features = self.processor(
                waveform,
                sampling_rate=16000,
                return_tensors="pt"
            ).input_features

            # Match the model's device and dtype (fp16 on GPU).
            input_features = input_features.to(self.device, dtype=self.torch_dtype)

            # Run inference; force Mandarin Chinese transcription when the
            # decoder prompt was built in __init__ (see BUGFIX note there).
            generate_kwargs = {}
            if self.forced_decoder_ids is not None:
                generate_kwargs["forced_decoder_ids"] = self.forced_decoder_ids
            with torch.no_grad():
                predicted_ids = self.model.generate(input_features, **generate_kwargs)

            # Decode token ids back to text.
            transcription = self.processor.batch_decode(
                predicted_ids,
                skip_special_tokens=True
            )[0]

            logger.info("Transcription complete: %s...", transcription[:100])

            return {
                "transcription": transcription.strip(),
                "audio_duration_seconds": float(len(waveform) / 16000),
                "model": "AImpower/StutteredSpeechASR"
            }

        except Exception as e:
            logger.error("Error during inference: %s", e, exc_info=True)
            return {
                "error": str(e),
                "transcription": None
            }
requirements.txt ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ --extra-index-url https://download.pytorch.org/whl/cu118
2
+ torch>=2.0.0
3
+ transformers>=4.30.0
4
+ librosa>=0.10.0
5
+ numpy>=1.24.0
6
+ soundfile>=0.12.0