Nymbo committed on
Commit
e2b9281
·
verified ·
1 Parent(s): 63941a5

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +614 -0
app.py ADDED
@@ -0,0 +1,614 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import torch
3
+ import tempfile
4
+ import os
5
+ import time
6
+ import datetime
7
+ import csv
8
+ import warnings
9
+ import numpy as np
10
+
11
# Silence warnings whose message matches these patterns.
warnings.filterwarnings("ignore", message=".*deprecated.*")
warnings.filterwarnings("ignore", message=".*torch.cuda.*")

# Optional dependencies. Each import is guarded so the module still loads
# when a package is unavailable; the missing name is bound to None and the
# functions that need it raise a user-facing error when called.
try:
    from nemo.collections.asr.models import ASRModel
except Exception as exc:  # NeMo can fail with more than ImportError
    ASRModel = None
    _NEMO_IMPORT_ERROR = str(exc)
else:
    _NEMO_IMPORT_ERROR = None

try:
    from pydub import AudioSegment
except ImportError:
    AudioSegment = None

try:
    import yt_dlp as youtube_dl
except ImportError:
    youtube_dl = None

# Model configuration
MODEL_NAME = "nvidia/parakeet-tdt-0.6b-v3"
SAMPLE_RATE = 16000  # Parakeet expects 16kHz audio
LONG_AUDIO_THRESHOLD_S = 8 * 60  # 8 minutes - switch to local attention
YT_LENGTH_LIMIT_S = 60 * 60  # Limit YouTube videos to 1 hour

# True when the SPACE_ID environment variable is set, i.e. when running on
# Hugging Face Spaces (YouTube won't work there due to network restrictions).
IS_HF_SPACE = "SPACE_ID" in os.environ

# Supported languages (auto-detected by the model)
SUPPORTED_LANGUAGES = [
    "Bulgarian (bg)",
    "Croatian (hr)",
    "Czech (cs)",
    "Danish (da)",
    "Dutch (nl)",
    "English (en)",
    "Estonian (et)",
    "Finnish (fi)",
    "French (fr)",
    "German (de)",
    "Greek (el)",
    "Hungarian (hu)",
    "Italian (it)",
    "Latvian (lv)",
    "Lithuanian (lt)",
    "Maltese (mt)",
    "Polish (pl)",
    "Portuguese (pt)",
    "Romanian (ro)",
    "Slovak (sk)",
    "Slovenian (sl)",
    "Spanish (es)",
    "Swedish (sv)",
    "Russian (ru)",
    "Ukrainian (uk)",
]

# Shared state for the Parakeet model, filled in by _init_parakeet() on the
# first transcription request.
_PARAKEET_STATE = dict(initialized=False, model=None, device="cpu")
55
+
56
+
57
def _init_parakeet() -> None:
    """Initialize the Parakeet model lazily on first use.

    Does nothing when the model has already been loaded. Otherwise loads
    ``MODEL_NAME`` through NeMo, puts it in eval mode, moves it to CUDA in
    bfloat16 when a GPU is available, and records the model and device in
    ``_PARAKEET_STATE``.

    Raises:
        gr.Error: If NeMo could not be imported, or if loading/moving the
            model fails (the original exception is chained as the cause).
    """
    if _PARAKEET_STATE["initialized"]:
        return

    if ASRModel is None:
        error_msg = _NEMO_IMPORT_ERROR or "Unknown import error"
        raise gr.Error(
            f"NeMo toolkit import failed: {error_msg}. "
            "Please run: pip install nemo_toolkit[asr]"
        )

    # Detect device
    device = "cuda" if torch.cuda.is_available() else "cpu"

    print(f"Initializing Parakeet model on device: {device}")

    try:
        model = ASRModel.from_pretrained(model_name=MODEL_NAME)
        model.eval()

        if device == "cuda":
            model.to("cuda")
            model.to(torch.bfloat16)
    except Exception as e:
        # Chain the cause so the full traceback stays available in the logs;
        # the user-facing message is truncated to keep the toast readable.
        raise gr.Error(f"Failed to initialize Parakeet model: {str(e)[:200]}") from e

    # Only mark the state initialized once the model is fully ready.
    _PARAKEET_STATE.update({
        "initialized": True,
        "model": model,
        "device": device,
    })
    print("Parakeet model initialized successfully.")
91
+
92
+
93
def get_device_info() -> str:
    """Return the name of the device used for inference.

    Before the model has been initialized this reports the device that
    would be picked ("cuda" when available, otherwise "cpu"); afterwards it
    reports the device recorded in ``_PARAKEET_STATE``.
    """
    if not _PARAKEET_STATE["initialized"]:
        return "cuda" if torch.cuda.is_available() else "cpu"
    return _PARAKEET_STATE["device"]
98
+
99
+
100
def _load_and_preprocess_audio(audio_path: str) -> tuple[str, float]:
    """
    Load an audio file and convert it to 16kHz mono if needed.

    Args:
        audio_path: Path of the file to load with pydub.

    Returns:
        (processed_path, duration_seconds). ``processed_path`` is
        ``audio_path`` itself when no conversion was necessary; otherwise it
        is a WAV file inside a fresh temporary directory that the caller is
        responsible for deleting.

    Raises:
        gr.Error: If pydub is not installed.
    """
    if AudioSegment is None:
        raise gr.Error("pydub not installed. Please run: pip install pydub")

    audio = AudioSegment.from_file(audio_path)
    duration_sec = audio.duration_seconds

    needs_processing = False

    # Resample to 16kHz if needed
    if audio.frame_rate != SAMPLE_RATE:
        audio = audio.set_frame_rate(SAMPLE_RATE)
        needs_processing = True

    # Convert to mono if stereo or multi-channel
    if audio.channels > 1:
        audio = audio.set_channels(1)
        needs_processing = True

    if not needs_processing:
        return audio_path, duration_sec

    # Export to a temp file. If the export fails, remove the directory we
    # just created so a failed request does not leak it.
    import shutil

    temp_dir = tempfile.mkdtemp()
    processed_path = os.path.join(temp_dir, "processed_audio.wav")
    try:
        audio.export(processed_path, format="wav")
    except Exception:
        shutil.rmtree(temp_dir, ignore_errors=True)
        raise
    return processed_path, duration_sec
131
+
132
+
133
def _format_srt_time(seconds: float) -> str:
    """Render a time offset in seconds as an SRT timestamp (HH:MM:SS,mmm).

    Negative inputs are clamped to zero. Hours are not wrapped at 24.
    """
    delta = datetime.timedelta(seconds=max(0.0, seconds))

    whole_seconds = int(delta.total_seconds())
    hours, remainder = divmod(whole_seconds, 3600)
    minutes, secs = divmod(remainder, 60)
    millis = delta.microseconds // 1000

    return f"{hours:02d}:{minutes:02d}:{secs:02d},{millis:03d}"
145
+
146
+
147
def _generate_srt_content(segment_timestamps: list) -> str:
    """Build the text of an SRT subtitle file from segment timestamps.

    Each entry must provide 'start', 'end' and 'segment' keys. Cues are
    numbered from 1 and every cue is followed by an empty line.
    """
    pieces = []
    for cue_number, seg in enumerate(segment_timestamps, start=1):
        begin = _format_srt_time(seg['start'])
        finish = _format_srt_time(seg['end'])
        pieces.extend((
            str(cue_number),
            f"{begin} --> {finish}",
            seg['segment'],
            "",
        ))
    return "\n".join(pieces)
159
+
160
+
161
def _generate_csv_content(segment_timestamps: list) -> str:
    """Build CSV text (header plus one row per segment) from timestamps.

    Each entry must provide 'start', 'end' and 'segment' keys; start and end
    are written with two decimal places.
    """
    import io

    buffer = io.StringIO()
    sheet = csv.writer(buffer)
    sheet.writerow(["Start (s)", "End (s)", "Segment"])
    sheet.writerows(
        [f"{seg['start']:.2f}", f"{seg['end']:.2f}", seg['segment']]
        for seg in segment_timestamps
    )
    return buffer.getvalue()
170
+
171
+
172
def transcribe_audio(
    audio_path: str,
    return_timestamps: bool,
    timestamp_level: str,
):
    """
    Transcribe an audio file using Parakeet.

    Args:
        audio_path: Path to the audio file.
        return_timestamps: Whether to request timestamps from the model.
        timestamp_level: Level of timestamps ("segment", "word", or "char").

    Returns:
        Tuple of (transcription_text, csv_update, srt_update). The two
        updates are ``gr.update(...)`` payloads carrying the CSV/SRT file
        path (or None) and a matching ``visible`` flag. Download files are
        only produced for segment-level timestamps.

    Raises:
        gr.Error: If no audio was provided, the model cannot be initialized,
            CUDA runs out of memory, or transcription fails for any other
            reason.
    """
    if not audio_path:
        raise gr.Error("Please provide an audio file to transcribe.")

    # Initialize model on first use
    _init_parakeet()
    model = _PARAKEET_STATE["model"]
    device = _PARAKEET_STATE["device"]

    processed_path = None
    long_audio_settings_applied = False

    try:
        # Preprocess audio
        gr.Info("Loading and preprocessing audio...")
        processed_path, duration_sec = _load_and_preprocess_audio(audio_path)

        # Apply long audio settings if needed
        if duration_sec > LONG_AUDIO_THRESHOLD_S:
            gr.Info(f"Audio is {duration_sec:.0f}s (>{LONG_AUDIO_THRESHOLD_S}s). Applying local attention for long audio.")
            try:
                model.change_attention_model("rel_pos_local_attn", [256, 256])
                # Flag as soon as the first change lands, so the finally
                # block reverts it even if the next call fails.
                long_audio_settings_applied = True
                model.change_subsampling_conv_chunking_factor(1)
            except Exception as e:
                gr.Warning(f"Could not apply long audio settings: {e}")

        # Ensure model is on correct device with correct dtype
        if device == "cuda":
            model.to("cuda")
            model.to(torch.bfloat16)
        else:
            model.to("cpu")
            model.to(torch.float32)

        # Transcribe
        gr.Info("Transcribing audio...")
        print(f"DEBUG: Calling transcribe with timestamps={return_timestamps}")
        output = model.transcribe([processed_path], timestamps=return_timestamps)
        print(f"DEBUG: Transcription complete, got output type: {type(output)}")

        if not output or not isinstance(output, list) or not output[0]:
            raise gr.Error("Transcription failed or produced unexpected output.")

        # Extract text
        transcription_text = output[0].text if hasattr(output[0], 'text') else str(output[0])
        print(f"DEBUG: Extracted text: {transcription_text[:100] if transcription_text else 'empty'}...")

        # Handle timestamps
        csv_path = None
        srt_path = None

        if return_timestamps and hasattr(output[0], 'timestamp') and output[0].timestamp:
            timestamps = output[0].timestamp

            # Get timestamps at the requested level
            if timestamp_level in timestamps:
                ts_data = timestamps[timestamp_level]

                # Format text with timestamps
                if timestamp_level == "segment":
                    lines = []
                    for ts in ts_data:
                        start = ts.get('start', 0)
                        end = ts.get('end', 0)
                        text = ts.get('segment', '')
                        lines.append(f"[{start:.2f}s - {end:.2f}s] {text}")
                    transcription_text = "\n".join(lines)

                    # Generate download files
                    temp_dir = tempfile.mkdtemp()

                    # CSV: newline='' stops the text layer from translating
                    # the csv module's "\r\n" row endings a second time.
                    csv_content = _generate_csv_content(ts_data)
                    csv_path = os.path.join(temp_dir, "transcription.csv")
                    with open(csv_path, 'w', encoding='utf-8', newline='') as f:
                        f.write(csv_content)

                    # SRT
                    srt_content = _generate_srt_content(ts_data)
                    srt_path = os.path.join(temp_dir, "transcription.srt")
                    with open(srt_path, 'w', encoding='utf-8') as f:
                        f.write(srt_content)

                elif timestamp_level == "word":
                    lines = []
                    for ts in ts_data:
                        start = ts.get('start', 0)
                        word = ts.get('word', '')
                        lines.append(f"[{start:.2f}s] {word}")
                    transcription_text = "\n".join(lines)

                elif timestamp_level == "char":
                    lines = []
                    for ts in ts_data:
                        start = ts.get('start', 0)
                        char = ts.get('char', '')
                        lines.append(f"[{start:.3f}s] {char}")
                    transcription_text = "\n".join(lines)

        gr.Info("Transcription complete!")
        print(f"DEBUG: Returning transcription of length {len(transcription_text)}")

        # Return with download buttons visibility using gr.update()
        return (
            transcription_text,
            gr.update(value=csv_path, visible=csv_path is not None),
            gr.update(value=srt_path, visible=srt_path is not None),
        )

    except gr.Error:
        raise
    except torch.cuda.OutOfMemoryError as e:
        raise gr.Error("CUDA out of memory. Please try a shorter audio file.") from e
    except Exception as e:
        raise gr.Error(f"Transcription failed: {str(e)[:200]}") from e

    finally:
        # Revert long audio settings
        if long_audio_settings_applied:
            try:
                model.change_attention_model("rel_pos")
                model.change_subsampling_conv_chunking_factor(-1)
            except Exception as e:
                # Best effort, but leave a trace: the shared model may still
                # be in its long-audio configuration for the next request.
                print(f"WARNING: Could not revert long audio settings: {e}")

        # Clean up temp file
        if processed_path and processed_path != audio_path:
            try:
                os.remove(processed_path)
                os.rmdir(os.path.dirname(processed_path))
            except OSError:
                pass

        # Note: We intentionally keep the model on GPU to avoid reload overhead
        # The model will be reused for subsequent transcriptions
324
+
325
+
326
def _get_yt_html_embed(yt_url: str) -> str:
    """Generate YouTube embed HTML for display.

    The video id is taken from, in order: the ``v`` query parameter
    (wherever it appears in the query string), the path of a ``youtu.be``
    short link, or the segment after ``/embed/``, ``/shorts/``, ``/live/``
    or ``/v/``. If none of these match, the text after ``?v=`` up to the
    next ``&`` is used (which is the whole input for a bare video id).

    The id is HTML-escaped before being placed in the iframe ``src``
    attribute because the URL is user-supplied.

    Args:
        yt_url: URL (or bare id) of the YouTube video.

    Returns:
        An HTML snippet containing a centered 500x320 iframe.
    """
    from html import escape
    from urllib.parse import parse_qs, urlparse

    url = yt_url.strip()
    parsed = urlparse(url)
    host = (parsed.hostname or "").lower()
    query_ids = parse_qs(parsed.query).get("v")
    path_parts = [part for part in parsed.path.split("/") if part]

    if query_ids:
        video_id = query_ids[0]
    elif host in {"youtu.be", "www.youtu.be"} and path_parts:
        video_id = path_parts[0]
    elif len(path_parts) >= 2 and path_parts[0] in {"embed", "shorts", "live", "v"}:
        video_id = path_parts[1]
    else:
        # Fallback for inputs the URL parser cannot make sense of.
        video_id = url.split("?v=")[-1].split("&")[0]

    safe_id = escape(video_id, quote=True)
    return (
        f'<center><iframe width="500" height="320" '
        f'src="https://www.youtube.com/embed/{safe_id}"></iframe></center>'
    )
333
+
334
+
335
def _download_yt_audio(yt_url: str, filepath: str) -> None:
    """Download a YouTube video to ``filepath`` after checking its length.

    Metadata is fetched first (without downloading) so that videos longer
    than ``YT_LENGTH_LIMIT_S`` are rejected before any media is transferred.

    Args:
        yt_url: URL of the video to download.
        filepath: Output path handed to yt-dlp as its output template.

    Raises:
        gr.Error: If yt-dlp is not installed, metadata extraction or the
            download fails, the duration cannot be parsed, or the video
            exceeds the length limit.
    """
    if youtube_dl is None:
        raise gr.Error("yt-dlp not installed. Please run: pip install yt-dlp")

    try:
        # Context manager so the metadata-only client is closed again.
        with youtube_dl.YoutubeDL() as info_loader:
            info = info_loader.extract_info(yt_url, download=False)
    except youtube_dl.utils.DownloadError as err:
        err_str = str(err)
        if "Failed to resolve" in err_str or "No address associated" in err_str:
            raise gr.Error(
                "YouTube download failed due to network restrictions. "
                "This feature requires running the app locally. "
                "On Hugging Face Spaces, outbound connections to YouTube are blocked."
            ) from err
        raise gr.Error(str(err)) from err

    # Parse duration. Prefer the numeric field; fall back to the "H:M:S"
    # string. A missing or None duration is treated as 0, matching the
    # previous default.
    duration = info.get("duration")
    if duration is not None:
        file_length_s = int(duration)
    else:
        duration_string = info.get("duration_string") or "0"
        try:
            file_length_s = 0
            for sub_length in duration_string.split(":"):
                file_length_s = file_length_s * 60 + int(sub_length)
        except ValueError as err:
            raise gr.Error(f"Could not parse video duration: {duration_string!r}") from err

    if file_length_s > YT_LENGTH_LIMIT_S:
        yt_limit_hms = time.strftime("%H:%M:%S", time.gmtime(YT_LENGTH_LIMIT_S))
        file_hms = time.strftime("%H:%M:%S", time.gmtime(file_length_s))
        raise gr.Error(f"Maximum YouTube length is {yt_limit_hms}, got {file_hms}.")

    ydl_opts = {
        "outtmpl": filepath,
        "format": "worstvideo[ext=mp4]+bestaudio[ext=m4a]/best[ext=mp4]/best",
    }

    with youtube_dl.YoutubeDL(ydl_opts) as ydl:
        try:
            ydl.download([yt_url])
        except (youtube_dl.utils.DownloadError, youtube_dl.utils.ExtractorError) as err:
            # Also catch DownloadError here, the same failure type the
            # metadata step above already handles.
            raise gr.Error(str(err)) from err
381
+
382
+
383
def transcribe_youtube(
    yt_url: str,
    return_timestamps: bool,
    timestamp_level: str,
):
    """
    Transcribe a YouTube video.

    The video is downloaded into a temporary directory that is removed when
    the generator finishes.

    Args:
        yt_url: URL of the YouTube video.
        return_timestamps: Whether to request timestamps from the model.
        timestamp_level: Level of timestamps ("segment" or "word").

    Yields:
        Tuples of (html_embed, text). ``text`` is a progress message until
        the final yield, which carries the transcription.

    Raises:
        gr.Error: If no URL was given, yt-dlp is missing, the download
            fails, CUDA runs out of memory, or transcription fails.
    """
    if not yt_url:
        raise gr.Error("Please provide a YouTube URL.")

    if youtube_dl is None:
        raise gr.Error("yt-dlp not installed. Please run: pip install yt-dlp")

    html_embed = _get_yt_html_embed(yt_url)

    # Initialize model
    _init_parakeet()
    model = _PARAKEET_STATE["model"]
    device = _PARAKEET_STATE["device"]

    # Download video to temp directory
    with tempfile.TemporaryDirectory() as tmpdir:
        filepath = os.path.join(tmpdir, "video.mp4")

        # Yield initial state while downloading
        yield html_embed, "Downloading video..."

        _download_yt_audio(yt_url, filepath)

        yield html_embed, "Processing audio..."

        processed_path = None
        long_audio_settings_applied = False

        try:
            # Preprocess audio
            processed_path, duration_sec = _load_and_preprocess_audio(filepath)

            # Apply long audio settings if needed
            if duration_sec > LONG_AUDIO_THRESHOLD_S:
                try:
                    model.change_attention_model("rel_pos_local_attn", [256, 256])
                    # Flag as soon as the first change lands, so the finally
                    # block reverts it even if the next call fails.
                    long_audio_settings_applied = True
                    model.change_subsampling_conv_chunking_factor(1)
                except Exception as e:
                    gr.Warning(f"Could not apply long audio settings: {e}")

            # Ensure model is on correct device
            if device == "cuda":
                model.to("cuda")
                model.to(torch.bfloat16)
            else:
                model.to("cpu")
                model.to(torch.float32)

            yield html_embed, "Transcribing audio..."

            # Transcribe
            output = model.transcribe([processed_path], timestamps=return_timestamps)

            if not output or not isinstance(output, list) or not output[0]:
                raise gr.Error("Transcription failed or produced unexpected output.")

            # Extract text
            transcription_text = output[0].text if hasattr(output[0], 'text') else str(output[0])

            # Handle timestamps if requested
            if return_timestamps and hasattr(output[0], 'timestamp') and output[0].timestamp:
                timestamps = output[0].timestamp
                if timestamp_level in timestamps:
                    ts_data = timestamps[timestamp_level]
                    if timestamp_level == "segment":
                        lines = []
                        for ts in ts_data:
                            start = ts.get('start', 0)
                            end = ts.get('end', 0)
                            text = ts.get('segment', '')
                            lines.append(f"[{start:.2f}s - {end:.2f}s] {text}")
                        transcription_text = "\n".join(lines)
                    elif timestamp_level == "word":
                        lines = []
                        for ts in ts_data:
                            start = ts.get('start', 0)
                            word = ts.get('word', '')
                            lines.append(f"[{start:.2f}s] {word}")
                        transcription_text = "\n".join(lines)

            yield html_embed, transcription_text

        except gr.Error:
            raise
        except torch.cuda.OutOfMemoryError as e:
            raise gr.Error("CUDA out of memory. Please try a shorter audio file.") from e
        except Exception as e:
            raise gr.Error(f"Transcription failed: {str(e)[:200]}") from e

        finally:
            # Revert long audio settings
            if long_audio_settings_applied:
                try:
                    model.change_attention_model("rel_pos")
                    model.change_subsampling_conv_chunking_factor(-1)
                except Exception as e:
                    # Best effort, but leave a trace: the shared model may
                    # still be in its long-audio configuration.
                    print(f"WARNING: Could not revert long audio settings: {e}")

            # Clean up temp file if different from original
            if processed_path and processed_path != filepath:
                try:
                    os.remove(processed_path)
                    os.rmdir(os.path.dirname(processed_path))
                except OSError:
                    pass
490
+
491
+
492
# Gradio UI: a header, an "Audio File" tab and, outside Hugging Face Spaces,
# a "YouTube" tab. Components are created in display order.
with gr.Blocks(title="Parakeet-ASR") as demo:
    # Header (the device name is resolved once, when the UI is built)
    header_html = f"""
        <h1 style='text-align: center;'>Parakeet-ASR 🦜</h1>
        <p style='text-align: center;'>
            Powered by <code>nvidia/parakeet-tdt-0.6b-v3</code> on
            <strong>{get_device_info().upper()}</strong>
        </p>
        <p style='text-align: center; font-size: 0.9em;'>
            Supports 25 European languages with automatic detection, punctuation, and capitalization.
        </p>
        """
    gr.HTML(header_html)

    with gr.Tabs():
        # Tab 1: audio from an uploaded file or the microphone
        with gr.TabItem("Audio File"):
            with gr.Row():
                # Left column: inputs
                with gr.Column():
                    audio_input = gr.Audio(
                        label="Audio Input",
                        sources=["microphone", "upload"],
                        type="filepath",
                    )
                    timestamps_checkbox = gr.Checkbox(label="Return Timestamps", value=False)
                    timestamp_level_radio = gr.Radio(
                        choices=["segment", "word", "char"],
                        value="segment",
                        label="Timestamp Level",
                        info="Level of detail for timestamps",
                        visible=False,
                    )

                    # The level selector is only shown while timestamps are on
                    timestamps_checkbox.change(
                        fn=lambda checked: gr.Radio(visible=checked),
                        inputs=[timestamps_checkbox],
                        outputs=[timestamp_level_radio],
                    )

                    transcribe_btn = gr.Button("Transcribe", variant="primary")

                # Right column: transcription and download buttons
                with gr.Column():
                    audio_output = gr.Textbox(
                        label="Transcription",
                        placeholder="Transcribed text will appear here...",
                        lines=12,
                    )

                    with gr.Row():
                        download_csv_btn = gr.DownloadButton(label="Download CSV", visible=False)
                        download_srt_btn = gr.DownloadButton(label="Download SRT", visible=False)

            transcribe_btn.click(
                fn=transcribe_audio,
                inputs=[audio_input, timestamps_checkbox, timestamp_level_radio],
                outputs=[audio_output, download_csv_btn, download_srt_btn],
                api_name="transcribe",
            )

        # Tab 2: YouTube (only shown when running locally)
        if not IS_HF_SPACE:
            with gr.TabItem("YouTube"):
                with gr.Row():
                    # Left column: inputs
                    with gr.Column():
                        yt_url_input = gr.Textbox(
                            label="YouTube URL",
                            placeholder="Paste a YouTube video URL here...",
                            lines=1,
                        )
                        yt_timestamps_checkbox = gr.Checkbox(label="Return Timestamps", value=False)
                        yt_timestamp_level_radio = gr.Radio(
                            choices=["segment", "word"],
                            value="segment",
                            label="Timestamp Level",
                            visible=False,
                        )

                        yt_timestamps_checkbox.change(
                            fn=lambda checked: gr.Radio(visible=checked),
                            inputs=[yt_timestamps_checkbox],
                            outputs=[yt_timestamp_level_radio],
                        )

                        yt_transcribe_btn = gr.Button("Transcribe YouTube", variant="primary")

                    # Right column: embedded video and transcription
                    with gr.Column():
                        yt_embed = gr.HTML(label="Video")
                        yt_output = gr.Textbox(
                            label="Transcription",
                            placeholder="Transcribed text will appear here...",
                            lines=10,
                        )

                yt_transcribe_btn.click(
                    fn=transcribe_youtube,
                    inputs=[yt_url_input, yt_timestamps_checkbox, yt_timestamp_level_radio],
                    outputs=[yt_embed, yt_output],
                    api_name="transcribe_youtube",
                )


if __name__ == "__main__":
    demo.queue().launch(theme="Nymbo/Nymbo_Theme")