bekzod123 committed on
Commit
2adef92
·
1 Parent(s): 0c73b73
Files changed (1) hide show
  1. app.py +144 -57
app.py CHANGED
@@ -1,20 +1,20 @@
1
- from concurrent.futures import ThreadPoolExecutor
2
- from pathlib import Path
3
  import csv
4
  import datetime
5
  import gc
6
  import os
7
  import re
8
  import shutil
 
 
9
 
10
  import gradio as gr
11
  import gradio.themes as gr_themes
12
- from huggingface_hub import hf_hub_download
13
- from nemo.collections.asr.models import ASRModel
14
  import numpy as np
15
- from pydub import AudioSegment
16
  import spaces
17
  import torch
 
 
 
18
 
19
  try:
20
  from nemo.collections.asr.models import SortformerEncLabelModel
@@ -108,7 +108,9 @@ def get_audio_segment(audio_path, start_second, end_second):
108
  return None
109
  return frame_rate, samples
110
  except Exception as e:
111
- print(f"Error clipping audio {audio_path} from {start_second}s to {end_second}s: {e}")
 
 
112
  return None
113
 
114
 
@@ -160,7 +162,9 @@ def remove_dc_offset(samples: np.ndarray) -> np.ndarray:
160
  return samples - np.mean(samples, dtype=np.float32)
161
 
162
 
163
- def fft_bandpass(samples: np.ndarray, sr: int, low_hz: float, high_hz: float) -> np.ndarray:
 
 
164
  samples = np.asarray(samples, dtype=np.float32)
165
  if samples.size == 0:
166
  return samples
@@ -205,7 +209,9 @@ def spectral_denoise(
205
  mask = np.clip(mask, min_mask, 1.0)
206
 
207
  cleaned_stft = magnitude * mask * np.exp(1j * phase)
208
- cleaned = librosa.istft(cleaned_stft, hop_length=hop, win_length=n_fft, length=len(samples))
 
 
209
  return cleaned.astype(np.float32, copy=False)
210
 
211
 
@@ -289,7 +295,9 @@ def preprocess_audio_for_transcription(
289
  samples = np.clip(raw / max_abs, -1.0, 1.0)
290
 
291
  samples = remove_dc_offset(samples)
292
- samples = spectral_denoise(samples, strength=1.25, noise_percentile=15.0, min_mask=0.06)
 
 
293
  samples = fft_bandpass(samples, sr=target_sr, low_hz=120.0, high_hz=3600.0)
294
  samples = dynamic_rms_normalize(
295
  samples=samples,
@@ -332,7 +340,11 @@ def _parse_rttm_line(line: str):
332
  speaker = parts[7]
333
  if start is None or dur is None or dur <= 0:
334
  return None
335
- return {"start": start, "end": start + dur, "speaker": normalize_speaker_label(speaker)}
 
 
 
 
336
 
337
 
338
  def _parse_simple_segment_line(line: str):
@@ -360,7 +372,9 @@ def parse_diarization_output(raw_output, audio_duration_sec=None) -> list:
360
  e = _try_float(end)
361
  if s is None or e is None or e <= s:
362
  return
363
- parsed.append({"start": s, "end": e, "speaker": normalize_speaker_label(speaker)})
 
 
364
 
365
  def walk(obj):
366
  if obj is None:
@@ -413,7 +427,11 @@ def parse_diarization_output(raw_output, audio_duration_sec=None) -> list:
413
  return
414
 
415
  if isinstance(obj, (list, tuple)):
416
- if len(obj) >= 3 and _try_float(obj[0]) is not None and _try_float(obj[1]) is not None:
 
 
 
 
417
  append_seg(obj[0], obj[1], obj[2])
418
  return
419
  for item in obj:
@@ -421,7 +439,9 @@ def parse_diarization_output(raw_output, audio_duration_sec=None) -> list:
421
  return
422
 
423
  if hasattr(obj, "start") and hasattr(obj, "end"):
424
- append_seg(getattr(obj, "start"), getattr(obj, "end"), getattr(obj, "speaker", "0"))
 
 
425
 
426
  walk(raw_output)
427
 
@@ -453,7 +473,10 @@ def merge_adjacent_speaker_segments(segments: list, max_gap_sec: float = 0.15) -
453
  merged = [segments[0].copy()]
454
  for seg in segments[1:]:
455
  last = merged[-1]
456
- if seg["speaker"] == last["speaker"] and seg["start"] - last["end"] <= max_gap_sec:
 
 
 
457
  last["end"] = max(last["end"], seg["end"])
458
  else:
459
  merged.append(seg.copy())
@@ -547,7 +570,9 @@ def transcribe_default_with_timestamps(transcribe_path: str):
547
  return segments
548
 
549
 
550
- def _overlap_seconds(a_start: float, a_end: float, b_start: float, b_end: float) -> float:
 
 
551
  return max(0.0, min(a_end, b_end) - max(a_start, b_start))
552
 
553
 
@@ -555,7 +580,9 @@ def _join_tokens(tokens: list) -> str:
555
  return " ".join(t for t in tokens if t).strip()
556
 
557
 
558
- def split_asr_by_diarization_segments(asr_segments: list, diar_segments: list, asr_words: list = None) -> list:
 
 
559
  if not diar_segments:
560
  return []
561
 
@@ -634,7 +661,9 @@ def _clean_token_spacing(text: str) -> str:
634
 
635
 
636
  def _capitalize_first_alpha(text: str) -> str:
637
- return re.sub(r"^([^A-Za-z]*)([a-z])", lambda m: m.group(1) + m.group(2).upper(), text)
 
 
638
 
639
 
640
  def _capitalize_after_full_stop(text: str) -> str:
@@ -702,9 +731,13 @@ UZ_ORDINAL_TO_CARDINAL = {
702
  "o'ninchi": "o'n",
703
  "oninchi": "o'n",
704
  }
705
- UZ_MONTHS_PATTERN = r"yanvar|fevral|mart|aprel|may|iyun|iyul|avgust|sentabr|oktabr|noyabr|dekabr"
 
 
706
 
707
- _TOKEN_CORE_RE = re.compile(r"^([^A-Za-z0-9'`ʻʼ’‘]*)([A-Za-z0-9'`ʻʼ’‘]+)([^A-Za-z0-9'`ʻʼ’‘]*)$")
 
 
708
 
709
 
710
  def _normalize_uz_word(word: str) -> str:
@@ -729,7 +762,12 @@ def _normalize_uz_word(word: str) -> str:
729
  def _is_uz_number_like(word: str) -> bool:
730
  if not word:
731
  return False
732
- if word in UZ_CARDINAL or word in UZ_SCALES or word == "yuz" or word in UZ_ORDINAL_TO_CARDINAL:
 
 
 
 
 
733
  return True
734
  return re.match(r"^.+(?:inchi|nchi)$", word) is not None
735
 
@@ -882,7 +920,9 @@ def normalize_uzbek_date_forms(text: str) -> str:
882
  return text
883
 
884
 
885
- def postprocess_segment_texts(segment_timestamps: list, diarization_enabled: bool) -> list:
 
 
886
  for ts in segment_timestamps:
887
  txt = str(ts.get("segment", "") or "")
888
  txt = _clean_token_spacing(txt)
@@ -905,7 +945,9 @@ def resolve_player_audio_path(prepared_path, fallback_path: str) -> str:
905
 
906
 
907
  @spaces.GPU
908
- def get_transcripts_and_raw_times(audio_path, session_dir, use_preprocessing=True, use_diarization=False):
 
 
909
  if not audio_path:
910
  gr.Error("No audio file path provided for transcription.", duration=None)
911
  return (
@@ -924,8 +966,12 @@ def get_transcripts_and_raw_times(audio_path, session_dir, use_preprocessing=Tru
924
  original_path_name = Path(audio_path).name
925
  audio_name = Path(audio_path).stem
926
 
927
- csv_button_update = gr.DownloadButton(label="Download Transcript (CSV)", visible=False)
928
- srt_button_update = gr.DownloadButton(label="Download Transcript (SRT)", visible=False)
 
 
 
 
929
 
930
  transcribe_path = audio_path
931
  info_path_name = original_path_name
@@ -952,7 +998,9 @@ def get_transcripts_and_raw_times(audio_path, session_dir, use_preprocessing=Tru
952
  processed_audio = preprocess_audio_for_transcription(
953
  audio=audio, target_sr=16000, frame_ms=500, target_rms_db=-20.0
954
  )
955
- processed_audio_path = Path(session_dir, f"{audio_name}_asr_preprocessed.wav")
 
 
956
  processed_audio.export(processed_audio_path, format="wav")
957
  transcribe_path = processed_audio_path.as_posix()
958
  info_path_name = f"{original_path_name} (preprocessed)"
@@ -974,12 +1022,17 @@ def get_transcripts_and_raw_times(audio_path, session_dir, use_preprocessing=Tru
974
 
975
  if duration_sec > 480:
976
  try:
977
- gr.Info("Audio longer than 8 minutes. Applying long audio settings.", duration=3)
 
 
 
978
  model.change_attention_model("rel_pos_local_attn", [256, 256])
979
  model.change_subsampling_conv_chunking_factor(1)
980
  long_audio_settings_applied = True
981
  except Exception as setting_e:
982
- gr.Warning(f"Could not apply long audio settings: {setting_e}", duration=5)
 
 
983
 
984
  if device == "cuda":
985
  model.to(torch.bfloat16)
@@ -990,16 +1043,7 @@ def get_transcripts_and_raw_times(audio_path, session_dir, use_preprocessing=Tru
990
  try:
991
  gr.Info("Running ASR and diarization in parallel...", duration=3)
992
 
993
- diar_input_path = transcribe_path
994
- if not use_preprocessing:
995
- diar_audio_path = Path(session_dir, f"{audio_name}_diar_16k_mono.wav")
996
- diar_audio = audio
997
- if diar_audio.channels != 1:
998
- diar_audio = diar_audio.set_channels(1)
999
- if diar_audio.frame_rate != 16000:
1000
- diar_audio = diar_audio.set_frame_rate(16000)
1001
- diar_audio.export(diar_audio_path, format="wav")
1002
- diar_input_path = diar_audio_path.as_posix()
1003
 
1004
  dmodel = get_diar_model()
1005
  dmodel.to(device)
@@ -1010,9 +1054,13 @@ def get_transcripts_and_raw_times(audio_path, session_dir, use_preprocessing=Tru
1010
 
1011
  def _run_diar():
1012
  try:
1013
- diar_output_local = dmodel.diarize(audio=diar_input_path, batch_size=1)
 
 
1014
  except TypeError:
1015
- diar_output_local = dmodel.diarize(audio=[diar_input_path], batch_size=1)
 
 
1016
 
1017
  diar_segments_local = parse_diarization_output(
1018
  diar_output_local,
@@ -1042,17 +1090,27 @@ def get_transcripts_and_raw_times(audio_path, session_dir, use_preprocessing=Tru
1042
  diar_segments=diar_segments,
1043
  asr_words=asr_words,
1044
  )
1045
- segment_timestamps = merge_consecutive_transcript_rows(segment_timestamps)
 
 
1046
 
1047
  if not segment_timestamps:
1048
- gr.Warning("No aligned diarized rows. Using ASR segmentation.", duration=7)
 
 
 
1049
  segment_timestamps = asr_segments
1050
 
1051
  gr.Info("Diarization + ASR complete.", duration=2)
1052
 
1053
  except Exception as diar_e:
1054
- gr.Warning(f"Diarization failed: {diar_e}. Using standard ASR segmentation.", duration=7)
1055
- segment_timestamps = transcribe_default_with_timestamps(transcribe_path)
 
 
 
 
 
1056
  else:
1057
  segment_timestamps = transcribe_default_with_timestamps(transcribe_path)
1058
 
@@ -1070,7 +1128,9 @@ def get_transcripts_and_raw_times(audio_path, session_dir, use_preprocessing=Tru
1070
  ]
1071
  for ts in segment_timestamps
1072
  ]
1073
- raw_times_data = [[float(ts["start"]), float(ts["end"])] for ts in segment_timestamps]
 
 
1074
 
1075
  try:
1076
  csv_file_path = Path(session_dir, f"transcription_{audio_name}.csv")
@@ -1082,7 +1142,9 @@ def get_transcripts_and_raw_times(audio_path, session_dir, use_preprocessing=Tru
1082
  value=csv_file_path, visible=True, label="Download Transcript (CSV)"
1083
  )
1084
  except Exception as csv_e:
1085
- gr.Error(f"Failed to create transcript CSV file: {csv_e}", duration=None)
 
 
1086
 
1087
  if segment_timestamps:
1088
  try:
@@ -1091,10 +1153,14 @@ def get_transcripts_and_raw_times(audio_path, session_dir, use_preprocessing=Tru
1091
  with open(srt_file_path, "w", encoding="utf-8") as f:
1092
  f.write(srt_content)
1093
  srt_button_update = gr.DownloadButton(
1094
- value=srt_file_path, visible=True, label="Download Transcript (SRT)"
 
 
1095
  )
1096
  except Exception as srt_e:
1097
- gr.Warning(f"Failed to create transcript SRT file: {srt_e}", duration=5)
 
 
1098
 
1099
  gr.Info("Transcription complete.", duration=2)
1100
  return (
@@ -1116,7 +1182,10 @@ def get_transcripts_and_raw_times(audio_path, session_dir, use_preprocessing=Tru
1116
  srt_button_update,
1117
  )
1118
  except FileNotFoundError:
1119
- gr.Error(f"Audio file not found for transcription: {Path(transcribe_path).name}", duration=None)
 
 
 
1120
  return (
1121
  [["Error", "Error", "N/A", "File not found for transcription"]],
1122
  [[0.0, 0.0]],
@@ -1140,7 +1209,9 @@ def get_transcripts_and_raw_times(audio_path, session_dir, use_preprocessing=Tru
1140
  model.change_attention_model("rel_pos")
1141
  model.change_subsampling_conv_chunking_factor(-1)
1142
  except Exception as revert_e:
1143
- gr.Warning(f"Issue reverting model settings: {revert_e}", duration=5)
 
 
1144
 
1145
  if device == "cuda":
1146
  model.cpu()
@@ -1230,7 +1301,9 @@ nvidia_theme = gr_themes.Default(
1230
 
1231
  with gr.Blocks(theme=nvidia_theme) as demo:
1232
  model_display_name = MODEL_NAME.split("/")[-1] if "/" in MODEL_NAME else MODEL_NAME
1233
- gr.Markdown(f"<h1 style='text-align:center;margin:0 auto;'>Speech Transcription with {model_display_name}</h1>")
 
 
1234
  gr.HTML(article)
1235
 
1236
  current_audio_path_state = gr.State(None)
@@ -1248,18 +1321,32 @@ with gr.Blocks(theme=nvidia_theme) as demo:
1248
 
1249
  with gr.Tabs():
1250
  with gr.TabItem("Audio File"):
1251
- file_input = gr.Audio(sources=["upload"], type="filepath", label="Upload Audio File")
1252
- gr.Examples(examples=examples, inputs=[file_input], label="Example Audio Files")
1253
- file_transcribe_btn = gr.Button("Transcribe Uploaded File", variant="primary")
 
 
 
 
 
 
1254
 
1255
  with gr.TabItem("Microphone"):
1256
- mic_input = gr.Audio(sources=["microphone"], type="filepath", label="Record Audio")
1257
- mic_transcribe_btn = gr.Button("Transcribe Microphone Input", variant="primary")
 
 
 
 
1258
 
1259
  gr.Markdown("---")
1260
  with gr.Row():
1261
- download_btn_csv = gr.DownloadButton(label="Download Transcript (CSV)", visible=False)
1262
- download_btn_srt = gr.DownloadButton(label="Download Transcript (SRT)", visible=False)
 
 
 
 
1263
 
1264
  vis_timestamps_df = gr.DataFrame(
1265
  headers=["Start (s)", "End (s)", "Speaker", "Segment"],
 
 
 
1
  import csv
2
  import datetime
3
  import gc
4
  import os
5
  import re
6
  import shutil
7
+ from concurrent.futures import ThreadPoolExecutor
8
+ from pathlib import Path
9
 
10
  import gradio as gr
11
  import gradio.themes as gr_themes
 
 
12
  import numpy as np
 
13
  import spaces
14
  import torch
15
+ from huggingface_hub import hf_hub_download
16
+ from nemo.collections.asr.models import ASRModel
17
+ from pydub import AudioSegment
18
 
19
  try:
20
  from nemo.collections.asr.models import SortformerEncLabelModel
 
108
  return None
109
  return frame_rate, samples
110
  except Exception as e:
111
+ print(
112
+ f"Error clipping audio {audio_path} from {start_second}s to {end_second}s: {e}"
113
+ )
114
  return None
115
 
116
 
 
162
  return samples - np.mean(samples, dtype=np.float32)
163
 
164
 
165
+ def fft_bandpass(
166
+ samples: np.ndarray, sr: int, low_hz: float, high_hz: float
167
+ ) -> np.ndarray:
168
  samples = np.asarray(samples, dtype=np.float32)
169
  if samples.size == 0:
170
  return samples
 
209
  mask = np.clip(mask, min_mask, 1.0)
210
 
211
  cleaned_stft = magnitude * mask * np.exp(1j * phase)
212
+ cleaned = librosa.istft(
213
+ cleaned_stft, hop_length=hop, win_length=n_fft, length=len(samples)
214
+ )
215
  return cleaned.astype(np.float32, copy=False)
216
 
217
 
 
295
  samples = np.clip(raw / max_abs, -1.0, 1.0)
296
 
297
  samples = remove_dc_offset(samples)
298
+ samples = spectral_denoise(
299
+ samples, strength=1.25, noise_percentile=15.0, min_mask=0.06
300
+ )
301
  samples = fft_bandpass(samples, sr=target_sr, low_hz=120.0, high_hz=3600.0)
302
  samples = dynamic_rms_normalize(
303
  samples=samples,
 
340
  speaker = parts[7]
341
  if start is None or dur is None or dur <= 0:
342
  return None
343
+ return {
344
+ "start": start,
345
+ "end": start + dur,
346
+ "speaker": normalize_speaker_label(speaker),
347
+ }
348
 
349
 
350
  def _parse_simple_segment_line(line: str):
 
372
  e = _try_float(end)
373
  if s is None or e is None or e <= s:
374
  return
375
+ parsed.append(
376
+ {"start": s, "end": e, "speaker": normalize_speaker_label(speaker)}
377
+ )
378
 
379
  def walk(obj):
380
  if obj is None:
 
427
  return
428
 
429
  if isinstance(obj, (list, tuple)):
430
+ if (
431
+ len(obj) >= 3
432
+ and _try_float(obj[0]) is not None
433
+ and _try_float(obj[1]) is not None
434
+ ):
435
  append_seg(obj[0], obj[1], obj[2])
436
  return
437
  for item in obj:
 
439
  return
440
 
441
  if hasattr(obj, "start") and hasattr(obj, "end"):
442
+ append_seg(
443
+ getattr(obj, "start"), getattr(obj, "end"), getattr(obj, "speaker", "0")
444
+ )
445
 
446
  walk(raw_output)
447
 
 
473
  merged = [segments[0].copy()]
474
  for seg in segments[1:]:
475
  last = merged[-1]
476
+ if (
477
+ seg["speaker"] == last["speaker"]
478
+ and seg["start"] - last["end"] <= max_gap_sec
479
+ ):
480
  last["end"] = max(last["end"], seg["end"])
481
  else:
482
  merged.append(seg.copy())
 
570
  return segments
571
 
572
 
573
+ def _overlap_seconds(
574
+ a_start: float, a_end: float, b_start: float, b_end: float
575
+ ) -> float:
576
  return max(0.0, min(a_end, b_end) - max(a_start, b_start))
577
 
578
 
 
580
  return " ".join(t for t in tokens if t).strip()
581
 
582
 
583
+ def split_asr_by_diarization_segments(
584
+ asr_segments: list, diar_segments: list, asr_words: list = None
585
+ ) -> list:
586
  if not diar_segments:
587
  return []
588
 
 
661
 
662
 
663
  def _capitalize_first_alpha(text: str) -> str:
664
+ return re.sub(
665
+ r"^([^A-Za-z]*)([a-z])", lambda m: m.group(1) + m.group(2).upper(), text
666
+ )
667
 
668
 
669
  def _capitalize_after_full_stop(text: str) -> str:
 
731
  "o'ninchi": "o'n",
732
  "oninchi": "o'n",
733
  }
734
+ UZ_MONTHS_PATTERN = (
735
+ r"yanvar|fevral|mart|aprel|may|iyun|iyul|avgust|sentabr|oktabr|noyabr|dekabr"
736
+ )
737
 
738
+ _TOKEN_CORE_RE = re.compile(
739
+ r"^([^A-Za-z0-9'`ʻʼ’‘]*)([A-Za-z0-9'`ʻʼ’‘]+)([^A-Za-z0-9'`ʻʼ’‘]*)$"
740
+ )
741
 
742
 
743
  def _normalize_uz_word(word: str) -> str:
 
762
  def _is_uz_number_like(word: str) -> bool:
763
  if not word:
764
  return False
765
+ if (
766
+ word in UZ_CARDINAL
767
+ or word in UZ_SCALES
768
+ or word == "yuz"
769
+ or word in UZ_ORDINAL_TO_CARDINAL
770
+ ):
771
  return True
772
  return re.match(r"^.+(?:inchi|nchi)$", word) is not None
773
 
 
920
  return text
921
 
922
 
923
+ def postprocess_segment_texts(
924
+ segment_timestamps: list, diarization_enabled: bool
925
+ ) -> list:
926
  for ts in segment_timestamps:
927
  txt = str(ts.get("segment", "") or "")
928
  txt = _clean_token_spacing(txt)
 
945
 
946
 
947
  @spaces.GPU
948
+ def get_transcripts_and_raw_times(
949
+ audio_path, session_dir, use_preprocessing=True, use_diarization=False
950
+ ):
951
  if not audio_path:
952
  gr.Error("No audio file path provided for transcription.", duration=None)
953
  return (
 
966
  original_path_name = Path(audio_path).name
967
  audio_name = Path(audio_path).stem
968
 
969
+ csv_button_update = gr.DownloadButton(
970
+ label="Download Transcript (CSV)", visible=False
971
+ )
972
+ srt_button_update = gr.DownloadButton(
973
+ label="Download Transcript (SRT)", visible=False
974
+ )
975
 
976
  transcribe_path = audio_path
977
  info_path_name = original_path_name
 
998
  processed_audio = preprocess_audio_for_transcription(
999
  audio=audio, target_sr=16000, frame_ms=500, target_rms_db=-20.0
1000
  )
1001
+ processed_audio_path = Path(
1002
+ session_dir, f"{audio_name}_asr_preprocessed.wav"
1003
+ )
1004
  processed_audio.export(processed_audio_path, format="wav")
1005
  transcribe_path = processed_audio_path.as_posix()
1006
  info_path_name = f"{original_path_name} (preprocessed)"
 
1022
 
1023
  if duration_sec > 480:
1024
  try:
1025
+ gr.Info(
1026
+ "Audio longer than 8 minutes. Applying long audio settings.",
1027
+ duration=3,
1028
+ )
1029
  model.change_attention_model("rel_pos_local_attn", [256, 256])
1030
  model.change_subsampling_conv_chunking_factor(1)
1031
  long_audio_settings_applied = True
1032
  except Exception as setting_e:
1033
+ gr.Warning(
1034
+ f"Could not apply long audio settings: {setting_e}", duration=5
1035
+ )
1036
 
1037
  if device == "cuda":
1038
  model.to(torch.bfloat16)
 
1043
  try:
1044
  gr.Info("Running ASR and diarization in parallel...", duration=3)
1045
 
1046
+ diar_input_path = audio_path
 
 
 
 
 
 
 
 
 
1047
 
1048
  dmodel = get_diar_model()
1049
  dmodel.to(device)
 
1054
 
1055
  def _run_diar():
1056
  try:
1057
+ diar_output_local = dmodel.diarize(
1058
+ audio=diar_input_path, batch_size=1
1059
+ )
1060
  except TypeError:
1061
+ diar_output_local = dmodel.diarize(
1062
+ audio=[diar_input_path], batch_size=1
1063
+ )
1064
 
1065
  diar_segments_local = parse_diarization_output(
1066
  diar_output_local,
 
1090
  diar_segments=diar_segments,
1091
  asr_words=asr_words,
1092
  )
1093
+ segment_timestamps = merge_consecutive_transcript_rows(
1094
+ segment_timestamps
1095
+ )
1096
 
1097
  if not segment_timestamps:
1098
+ gr.Warning(
1099
+ "No aligned diarized rows. Using ASR segmentation.",
1100
+ duration=7,
1101
+ )
1102
  segment_timestamps = asr_segments
1103
 
1104
  gr.Info("Diarization + ASR complete.", duration=2)
1105
 
1106
  except Exception as diar_e:
1107
+ gr.Warning(
1108
+ f"Diarization failed: {diar_e}. Using standard ASR segmentation.",
1109
+ duration=7,
1110
+ )
1111
+ segment_timestamps = transcribe_default_with_timestamps(
1112
+ transcribe_path
1113
+ )
1114
  else:
1115
  segment_timestamps = transcribe_default_with_timestamps(transcribe_path)
1116
 
 
1128
  ]
1129
  for ts in segment_timestamps
1130
  ]
1131
+ raw_times_data = [
1132
+ [float(ts["start"]), float(ts["end"])] for ts in segment_timestamps
1133
+ ]
1134
 
1135
  try:
1136
  csv_file_path = Path(session_dir, f"transcription_{audio_name}.csv")
 
1142
  value=csv_file_path, visible=True, label="Download Transcript (CSV)"
1143
  )
1144
  except Exception as csv_e:
1145
+ gr.Error(
1146
+ f"Failed to create transcript CSV file: {csv_e}", duration=None
1147
+ )
1148
 
1149
  if segment_timestamps:
1150
  try:
 
1153
  with open(srt_file_path, "w", encoding="utf-8") as f:
1154
  f.write(srt_content)
1155
  srt_button_update = gr.DownloadButton(
1156
+ value=srt_file_path,
1157
+ visible=True,
1158
+ label="Download Transcript (SRT)",
1159
  )
1160
  except Exception as srt_e:
1161
+ gr.Warning(
1162
+ f"Failed to create transcript SRT file: {srt_e}", duration=5
1163
+ )
1164
 
1165
  gr.Info("Transcription complete.", duration=2)
1166
  return (
 
1182
  srt_button_update,
1183
  )
1184
  except FileNotFoundError:
1185
+ gr.Error(
1186
+ f"Audio file not found for transcription: {Path(transcribe_path).name}",
1187
+ duration=None,
1188
+ )
1189
  return (
1190
  [["Error", "Error", "N/A", "File not found for transcription"]],
1191
  [[0.0, 0.0]],
 
1209
  model.change_attention_model("rel_pos")
1210
  model.change_subsampling_conv_chunking_factor(-1)
1211
  except Exception as revert_e:
1212
+ gr.Warning(
1213
+ f"Issue reverting model settings: {revert_e}", duration=5
1214
+ )
1215
 
1216
  if device == "cuda":
1217
  model.cpu()
 
1301
 
1302
  with gr.Blocks(theme=nvidia_theme) as demo:
1303
  model_display_name = MODEL_NAME.split("/")[-1] if "/" in MODEL_NAME else MODEL_NAME
1304
+ gr.Markdown(
1305
+ f"<h1 style='text-align:center;margin:0 auto;'>Speech Transcription with {model_display_name}</h1>"
1306
+ )
1307
  gr.HTML(article)
1308
 
1309
  current_audio_path_state = gr.State(None)
 
1321
 
1322
  with gr.Tabs():
1323
  with gr.TabItem("Audio File"):
1324
+ file_input = gr.Audio(
1325
+ sources=["upload"], type="filepath", label="Upload Audio File"
1326
+ )
1327
+ gr.Examples(
1328
+ examples=examples, inputs=[file_input], label="Example Audio Files"
1329
+ )
1330
+ file_transcribe_btn = gr.Button(
1331
+ "Transcribe Uploaded File", variant="primary"
1332
+ )
1333
 
1334
  with gr.TabItem("Microphone"):
1335
+ mic_input = gr.Audio(
1336
+ sources=["microphone"], type="filepath", label="Record Audio"
1337
+ )
1338
+ mic_transcribe_btn = gr.Button(
1339
+ "Transcribe Microphone Input", variant="primary"
1340
+ )
1341
 
1342
  gr.Markdown("---")
1343
  with gr.Row():
1344
+ download_btn_csv = gr.DownloadButton(
1345
+ label="Download Transcript (CSV)", visible=False
1346
+ )
1347
+ download_btn_srt = gr.DownloadButton(
1348
+ label="Download Transcript (SRT)", visible=False
1349
+ )
1350
 
1351
  vis_timestamps_df = gr.DataFrame(
1352
  headers=["Start (s)", "End (s)", "Speaker", "Segment"],