mic3333 committed on
Commit
55d67f9
·
1 Parent(s): 1f8fa97

simplify streaming transcription by removing VAD, diarization, and complex buffering logic

Browse files
Files changed (1) hide show
  1. app.py +165 -629
app.py CHANGED
@@ -1,33 +1,18 @@
1
  import os
2
- from contextlib import contextmanager, nullcontext
3
- from collections import deque
4
  import numpy as np
5
  import gradio as gr
6
  import torch
7
  from transformers import AutoModelForSpeechSeq2Seq, AutoProcessor
8
  import spaces
9
  import traceback
10
- import webrtcvad
11
- import re
12
- from difflib import SequenceMatcher
13
  from pydub import AudioSegment
14
 
15
- try:
16
- from pyannote.audio import Pipeline
17
- _HAVE_DIARIZATION = True
18
- except Exception:
19
- Pipeline = None
20
- _HAVE_DIARIZATION = False
21
-
22
  # -------------------------
23
- # Config / Model Loading
24
  # -------------------------
25
- print("πŸš€ Loading Whisper model at startup...")
26
- torch.set_float32_matmul_precision("high")
27
 
28
  model_id = "openai/whisper-large-v3-turbo"
29
-
30
- # Decide device and dtype once
31
  DEVICE = "cuda:0" if torch.cuda.is_available() else "cpu"
32
  TORCH_DTYPE = torch.float16 if torch.cuda.is_available() else torch.float32
33
 
@@ -42,213 +27,18 @@ model = AutoModelForSpeechSeq2Seq.from_pretrained(
42
  model.to(DEVICE)
43
  model.eval()
44
 
45
- # Configure generation settings
46
- try:
47
- model.generation_config.cache_implementation = "static"
48
- model.generation_config.max_new_tokens = 256
49
- except Exception as e:
50
- print("⚠️ Could not configure static cache on generation_config:", e)
51
-
52
  processor = AutoProcessor.from_pretrained(model_id)
53
- print(f"βœ… Model and processor loaded on {DEVICE}")
54
 
55
  # -------------------------
56
- # Globals / constants
57
  # -------------------------
58
  SAMPLE_RATE = 16000
59
- BUFFER_DURATION = 8 # seconds
60
- MAX_BUFFER_SAMPLES = int(SAMPLE_RATE * BUFFER_DURATION)
61
-
62
- # VAD (webrtcvad)
63
- vad = webrtcvad.Vad(2) # aggressiveness 0-3
64
-
65
- # Sentence splitting regex
66
- sentence_split_re = re.compile(
67
- r"(?<!Mr\.)(?<!Ms\.)(?<!Mrs\.)(?<!Dr\.)(?<!St\.)(?<!Jr\.)(?<!Sr\.)"
68
- r"(?<!Prof\.)(?<!Inc\.)(?<!Ltd\.)(?<!U\.S\.)"
69
- r"(?<=[.!?])\s+"
70
- )
71
-
72
-
73
- def create_initial_state():
74
- """
75
- Create a fresh per-session state dictionary.
76
- Uses list instead of deque and set for Gradio serialization compatibility.
77
- We convert back to deque/set during processing for efficiency.
78
- """
79
- return {
80
- "buffer": [], # Will be converted to deque during processing
81
- "full_transcript": "",
82
- "last_transcription": "",
83
- "entries": [],
84
- "processed_samples": 0,
85
- "total_audio_samples": 0,
86
- "speaker_map": {},
87
- "next_speaker_idx": 1,
88
- "seen_texts": [], # Will be converted to set during processing
89
- "unprocessed_audio": np.array([], dtype=np.float32),
90
- }
91
-
92
-
93
- def is_near_duplicate(a: str, b: str, threshold: float = 0.6) -> bool:
94
- """
95
- Return True if sentences a and b are very similar.
96
- """
97
- if not a or not b:
98
- return False
99
- ratio = SequenceMatcher(None, a.lower(), b.lower()).ratio()
100
- return ratio >= threshold
101
-
102
-
103
- def format_timestamp(seconds: float) -> str:
104
- """
105
- Format seconds as mm:ss.mmm (or hh:mm:ss.mmm for long audio).
106
- """
107
- total_ms = int(seconds * 1000)
108
- hours, rem = divmod(total_ms, 3_600_000)
109
- minutes, rem = divmod(rem, 60_000)
110
- secs, ms = divmod(rem, 1_000)
111
- if hours:
112
- return f"{hours:02d}:{minutes:02d}:{secs:02d}.{ms:03d}"
113
- return f"{minutes:02d}:{secs:02d}.{ms:03d}"
114
-
115
-
116
- diarization_pipeline = None
117
- diarization_call_count = 0 # Throttle diarization calls
118
-
119
-
120
- @contextmanager
121
- def _unsafe_torch_load_context():
122
- """
123
- Temporarily force torch.load to use weights_only=False.
124
- """
125
- orig_load = torch.load
126
-
127
- def _patched_load(*args, **kwargs):
128
- kwargs.setdefault("weights_only", False)
129
- return orig_load(*args, **kwargs)
130
-
131
- torch.load = _patched_load
132
- try:
133
- yield
134
- finally:
135
- torch.load = orig_load
136
-
137
-
138
- def get_diarization_pipeline():
139
- """
140
- Lazily load the pyannote diarization pipeline if available and configured.
141
- """
142
- global diarization_pipeline
143
- if not _HAVE_DIARIZATION:
144
- return None
145
- if diarization_pipeline is not None:
146
- return diarization_pipeline
147
-
148
- token = (
149
- os.environ.get("PYANNOTE_TOKEN")
150
- or os.environ.get("HF_TOKEN")
151
- or os.environ.get("HF_API_TOKEN")
152
- )
153
- if not token:
154
- print(
155
- "Diarization disabled: no Hugging Face token found. "
156
- "Set PYANNOTE_TOKEN, HF_TOKEN, or HF_API_TOKEN in your Space settings."
157
- )
158
- return None
159
-
160
- try:
161
- import torch.serialization as ts
162
- safe = []
163
- try:
164
- from torch.torch_version import TorchVersion
165
- safe.append(TorchVersion)
166
- except Exception:
167
- pass
168
- try:
169
- from pyannote.audio.core.task import Specifications, Problem, Resolution
170
- safe.append(Specifications)
171
- safe.append(Problem)
172
- safe.append(Resolution)
173
- except Exception:
174
- pass
175
-
176
- ctx = ts.safe_globals(safe) if safe else nullcontext()
177
- with _unsafe_torch_load_context():
178
- with ctx:
179
- diarization_pipeline = Pipeline.from_pretrained(
180
- "pyannote/speaker-diarization-3.1",
181
- use_auth_token=token,
182
- )
183
- print("βœ… Loaded pyannote speaker diarization pipeline.")
184
- except Exception as e:
185
- print("❌ Failed to load diarization pipeline:", e)
186
- diarization_pipeline = None
187
- return diarization_pipeline
188
-
189
-
190
- # -------------------------
191
- # VAD helpers
192
- # -------------------------
193
- def frame_generator(frame_duration_ms, audio, sample_rate):
194
- """
195
- Yields contiguous frames (numpy float32 array chunks).
196
- """
197
- n = int(sample_rate * (frame_duration_ms / 1000.0))
198
- offset = 0
199
- while offset + n <= len(audio):
200
- yield audio[offset:offset + n]
201
- offset += n
202
-
203
-
204
- def vad_collector(audio, sample_rate, frame_ms=30):
205
- """
206
- Return list of (start_sample, end_sample) voiced segments in `audio`.
207
- """
208
- frames = list(frame_generator(frame_ms, audio, sample_rate))
209
- if not frames:
210
- return []
211
-
212
- # Convert each frame to 16-bit PCM bytes for webrtcvad
213
- voiced_flags = []
214
- for f in frames:
215
- pcm16 = np.clip(f, -1.0, 1.0) # Ensure range
216
- pcm16 = (pcm16 * 32767).astype(np.int16).tobytes()
217
- try:
218
- is_speech = vad.is_speech(pcm16, sample_rate)
219
- except Exception:
220
- is_speech = False
221
- voiced_flags.append(is_speech)
222
 
223
- # Group consecutive voiced frames
224
- segments_ms = []
225
- start_frame = None
226
- for i, flag in enumerate(voiced_flags):
227
- if flag and start_frame is None:
228
- start_frame = i
229
- elif (not flag) and (start_frame is not None):
230
- segments_ms.append((start_frame * frame_ms, i * frame_ms))
231
- start_frame = None
232
- if start_frame is not None:
233
- segments_ms.append((start_frame * frame_ms, len(frames) * frame_ms))
234
 
235
- # Convert ms to sample indices
236
- sample_segments = []
237
- for s_ms, e_ms in segments_ms:
238
- s = int((s_ms / 1000.0) * sample_rate)
239
- e = int((e_ms / 1000.0) * sample_rate)
240
- sample_segments.append((s, e))
241
- return sample_segments
242
-
243
-
244
- # -------------------------
245
- # Audio resampling helper
246
- # -------------------------
247
- def resample_audio(audio, orig_sr, target_sr=16000):
248
- """
249
- Simple linear interpolation resampling.
250
- For production, consider using librosa or torchaudio for better quality.
251
- """
252
  if orig_sr == target_sr:
253
  return audio
254
  duration = len(audio) / orig_sr
@@ -256,489 +46,235 @@ def resample_audio(audio, orig_sr, target_sr=16000):
256
  if target_length == 0:
257
  return np.array([], dtype=np.float32)
258
  indices = np.linspace(0, len(audio) - 1, target_length)
259
- resampled = np.interp(indices, np.arange(len(audio)), audio)
260
- return resampled.astype(np.float32)
261
 
262
 
263
- # -------------------------
264
- # Core streaming transcription
265
- # -------------------------
266
  @spaces.GPU
267
- def stream_transcribe(audio, state):
268
  """
269
- Receives streaming audio chunks from Gradio Audio component.
270
- Returns (full_transcript, state).
 
271
  """
272
- global diarization_call_count
273
-
274
- # Ensure we have per-session state
275
- if state is None:
276
- state = create_initial_state()
277
-
278
- # Make a working copy and convert types back
279
- # Convert buffer from list back to deque if needed
280
- buffer_data = state["buffer"]
281
- if isinstance(buffer_data, list):
282
- buffer = deque(buffer_data, maxlen=MAX_BUFFER_SAMPLES)
283
- else:
284
- buffer = buffer_data
285
-
286
- full_transcript = state["full_transcript"]
287
- last_transcription = state["last_transcription"]
288
- entries = state["entries"].copy() # Copy list to avoid mutations
289
- processed_samples = state["processed_samples"]
290
- total_audio_samples = state["total_audio_samples"]
291
- speaker_map = state["speaker_map"].copy()
292
- next_speaker_idx = state["next_speaker_idx"]
293
-
294
- # Convert seen_texts back to set if it's a list
295
- seen_texts_data = state["seen_texts"]
296
- if isinstance(seen_texts_data, list):
297
- seen_texts = set(seen_texts_data)
298
- else:
299
- seen_texts = seen_texts_data.copy() if isinstance(seen_texts_data, set) else set()
300
-
301
- unprocessed_audio = state["unprocessed_audio"]
302
-
303
  try:
304
- if audio is None:
305
- return full_transcript, state
306
-
307
- # Expect (sr, data)
308
- if not (isinstance(audio, (list, tuple)) and len(audio) == 2):
309
- return full_transcript, state
310
- sr, data = audio
311
-
312
- if data is None or (isinstance(data, np.ndarray) and data.size == 0):
313
- return full_transcript, state
314
-
315
- # Convert to numpy float32
 
316
  data = np.asarray(data, dtype=np.float32)
317
-
318
- # If stereo, convert to mono
319
  if data.ndim == 2:
320
  data = np.mean(data, axis=1)
321
-
322
- # If int PCM, normalize
323
  if data.dtype == np.int16:
324
  data = data.astype(np.float32) / 32768.0
325
  elif data.dtype == np.int32:
326
  data = data.astype(np.float32) / 2147483648.0
327
-
328
- # Resample if needed
329
- if sr != SAMPLE_RATE:
330
- data = resample_audio(data, sr, SAMPLE_RATE)
331
-
332
- # Validate data range
333
- data = np.clip(data, -1.0, 1.0)
334
-
335
- # Track total samples received
336
- num_new = len(data)
337
- total_audio_samples += num_new
338
-
339
- # Add to buffer (deque will auto-trim to maxlen)
340
- buffer.extend(data)
341
 
342
- # Accumulate unprocessed audio for VAD
343
- unprocessed_audio = np.concatenate([unprocessed_audio, data])
344
-
345
- # If buffer too short, wait
346
- if len(buffer) < int(0.5 * SAMPLE_RATE):
347
- state["buffer"] = buffer
348
- state["total_audio_samples"] = total_audio_samples
349
- state["unprocessed_audio"] = unprocessed_audio
350
- return full_transcript, state
351
-
352
- # Only run VAD on NEW audio (unprocessed)
353
- if len(unprocessed_audio) < int(0.3 * SAMPLE_RATE):
354
- # Not enough new audio to process
355
- state["buffer"] = buffer
356
- state["total_audio_samples"] = total_audio_samples
357
- state["unprocessed_audio"] = unprocessed_audio
358
- return full_transcript, state
359
-
360
- # Run VAD on unprocessed audio to find speech
361
- segments = vad_collector(unprocessed_audio, SAMPLE_RATE)
362
-
363
- if not segments:
364
- # No speech detected in new audio, clear unprocessed buffer
365
- state["buffer"] = buffer
366
- state["total_audio_samples"] = total_audio_samples
367
- state["unprocessed_audio"] = np.array([], dtype=np.float32)
368
- return full_transcript, state
369
-
370
- # Get the last voiced segment from unprocessed audio
371
- start_samp, end_samp = segments[-1]
372
 
373
- # Extend with context from the full buffer
374
- # Calculate where this segment is in the full buffer
375
- buffer_array = np.array(buffer)
376
- buffer_len = len(buffer_array)
377
- unprocessed_len = len(unprocessed_audio)
378
 
379
- # Offset of unprocessed audio within buffer
380
- unprocessed_offset = buffer_len - unprocessed_len
 
 
 
381
 
382
- # Absolute positions in buffer
383
- abs_start_in_buffer = unprocessed_offset + start_samp
384
- abs_end_in_buffer = unprocessed_offset + end_samp
 
385
 
386
- # Add context
387
- ctx = int(0.15 * SAMPLE_RATE)
388
- s = max(0, abs_start_in_buffer - ctx)
389
- e = min(buffer_len, abs_end_in_buffer + ctx)
390
 
391
- segment_audio = buffer_array[s:e]
392
-
393
- # Calculate absolute timestamps
394
- # Buffer represents the last BUFFER_DURATION seconds of audio
395
- # The START of the buffer corresponds to (total_audio_samples - buffer_len)
396
- buffer_start_sample = total_audio_samples - buffer_len
397
 
398
- abs_start = buffer_start_sample + s
399
- abs_end = buffer_start_sample + e
400
- start_time = abs_start / SAMPLE_RATE
401
- end_time = abs_end / SAMPLE_RATE
402
-
403
- # Clear unprocessed audio after processing
404
- unprocessed_audio = np.array([], dtype=np.float32)
405
-
406
- # Optional speaker diarization (throttled - only every 3rd call for performance)
407
- speaker_label = "Speaker 1"
408
- diarization_call_count += 1
409
- pipeline = get_diarization_pipeline()
410
- if pipeline is not None and (diarization_call_count % 3 == 0):
411
- try:
412
- wave = torch.from_numpy(segment_audio).float().unsqueeze(0)
413
- diarization = pipeline({"waveform": wave, "sample_rate": SAMPLE_RATE})
414
- speaker_durations = {}
415
- for segment, _, raw_speaker in diarization.itertracks(yield_label=True):
416
- dur = segment.end - segment.start
417
- speaker_durations[raw_speaker] = speaker_durations.get(raw_speaker, 0.0) + dur
418
- if speaker_durations:
419
- dominant_raw = max(speaker_durations, key=speaker_durations.get)
420
- if dominant_raw not in speaker_map:
421
- speaker_map[dominant_raw] = f"Speaker {next_speaker_idx}"
422
- next_speaker_idx += 1
423
- speaker_label = speaker_map[dominant_raw]
424
- except Exception as e:
425
- print("Diarization failed:", e)
426
-
427
- # Skip if segment too short
428
- if len(segment_audio) < int(0.25 * SAMPLE_RATE):
429
- state["buffer"] = buffer
430
- state["total_audio_samples"] = total_audio_samples
431
- state["unprocessed_audio"] = unprocessed_audio
432
- return full_transcript, state
433
-
434
- # Process segment with Whisper
435
- inputs = processor(segment_audio.copy(), sampling_rate=SAMPLE_RATE, return_tensors="pt")
436
  input_features = inputs.input_features.to(DEVICE, dtype=TORCH_DTYPE)
437
-
438
- # Generate with optimized settings for streaming (reduced beam search)
439
  with torch.no_grad():
440
  predicted_ids = model.generate(
441
  input_features,
442
  max_new_tokens=128,
443
- num_beams=1, # Greedy decoding for speed in streaming
444
- no_repeat_ngram_size=4,
445
- repetition_penalty=1.3,
446
- length_penalty=0.7,
447
  temperature=0.0,
448
  do_sample=False,
449
- early_stopping=True,
450
- suppress_tokens=[1, 2, 7, 9],
451
- forced_decoder_ids=None,
452
  )
453
-
454
- transcription = processor.batch_decode(predicted_ids, skip_special_tokens=True)
455
- text = transcription[0].strip()
456
-
457
- # Sentence-by-sentence commit logic
458
- if not text:
459
- state["buffer"] = buffer
460
- state["total_audio_samples"] = total_audio_samples
461
- state["processed_samples"] = processed_samples
462
- state["unprocessed_audio"] = unprocessed_audio
463
- return full_transcript, state
464
-
465
- # Split into sentences
466
- ends_with_punct = bool(re.search(r"[.!?]\s*$", text))
467
- parts = sentence_split_re.split(text)
468
-
469
- if ends_with_punct:
470
- finished = parts
471
- else:
472
- finished = parts[:-1]
473
-
474
- # Process finished sentences
475
- for snt in finished:
476
- snt = snt.strip()
477
- if not snt:
478
- continue
479
-
480
- # Skip if exact duplicate
481
- if snt in seen_texts:
482
- continue
483
-
484
- # Skip if near-duplicate of last transcription
485
- if last_transcription and is_near_duplicate(snt, last_transcription, threshold=0.75):
486
- continue
487
-
488
- # Check for similar existing entries (O(n) but necessary for quality)
489
- is_duplicate = False
490
- for idx, entry in enumerate(entries):
491
- if is_near_duplicate(snt, entry["text"], threshold=0.7):
492
- # If new sentence is longer, upgrade the old one
493
- if len(snt) > len(entry["text"]):
494
- entries[idx] = {
495
- "text": snt,
496
- "start": entry["start"],
497
- "end": end_time,
498
- "speaker": speaker_label,
499
- }
500
- seen_texts.discard(entry["text"])
501
- seen_texts.add(snt)
502
- is_duplicate = True
503
- break
504
-
505
- if is_duplicate:
506
- last_transcription = snt
507
- continue
508
-
509
- # Add new entry
510
- entry = {
511
- "text": snt,
512
- "start": start_time,
513
- "end": end_time,
514
- "speaker": speaker_label,
515
- }
516
- entries.append(entry)
517
- seen_texts.add(snt)
518
- last_transcription = snt
519
-
520
- # Build formatted transcript
521
- lines = []
522
- for entry in entries:
523
- ts = format_timestamp(entry["start"])
524
- speaker = entry["speaker"]
525
- text_out = entry["text"]
526
- if text_out:
527
- lines.append(f"[{ts}] {speaker}: {text_out}")
528
-
529
- full_transcript = "\n".join(lines)
530
-
531
- # Update state (create new dict to avoid mutation)
532
- # Convert deque to list and set to list for Gradio compatibility
533
- state = {
534
- "buffer": list(buffer), # Convert deque to list for Gradio
535
- "full_transcript": full_transcript,
536
- "last_transcription": last_transcription,
537
- "entries": entries,
538
- "processed_samples": processed_samples,
539
- "total_audio_samples": total_audio_samples,
540
- "speaker_map": speaker_map,
541
- "next_speaker_idx": next_speaker_idx,
542
- "seen_texts": list(seen_texts), # Convert set to list for Gradio
543
- "unprocessed_audio": unprocessed_audio,
544
- }
545
-
546
- return full_transcript, state
547
-
548
  except Exception as e:
549
- print("Error in stream_transcribe:")
550
- print(traceback.format_exc())
551
- return full_transcript, state
552
 
553
 
554
- # -------------------------
555
- # Reset helper
556
- # -------------------------
557
- def reset_transcript(state):
558
- state = create_initial_state()
559
- return "", state
560
-
561
-
562
- def transcribe_uploaded_file(file, state):
563
- """
564
- High-accuracy transcription for uploaded audio file.
565
- Uses larger beam search for better quality.
566
- """
567
  if file is None:
568
- return state.get("full_transcript", ""), state
569
-
570
- path = getattr(file, "name", None) or file
571
  try:
572
- audio = AudioSegment.from_file(path)
573
- audio = audio.set_channels(1)
574
- audio = audio.set_frame_rate(SAMPLE_RATE)
575
- samples = np.array(audio.get_array_of_samples(), dtype=np.float32)
576
 
577
- # Normalize based on sample width
578
  if audio.sample_width == 2:
579
  samples /= 32768.0
580
  elif audio.sample_width == 4:
581
  samples /= 2147483648.0
582
 
583
- # Clip to valid range
584
  samples = np.clip(samples, -1.0, 1.0)
585
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
586
  except Exception as e:
587
- print("Error loading uploaded audio file:", e)
588
- return state.get("full_transcript", ""), state
589
-
590
- # Fresh state for file
591
- state = create_initial_state()
592
-
593
- # Process in 30-second chunks
594
- chunk_sec = 30.0
595
- chunk_size = int(SAMPLE_RATE * chunk_sec)
596
- texts = []
597
-
598
- for start in range(0, len(samples), chunk_size):
599
- chunk = samples[start:start + chunk_size]
600
- if len(chunk) < int(0.5 * SAMPLE_RATE): # Skip very short chunks
601
- continue
602
-
603
- inputs = processor(
604
- chunk,
605
- sampling_rate=SAMPLE_RATE,
606
- return_tensors="pt",
607
- )
608
- input_features = inputs.input_features.to(DEVICE, dtype=TORCH_DTYPE)
609
-
610
- with torch.no_grad():
611
- predicted_ids = model.generate(
612
- input_features,
613
- max_new_tokens=256,
614
- num_beams=5, # Higher beam search for file upload quality
615
- no_repeat_ngram_size=4,
616
- repetition_penalty=1.3,
617
- length_penalty=0.7,
618
- temperature=0.0,
619
- do_sample=False,
620
- early_stopping=True,
621
- suppress_tokens=[1, 2, 7, 9],
622
- forced_decoder_ids=None,
623
- )
624
 
625
- text = processor.batch_decode(predicted_ids, skip_special_tokens=True)[0].strip()
626
- if text:
627
- texts.append(text)
628
-
629
- full_text = " ".join(texts).strip()
630
- duration_sec = len(samples) / SAMPLE_RATE if len(samples) > 0 else 0.0
631
-
632
- entries = []
633
- if full_text:
634
- entries.append({
635
- "text": full_text,
636
- "start": 0.0,
637
- "end": duration_sec,
638
- "speaker": "Speaker 1",
639
- })
640
- formatted = f"[{format_timestamp(0.0)}] Speaker 1: {full_text}"
641
- else:
642
- formatted = ""
643
 
644
- state["entries"] = entries
645
- state["full_transcript"] = formatted
646
- state["last_transcription"] = full_text
647
- state["total_audio_samples"] = len(samples)
648
- state["seen_texts"] = [full_text] if full_text else [] # List instead of set
649
-
650
- return formatted, state
651
 
652
 
653
  # -------------------------
654
  # Gradio UI
655
  # -------------------------
656
- with gr.Blocks(title="🎀 Whisper ASR", theme=gr.themes.Soft()) as demo:
657
  gr.Markdown(
658
  """
659
- # 🎀 Whisper Real-Time ASR
660
 
661
- **πŸ’‘ How to use:**
662
- 1. Click the **microphone icon** to start recording
663
- 2. See real-time transcription below
664
- 3. Click **Clear** to reset the transcript
665
- 4. Click **Copy** to copy the transcript to clipboard
666
-
667
- Using OpenAI Whisper-large-v3-turbo with optimized streaming performance.
668
  """
669
  )
670
-
671
  with gr.Row():
672
- with gr.Column(scale=1):
673
- source_selector = gr.Radio(
674
- choices=["Microphone (live)", "Upload audio file"],
675
- value="Microphone (live)",
676
- label="Audio source",
677
  )
678
- mic_input = gr.Audio(
 
679
  sources=["microphone"],
680
  type="numpy",
681
  streaming=True,
682
- label="πŸŽ™οΈ Speak with your microphone",
683
- visible=True,
684
  )
 
685
  file_input = gr.File(
686
- label="πŸ“ Upload audio file",
687
  file_types=["audio"],
688
- file_count="single",
689
- visible=False,
690
  )
691
- transcribe_file_btn = gr.Button(
692
- "Transcribe Uploaded File", variant="secondary", visible=False
 
 
693
  )
694
- clear_btn = gr.Button("πŸ—‘οΈ Clear Transcript", variant="secondary")
695
- with gr.Column(scale=2):
696
- output_box = gr.Textbox(
697
- label="πŸ“„ Full Transcription",
698
- lines=10,
699
- interactive=False,
700
- show_copy_button=True
 
701
  )
702
-
703
- state = gr.State(create_initial_state())
704
-
705
- def _update_source_ui(source_choice):
706
- use_mic = source_choice.startswith("Microphone")
 
 
707
  return (
708
- gr.update(visible=use_mic),
709
- gr.update(visible=not use_mic),
710
- gr.update(visible=not use_mic),
711
  )
712
-
713
- source_selector.change(
714
- _update_source_ui,
715
- inputs=source_selector,
716
- outputs=[mic_input, file_input, transcribe_file_btn],
717
  )
718
-
719
- mic_input.stream(
720
- fn=stream_transcribe,
721
- inputs=[mic_input, state],
722
- outputs=[output_box, state],
 
723
  )
724
-
725
- transcribe_file_btn.click(
726
- fn=transcribe_uploaded_file,
727
- inputs=[file_input, state],
728
- outputs=[output_box, state],
 
729
  )
730
-
 
731
  clear_btn.click(
732
- fn=reset_transcript,
733
- inputs=state,
734
- outputs=[output_box, state],
735
  )
736
 
737
  if __name__ == "__main__":
738
- # Launch without show_api parameter to avoid schema generation bug
739
- # in some Gradio versions
740
- try:
741
- demo.launch(share=True, show_api=False)
742
- except TypeError:
743
- # Fallback if show_api causes issues
744
- demo.launch(share=True)
 
1
  import os
 
 
2
  import numpy as np
3
  import gradio as gr
4
  import torch
5
  from transformers import AutoModelForSpeechSeq2Seq, AutoProcessor
6
  import spaces
7
  import traceback
 
 
 
8
  from pydub import AudioSegment
9
 
 
 
 
 
 
 
 
10
  # -------------------------
11
+ # Model Loading
12
  # -------------------------
13
+ print("πŸš€ Loading Whisper model...")
 
14
 
15
  model_id = "openai/whisper-large-v3-turbo"
 
 
16
  DEVICE = "cuda:0" if torch.cuda.is_available() else "cpu"
17
  TORCH_DTYPE = torch.float16 if torch.cuda.is_available() else torch.float32
18
 
 
27
  model.to(DEVICE)
28
  model.eval()
29
 
 
 
 
 
 
 
 
30
  processor = AutoProcessor.from_pretrained(model_id)
31
+ print(f"βœ… Model loaded on {DEVICE}")
32
 
33
  # -------------------------
34
+ # Constants
35
  # -------------------------
36
  SAMPLE_RATE = 16000
37
+ BUFFER_SECONDS = 10
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
38
 
 
 
 
 
 
 
 
 
 
 
 
39
 
40
def simple_resample(audio, orig_sr, target_sr=16000):
    """Simple resampling using linear interpolation.

    For production-quality resampling, prefer librosa or torchaudio;
    this is intentionally dependency-free.

    Args:
        audio: 1-D numpy array of audio samples.
        orig_sr: sampling rate of ``audio`` in Hz.
        target_sr: desired sampling rate in Hz (default 16 kHz, Whisper's
            expected rate).

    Returns:
        float32 numpy array resampled to ``target_sr``. Returns the input
        unchanged when the rates already match, and an empty array when
        the resampled length would be zero.
    """
    if orig_sr == target_sr:
        return audio
    duration = len(audio) / orig_sr
    target_length = int(duration * target_sr)
    if target_length == 0:
        return np.array([], dtype=np.float32)
    # Map each output sample position back onto the source index space,
    # then linearly interpolate between neighbouring source samples.
    indices = np.linspace(0, len(audio) - 1, target_length)
    return np.interp(indices, np.arange(len(audio)), audio).astype(np.float32)
 
50
 
51
 
 
 
 
52
@spaces.GPU
def transcribe_audio(audio_chunk, history):
    """
    Simple streaming transcription.

    Args:
        audio_chunk: (sample_rate, audio_data) tuple from the Gradio
            streaming Audio component, or None.
        history: accumulated audio buffer (float32 numpy array) carried
            between calls via gr.State.

    Returns:
        (history, text): the updated rolling buffer and the transcription
        of the buffered audio ("" when there is nothing to transcribe yet).
    """
    try:
        if audio_chunk is None:
            return history, ""

        # Gradio streams audio as a (sample_rate, ndarray) tuple.
        if isinstance(audio_chunk, tuple):
            sr, data = audio_chunk
        else:
            return history, ""

        if data is None or len(data) == 0:
            return history, ""

        # BUG FIX: inspect the ORIGINAL dtype before any float conversion.
        # Previously the data was cast to float32 first, so the int16/int32
        # branches below could never match: integer PCM kept its raw
        # magnitude and the later clip to [-1, 1] destroyed the signal.
        data = np.asarray(data)
        if data.dtype == np.int16:
            data = data.astype(np.float32) / 32768.0
        elif data.dtype == np.int32:
            data = data.astype(np.float32) / 2147483648.0
        else:
            data = data.astype(np.float32)

        # Downmix stereo to mono.
        if data.ndim == 2:
            data = np.mean(data, axis=1)

        data = np.clip(data, -1.0, 1.0)

        # Resample if needed
        if sr != SAMPLE_RATE:
            data = simple_resample(data, sr, SAMPLE_RATE)

        # Append the new chunk to the rolling buffer.
        if history is None or len(history) == 0:
            history = data
        else:
            history = np.concatenate([history, data])

        # Keep only the last BUFFER_SECONDS of audio.
        max_samples = SAMPLE_RATE * BUFFER_SECONDS
        if len(history) > max_samples:
            history = history[-max_samples:]

        # Need a minimum amount of audio before transcribing.
        if len(history) < SAMPLE_RATE * 0.5:  # 0.5 seconds minimum
            return history, ""

        # Transcribe the whole buffer.
        inputs = processor(
            history,
            sampling_rate=SAMPLE_RATE,
            return_tensors="pt"
        )
        input_features = inputs.input_features.to(DEVICE, dtype=TORCH_DTYPE)

        with torch.no_grad():
            predicted_ids = model.generate(
                input_features,
                max_new_tokens=128,
                num_beams=1,  # Greedy for speed
                temperature=0.0,
                do_sample=False,
            )

        text = processor.batch_decode(predicted_ids, skip_special_tokens=True)[0].strip()

        return history, text

    except Exception as e:
        print(f"Error: {e}")
        traceback.print_exc()
        return history if history is not None else np.array([]), ""
130
 
131
 
132
def transcribe_file(file):
    """Transcribe an uploaded audio file.

    Args:
        file: a Gradio file object (exposing a ``.name`` path attribute)
            or a plain path string, or None.

    Returns:
        The transcription text, "" for no file, or an "Error: ..." message
        when loading/transcription fails.
    """
    if file is None:
        return ""

    # Accept both a file-like object from gr.File and a raw path string
    # (newer Gradio versions may pass the path directly).
    path = getattr(file, "name", None) or file

    try:
        # Load audio file; downmix to mono and resample via pydub/ffmpeg.
        audio = AudioSegment.from_file(path)
        audio = audio.set_channels(1).set_frame_rate(SAMPLE_RATE)

        samples = np.array(audio.get_array_of_samples(), dtype=np.float32)
        # Scale integer PCM into [-1, 1] based on sample width.
        if audio.sample_width == 2:
            samples /= 32768.0
        elif audio.sample_width == 4:
            samples /= 2147483648.0

        samples = np.clip(samples, -1.0, 1.0)

        # Process in chunks
        chunk_size = SAMPLE_RATE * 30  # 30 second chunks
        texts = []

        for start in range(0, len(samples), chunk_size):
            chunk = samples[start:start + chunk_size]
            if len(chunk) < SAMPLE_RATE * 0.5:
                continue  # skip very short trailing chunks

            inputs = processor(chunk, sampling_rate=SAMPLE_RATE, return_tensors="pt")
            input_features = inputs.input_features.to(DEVICE, dtype=TORCH_DTYPE)

            with torch.no_grad():
                predicted_ids = model.generate(
                    input_features,
                    max_new_tokens=256,
                    num_beams=5,  # Better quality for files
                    temperature=0.0,
                    do_sample=False,
                )

            text = processor.batch_decode(predicted_ids, skip_special_tokens=True)[0].strip()
            if text:
                texts.append(text)

        return " ".join(texts)

    except Exception as e:
        print(f"File transcription error: {e}")
        traceback.print_exc()
        return f"Error: {str(e)}"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
181
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
182
 
183
def clear_history():
    """Reset the rolling audio buffer and the transcription textbox."""
    empty_buffer = np.array([])
    empty_text = ""
    return empty_buffer, empty_text
 
 
 
 
186
 
187
 
188
  # -------------------------
189
  # Gradio UI
190
  # -------------------------
191
# Build the Gradio interface: a source selector that toggles between a
# streaming microphone and a file-upload path, with a shared output box.
with gr.Blocks(title="🎤 Whisper ASR") as demo:
    gr.Markdown(
        """
        # 🎤 Whisper Real-Time Transcription

        **How to use:**
        - **Microphone**: Click to record, speak, see live transcription
        - **File Upload**: Upload audio file and click "Transcribe"
        - **Clear**: Reset the transcription

        Using Whisper-large-v3-turbo
        """
    )

    with gr.Row():
        with gr.Column():
            source = gr.Radio(
                ["Microphone", "Upload File"],
                value="Microphone",
                label="Audio Source",
            )

            mic = gr.Audio(
                sources=["microphone"],
                type="numpy",
                streaming=True,
                label="🎙️ Microphone",
                visible=True,
            )

            file_input = gr.File(
                label="📁 Upload Audio",
                file_types=["audio"],
                visible=False,
            )

            transcribe_btn = gr.Button("Transcribe File", visible=False)

            clear_btn = gr.Button("🗑️ Clear")

        with gr.Column():
            output = gr.Textbox(
                label="📄 Transcription",
                lines=12,
                interactive=False,
            )

    # Session state: just the rolling audio buffer.
    audio_history = gr.State(np.array([]))

    def update_ui(choice):
        """Show the mic widget or the file widgets depending on the source."""
        use_mic = choice == "Microphone"
        return (
            gr.update(visible=use_mic),
            gr.update(visible=not use_mic),
            gr.update(visible=not use_mic),
        )

    source.change(
        update_ui,
        inputs=source,
        outputs=[mic, file_input, transcribe_btn],
    )

    # Live microphone streaming -> rolling-buffer transcription.
    mic.stream(
        transcribe_audio,
        inputs=[mic, audio_history],
        outputs=[audio_history, output],
    )

    # One-shot transcription of an uploaded file.
    transcribe_btn.click(
        transcribe_file,
        inputs=file_input,
        outputs=output,
    )

    # Reset both the buffer and the displayed text.
    clear_btn.click(
        clear_history,
        outputs=[audio_history, output],
    )

if __name__ == "__main__":
    demo.launch(share=True)