michaeltangz committed on
Commit
a9a8aec
·
1 Parent(s): 03bd1f9

install Flash Attention 2 and optimize Whisper model loading; enhance streaming transcription with pipeline approach and latency tracking

Browse files
Files changed (1) hide show
  1. app.py +156 -279
app.py CHANGED
@@ -2,11 +2,22 @@ import os
2
  import numpy as np
3
  import gradio as gr
4
  import torch
5
- from transformers import AutoModelForSpeechSeq2Seq, AutoProcessor
6
  import spaces
7
  import traceback
8
  from pydub import AudioSegment
9
  import librosa
 
 
 
 
 
 
 
 
 
 
 
10
 
11
  # -------------------------
12
  # Model Loading
@@ -14,7 +25,7 @@ import librosa
14
  print("πŸš€ Loading Whisper model...")
15
 
16
  model_id = "openai/whisper-large-v3-turbo"
17
- DEVICE = "cuda:0" if torch.cuda.is_available() else "cpu"
18
  TORCH_DTYPE = torch.float16 if torch.cuda.is_available() else torch.float32
19
 
20
  print(f"Using device={DEVICE}, dtype={TORCH_DTYPE}")
@@ -24,230 +35,94 @@ model = AutoModelForSpeechSeq2Seq.from_pretrained(
24
  torch_dtype=TORCH_DTYPE,
25
  low_cpu_mem_usage=True,
26
  use_safetensors=True,
 
27
  )
28
  model.to(DEVICE)
29
- model.eval()
30
 
31
  processor = AutoProcessor.from_pretrained(model_id)
 
 
 
 
 
 
 
 
 
 
 
 
 
32
  print(f"βœ… Model loaded on {DEVICE}")
33
 
34
  # -------------------------
35
  # Constants
36
  # -------------------------
37
  SAMPLE_RATE = 16000
38
- BUFFER_SECONDS = 30 # Increased from 10 to keep more context
39
- MIN_AUDIO_LENGTH = 2.0 # Minimum 2 seconds before transcribing
40
- OVERLAP_SECONDS = 2 # Keep overlap for context
41
-
42
-
43
- def resample_audio(audio, orig_sr, target_sr=16000):
44
- """High-quality resampling using librosa."""
45
- if orig_sr == target_sr:
46
- return audio
47
- try:
48
- return librosa.resample(audio, orig_sr=orig_sr, target_sr=target_sr)
49
- except Exception as e:
50
- print(f"Librosa resample failed: {e}, using linear interpolation")
51
- # Fallback to simple resampling
52
- duration = len(audio) / orig_sr
53
- target_length = int(duration * target_sr)
54
- if target_length == 0:
55
- return np.array([], dtype=np.float32)
56
- indices = np.linspace(0, len(audio) - 1, target_length)
57
- return np.interp(indices, np.arange(len(audio)), audio).astype(np.float32)
58
-
59
-
60
- def detect_voice_activity(audio, threshold=0.01):
61
- """Simple VAD: check if audio has sufficient energy."""
62
- if len(audio) == 0:
63
- return False
64
- rms = np.sqrt(np.mean(audio**2))
65
- return rms > threshold
66
 
67
 
68
  @spaces.GPU
69
- def transcribe_audio(audio_chunk, history, full_transcript, last_transcribed_length):
70
  """
71
- Improved streaming transcription with better accuracy.
72
- audio_chunk: (sample_rate, audio_data) from Gradio
73
- history: accumulated audio buffer as numpy array
74
- full_transcript: accumulated text transcript
75
- last_transcribed_length: length of audio already transcribed
76
  """
 
77
  try:
78
- if audio_chunk is None:
79
- return history, full_transcript, full_transcript, last_transcribed_length
80
-
81
- # Parse audio
82
- if isinstance(audio_chunk, tuple):
83
- sr, data = audio_chunk
84
- else:
85
- return history, full_transcript, full_transcript, last_transcribed_length
86
-
87
- if data is None or len(data) == 0:
88
- return history, full_transcript, full_transcript, last_transcribed_length
89
-
90
- # Convert to mono float32
91
- data = np.asarray(data, dtype=np.float32)
92
- if data.ndim == 2:
93
- data = np.mean(data, axis=1)
94
 
95
- # Normalize if needed
96
- if data.dtype == np.int16:
97
- data = data.astype(np.float32) / 32768.0
98
- elif data.dtype == np.int32:
99
- data = data.astype(np.float32) / 2147483648.0
100
 
101
- data = np.clip(data, -1.0, 1.0)
 
 
102
 
103
- # Resample if needed
104
- if sr != SAMPLE_RATE:
105
- data = resample_audio(data, sr, SAMPLE_RATE)
 
 
106
 
107
- # Initialize history if needed
108
- if history is None or len(history) == 0:
109
- history = data
110
  else:
111
- history = np.concatenate([history, data])
112
 
113
- # Keep buffer within limits
114
- max_samples = SAMPLE_RATE * BUFFER_SECONDS
115
- if len(history) > max_samples:
116
- # Keep some overlap for context
117
- overlap_samples = int(SAMPLE_RATE * OVERLAP_SECONDS)
118
- history = history[-(max_samples + overlap_samples):]
119
-
120
- # Need minimum audio to transcribe
121
- min_samples = int(SAMPLE_RATE * MIN_AUDIO_LENGTH)
122
- if len(history) < min_samples:
123
- return history, full_transcript, full_transcript, last_transcribed_length
124
-
125
- # Check for voice activity
126
- if not detect_voice_activity(history[-min_samples:]):
127
- return history, full_transcript, full_transcript, last_transcribed_length
128
-
129
- # Only transcribe new audio (not already transcribed)
130
- new_audio_length = len(history) - last_transcribed_length
131
- if new_audio_length < SAMPLE_RATE * 1.0: # Wait for at least 1 second of new audio
132
- return history, full_transcript, full_transcript, last_transcribed_length
133
-
134
- # Transcribe the buffer with better parameters
135
- inputs = processor(
136
- history,
137
- sampling_rate=SAMPLE_RATE,
138
- return_tensors="pt"
139
- )
140
-
141
- input_features = inputs.input_features.to(DEVICE, dtype=TORCH_DTYPE)
142
-
143
- with torch.no_grad():
144
- predicted_ids = model.generate(
145
- input_features,
146
- max_new_tokens=440, # Leave room for special tokens (total must be < 448)
147
- num_beams=3, # Beam search for better quality (balanced)
148
- do_sample=False,
149
- language="en", # Specify language for better accuracy
150
- task="transcribe",
151
- )
152
-
153
- text = processor.batch_decode(predicted_ids, skip_special_tokens=True)[0].strip()
154
-
155
- if not text:
156
- return history, full_transcript, full_transcript, last_transcribed_length
157
-
158
- # Update the full transcript
159
- # Check if new text is different from what we already have
160
- if full_transcript:
161
- # If the new transcription starts with the end of our current transcript,
162
- # only add the new part
163
- words_current = full_transcript.split()
164
- words_new = text.split()
165
-
166
- # Find overlap
167
- overlap_found = False
168
- for i in range(min(len(words_current), len(words_new))):
169
- if words_current[-(i+1):] == words_new[:i+1]:
170
- # Found overlap, add only new words
171
- new_words = words_new[i+1:]
172
- if new_words:
173
- full_transcript = full_transcript + " " + " ".join(new_words)
174
- overlap_found = True
175
- break
176
-
177
- if not overlap_found:
178
- # No overlap found, check if it's completely new
179
- if text not in full_transcript:
180
- full_transcript = full_transcript + " " + text
181
  else:
182
- full_transcript = text
183
 
184
- # Update last transcribed length
185
- last_transcribed_length = len(history)
 
 
186
 
187
- return history, full_transcript, full_transcript, last_transcribed_length
188
-
189
  except Exception as e:
190
- print(f"Error: {e}")
191
  traceback.print_exc()
192
- return (
193
- history if history is not None else np.array([]),
194
- full_transcript,
195
- full_transcript,
196
- last_transcribed_length
197
- )
198
 
199
 
 
200
  def transcribe_file(file):
201
- """Transcribe an uploaded audio file with high quality settings."""
202
  if file is None:
203
  return ""
204
 
 
205
  try:
206
- # Load audio file using librosa for better quality
207
- audio_data, sr = librosa.load(file.name, sr=SAMPLE_RATE, mono=True)
208
 
209
- # Normalize
210
- audio_data = np.clip(audio_data, -1.0, 1.0)
211
 
212
- # Process in chunks with overlap
213
- chunk_size = SAMPLE_RATE * 30 # 30 second chunks
214
- overlap_size = SAMPLE_RATE * 2 # 2 second overlap
215
- texts = []
216
-
217
- for start in range(0, len(audio_data), chunk_size - overlap_size):
218
- chunk = audio_data[start:start + chunk_size]
219
- if len(chunk) < SAMPLE_RATE * 1.0: # Skip chunks less than 1 second
220
- continue
221
-
222
- inputs = processor(chunk, sampling_rate=SAMPLE_RATE, return_tensors="pt")
223
- input_features = inputs.input_features.to(DEVICE, dtype=TORCH_DTYPE)
224
-
225
- with torch.no_grad():
226
- predicted_ids = model.generate(
227
- input_features,
228
- max_new_tokens=440, # Leave room for special tokens (total must be < 448)
229
- num_beams=5, # Higher beam search for best quality
230
- language="en",
231
- task="transcribe",
232
- )
233
-
234
- text = processor.batch_decode(predicted_ids, skip_special_tokens=True)[0].strip()
235
- if text:
236
- # Remove duplicate text from overlaps
237
- if texts and text.startswith(texts[-1].split()[-5:][0] if len(texts[-1].split()) >= 5 else ""):
238
- # Find overlap and merge
239
- words_prev = texts[-1].split()
240
- words_curr = text.split()
241
- for i in range(min(10, len(words_prev), len(words_curr))):
242
- if words_prev[-(i+1):] == words_curr[:i+1]:
243
- texts[-1] = texts[-1] + " " + " ".join(words_curr[i+1:])
244
- break
245
- else:
246
- texts.append(text)
247
- else:
248
- texts.append(text)
249
-
250
- return " ".join(texts)
251
 
252
  except Exception as e:
253
  print(f"File transcription error: {e}")
@@ -255,111 +130,113 @@ def transcribe_file(file):
255
  return f"Error: {str(e)}"
256
 
257
 
258
- def clear_history():
259
- """Reset everything."""
260
- return np.array([]), "", "", 0
 
 
 
 
 
261
 
262
 
263
  # -------------------------
264
  # Gradio UI
265
  # -------------------------
266
- with gr.Blocks(title="🎀 Whisper ASR") as demo:
267
  gr.Markdown(
268
  """
269
- # 🎀 Whisper Real-Time Transcription (Improved Accuracy)
270
 
271
- **How to use:**
272
- - **Microphone**: Click to record, speak, see live transcription
273
- - **File Upload**: Upload audio file and click "Transcribe"
274
- - **Clear**: Reset the transcription
275
 
276
- **Improvements:**
277
- - Higher quality beam search
278
- - Better context retention (30s buffer)
279
- - Proper audio resampling with librosa
280
- - Voice activity detection
281
- - Smarter overlap handling
282
 
283
- Using Whisper-large-v3-turbo
284
  """
285
  )
286
 
287
- with gr.Row():
288
- with gr.Column():
289
- source = gr.Radio(
290
- ["Microphone", "Upload File"],
291
- value="Microphone",
292
- label="Audio Source"
293
- )
294
-
295
- mic = gr.Audio(
296
- sources=["microphone"],
297
- type="numpy",
298
- streaming=True,
299
- label="πŸŽ™οΈ Microphone",
300
- visible=True
301
- )
302
-
303
- file_input = gr.File(
304
- label="πŸ“ Upload Audio",
305
- file_types=["audio"],
306
- visible=False
307
- )
308
-
309
- transcribe_btn = gr.Button(
310
- "Transcribe File",
311
- visible=False
312
- )
313
 
314
- clear_btn = gr.Button("πŸ—‘οΈ Clear")
 
 
 
 
 
 
 
 
 
 
 
315
 
316
- with gr.Column():
317
- output = gr.Textbox(
318
- label="πŸ“„ Transcription",
319
- lines=12,
320
- interactive=False
321
- )
322
-
323
- # State: audio buffer, full transcript, and last transcribed length
324
- audio_history = gr.State(np.array([]))
325
- transcript_state = gr.State("")
326
- last_transcribed_state = gr.State(0)
327
-
328
- # Toggle UI based on source
329
- def update_ui(choice):
330
- is_mic = choice == "Microphone"
331
- return (
332
- gr.update(visible=is_mic),
333
- gr.update(visible=not is_mic),
334
- gr.update(visible=not is_mic)
 
335
  )
336
 
337
- source.change(
338
- update_ui,
339
- inputs=source,
340
- outputs=[mic, file_input, transcribe_btn]
341
- )
342
-
343
- # Streaming mic input
344
- mic.stream(
345
- transcribe_audio,
346
- inputs=[mic, audio_history, transcript_state, last_transcribed_state],
347
- outputs=[audio_history, transcript_state, output, last_transcribed_state]
348
- )
349
-
350
- # File transcription
351
- transcribe_btn.click(
352
- transcribe_file,
353
- inputs=file_input,
354
- outputs=output
355
- )
356
-
357
- # Clear button
358
- clear_btn.click(
359
- clear_history,
360
- outputs=[audio_history, transcript_state, output, last_transcribed_state]
361
- )
 
 
 
 
 
 
362
 
363
  if __name__ == "__main__":
364
- # share=False on Spaces (automatically public), True for local
365
- demo.launch(share=False)
 
2
  import numpy as np
3
  import gradio as gr
4
  import torch
5
+ from transformers import AutoModelForSpeechSeq2Seq, AutoProcessor, WhisperTokenizer, pipeline
6
  import spaces
7
  import traceback
8
  from pydub import AudioSegment
9
  import librosa
10
import subprocess
import time

# -------------------------
# Install Flash Attention 2
# -------------------------
# FLASH_ATTENTION_SKIP_CUDA_BUILD=TRUE tells flash-attn's setup to use a
# prebuilt wheel instead of compiling CUDA kernels (which fails / times out
# on Spaces build machines).  Inherit the current environment — passing only
# the single variable would wipe PATH/HOME and can make `pip` unresolvable.
_install = subprocess.run(
    "pip install flash-attn --no-build-isolation",
    env={**os.environ, "FLASH_ATTENTION_SKIP_CUDA_BUILD": "TRUE"},
    shell=True,
)
if _install.returncode != 0:
    # Model loading below requests attn_implementation="flash_attention_2";
    # surface the failure instead of letting that load crash mysteriously.
    print("WARNING: flash-attn installation failed "
          f"(exit code {_install.returncode}); model load may error.")
 
22
  # -------------------------
23
  # Model Loading
 
25
print("πŸš€ Loading Whisper model...")

model_id = "openai/whisper-large-v3-turbo"
# Prefer CUDA when present; fp16 on GPU halves memory and bandwidth, while
# fp32 on CPU avoids unsupported/slow half-precision ops.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
TORCH_DTYPE = torch.float16 if torch.cuda.is_available() else torch.float32

print(f"Using device={DEVICE}, dtype={TORCH_DTYPE}")
 
35
  torch_dtype=TORCH_DTYPE,
36
  low_cpu_mem_usage=True,
37
  use_safetensors=True,
38
+ attn_implementation="flash_attention_2"
39
  )
40
  model.to(DEVICE)
 
41
 
42
processor = AutoProcessor.from_pretrained(model_id)
# NOTE(review): `processor` already bundles this tokenizer as
# `processor.tokenizer`; the separate download is redundant but harmless.
tokenizer = WhisperTokenizer.from_pretrained(model_id)

# Create pipeline with proper configuration
pipe = pipeline(
    task="automatic-speech-recognition",
    model=model,
    tokenizer=tokenizer,
    feature_extractor=processor.feature_extractor,
    chunk_length_s=30,  # Process 30-second chunks
    torch_dtype=TORCH_DTYPE,
    # Model was already moved via model.to(DEVICE); passing device here keeps
    # the pipeline's placement consistent with it.
    device=DEVICE,
)

print(f"βœ… Model loaded on {DEVICE}")
57
 
58
# -------------------------
# Constants
# -------------------------
SAMPLE_RATE = 16000  # Whisper consumes 16 kHz mono audio
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
62
 
63
 
64
@spaces.GPU
def stream_transcribe(stream, new_chunk):
    """
    Streaming transcription using the ASR pipeline.

    Args:
        stream: accumulated mono float32 audio buffer (np.ndarray) or None
            on the first chunk of a session.
        new_chunk: ``(sample_rate, audio_data)`` tuple from Gradio's
            streaming microphone component, or None.

    Returns:
        Tuple of (updated stream buffer, transcription text,
        latency string in seconds).
    """
    start_time = time.time()
    try:
        if new_chunk is None:
            return stream, "", "0.00"

        sr, y = new_chunk
        if y is None or len(y) == 0:
            # Nothing new to add — keep the buffer and the display untouched.
            return stream, "", "0.00"

        # Scale integer PCM to float32 in [-1, 1] BEFORE any averaging
        # (mean() on int samples would silently promote to float64).
        # Scaling by the dtype's full range — not the chunk's own peak —
        # keeps loudness consistent across chunks and avoids amplifying
        # background noise during quiet passages.
        if np.issubdtype(y.dtype, np.integer):
            y = y.astype(np.float32) / np.iinfo(y.dtype).max
        else:
            y = y.astype(np.float32)

        # Downmix stereo to mono.
        if y.ndim > 1:
            y = y.mean(axis=1)

        y = np.clip(y, -1.0, 1.0)

        # Concatenate with the existing session buffer.
        if stream is not None and len(stream) > 0:
            stream = np.concatenate([stream, y])
        else:
            stream = y

        # Transcribe once we have at least half a second of audio.  The
        # duration check must use the mic's actual rate `sr` (often 48 kHz,
        # not SAMPLE_RATE); the pipeline resamples internally based on the
        # `sampling_rate` we pass alongside the raw samples.
        if len(stream) > sr * 0.5:
            transcription = pipe({"sampling_rate": sr, "raw": stream})["text"]
        else:
            transcription = ""

        latency = time.time() - start_time
        return stream, transcription, f"{latency:.2f}"

    except Exception as e:
        print(f"Error during streaming transcription: {e}")
        traceback.print_exc()
        # Preserve the buffer so one bad chunk doesn't reset the session.
        return stream if stream is not None else np.array([]), "", "Error"
 
 
 
 
 
109
 
110
 
111
@spaces.GPU
def transcribe_file(file):
    """
    Transcribe an uploaded audio file with the ASR pipeline.

    Args:
        file: path to the audio file (str, as delivered by
            ``gr.Audio(type="filepath")``) or an object exposing ``.name``
            (e.g. a ``gr.File`` / tempfile wrapper), or None.

    Returns:
        Transcription text annotated with elapsed wall-clock seconds,
        "" when no file was given, or an "Error: ..." message string.
    """
    if file is None:
        return ""

    start_time = time.time()
    try:
        # gr.Audio(type="filepath") hands us a plain string path; the older
        # gr.File component passed an object exposing .name.  Accept both —
        # unconditional `file.name` raised AttributeError on str paths and
        # turned every upload into an "Error: ..." result.
        path = file if isinstance(file, str) else file.name

        # Use pipeline directly on the file; it handles decoding/resampling.
        transcription = pipe(path)["text"]

        latency = time.time() - start_time
        return f"{transcription}\n\n(Transcribed in {latency:.2f}s)"

    except Exception as e:
        print(f"File transcription error: {e}")
        traceback.print_exc()
        return f"Error: {str(e)}"
131
 
132
 
133
def clear_output():
    """Reset a transcription textbox to an empty string."""
    return ""


def clear_state():
    """Drop the accumulated audio buffer, starting a fresh streaming session."""
    return None
141
 
142
 
143
  # -------------------------
144
  # Gradio UI
145
  # -------------------------
146
# -------------------------
# UI layout: two tabs (live microphone, file upload), each pairing an input
# column with an output column.  Event wiring happens inside each tab.
# -------------------------
with gr.Blocks(title="🎀 Whisper ASR", theme=gr.themes.Ocean()) as demo:
    gr.Markdown(
        """
        # 🎀 Whisper Large V3 Turbo - Real-Time Transcription

        **Transcribe audio in real-time with high accuracy!**

        This demo uses:
        - Model: `openai/whisper-large-v3-turbo`
        - Flash Attention 2 for speed
        - Optimized pipeline for best accuracy

        **Note:** First transcription may take ~5 seconds. After that, it runs smoothly.
        """
    )

    with gr.Tab("πŸŽ™οΈ Microphone"):
        with gr.Row():
            with gr.Column():
                # Streaming numpy chunks feed stream_transcribe below.
                mic_input = gr.Audio(
                    sources=["microphone"],
                    type="numpy",
                    streaming=True,
                    label="Microphone Input"
                )
                with gr.Row():
                    clear_mic_btn = gr.Button("πŸ—‘οΈ Clear", size="sm")

            with gr.Column():
                mic_output = gr.Textbox(
                    label="πŸ“„ Real-Time Transcription",
                    lines=10,
                    interactive=False
                )
                latency_box = gr.Textbox(
                    label="⚑ Latency (seconds)",
                    value="0.00",
                    interactive=False,
                    scale=0
                )

        # State for streaming: holds the accumulated audio buffer between
        # chunk callbacks.
        stream_state = gr.State()

        # Streaming transcription: a chunk is emitted every 2 s, sessions
        # are capped at 60 s of streaming per time_limit.
        mic_input.stream(
            stream_transcribe,
            inputs=[stream_state, mic_input],
            outputs=[stream_state, mic_output, latency_box],
            time_limit=60,
            stream_every=2,
            concurrency_limit=None
        )

        # Clear button: first drop the audio buffer, then blank the textbox.
        clear_mic_btn.click(
            clear_state,
            outputs=[stream_state]
        ).then(
            clear_output,
            outputs=[mic_output]
        )

    with gr.Tab("πŸ“ Upload File"):
        with gr.Row():
            with gr.Column():
                # NOTE: type="filepath" delivers a str path to
                # transcribe_file (not a file object).
                file_input = gr.Audio(
                    sources=["upload"],
                    type="filepath",
                    label="Upload Audio File"
                )
                with gr.Row():
                    transcribe_file_btn = gr.Button("▢️ Transcribe", variant="primary")
                    clear_file_btn = gr.Button("πŸ—‘οΈ Clear", size="sm")

            with gr.Column():
                file_output = gr.Textbox(
                    label="πŸ“„ Transcription",
                    lines=10,
                    interactive=False
                )

        # File transcription
        transcribe_file_btn.click(
            transcribe_file,
            inputs=file_input,
            outputs=file_output
        )

        # Clear button
        clear_file_btn.click(
            clear_output,
            outputs=[file_output]
        )
+ )
240
 
241
if __name__ == "__main__":
    # share=True creates a public gradio.live tunnel for local runs; on
    # Hugging Face Spaces it is ignored (the app is already publicly served).
    demo.launch(share=True)