import gradio as gr import whisper import torch from pyannote.audio import Pipeline from pydub import AudioSegment import re import os from typing import List, Dict, Tuple import tempfile # Detect and use GPU if available device = "cuda" if torch.cuda.is_available() else "cpu" print(f"Using device: {device}") # Load models (will be cached after first load) print("Loading Whisper model...") whisper_model = whisper.load_model("large-v2", device=device) # Load on GPU if available print(f"Whisper model loaded on {device}") # Diarization pipeline will be loaded on-demand with user's token # Filler words and minimal vocalizations to remove FILLER_WORDS = [ r'\buh\b', r'\bum\b', r'\bmmm+\b', r'\bmm+\b', r'\bhmm+\b', r'\bahh+\b', r'\buhh+\b', r'\berr+\b', r'\boh\b', r'\byou know\b', r'\blike\b', r'\bbasically\b', r'\bliterally\b', r'\bactually\b', r'\bokay\b', r'\bright\b', r'\byeah\b', r'\buh-huh\b', r'\bmhm\b', r'\bnah\b' ] def convert_to_wav(audio_path: str) -> str: """Convert audio file to WAV format for processing.""" audio = AudioSegment.from_file(audio_path) wav_path = tempfile.mktemp(suffix=".wav") audio.export(wav_path, format="wav") return wav_path def clean_text(text: str) -> str: """Remove filler words, stutters, and clean up text.""" # Remove filler words for filler in FILLER_WORDS: text = re.sub(filler, '', text, flags=re.IGNORECASE) # Remove stutters (e.g., "I-I-I" -> "I") text = re.sub(r'\b(\w+)(-\1)+\b', r'\1', text) # Clean up extra spaces text = re.sub(r'\s+', ' ', text) text = text.strip() return text def identify_speaker(speaker_label: str, voice_mapping: Dict[str, str] = None) -> str: """ Identify speaker based on diarization label and user-provided voice mapping. Args: speaker_label: The speaker label from diarization (e.g., "SPEAKER_00") voice_mapping: Dictionary mapping speaker labels to names Returns: The identified speaker name """ if voice_mapping and speaker_label in voice_mapping: return voice_mapping[speaker_label] else: # Fallback for unmapped speakers speaker_num = speaker_label.split("_")[-1] if "_" in speaker_label else "00" return f"Speaker {speaker_num}" def format_timestamp(seconds: float) -> str: """Convert seconds to SRT timestamp format (HH:MM:SS,mmm).""" hours = int(seconds // 3600) minutes = int((seconds % 3600) // 60) secs = int(seconds % 60) millis = int((seconds % 1) * 1000) return f"{hours:02d}:{minutes:02d}:{secs:02d},{millis:03d}" def split_into_sentences(text: str) -> List[str]: """Split text into sentences for better subtitle formatting.""" # Split on sentence boundaries sentences = re.split(r'(?<=[.!?])\s+', text) return [s.strip() for s in sentences if s.strip()] def process_audio_to_srt( audio_path: str, hf_token: str, voice1_name: str = "", voice1_desc: str = "", voice2_name: str = "", voice2_desc: str = "", voice3_name: str = "", voice3_desc: str = "", progress=gr.Progress() ) -> Tuple[str, str]: """ Main processing function: STT + Diarization + SRT generation. Args: audio_path: Path to the audio file hf_token: Hugging Face API token for accessing Pyannote models voice1_name: Name for the first voice voice1_desc: Description for the first voice voice2_name: Name for the second voice voice2_desc: Description for the second voice voice3_name: Name for the third voice voice3_desc: Description for the third voice progress: Gradio progress tracker Returns: (srt_content, debug_info) """ # Validate HF token if not hf_token or not hf_token.strip(): return "Error: Hugging Face token is required. Please provide your HF token.", "Token validation failed" # Build voice mapping from user inputs voice_mapping = {} if voice1_name.strip(): voice_mapping["SPEAKER_00"] = voice1_name.strip() if voice2_name.strip(): voice_mapping["SPEAKER_01"] = voice2_name.strip() if voice3_name.strip(): voice_mapping["SPEAKER_02"] = voice3_name.strip() try: progress(0, desc="Loading Pyannote diarization pipeline...") # Load diarization pipeline with user's token try: diarization_pipeline = Pipeline.from_pretrained( "pyannote/speaker-diarization-3.1", token=hf_token.strip() ) # Move to GPU if available if device == "cuda": diarization_pipeline.to(torch.device(device)) except Exception as e: error_msg = str(e) if "gated repo" in error_msg.lower() or "agreement" in error_msg.lower(): return ("Error: You need to accept the user agreement for pyannote/speaker-diarization-3.1\n" "Please visit: https://huggingface.co/pyannote/speaker-diarization-3.1\n" "Accept the agreement, then try again."), f"Pipeline loading failed: {error_msg}" elif "token" in error_msg.lower() or "unauthorized" in error_msg.lower(): return ("Error: Invalid Hugging Face token. Please check your token and try again.\n" "Get your token at: https://huggingface.co/settings/tokens"), f"Token validation failed: {error_msg}" else: return f"Error loading diarization pipeline: {error_msg}", f"Pipeline loading failed: {error_msg}" progress(0.05, desc="Converting audio to WAV format...") # Convert to WAV if needed if not audio_path.endswith('.wav'): wav_path = convert_to_wav(audio_path) else: wav_path = audio_path # Step 1: Transcribe with Whisper progress(0.1, desc="Starting Whisper transcription (this may take 2-5 minutes)...") result = whisper_model.transcribe( wav_path, language="en", word_timestamps=True, verbose=False, fp16=(device == "cuda") # Use FP16 on GPU for faster processing ) # Step 2: Perform speaker diarization progress(0.4, desc="Transcription complete! Now analyzing speakers with Pyannote...") progress(0.45, desc="Pyannote: Loading audio and extracting features...") progress(0.5, desc="Pyannote: Detecting speaker segments (this is the longest step - 3-10 minutes)...") diarization = diarization_pipeline(wav_path) # Step 3: Align transcription with speaker labels progress(0.75, desc="Diarization complete! Matching speakers to transcription...") # Create a list of speaker segments speaker_segments = [] for turn, _, speaker in diarization.itertracks(yield_label=True): speaker_segments.append({ 'start': turn.start, 'end': turn.end, 'speaker': speaker }) # Match words to speakers segments_with_speakers = [] for segment in result['segments']: segment_start = segment['start'] segment_end = segment['end'] segment_text = segment['text'].strip() # Find the speaker for this segment (based on overlap) speaker = None max_overlap = 0 for spk_seg in speaker_segments: overlap_start = max(segment_start, spk_seg['start']) overlap_end = min(segment_end, spk_seg['end']) overlap_duration = max(0, overlap_end - overlap_start) if overlap_duration > max_overlap: max_overlap = overlap_duration speaker = spk_seg['speaker'] if speaker: speaker_name = identify_speaker(speaker, voice_mapping) segments_with_speakers.append({ 'start': segment_start, 'end': segment_end, 'text': segment_text, 'speaker': speaker_name }) # Step 4: Generate SRT with formatting rules progress(0.85, desc="Cleaning text and formatting SRT subtitles...") srt_lines = [] subtitle_number = 1 for seg in segments_with_speakers: # Clean the text cleaned_text = clean_text(seg['text']) if not cleaned_text: continue # Split into sentences if needed sentences = split_into_sentences(cleaned_text) if not sentences: sentences = [cleaned_text] # Create subtitle blocks (one per sentence) for sentence in sentences: if not sentence: continue start_time = format_timestamp(seg['start']) end_time = format_timestamp(seg['end']) # Format: subtitle number, timestamps, (Speaker) text srt_lines.append(f"{subtitle_number}") srt_lines.append(f"{start_time} --> {end_time}") srt_lines.append(f"({seg['speaker']}) {sentence}") srt_lines.append("") # Blank line between subtitles subtitle_number += 1 srt_content = "\n".join(srt_lines) # Clean up temporary file if wav_path != audio_path and os.path.exists(wav_path): os.remove(wav_path) debug_info = f"Processed successfully!\nTotal segments: {len(segments_with_speakers)}\nTotal subtitles: {subtitle_number - 1}" progress(1.0, desc="Complete! SRT file ready for download.") return srt_content, debug_info except Exception as e: return f"Error: {str(e)}", f"Processing failed: {str(e)}" def save_srt_file(srt_content: str) -> str: """Save SRT content to a temporary file for download.""" if not srt_content or srt_content.startswith("Error"): return None temp_file = tempfile.NamedTemporaryFile(mode='w', suffix='.srt', delete=False, encoding='utf-8') temp_file.write(srt_content) temp_file.close() return temp_file.name # Create Gradio interface with gr.Blocks(title="Audio to SRT Converter with Speaker Diarization", theme=gr.themes.Soft()) as demo: # Display GPU info gpu_info = f"Running on: {device.upper()}" if device == "cuda": gpu_name = torch.cuda.get_device_name(0) gpu_info += f" ({gpu_name})" gr.Markdown(f""" # Audio to SRT Converter with Speaker Diarization Convert audio files to formatted SRT subtitles with automatic speaker detection and identification.