sarahai committed
Commit d5bf736 · verified · 1 Parent(s): 3145737

Create app.py

Files changed (1)
app.py +132 -0
app.py ADDED
@@ -0,0 +1,132 @@
+ import os
+ import wget
+ import torch
+ import streamlit as st
+ from omegaconf import OmegaConf
+ import nemo.collections.asr as nemo_asr
+ import json
+ import soundfile as sf
+ from torchaudio.transforms import Resample
+ import shutil
+
+ # --- 1. SETUP & CONFIGURATION ---
+ TEMP_DIR = os.path.join(os.getcwd(), "temp_streamlit_output")
+ os.makedirs(TEMP_DIR, exist_ok=True)
+ NUM_SPEAKERS = 2  # Default number of speakers
+
+ @st.cache_resource
+ def load_models():
+     """
+     Loads all the necessary models and configurations once and caches them.
+     """
+     # Load the official NeMo diarization inference config (telephonic preset)
+     CONFIG_URL = "https://raw.githubusercontent.com/NVIDIA/NeMo/main/examples/speaker_tasks/diarization/conf/inference/diar_infer_telephonic.yaml"
+     CONFIG_PATH = os.path.join(TEMP_DIR, "diar_infer_telephonic.yaml")
+     if not os.path.exists(CONFIG_PATH):
+         wget.download(CONFIG_URL, CONFIG_PATH)
+     cfg = OmegaConf.load(CONFIG_PATH)
+
+     # Load the Silero VAD model (ONNX) and its helper utilities via torch.hub
+     vad_model, utils = torch.hub.load(repo_or_dir='snakers4/silero-vad', model='silero_vad', force_reload=False, onnx=True)
+     get_speech_timestamps_func = utils[0]  # first utility is get_speech_timestamps
+
+     return cfg, vad_model, get_speech_timestamps_func
+
+ # --- 2. HELPER FUNCTIONS (from our previous script) ---
+ def run_silero_vad(audio_path, vad_model, get_speech_timestamps_func):
+     SAMPLING_RATE = 16000
+     wav, sr = sf.read(audio_path)
+     if len(wav.shape) > 1:
+         wav = wav.mean(axis=1)  # downmix multi-channel audio to mono
+     if sr != SAMPLING_RATE:
+         resampler = Resample(orig_freq=sr, new_freq=SAMPLING_RATE)
+         wav = resampler(torch.tensor(wav, dtype=torch.float32))
+     else:
+         wav = torch.tensor(wav, dtype=torch.float32)
+     return get_speech_timestamps_func(wav, vad_model, sampling_rate=SAMPLING_RATE)
+
+ def write_vad_manifest(timestamps, audio_path, manifest_path):
+     # Silero returns start/end in samples at 16 kHz; convert to seconds for the manifest
+     with open(manifest_path, 'w') as f:
+         for ts in timestamps:
+             entry = {'audio_filepath': audio_path, 'offset': ts['start'] / 16000.0, 'duration': (ts['end'] - ts['start']) / 16000.0}
+             f.write(json.dumps(entry) + '\n')
+
+ def format_rttm_labels(input_rttm, output_rttm):
+     # Rewrite the diarizer's "speaker_N" labels in the standard "SPEAKER_NN" form.
+     # Iterate from high to low so "speaker_1" is not substituted inside "speaker_10".
+     with open(input_rttm, 'r') as infile, open(output_rttm, 'w') as outfile:
+         for line in infile:
+             new_line = line
+             for i in range(19, -1, -1):
+                 raw_label = f"speaker_{i}"
+                 standard_label = f"SPEAKER_{i:02d}"
+                 if raw_label in new_line:
+                     new_line = new_line.replace(raw_label, standard_label)
+             outfile.write(new_line)
+
+ # --- 3. MAIN DIARIZATION LOGIC ---
+ def diarize_audio(audio_file_path, num_speakers, cfg, vad_model, get_speech_timestamps_func):
+     # Modify the config with our parameters
+     cfg.diarizer.manifest_filepath = os.path.join(TEMP_DIR, "input_manifest.json")
+     cfg.diarizer.out_dir = TEMP_DIR
+     cfg.diarizer.speaker_embeddings.model_path = "titanet_large"
+     cfg.diarizer.msdd_model.model_path = None
+     cfg.diarizer.clustering.parameters.num_speakers = num_speakers
+
+     # Prepare VAD output
+     vad_timestamps = run_silero_vad(audio_file_path, vad_model, get_speech_timestamps_func)
+     vad_manifest_path = os.path.join(TEMP_DIR, "vad_outputs.json")
+     write_vad_manifest(vad_timestamps, audio_file_path, vad_manifest_path)
+
+     # Prepare main manifest
+     meta = {'audio_filepath': audio_file_path, 'offset': 0, 'duration': None, 'label': 'infer', 'text': '-', 'vad_filepath': vad_manifest_path}
+     with open(cfg.diarizer.manifest_filepath, "w") as f:
+         f.write(json.dumps(meta) + '\n')
+
+     # Initialize and run diarizer
+     diarizer = nemo_asr.models.ClusteringDiarizer(cfg=cfg)
+     diarizer.diarize()
+
+     # Format and return the path to the final RTTM
+     file_id = os.path.splitext(os.path.basename(audio_file_path))[0]
+     raw_rttm_path = os.path.join(TEMP_DIR, "pred_rttms", f"{file_id}.rttm")
+     final_rttm_path = os.path.join(TEMP_DIR, f"{file_id}_formatted.rttm")
+     format_rttm_labels(raw_rttm_path, final_rttm_path)
+
+     return final_rttm_path
+
+ # --- 4. STREAMLIT UI ---
+ st.set_page_config(layout="wide")
+ st.title("🗣️ Speaker Diarization Tool")
+
+ st.write("Upload an audio file (.wav, .mp3, .flac) and the model will determine who spoke when.")
+
+ # Load models once
+ cfg, vad_model, get_speech_timestamps_func = load_models()
+
+ uploaded_file = st.file_uploader("Choose an audio file...", type=["wav", "mp3", "flac"])
+
+ if uploaded_file is not None:
+     st.audio(uploaded_file, format='audio/wav')
+
+     if st.button("Diarize Audio"):
+         with st.spinner('Processing... This may take a moment.'):
+             # Save uploaded file to a temporary path
+             temp_audio_path = os.path.join(TEMP_DIR, uploaded_file.name)
+             with open(temp_audio_path, "wb") as f:
+                 f.write(uploaded_file.getbuffer())
+
+             # Run diarization
+             final_rttm_path = diarize_audio(temp_audio_path, NUM_SPEAKERS, cfg, vad_model, get_speech_timestamps_func)
+
+             # Read and display the RTTM content
+             with open(final_rttm_path, 'r') as f:
+                 rttm_content = f.read()
+
+             st.success("Diarization complete!")
+             st.text_area("RTTM Output", rttm_content, height=300)
+
+             # Add a download button
+             st.download_button(
+                 label="Download RTTM file",
+                 data=rttm_content,
+                 file_name=os.path.basename(final_rttm_path),
+                 mime='text/plain',
+             )
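
For reference, each line of the RTTM file shown and downloaded above follows the standard RTTM segment layout, SPEAKER <file-id> <channel> <onset> <duration> <NA> <NA> <speaker-label> <NA> <NA>. With the label formatting applied by format_rttm_labels, a line would look roughly like this (file id and timings are illustrative only, not output from a real run):

SPEAKER my_call 1 0.58 3.42 <NA> <NA> SPEAKER_00 <NA> <NA>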