| import streamlit as st |
| import torch |
| import torchaudio |
| from pyannote.audio import Pipeline |
| from pyannote.audio.pipelines.utils.hook import ProgressHook |
| import tempfile |
| import os |
| import matplotlib.pyplot as plt |
| from pyannote.core import notebook |
| from huggingface_hub import HfApi, snapshot_download, hf_hub_download |
| from huggingface_hub.errors import LocalEntryNotFoundError, HfHubHTTPError |
| import requests |
| import pyannote.audio |
| import sys |
| import traceback |
| from speechbrain.pretrained import EncoderClassifier |
| from pydub import AudioSegment |
| import numpy as np |
|
|
| |
# Configure the browser tab / wide layout, then render the page heading.
st.set_page_config(
    layout="wide",
    page_title="Optimized Speaker Diarization App",
)
st.title("Optimized Speaker Diarization App")

# Hugging Face access token read from the environment; the app needs it to
# reach the gated pyannote models, so halt the script when it is absent/empty.
HF_TOKEN = os.environ.get("HF_TOKEN")
if HF_TOKEN is None or HF_TOKEN == "":
    st.error("HF_TOKEN not found in environment variables. Please set it in your Hugging Face Space secrets.")
    st.stop()
|
|
|
|
|
|
class ProgressHook:
    """Progress reporter passed to the diarization pipeline as ``hook``.

    NOTE: this class shadows the ``ProgressHook`` imported from
    ``pyannote.audio.pipelines.utils.hook`` at the top of the file; this
    Streamlit-specific version is the one actually used.

    Args:
        status: object with an ``update(label=..., state=...)`` method (the
            app passes an ``st.status`` container); its label shows the
            current stage and percentage.
        progress_bar: object with a ``progress(value)`` method (the app passes
            an ``st.progress`` bar); receives a fraction in [0, 1].
    """

    def __init__(self, status, progress_bar):
        self.status = status
        self.progress_bar = progress_bar
        self.total = 0
        self.completed = 0
        self.current_stage = ""

    def __call__(self, *args, **kwargs):
        """Handle one hook callback.

        Accepted call shapes:

        * ``hook(stage_name, artefact, completed=..., total=..., **extra)`` —
          records the stage and, when both counters are non-None, the progress.
        * ``hook(stage_name, artefact)`` — records the stage only.
        * ``hook(completed, total)`` — two numbers, progress only.
        """
        if len(args) == 2 and isinstance(args[0], str):
            self.current_stage = args[0]
            self.status.update(label=f"Processing: {self.current_stage}", state="running")

        # A stage call may also carry counters, so progress is checked
        # independently of the stage branch above. Counters passed explicitly
        # as None are ignored (comparing None with 0 would raise TypeError).
        completed = kwargs.get("completed")
        total = kwargs.get("total")
        if completed is not None and total is not None:
            self.completed = completed
            self.total = total
            self._update_progress()
        elif len(args) == 2 and all(isinstance(arg, (int, float)) for arg in args):
            self.completed, self.total = args
            self._update_progress()

    def _update_progress(self):
        """Push ``completed / total`` (clamped to 1.0) to the status label and bar.

        Does nothing while ``total`` is not positive, avoiding division by zero.
        """
        if self.total > 0:
            progress_percentage = min(self.completed / self.total, 1.0)
            self.status.update(label=f"Processing: {self.current_stage} - {progress_percentage:.1%} complete", state="running")
            self.progress_bar.progress(progress_percentage)
|
|
|
|
|
|
def preprocess_audio(tmp_path, segment_size=160000):
    """Decode an audio file into a mono 16 kHz float waveform padded to whole segments.

    The file at ``tmp_path`` is decoded with pydub, then:

    * downmixed to mono (any channel count),
    * resampled to 16 kHz,
    * converted to 16-bit samples.

    It is then scaled to [-1, 1) and zero-padded on the right so its length is
    a multiple of ``segment_size`` samples. The result is written to a new
    temporary WAV file.

    Args:
        tmp_path: Path of the input audio file.
        segment_size: The padded waveform length is rounded up to a multiple
            of this many samples (the default 160000 is 10 s at 16 kHz).

    Returns:
        ``(waveform, 16000, processed_path)``: ``waveform`` is a float32
        tensor of shape ``(1, num_samples)``. ``processed_path`` is the
        temporary WAV file; the caller must delete it.

    Raises:
        ValueError: if ``segment_size`` is not positive.
    """
    if segment_size <= 0:
        raise ValueError("segment_size must be a positive number of samples")

    target_rate = 16000
    audio = AudioSegment.from_file(tmp_path)

    # Downmix every multi-channel layout, not just stereo; otherwise
    # get_array_of_samples() would return interleaved channel samples.
    if audio.channels != 1:
        audio = audio.set_channels(1)

    if audio.frame_rate != target_rate:
        audio = audio.set_frame_rate(target_rate)
        st.info("Resampled audio to 16 kHz")

    # The 32768 divisor below is only the full-scale value for 16-bit
    # samples, so normalise the sample width first (8/24/32-bit inputs).
    if audio.sample_width != 2:
        audio = audio.set_sample_width(2)

    samples = np.asarray(audio.get_array_of_samples(), dtype=np.float32)
    waveform = torch.from_numpy(samples).unsqueeze(0) / 32768.0

    # Zeros needed to reach the next multiple of segment_size (0 when aligned).
    padding_length = -waveform.shape[1] % segment_size

    if padding_length > 0:
        pad = torch.zeros((waveform.shape[0], padding_length))
        waveform = torch.cat((waveform, pad), dim=1)
        st.info(f"Padded waveform with {padding_length} zeros")
    else:
        st.info("No padding needed")

    # Only reserve the name here; write after the handle is closed, because
    # re-opening a still-open NamedTemporaryFile fails on Windows.
    with tempfile.NamedTemporaryFile(delete=False, suffix='.wav') as processed_file:
        processed_path = processed_file.name
    try:
        torchaudio.save(processed_path, waveform, target_rate)
    except Exception:
        # Don't leak the temp file when saving fails; the caller never sees the path.
        os.unlink(processed_path)
        raise
    st.info("Saved processed waveform to temporary WAV file")

    return waveform, target_rate, processed_path
|
|
| def check_versions(): |
| st.info("Checking package versions...") |
| |
| pyannote_version = pyannote.audio.__version__ |
| torch_version = torch.__version__ |
| |
| st.write(f"Pyannote Audio version: {pyannote_version}") |
| st.write(f"PyTorch version: {torch_version}") |
| |
| if pyannote_version < "3.1.0": |
| st.warning("Your pyannote.audio version might be outdated. Consider upgrading to 3.1.0 or later.") |
| |
| if torch_version < "2.0.0": |
| st.warning("Your PyTorch version might be outdated. Consider upgrading to 2.0.0 or later.") |
|
|
# Executed at module top level, so the version report is rendered every time
# this script runs (before the sidebar and tabs are built).
check_versions()
|
|
def verify_token(token):
    """Validate ``token`` via the Hugging Face ``whoami`` call and report the outcome.

    Args:
        token: Hugging Face access token to check.

    Returns:
        True when ``whoami`` succeeds and the account name can be shown;
        False if either step raises (the error is shown in the page).
    """
    hub = HfApi()
    try:
        identity = hub.whoami(token=token)
        st.success(f"Token verified. Logged in as: {identity['name']}")
    except Exception as err:
        st.error(f"Token verification failed: {str(err)}")
        return False
    return True
|
|
def check_hf_api(timeout=30):
    """Query the Hub metadata endpoint for the diarization model and report the result.

    Sends an authenticated GET (using ``HF_TOKEN``) to the model's API URL.
    On success the JSON response is shown in an expander. On failure the
    error is shown in the page; nothing is raised.

    Args:
        timeout: Seconds to wait for the Hub before giving up, so a hung
            connection cannot block the app indefinitely.
    """
    st.info("Checking Hugging Face API...")
    api_url = "https://huggingface.co/api/models/pyannote/speaker-diarization-3.1"
    headers = {"Authorization": f"Bearer {HF_TOKEN}"}

    try:
        response = requests.get(api_url, headers=headers, timeout=timeout)
        response.raise_for_status()
        st.success("Successfully connected to Hugging Face API")
        with st.expander("API Response"):
            st.json(response.json())
    except requests.exceptions.RequestException as e:
        st.error(f"Error connecting to Hugging Face API: {str(e)}")
        # Read the response from the exception: for connection errors or
        # timeouts `requests.get` never returned, so there is no local
        # response, and `e.response` is None.
        failed_response = e.response
        if failed_response is not None and failed_response.status_code == 403:
            st.error("Access denied. Please check your token permissions.")
            st.info("Ensure your token has permission to access gated repositories.")
            st.code(failed_response.text)
|
|
def verify_model_files():
    """Download each expected pipeline file from the Hub and report whether it is present.

    Each file is fetched from ``pyannote/speaker-diarization-3.1`` with
    ``hf_hub_download`` using ``HF_TOKEN``. Failures are shown in the page
    rather than raised, so one missing file does not hide the status of the
    others.
    """
    st.info("Verifying model files...")
    # NOTE(review): confirm these filenames actually exist in the
    # pyannote/speaker-diarization-3.1 repo; any that don't are reported as
    # download errors below.
    required_files = [
        "config.yaml",
        "pytorch_model.bin",
        "pyannote_serialized_object.bin",
    ]

    for file in required_files:
        try:
            # `token=` is the documented huggingface_hub keyword;
            # `use_auth_token=` is deprecated.
            path = hf_hub_download("pyannote/speaker-diarization-3.1", filename=file, token=HF_TOKEN)
            if os.path.exists(path):
                st.success(f"File {file} found at {path}")
            else:
                st.error(f"File {file} not found")
        except Exception as e:
            st.error(f"Error downloading {file}: {str(e)}")
|
|
|
|
@st.cache_resource
def load_pipeline():
    """Load the ``pyannote/speaker-diarization-3.1`` pipeline once and cache it.

    The pipeline is moved to the GPU when CUDA is available. Because of
    ``st.cache_resource`` the loaded object is reused across reruns. An
    exception is not cached, so a failed load is retried next time.

    Returns:
        The loaded ``pyannote.audio.Pipeline``.

    Raises:
        RuntimeError: if ``Pipeline.from_pretrained`` returns ``None``.
        Exception: any loading error is shown in the page and re-raised.
    """
    try:
        st.info("Attempting to load the pipeline...")
        pipeline = Pipeline.from_pretrained(
            "pyannote/speaker-diarization-3.1",
            use_auth_token=HF_TOKEN
        )
        # from_pretrained may return None instead of raising (e.g. when a
        # gated repo cannot be accessed). Fail here rather than caching None
        # and crashing later with "'NoneType' object is not callable".
        if pipeline is None:
            raise RuntimeError(
                "Pipeline.from_pretrained returned None for "
                "'pyannote/speaker-diarization-3.1'. Check that HF_TOKEN has "
                "access to this gated repository."
            )
        st.success("Pipeline created successfully")

        if torch.cuda.is_available():
            st.info("Moving pipeline to GPU...")
            pipeline.to(torch.device("cuda"))
            st.success("Pipeline moved to GPU")

        return pipeline
    except Exception as e:
        st.error(f"Error loading pipeline: {str(e)}")
        st.error("Error details:")
        st.code(traceback.format_exc())
        # Bare `raise` re-raises the active exception with its traceback untouched.
        raise
|
|
@st.cache_resource
def load_speechbrain_model():
    """Load the SpeechBrain ``EncoderClassifier`` from ``speechbrain/spkrec-ecapa-voxceleb``.

    Wrapped in ``st.cache_resource`` so the model is loaded once and reused
    across reruns.
    """
    st.info("Loading SpeechBrain model...")
    model = EncoderClassifier.from_hparams(
        source="speechbrain/spkrec-ecapa-voxceleb",
    )
    st.success("SpeechBrain model loaded successfully")
    return model
|
|
| |
# Sidebar controls. The speaker-count variables are only defined when
# advanced options are enabled; the processing code reads them only then.
with st.sidebar:
    st.header("Settings")
    show_advanced = st.toggle("Show Advanced Options")
    if show_advanced:
        num_speakers = st.number_input("Number of speakers (0 for auto)", min_value=0, value=0)
        min_speakers = st.number_input("Minimum number of speakers", min_value=1, value=1)
        max_speakers = st.number_input("Maximum number of speakers", min_value=1, value=5)
        # min/max are only used when num_speakers is 0 (auto). An inverted
        # range is meaningless, so raise the maximum to the minimum and say so.
        if num_speakers == 0 and max_speakers < min_speakers:
            st.warning(
                f"Maximum number of speakers ({max_speakers}) is below the minimum "
                f"({min_speakers}); using {min_speakers} as the maximum."
            )
            max_speakers = min_speakers

# Main page layout: processing, textual results and the plot live in separate tabs.
tab1, tab2, tab3 = st.tabs(["Upload & Process", "Results", "Visualization"])
|
|
|
|
|
|
with tab1:
    uploaded_file = st.file_uploader("Choose an audio file", type=['wav', 'mp3', 'flac'])

    if uploaded_file is not None:
        # Persist the upload to disk (keeping its extension) so it can be
        # decoded from a file path.
        with tempfile.NamedTemporaryFile(delete=False, suffix=os.path.splitext(uploaded_file.name)[1]) as tmp_file:
            tmp_file.write(uploaded_file.getvalue())
            tmp_path = tmp_file.name

        # Set by preprocess_audio; stays None if processing fails before that,
        # so the cleanup below knows whether there is a second file to delete.
        processed_path = None
        try:
            if verify_token(HF_TOKEN):
                check_hf_api()
                verify_model_files()
                pipeline = load_pipeline()
                speechbrain_model = load_speechbrain_model()
            else:
                st.stop()

            waveform, sample_rate, processed_path = preprocess_audio(tmp_path)

            with st.status("Processing audio...", expanded=True) as status:
                progress_bar = st.progress(0)
                progress_hook = ProgressHook(status, progress_bar)

                diarization_args = {
                    "file": processed_path,
                    "hook": progress_hook,
                }
                if show_advanced:
                    if num_speakers > 0:
                        diarization_args["num_speakers"] = num_speakers
                    else:
                        diarization_args["min_speakers"] = min_speakers
                        diarization_args["max_speakers"] = max_speakers

                diarization = pipeline(**diarization_args)
                status.update(label="Diarization complete!", state="complete")

            # RTTM file-id: the uploaded recording's name, not the random temp
            # filename. Whitespace is the RTTM field separator, so it is
            # replaced with underscores.
            recording_id = "_".join(os.path.splitext(uploaded_file.name)[0].split()) or "recording"
            rttm_content = "".join(
                f"SPEAKER {recording_id} 1 {turn.start:.3f} {turn.duration:.3f} <NA> <NA> {speaker} <NA> <NA>\n"
                for turn, _, speaker in diarization.itertracks(yield_label=True)
            )

            # One embedding batch over the whole (padded) waveform.
            embeddings = speechbrain_model.encode_batch(waveform)
            st.success("Speaker embeddings generated successfully")

        except Exception as e:
            st.error(f"An error occurred: {str(e)}")
            st.error("Error details:")
            st.code(traceback.format_exc())

        finally:
            # Best-effort cleanup that cannot raise from here and mask the
            # original error.
            for leftover in (tmp_path, processed_path):
                if leftover is not None and os.path.exists(leftover):
                    os.unlink(leftover)
|
|
|
|
with tab2:
    # `diarization` and `rttm_content` only exist once the processing tab has
    # completed diarization during this script run.
    has_results = 'diarization' in locals()
    if has_results:
        st.subheader("Diarization Results")
        speaker_labels = diarization.labels()
        st.metric("Number of speakers detected", len(speaker_labels))

        rttm_expander = st.expander("RTTM Output")
        with rttm_expander:
            st.text_area("RTTM Content", rttm_content, height=300)

        st.download_button(
            label="Download RTTM file",
            data=rttm_content,
            file_name="diarization.rttm",
            mime="text/plain",
        )
|
|
with tab3:
    if 'diarization' in locals():
        if st.button("Visualize Diarization"):
            if not diarization.labels():
                st.info("No speech segments to visualize.")
            else:
                # `notebook` is a shared module-level object: pin its crop to
                # this result's extent so a previous file's range is not reused.
                notebook.crop = diarization.get_timeline().extent()
                fig, ax = plt.subplots(figsize=(10, 2))
                # pyannote.core's Notebook draws an Annotation via
                # plot_annotation (it has no plot_diarization method).
                notebook.plot_annotation(diarization, ax=ax)
                fig.tight_layout()
                st.pyplot(fig)
                # Release the figure; pyplot otherwise keeps it alive across reruns.
                plt.close(fig)
|
|
| |
with st.expander("Debug Information"):
    # Environment snapshot for troubleshooting; each line is written separately.
    cuda_ok = torch.cuda.is_available()
    debug_lines = (
        f"Working directory: {os.getcwd()}",
        f"Files in working directory: {os.listdir()}",
        f"Python version: {sys.version.split()[0]}",
        f"PyTorch version: {torch.__version__}",
        f"Pyannote Audio version: {pyannote.audio.__version__}",
        f"CUDA available: {cuda_ok}",
        f"Device: {'CUDA' if cuda_ok else 'CPU'}",
    )
    for line in debug_lines:
        st.write(line)
|
|
| |
# Static help panel, rendered on every run: steps for granting the Hugging Face
# token the access needed to download the gated pyannote models.
with st.expander("Token Permissions"):
    st.markdown("""
If you're encountering access issues, please ensure your Hugging Face token has the following permissions:
1. Go to [https://huggingface.co/settings/tokens](https://huggingface.co/settings/tokens)
2. Find your token or create a new one
3. Ensure "Read" access is granted
4. Check the box for "Access to gated repositories"
5. Save the changes and try again
""")
|
|
| |
if st.button("Clear Cache"):
    import shutil

    # The loaded models (load_pipeline / load_speechbrain_model) live in
    # st.cache_resource, not on disk; clear it so the next run reloads them.
    st.cache_resource.clear()
    st.success("Cleared in-memory model cache.")

    # Legacy on-disk cache directory, removed if present.
    cache_dir = "./model_cache"
    if os.path.isdir(cache_dir):
        shutil.rmtree(cache_dir)
        st.success("Cache cleared successfully.")
    else:
        st.info("No cache directory found.")