| import streamlit as st |
| from pyannote.audio import Pipeline |
| import whisper |
| import tempfile |
| import os |
| import torch |
| from transformers import pipeline as tf_pipeline |
| from pydub import AudioSegment |
| import io |
|
|
@st.cache_resource
def _load_models_cached():
    """Load the three models once and cache them; raise if any load fails.

    Only this helper is wrapped in ``st.cache_resource``. Because it signals
    failure by raising instead of returning a placeholder, a failed attempt
    leaves nothing in the cache and the next call retries the load.

    Returns:
        A ``(diarization, transcriber, summarizer)`` tuple.

    Raises:
        ValueError: if any loader handed back ``None``.
        Exception: whatever the underlying loaders raise (missing
            ``hf_token`` secret, download errors, ...).
    """
    diarization = Pipeline.from_pretrained(
        "pyannote/speaker-diarization",
        use_auth_token=st.secrets["hf_token"],
    )

    transcriber = whisper.load_model("base")

    summarizer = tf_pipeline(
        "summarization",
        model="facebook/bart-large-cnn",
        # transformers convention: GPU index, or -1 for CPU.
        device=0 if torch.cuda.is_available() else -1,
    )

    # Compare against None explicitly instead of relying on the truthiness of
    # model objects. NOTE(review): the diarization loader presumably can hand
    # back None (e.g. gated model without access) rather than raise - confirm.
    if diarization is None or transcriber is None or summarizer is None:
        raise ValueError("One or more models failed to load")

    return diarization, transcriber, summarizer


def load_models():
    """Return the ``(diarization, transcriber, summarizer)`` models.

    Thin, uncached wrapper around the cached loader. On failure the error is
    reported in the Streamlit UI and ``(None, None, None)`` is returned, so
    callers can keep unpacking three values. The failure itself is never
    cached, so a later call tries again.

    Returns:
        The three loaded models, or ``(None, None, None)`` if loading failed.
    """
    try:
        return _load_models_cached()
    except Exception as e:
        st.error(f"Error loading models: {str(e)}")
        st.error("Debug info: Check if HF token is valid and has necessary permissions")
        return None, None, None
|
|
def process_audio(audio_file, max_duration=600):
    """Convert an uploaded audio file and run diarization, transcription, summary.

    The upload is decoded with pydub, normalised to 16 kHz / mono / 16-bit WAV,
    written to a temporary file, and then passed to the three models returned
    by ``load_models()``. Progress and errors are reported through Streamlit.

    Args:
        audio_file: Uploaded file object exposing ``getvalue()`` (raw bytes)
            and ``name`` (used to pick the MP3 or WAV decoder).
        max_duration: Maximum number of seconds of audio to analyse. Longer
            audio is trimmed to this length and a warning is shown. Pass
            ``None`` to disable the limit.

    Returns:
        A dict with keys ``"diarization"``, ``"transcription"`` and
        ``"summary"`` on success, or ``None`` on any failure (the error has
        already been shown in the UI).
    """
    tmp_path = None
    try:
        audio_bytes = io.BytesIO(audio_file.getvalue())

        try:
            if audio_file.name.lower().endswith(".mp3"):
                audio = AudioSegment.from_mp3(audio_bytes)
            else:
                audio = AudioSegment.from_wav(audio_bytes)

            # len() of an AudioSegment is its duration in milliseconds.
            if max_duration is not None and len(audio) > max_duration * 1000:
                st.warning(
                    f"Audio is longer than {max_duration} seconds; "
                    f"only the first {max_duration} seconds will be analyzed"
                )
                audio = audio[: int(max_duration * 1000)]

            audio = audio.set_frame_rate(16000)
            audio = audio.set_channels(1)
            audio = audio.set_sample_width(2)

            # Reserve a path, then close our handle BEFORE exporting so the
            # exporter is not writing to a file we still hold open.
            with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp:
                tmp_path = tmp.name

            # export() returns an open file object; close it right away.
            audio.export(
                tmp_path,
                format="wav",
                parameters=["-ac", "1", "-ar", "16000"],
            ).close()

        except Exception as e:
            st.error(f"Error converting audio: {str(e)}")
            return None

        diarization, transcriber, summarizer = load_models()
        if diarization is None or transcriber is None or summarizer is None:
            # Return None (not a truthy message string) so callers that test
            # the result before indexing it do not crash.
            st.error("Model loading failed")
            return None

        with st.spinner("Identifying speakers..."):
            diarization_result = diarization(tmp_path)

        with st.spinner("Transcribing audio..."):
            transcription = transcriber.transcribe(tmp_path)

        with st.spinner("Generating summary..."):
            text = (transcription.get("text") or "").strip()
            if text:
                # truncation=True clips transcripts that exceed the model's
                # maximum input length instead of failing on them.
                summary = summarizer(
                    text,
                    max_length=130,
                    min_length=30,
                    truncation=True,
                )
                summary_text = summary[0]["summary_text"]
            else:
                # Nothing was transcribed, so there is nothing to summarise.
                summary_text = ""

        return {
            "diarization": diarization_result,
            "transcription": transcription,
            "summary": summary_text,
        }

    except Exception as e:
        st.error(f"Error processing audio: {str(e)}")
        return None

    finally:
        # Always remove the temp file (it was created with delete=False),
        # whichever path we leave by.
        if tmp_path is not None:
            try:
                os.unlink(tmp_path)
            except OSError:
                # Best-effort cleanup: a leftover temp file must not mask the
                # real result or error.
                pass
|
|
def format_speaker_segments(diarization_result, transcription):
    """Attach transcribed text to each diarized speaker turn.

    For every ``(turn, _, speaker)`` yielded by
    ``diarization_result.itertracks(yield_label=True)``, collect the text of
    every transcription segment that overlaps the turn in time. A segment that
    overlaps several turns is repeated in each of them.

    Args:
        diarization_result: Object exposing ``itertracks(yield_label=True)``
            whose turns have ``start`` / ``end`` attributes, or ``None``.
        transcription: Mapping with an optional ``"segments"`` list of dicts
            carrying ``"start"``, ``"end"`` and ``"text"``. ``None`` or a
            missing key is treated as "no segments".

    Returns:
        A list of dicts with keys ``"speaker"``, ``"start"``, ``"end"`` and
        ``"text"``, one per speaker turn; an empty list if there is no
        diarization result or if formatting fails (the error is shown via
        Streamlit).
    """
    if diarization_result is None:
        return []

    whisper_segments = (transcription or {}).get("segments") or []
    formatted_segments = []

    try:
        for turn, _, speaker in diarization_result.itertracks(yield_label=True):
            turn_start = float(turn.start)
            turn_end = float(turn.end)
            texts = []

            for w_segment in whisper_segments:
                w_start = float(w_segment["start"])
                w_end = float(w_segment["end"])

                starts_inside = turn_start <= w_start < turn_end
                ends_inside = turn_start < w_end <= turn_end
                # A segment that begins before the turn and ends after it
                # covers the whole turn even though neither endpoint lies
                # inside it.
                spans_turn = w_start < turn_start and w_end > turn_end

                if starts_inside or ends_inside or spans_turn:
                    texts.append(w_segment["text"].strip())

            formatted_segments.append({
                "speaker": str(speaker),
                "start": turn_start,
                "end": turn_end,
                "text": " ".join(texts).strip(),
            })

    except Exception as e:
        st.error(f"Error formatting segments: {str(e)}")
        return []

    return formatted_segments
|
|
def format_timestamp(seconds):
    """Format a duration in seconds as ``MM:SS.cc``.

    The value is rounded to whole centiseconds *before* it is split into
    minutes and seconds, so a value such as 59.999 carries into the minute
    field ("01:00.00") instead of rendering as "00:60.00".

    Args:
        seconds: Duration in seconds (int or float).

    Returns:
        A string with zero-padded minutes, seconds and two decimal places.
    """
    total_centiseconds = int(round(float(seconds) * 100))
    # 6000 centiseconds per minute; integer arithmetic keeps the fields exact.
    minutes, centiseconds = divmod(total_centiseconds, 6000)
    whole_seconds, hundredths = divmod(centiseconds, 100)
    return f"{minutes:02d}:{whole_seconds:02d}.{hundredths:02d}"
|
|
def main():
    """Render the Streamlit page: upload, playback, analysis and result tabs.

    Flow: show the uploader; once a file is present, display its size and an
    audio player; when "Analyze Audio" is pressed (and the file is at most
    200 MB) run ``process_audio`` and render the speakers, transcription and
    summary in three tabs.
    """
    st.title("Multi-Speaker Audio Analyzer")
    st.write("Upload an audio file (MP3/WAV) up to 5 minutes long for best performance")

    uploaded_file = st.file_uploader("Choose a file", type=["mp3", "wav"])
    if not uploaded_file:
        return

    file_size = len(uploaded_file.getvalue()) / (1024 * 1024)
    st.write(f"File size: {file_size:.2f} MB")

    # Tell the player the real container type instead of always claiming WAV.
    is_mp3 = uploaded_file.name.lower().endswith(".mp3")
    st.audio(uploaded_file, format="audio/mpeg" if is_mp3 else "audio/wav")

    if not st.button("Analyze Audio"):
        return

    if file_size > 200:
        st.error("File size exceeds 200MB limit")
        return

    results = process_audio(uploaded_file)
    if not results:
        return
    if not isinstance(results, dict):
        # process_audio may hand back a plain message instead of a result
        # dict; show it rather than crashing on results["..."] below.
        st.error(str(results))
        return

    tab1, tab2, tab3 = st.tabs(["Speakers", "Transcription", "Summary"])

    with tab1:
        st.write("Speaker Timeline:")
        segments = format_speaker_segments(
            results["diarization"],
            results["transcription"],
        )

        if segments:
            colors = ['🔵', '🔴']
            # Stable indices for labels that are not of the form PREFIX_<int>.
            fallback_index = {}

            for segment in segments:
                col1, col2, col3 = st.columns([2, 3, 5])

                with col1:
                    label = segment['speaker']
                    try:
                        speaker_num = int(label.split('_')[1])
                    except (IndexError, ValueError):
                        # Unexpected label format: number by first appearance.
                        speaker_num = fallback_index.setdefault(
                            label, len(fallback_index)
                        )
                    speaker_color = colors[speaker_num % len(colors)]
                    st.write(f"{speaker_color} {segment['speaker']}")

                with col2:
                    start_time = format_timestamp(segment['start'])
                    end_time = format_timestamp(segment['end'])
                    st.write(f"{start_time} → {end_time}")

                with col3:
                    if segment['text']:
                        st.write(f"\"{segment['text']}\"")
                    else:
                        st.write("(no speech detected)")

                st.markdown("---")
        else:
            st.warning("No speaker segments detected")

    with tab2:
        st.write("Transcription:")
        if "text" in results["transcription"]:
            st.write(results["transcription"]["text"])
        else:
            st.warning("No transcription available")

    with tab3:
        st.write("Summary:")
        if results["summary"]:
            st.write(results["summary"])
        else:
            st.warning("No summary available")
|
|
# Script entry point: start the app only when this file is executed directly,
# not when it is imported as a module.
if __name__ == "__main__":
    main()