fixes
Browse files
app.py
CHANGED
|
@@ -1,20 +1,20 @@
|
|
| 1 |
-
from concurrent.futures import ThreadPoolExecutor
|
| 2 |
-
from pathlib import Path
|
| 3 |
import csv
|
| 4 |
import datetime
|
| 5 |
import gc
|
| 6 |
import os
|
| 7 |
import re
|
| 8 |
import shutil
|
|
|
|
|
|
|
| 9 |
|
| 10 |
import gradio as gr
|
| 11 |
import gradio.themes as gr_themes
|
| 12 |
-
from huggingface_hub import hf_hub_download
|
| 13 |
-
from nemo.collections.asr.models import ASRModel
|
| 14 |
import numpy as np
|
| 15 |
-
from pydub import AudioSegment
|
| 16 |
import spaces
|
| 17 |
import torch
|
|
|
|
|
|
|
|
|
|
| 18 |
|
| 19 |
try:
|
| 20 |
from nemo.collections.asr.models import SortformerEncLabelModel
|
|
@@ -108,7 +108,9 @@ def get_audio_segment(audio_path, start_second, end_second):
|
|
| 108 |
return None
|
| 109 |
return frame_rate, samples
|
| 110 |
except Exception as e:
|
| 111 |
-
print(
|
|
|
|
|
|
|
| 112 |
return None
|
| 113 |
|
| 114 |
|
|
@@ -160,7 +162,9 @@ def remove_dc_offset(samples: np.ndarray) -> np.ndarray:
|
|
| 160 |
return samples - np.mean(samples, dtype=np.float32)
|
| 161 |
|
| 162 |
|
| 163 |
-
def fft_bandpass(
|
|
|
|
|
|
|
| 164 |
samples = np.asarray(samples, dtype=np.float32)
|
| 165 |
if samples.size == 0:
|
| 166 |
return samples
|
|
@@ -205,7 +209,9 @@ def spectral_denoise(
|
|
| 205 |
mask = np.clip(mask, min_mask, 1.0)
|
| 206 |
|
| 207 |
cleaned_stft = magnitude * mask * np.exp(1j * phase)
|
| 208 |
-
cleaned = librosa.istft(
|
|
|
|
|
|
|
| 209 |
return cleaned.astype(np.float32, copy=False)
|
| 210 |
|
| 211 |
|
|
@@ -289,7 +295,9 @@ def preprocess_audio_for_transcription(
|
|
| 289 |
samples = np.clip(raw / max_abs, -1.0, 1.0)
|
| 290 |
|
| 291 |
samples = remove_dc_offset(samples)
|
| 292 |
-
samples = spectral_denoise(
|
|
|
|
|
|
|
| 293 |
samples = fft_bandpass(samples, sr=target_sr, low_hz=120.0, high_hz=3600.0)
|
| 294 |
samples = dynamic_rms_normalize(
|
| 295 |
samples=samples,
|
|
@@ -332,7 +340,11 @@ def _parse_rttm_line(line: str):
|
|
| 332 |
speaker = parts[7]
|
| 333 |
if start is None or dur is None or dur <= 0:
|
| 334 |
return None
|
| 335 |
-
return {
|
|
|
|
|
|
|
|
|
|
|
|
|
| 336 |
|
| 337 |
|
| 338 |
def _parse_simple_segment_line(line: str):
|
|
@@ -360,7 +372,9 @@ def parse_diarization_output(raw_output, audio_duration_sec=None) -> list:
|
|
| 360 |
e = _try_float(end)
|
| 361 |
if s is None or e is None or e <= s:
|
| 362 |
return
|
| 363 |
-
parsed.append(
|
|
|
|
|
|
|
| 364 |
|
| 365 |
def walk(obj):
|
| 366 |
if obj is None:
|
|
@@ -413,7 +427,11 @@ def parse_diarization_output(raw_output, audio_duration_sec=None) -> list:
|
|
| 413 |
return
|
| 414 |
|
| 415 |
if isinstance(obj, (list, tuple)):
|
| 416 |
-
if
|
|
|
|
|
|
|
|
|
|
|
|
|
| 417 |
append_seg(obj[0], obj[1], obj[2])
|
| 418 |
return
|
| 419 |
for item in obj:
|
|
@@ -421,7 +439,9 @@ def parse_diarization_output(raw_output, audio_duration_sec=None) -> list:
|
|
| 421 |
return
|
| 422 |
|
| 423 |
if hasattr(obj, "start") and hasattr(obj, "end"):
|
| 424 |
-
append_seg(
|
|
|
|
|
|
|
| 425 |
|
| 426 |
walk(raw_output)
|
| 427 |
|
|
@@ -453,7 +473,10 @@ def merge_adjacent_speaker_segments(segments: list, max_gap_sec: float = 0.15) -
|
|
| 453 |
merged = [segments[0].copy()]
|
| 454 |
for seg in segments[1:]:
|
| 455 |
last = merged[-1]
|
| 456 |
-
if
|
|
|
|
|
|
|
|
|
|
| 457 |
last["end"] = max(last["end"], seg["end"])
|
| 458 |
else:
|
| 459 |
merged.append(seg.copy())
|
|
@@ -547,7 +570,9 @@ def transcribe_default_with_timestamps(transcribe_path: str):
|
|
| 547 |
return segments
|
| 548 |
|
| 549 |
|
| 550 |
-
def _overlap_seconds(
|
|
|
|
|
|
|
| 551 |
return max(0.0, min(a_end, b_end) - max(a_start, b_start))
|
| 552 |
|
| 553 |
|
|
@@ -555,7 +580,9 @@ def _join_tokens(tokens: list) -> str:
|
|
| 555 |
return " ".join(t for t in tokens if t).strip()
|
| 556 |
|
| 557 |
|
| 558 |
-
def split_asr_by_diarization_segments(
|
|
|
|
|
|
|
| 559 |
if not diar_segments:
|
| 560 |
return []
|
| 561 |
|
|
@@ -634,7 +661,9 @@ def _clean_token_spacing(text: str) -> str:
|
|
| 634 |
|
| 635 |
|
| 636 |
def _capitalize_first_alpha(text: str) -> str:
|
| 637 |
-
return re.sub(
|
|
|
|
|
|
|
| 638 |
|
| 639 |
|
| 640 |
def _capitalize_after_full_stop(text: str) -> str:
|
|
@@ -702,9 +731,13 @@ UZ_ORDINAL_TO_CARDINAL = {
|
|
| 702 |
"o'ninchi": "o'n",
|
| 703 |
"oninchi": "o'n",
|
| 704 |
}
|
| 705 |
-
UZ_MONTHS_PATTERN =
|
|
|
|
|
|
|
| 706 |
|
| 707 |
-
_TOKEN_CORE_RE = re.compile(
|
|
|
|
|
|
|
| 708 |
|
| 709 |
|
| 710 |
def _normalize_uz_word(word: str) -> str:
|
|
@@ -729,7 +762,12 @@ def _normalize_uz_word(word: str) -> str:
|
|
| 729 |
def _is_uz_number_like(word: str) -> bool:
|
| 730 |
if not word:
|
| 731 |
return False
|
| 732 |
-
if
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 733 |
return True
|
| 734 |
return re.match(r"^.+(?:inchi|nchi)$", word) is not None
|
| 735 |
|
|
@@ -882,7 +920,9 @@ def normalize_uzbek_date_forms(text: str) -> str:
|
|
| 882 |
return text
|
| 883 |
|
| 884 |
|
| 885 |
-
def postprocess_segment_texts(
|
|
|
|
|
|
|
| 886 |
for ts in segment_timestamps:
|
| 887 |
txt = str(ts.get("segment", "") or "")
|
| 888 |
txt = _clean_token_spacing(txt)
|
|
@@ -905,7 +945,9 @@ def resolve_player_audio_path(prepared_path, fallback_path: str) -> str:
|
|
| 905 |
|
| 906 |
|
| 907 |
@spaces.GPU
|
| 908 |
-
def get_transcripts_and_raw_times(
|
|
|
|
|
|
|
| 909 |
if not audio_path:
|
| 910 |
gr.Error("No audio file path provided for transcription.", duration=None)
|
| 911 |
return (
|
|
@@ -924,8 +966,12 @@ def get_transcripts_and_raw_times(audio_path, session_dir, use_preprocessing=Tru
|
|
| 924 |
original_path_name = Path(audio_path).name
|
| 925 |
audio_name = Path(audio_path).stem
|
| 926 |
|
| 927 |
-
csv_button_update = gr.DownloadButton(
|
| 928 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 929 |
|
| 930 |
transcribe_path = audio_path
|
| 931 |
info_path_name = original_path_name
|
|
@@ -952,7 +998,9 @@ def get_transcripts_and_raw_times(audio_path, session_dir, use_preprocessing=Tru
|
|
| 952 |
processed_audio = preprocess_audio_for_transcription(
|
| 953 |
audio=audio, target_sr=16000, frame_ms=500, target_rms_db=-20.0
|
| 954 |
)
|
| 955 |
-
processed_audio_path = Path(
|
|
|
|
|
|
|
| 956 |
processed_audio.export(processed_audio_path, format="wav")
|
| 957 |
transcribe_path = processed_audio_path.as_posix()
|
| 958 |
info_path_name = f"{original_path_name} (preprocessed)"
|
|
@@ -974,12 +1022,17 @@ def get_transcripts_and_raw_times(audio_path, session_dir, use_preprocessing=Tru
|
|
| 974 |
|
| 975 |
if duration_sec > 480:
|
| 976 |
try:
|
| 977 |
-
gr.Info(
|
|
|
|
|
|
|
|
|
|
| 978 |
model.change_attention_model("rel_pos_local_attn", [256, 256])
|
| 979 |
model.change_subsampling_conv_chunking_factor(1)
|
| 980 |
long_audio_settings_applied = True
|
| 981 |
except Exception as setting_e:
|
| 982 |
-
gr.Warning(
|
|
|
|
|
|
|
| 983 |
|
| 984 |
if device == "cuda":
|
| 985 |
model.to(torch.bfloat16)
|
|
@@ -990,16 +1043,7 @@ def get_transcripts_and_raw_times(audio_path, session_dir, use_preprocessing=Tru
|
|
| 990 |
try:
|
| 991 |
gr.Info("Running ASR and diarization in parallel...", duration=3)
|
| 992 |
|
| 993 |
-
diar_input_path =
|
| 994 |
-
if not use_preprocessing:
|
| 995 |
-
diar_audio_path = Path(session_dir, f"{audio_name}_diar_16k_mono.wav")
|
| 996 |
-
diar_audio = audio
|
| 997 |
-
if diar_audio.channels != 1:
|
| 998 |
-
diar_audio = diar_audio.set_channels(1)
|
| 999 |
-
if diar_audio.frame_rate != 16000:
|
| 1000 |
-
diar_audio = diar_audio.set_frame_rate(16000)
|
| 1001 |
-
diar_audio.export(diar_audio_path, format="wav")
|
| 1002 |
-
diar_input_path = diar_audio_path.as_posix()
|
| 1003 |
|
| 1004 |
dmodel = get_diar_model()
|
| 1005 |
dmodel.to(device)
|
|
@@ -1010,9 +1054,13 @@ def get_transcripts_and_raw_times(audio_path, session_dir, use_preprocessing=Tru
|
|
| 1010 |
|
| 1011 |
def _run_diar():
|
| 1012 |
try:
|
| 1013 |
-
diar_output_local = dmodel.diarize(
|
|
|
|
|
|
|
| 1014 |
except TypeError:
|
| 1015 |
-
diar_output_local = dmodel.diarize(
|
|
|
|
|
|
|
| 1016 |
|
| 1017 |
diar_segments_local = parse_diarization_output(
|
| 1018 |
diar_output_local,
|
|
@@ -1042,17 +1090,27 @@ def get_transcripts_and_raw_times(audio_path, session_dir, use_preprocessing=Tru
|
|
| 1042 |
diar_segments=diar_segments,
|
| 1043 |
asr_words=asr_words,
|
| 1044 |
)
|
| 1045 |
-
segment_timestamps = merge_consecutive_transcript_rows(
|
|
|
|
|
|
|
| 1046 |
|
| 1047 |
if not segment_timestamps:
|
| 1048 |
-
gr.Warning(
|
|
|
|
|
|
|
|
|
|
| 1049 |
segment_timestamps = asr_segments
|
| 1050 |
|
| 1051 |
gr.Info("Diarization + ASR complete.", duration=2)
|
| 1052 |
|
| 1053 |
except Exception as diar_e:
|
| 1054 |
-
gr.Warning(
|
| 1055 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1056 |
else:
|
| 1057 |
segment_timestamps = transcribe_default_with_timestamps(transcribe_path)
|
| 1058 |
|
|
@@ -1070,7 +1128,9 @@ def get_transcripts_and_raw_times(audio_path, session_dir, use_preprocessing=Tru
|
|
| 1070 |
]
|
| 1071 |
for ts in segment_timestamps
|
| 1072 |
]
|
| 1073 |
-
raw_times_data = [
|
|
|
|
|
|
|
| 1074 |
|
| 1075 |
try:
|
| 1076 |
csv_file_path = Path(session_dir, f"transcription_{audio_name}.csv")
|
|
@@ -1082,7 +1142,9 @@ def get_transcripts_and_raw_times(audio_path, session_dir, use_preprocessing=Tru
|
|
| 1082 |
value=csv_file_path, visible=True, label="Download Transcript (CSV)"
|
| 1083 |
)
|
| 1084 |
except Exception as csv_e:
|
| 1085 |
-
gr.Error(
|
|
|
|
|
|
|
| 1086 |
|
| 1087 |
if segment_timestamps:
|
| 1088 |
try:
|
|
@@ -1091,10 +1153,14 @@ def get_transcripts_and_raw_times(audio_path, session_dir, use_preprocessing=Tru
|
|
| 1091 |
with open(srt_file_path, "w", encoding="utf-8") as f:
|
| 1092 |
f.write(srt_content)
|
| 1093 |
srt_button_update = gr.DownloadButton(
|
| 1094 |
-
value=srt_file_path,
|
|
|
|
|
|
|
| 1095 |
)
|
| 1096 |
except Exception as srt_e:
|
| 1097 |
-
gr.Warning(
|
|
|
|
|
|
|
| 1098 |
|
| 1099 |
gr.Info("Transcription complete.", duration=2)
|
| 1100 |
return (
|
|
@@ -1116,7 +1182,10 @@ def get_transcripts_and_raw_times(audio_path, session_dir, use_preprocessing=Tru
|
|
| 1116 |
srt_button_update,
|
| 1117 |
)
|
| 1118 |
except FileNotFoundError:
|
| 1119 |
-
gr.Error(
|
|
|
|
|
|
|
|
|
|
| 1120 |
return (
|
| 1121 |
[["Error", "Error", "N/A", "File not found for transcription"]],
|
| 1122 |
[[0.0, 0.0]],
|
|
@@ -1140,7 +1209,9 @@ def get_transcripts_and_raw_times(audio_path, session_dir, use_preprocessing=Tru
|
|
| 1140 |
model.change_attention_model("rel_pos")
|
| 1141 |
model.change_subsampling_conv_chunking_factor(-1)
|
| 1142 |
except Exception as revert_e:
|
| 1143 |
-
gr.Warning(
|
|
|
|
|
|
|
| 1144 |
|
| 1145 |
if device == "cuda":
|
| 1146 |
model.cpu()
|
|
@@ -1230,7 +1301,9 @@ nvidia_theme = gr_themes.Default(
|
|
| 1230 |
|
| 1231 |
with gr.Blocks(theme=nvidia_theme) as demo:
|
| 1232 |
model_display_name = MODEL_NAME.split("/")[-1] if "/" in MODEL_NAME else MODEL_NAME
|
| 1233 |
-
gr.Markdown(
|
|
|
|
|
|
|
| 1234 |
gr.HTML(article)
|
| 1235 |
|
| 1236 |
current_audio_path_state = gr.State(None)
|
|
@@ -1248,18 +1321,32 @@ with gr.Blocks(theme=nvidia_theme) as demo:
|
|
| 1248 |
|
| 1249 |
with gr.Tabs():
|
| 1250 |
with gr.TabItem("Audio File"):
|
| 1251 |
-
file_input = gr.Audio(
|
| 1252 |
-
|
| 1253 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1254 |
|
| 1255 |
with gr.TabItem("Microphone"):
|
| 1256 |
-
mic_input = gr.Audio(
|
| 1257 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1258 |
|
| 1259 |
gr.Markdown("---")
|
| 1260 |
with gr.Row():
|
| 1261 |
-
download_btn_csv = gr.DownloadButton(
|
| 1262 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1263 |
|
| 1264 |
vis_timestamps_df = gr.DataFrame(
|
| 1265 |
headers=["Start (s)", "End (s)", "Speaker", "Segment"],
|
|
|
|
|
|
|
|
|
|
| 1 |
import csv
|
| 2 |
import datetime
|
| 3 |
import gc
|
| 4 |
import os
|
| 5 |
import re
|
| 6 |
import shutil
|
| 7 |
+
from concurrent.futures import ThreadPoolExecutor
|
| 8 |
+
from pathlib import Path
|
| 9 |
|
| 10 |
import gradio as gr
|
| 11 |
import gradio.themes as gr_themes
|
|
|
|
|
|
|
| 12 |
import numpy as np
|
|
|
|
| 13 |
import spaces
|
| 14 |
import torch
|
| 15 |
+
from huggingface_hub import hf_hub_download
|
| 16 |
+
from nemo.collections.asr.models import ASRModel
|
| 17 |
+
from pydub import AudioSegment
|
| 18 |
|
| 19 |
try:
|
| 20 |
from nemo.collections.asr.models import SortformerEncLabelModel
|
|
|
|
| 108 |
return None
|
| 109 |
return frame_rate, samples
|
| 110 |
except Exception as e:
|
| 111 |
+
print(
|
| 112 |
+
f"Error clipping audio {audio_path} from {start_second}s to {end_second}s: {e}"
|
| 113 |
+
)
|
| 114 |
return None
|
| 115 |
|
| 116 |
|
|
|
|
| 162 |
return samples - np.mean(samples, dtype=np.float32)
|
| 163 |
|
| 164 |
|
| 165 |
+
def fft_bandpass(
|
| 166 |
+
samples: np.ndarray, sr: int, low_hz: float, high_hz: float
|
| 167 |
+
) -> np.ndarray:
|
| 168 |
samples = np.asarray(samples, dtype=np.float32)
|
| 169 |
if samples.size == 0:
|
| 170 |
return samples
|
|
|
|
| 209 |
mask = np.clip(mask, min_mask, 1.0)
|
| 210 |
|
| 211 |
cleaned_stft = magnitude * mask * np.exp(1j * phase)
|
| 212 |
+
cleaned = librosa.istft(
|
| 213 |
+
cleaned_stft, hop_length=hop, win_length=n_fft, length=len(samples)
|
| 214 |
+
)
|
| 215 |
return cleaned.astype(np.float32, copy=False)
|
| 216 |
|
| 217 |
|
|
|
|
| 295 |
samples = np.clip(raw / max_abs, -1.0, 1.0)
|
| 296 |
|
| 297 |
samples = remove_dc_offset(samples)
|
| 298 |
+
samples = spectral_denoise(
|
| 299 |
+
samples, strength=1.25, noise_percentile=15.0, min_mask=0.06
|
| 300 |
+
)
|
| 301 |
samples = fft_bandpass(samples, sr=target_sr, low_hz=120.0, high_hz=3600.0)
|
| 302 |
samples = dynamic_rms_normalize(
|
| 303 |
samples=samples,
|
|
|
|
| 340 |
speaker = parts[7]
|
| 341 |
if start is None or dur is None or dur <= 0:
|
| 342 |
return None
|
| 343 |
+
return {
|
| 344 |
+
"start": start,
|
| 345 |
+
"end": start + dur,
|
| 346 |
+
"speaker": normalize_speaker_label(speaker),
|
| 347 |
+
}
|
| 348 |
|
| 349 |
|
| 350 |
def _parse_simple_segment_line(line: str):
|
|
|
|
| 372 |
e = _try_float(end)
|
| 373 |
if s is None or e is None or e <= s:
|
| 374 |
return
|
| 375 |
+
parsed.append(
|
| 376 |
+
{"start": s, "end": e, "speaker": normalize_speaker_label(speaker)}
|
| 377 |
+
)
|
| 378 |
|
| 379 |
def walk(obj):
|
| 380 |
if obj is None:
|
|
|
|
| 427 |
return
|
| 428 |
|
| 429 |
if isinstance(obj, (list, tuple)):
|
| 430 |
+
if (
|
| 431 |
+
len(obj) >= 3
|
| 432 |
+
and _try_float(obj[0]) is not None
|
| 433 |
+
and _try_float(obj[1]) is not None
|
| 434 |
+
):
|
| 435 |
append_seg(obj[0], obj[1], obj[2])
|
| 436 |
return
|
| 437 |
for item in obj:
|
|
|
|
| 439 |
return
|
| 440 |
|
| 441 |
if hasattr(obj, "start") and hasattr(obj, "end"):
|
| 442 |
+
append_seg(
|
| 443 |
+
getattr(obj, "start"), getattr(obj, "end"), getattr(obj, "speaker", "0")
|
| 444 |
+
)
|
| 445 |
|
| 446 |
walk(raw_output)
|
| 447 |
|
|
|
|
| 473 |
merged = [segments[0].copy()]
|
| 474 |
for seg in segments[1:]:
|
| 475 |
last = merged[-1]
|
| 476 |
+
if (
|
| 477 |
+
seg["speaker"] == last["speaker"]
|
| 478 |
+
and seg["start"] - last["end"] <= max_gap_sec
|
| 479 |
+
):
|
| 480 |
last["end"] = max(last["end"], seg["end"])
|
| 481 |
else:
|
| 482 |
merged.append(seg.copy())
|
|
|
|
| 570 |
return segments
|
| 571 |
|
| 572 |
|
| 573 |
+
def _overlap_seconds(
|
| 574 |
+
a_start: float, a_end: float, b_start: float, b_end: float
|
| 575 |
+
) -> float:
|
| 576 |
return max(0.0, min(a_end, b_end) - max(a_start, b_start))
|
| 577 |
|
| 578 |
|
|
|
|
| 580 |
return " ".join(t for t in tokens if t).strip()
|
| 581 |
|
| 582 |
|
| 583 |
+
def split_asr_by_diarization_segments(
|
| 584 |
+
asr_segments: list, diar_segments: list, asr_words: list = None
|
| 585 |
+
) -> list:
|
| 586 |
if not diar_segments:
|
| 587 |
return []
|
| 588 |
|
|
|
|
| 661 |
|
| 662 |
|
| 663 |
def _capitalize_first_alpha(text: str) -> str:
|
| 664 |
+
return re.sub(
|
| 665 |
+
r"^([^A-Za-z]*)([a-z])", lambda m: m.group(1) + m.group(2).upper(), text
|
| 666 |
+
)
|
| 667 |
|
| 668 |
|
| 669 |
def _capitalize_after_full_stop(text: str) -> str:
|
|
|
|
| 731 |
"o'ninchi": "o'n",
|
| 732 |
"oninchi": "o'n",
|
| 733 |
}
|
| 734 |
+
UZ_MONTHS_PATTERN = (
|
| 735 |
+
r"yanvar|fevral|mart|aprel|may|iyun|iyul|avgust|sentabr|oktabr|noyabr|dekabr"
|
| 736 |
+
)
|
| 737 |
|
| 738 |
+
_TOKEN_CORE_RE = re.compile(
|
| 739 |
+
r"^([^A-Za-z0-9'`ʻʼ’‘]*)([A-Za-z0-9'`ʻʼ’‘]+)([^A-Za-z0-9'`ʻʼ’‘]*)$"
|
| 740 |
+
)
|
| 741 |
|
| 742 |
|
| 743 |
def _normalize_uz_word(word: str) -> str:
|
|
|
|
| 762 |
def _is_uz_number_like(word: str) -> bool:
|
| 763 |
if not word:
|
| 764 |
return False
|
| 765 |
+
if (
|
| 766 |
+
word in UZ_CARDINAL
|
| 767 |
+
or word in UZ_SCALES
|
| 768 |
+
or word == "yuz"
|
| 769 |
+
or word in UZ_ORDINAL_TO_CARDINAL
|
| 770 |
+
):
|
| 771 |
return True
|
| 772 |
return re.match(r"^.+(?:inchi|nchi)$", word) is not None
|
| 773 |
|
|
|
|
| 920 |
return text
|
| 921 |
|
| 922 |
|
| 923 |
+
def postprocess_segment_texts(
|
| 924 |
+
segment_timestamps: list, diarization_enabled: bool
|
| 925 |
+
) -> list:
|
| 926 |
for ts in segment_timestamps:
|
| 927 |
txt = str(ts.get("segment", "") or "")
|
| 928 |
txt = _clean_token_spacing(txt)
|
|
|
|
| 945 |
|
| 946 |
|
| 947 |
@spaces.GPU
|
| 948 |
+
def get_transcripts_and_raw_times(
|
| 949 |
+
audio_path, session_dir, use_preprocessing=True, use_diarization=False
|
| 950 |
+
):
|
| 951 |
if not audio_path:
|
| 952 |
gr.Error("No audio file path provided for transcription.", duration=None)
|
| 953 |
return (
|
|
|
|
| 966 |
original_path_name = Path(audio_path).name
|
| 967 |
audio_name = Path(audio_path).stem
|
| 968 |
|
| 969 |
+
csv_button_update = gr.DownloadButton(
|
| 970 |
+
label="Download Transcript (CSV)", visible=False
|
| 971 |
+
)
|
| 972 |
+
srt_button_update = gr.DownloadButton(
|
| 973 |
+
label="Download Transcript (SRT)", visible=False
|
| 974 |
+
)
|
| 975 |
|
| 976 |
transcribe_path = audio_path
|
| 977 |
info_path_name = original_path_name
|
|
|
|
| 998 |
processed_audio = preprocess_audio_for_transcription(
|
| 999 |
audio=audio, target_sr=16000, frame_ms=500, target_rms_db=-20.0
|
| 1000 |
)
|
| 1001 |
+
processed_audio_path = Path(
|
| 1002 |
+
session_dir, f"{audio_name}_asr_preprocessed.wav"
|
| 1003 |
+
)
|
| 1004 |
processed_audio.export(processed_audio_path, format="wav")
|
| 1005 |
transcribe_path = processed_audio_path.as_posix()
|
| 1006 |
info_path_name = f"{original_path_name} (preprocessed)"
|
|
|
|
| 1022 |
|
| 1023 |
if duration_sec > 480:
|
| 1024 |
try:
|
| 1025 |
+
gr.Info(
|
| 1026 |
+
"Audio longer than 8 minutes. Applying long audio settings.",
|
| 1027 |
+
duration=3,
|
| 1028 |
+
)
|
| 1029 |
model.change_attention_model("rel_pos_local_attn", [256, 256])
|
| 1030 |
model.change_subsampling_conv_chunking_factor(1)
|
| 1031 |
long_audio_settings_applied = True
|
| 1032 |
except Exception as setting_e:
|
| 1033 |
+
gr.Warning(
|
| 1034 |
+
f"Could not apply long audio settings: {setting_e}", duration=5
|
| 1035 |
+
)
|
| 1036 |
|
| 1037 |
if device == "cuda":
|
| 1038 |
model.to(torch.bfloat16)
|
|
|
|
| 1043 |
try:
|
| 1044 |
gr.Info("Running ASR and diarization in parallel...", duration=3)
|
| 1045 |
|
| 1046 |
+
diar_input_path = audio_path
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1047 |
|
| 1048 |
dmodel = get_diar_model()
|
| 1049 |
dmodel.to(device)
|
|
|
|
| 1054 |
|
| 1055 |
def _run_diar():
|
| 1056 |
try:
|
| 1057 |
+
diar_output_local = dmodel.diarize(
|
| 1058 |
+
audio=diar_input_path, batch_size=1
|
| 1059 |
+
)
|
| 1060 |
except TypeError:
|
| 1061 |
+
diar_output_local = dmodel.diarize(
|
| 1062 |
+
audio=[diar_input_path], batch_size=1
|
| 1063 |
+
)
|
| 1064 |
|
| 1065 |
diar_segments_local = parse_diarization_output(
|
| 1066 |
diar_output_local,
|
|
|
|
| 1090 |
diar_segments=diar_segments,
|
| 1091 |
asr_words=asr_words,
|
| 1092 |
)
|
| 1093 |
+
segment_timestamps = merge_consecutive_transcript_rows(
|
| 1094 |
+
segment_timestamps
|
| 1095 |
+
)
|
| 1096 |
|
| 1097 |
if not segment_timestamps:
|
| 1098 |
+
gr.Warning(
|
| 1099 |
+
"No aligned diarized rows. Using ASR segmentation.",
|
| 1100 |
+
duration=7,
|
| 1101 |
+
)
|
| 1102 |
segment_timestamps = asr_segments
|
| 1103 |
|
| 1104 |
gr.Info("Diarization + ASR complete.", duration=2)
|
| 1105 |
|
| 1106 |
except Exception as diar_e:
|
| 1107 |
+
gr.Warning(
|
| 1108 |
+
f"Diarization failed: {diar_e}. Using standard ASR segmentation.",
|
| 1109 |
+
duration=7,
|
| 1110 |
+
)
|
| 1111 |
+
segment_timestamps = transcribe_default_with_timestamps(
|
| 1112 |
+
transcribe_path
|
| 1113 |
+
)
|
| 1114 |
else:
|
| 1115 |
segment_timestamps = transcribe_default_with_timestamps(transcribe_path)
|
| 1116 |
|
|
|
|
| 1128 |
]
|
| 1129 |
for ts in segment_timestamps
|
| 1130 |
]
|
| 1131 |
+
raw_times_data = [
|
| 1132 |
+
[float(ts["start"]), float(ts["end"])] for ts in segment_timestamps
|
| 1133 |
+
]
|
| 1134 |
|
| 1135 |
try:
|
| 1136 |
csv_file_path = Path(session_dir, f"transcription_{audio_name}.csv")
|
|
|
|
| 1142 |
value=csv_file_path, visible=True, label="Download Transcript (CSV)"
|
| 1143 |
)
|
| 1144 |
except Exception as csv_e:
|
| 1145 |
+
gr.Error(
|
| 1146 |
+
f"Failed to create transcript CSV file: {csv_e}", duration=None
|
| 1147 |
+
)
|
| 1148 |
|
| 1149 |
if segment_timestamps:
|
| 1150 |
try:
|
|
|
|
| 1153 |
with open(srt_file_path, "w", encoding="utf-8") as f:
|
| 1154 |
f.write(srt_content)
|
| 1155 |
srt_button_update = gr.DownloadButton(
|
| 1156 |
+
value=srt_file_path,
|
| 1157 |
+
visible=True,
|
| 1158 |
+
label="Download Transcript (SRT)",
|
| 1159 |
)
|
| 1160 |
except Exception as srt_e:
|
| 1161 |
+
gr.Warning(
|
| 1162 |
+
f"Failed to create transcript SRT file: {srt_e}", duration=5
|
| 1163 |
+
)
|
| 1164 |
|
| 1165 |
gr.Info("Transcription complete.", duration=2)
|
| 1166 |
return (
|
|
|
|
| 1182 |
srt_button_update,
|
| 1183 |
)
|
| 1184 |
except FileNotFoundError:
|
| 1185 |
+
gr.Error(
|
| 1186 |
+
f"Audio file not found for transcription: {Path(transcribe_path).name}",
|
| 1187 |
+
duration=None,
|
| 1188 |
+
)
|
| 1189 |
return (
|
| 1190 |
[["Error", "Error", "N/A", "File not found for transcription"]],
|
| 1191 |
[[0.0, 0.0]],
|
|
|
|
| 1209 |
model.change_attention_model("rel_pos")
|
| 1210 |
model.change_subsampling_conv_chunking_factor(-1)
|
| 1211 |
except Exception as revert_e:
|
| 1212 |
+
gr.Warning(
|
| 1213 |
+
f"Issue reverting model settings: {revert_e}", duration=5
|
| 1214 |
+
)
|
| 1215 |
|
| 1216 |
if device == "cuda":
|
| 1217 |
model.cpu()
|
|
|
|
| 1301 |
|
| 1302 |
with gr.Blocks(theme=nvidia_theme) as demo:
|
| 1303 |
model_display_name = MODEL_NAME.split("/")[-1] if "/" in MODEL_NAME else MODEL_NAME
|
| 1304 |
+
gr.Markdown(
|
| 1305 |
+
f"<h1 style='text-align:center;margin:0 auto;'>Speech Transcription with {model_display_name}</h1>"
|
| 1306 |
+
)
|
| 1307 |
gr.HTML(article)
|
| 1308 |
|
| 1309 |
current_audio_path_state = gr.State(None)
|
|
|
|
| 1321 |
|
| 1322 |
with gr.Tabs():
|
| 1323 |
with gr.TabItem("Audio File"):
|
| 1324 |
+
file_input = gr.Audio(
|
| 1325 |
+
sources=["upload"], type="filepath", label="Upload Audio File"
|
| 1326 |
+
)
|
| 1327 |
+
gr.Examples(
|
| 1328 |
+
examples=examples, inputs=[file_input], label="Example Audio Files"
|
| 1329 |
+
)
|
| 1330 |
+
file_transcribe_btn = gr.Button(
|
| 1331 |
+
"Transcribe Uploaded File", variant="primary"
|
| 1332 |
+
)
|
| 1333 |
|
| 1334 |
with gr.TabItem("Microphone"):
|
| 1335 |
+
mic_input = gr.Audio(
|
| 1336 |
+
sources=["microphone"], type="filepath", label="Record Audio"
|
| 1337 |
+
)
|
| 1338 |
+
mic_transcribe_btn = gr.Button(
|
| 1339 |
+
"Transcribe Microphone Input", variant="primary"
|
| 1340 |
+
)
|
| 1341 |
|
| 1342 |
gr.Markdown("---")
|
| 1343 |
with gr.Row():
|
| 1344 |
+
download_btn_csv = gr.DownloadButton(
|
| 1345 |
+
label="Download Transcript (CSV)", visible=False
|
| 1346 |
+
)
|
| 1347 |
+
download_btn_srt = gr.DownloadButton(
|
| 1348 |
+
label="Download Transcript (SRT)", visible=False
|
| 1349 |
+
)
|
| 1350 |
|
| 1351 |
vis_timestamps_df = gr.DataFrame(
|
| 1352 |
headers=["Start (s)", "End (s)", "Speaker", "Segment"],
|