import os
import torch
import gradio as gr
from pathlib import Path
import numpy as np
from pyannote.audio import Pipeline
from transformers import pipeline
from datasets import Dataset
import librosa
import soundfile as sf
from datetime import timedelta
from transformers.utils import is_flash_attn_2_available


class DiarizationTranscriptionTranslation:
    """End-to-end pipeline: speaker diarization -> Whisper transcription -> translation to English.

    Loads all models at construction time. When the pyannote diarization model
    cannot be loaded (e.g. missing Hugging Face auth), falls back to a simple
    fixed-length segmentation with alternating speaker labels.
    """

    def __init__(self):
        # Prefer GPU when available; all sub-pipelines follow this choice.
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        print(f"Using device: {self.device}")
        # Load models for diarization, transcription, and translation.
        self.load_models()

    def load_models(self):
        """Load the ASR, translation, and (best-effort) diarization models."""
        print("Loading models...")

        # whisper-base: much smaller/faster to download than the large variants
        # while keeping acceptable quality for this demo.
        model_name = "openai/whisper-base"
        self.asr_pipeline = pipeline(
            "automatic-speech-recognition",
            model=model_name,
            torch_dtype=torch.float32,  # float32 for best compatibility with the base model
            device="cuda:0" if self.device == "cuda" else "cpu",
        )

        # Multilingual -> English translation model.
        self.translation_pipeline = pipeline(
            "translation",
            model="Helsinki-NLP/opus-mt-mul-en",
            device=0 if self.device == "cuda" else -1,
        )

        # pyannote.audio models require Hugging Face authentication; on HF Spaces
        # the token is normally injected via the space settings. Try with auth,
        # then without (cache / public), then fall back to None (heuristic mode).
        try:
            self.diarization_pipeline = Pipeline.from_pretrained(
                "pyannote/speaker-diarization@2.1",
                use_auth_token=True,
            )
            print("Diarization model loaded successfully!")
        except Exception as e:
            print(f"Could not load diarization model with auth: {e}")
            try:
                self.diarization_pipeline = Pipeline.from_pretrained(
                    "pyannote/speaker-diarization@2.1"
                )
                print("Diarization model loaded successfully from cache!")
            except Exception as e2:
                print(f"Could not load diarization model: {e2}")
                print("Using fallback diarization method")
                self.diarization_pipeline = None

        print("Models loaded successfully!")

    def load_audio(self, file_path, sr=16000):
        """Load an audio file and return ``(waveform, sample_rate)``.

        The waveform is resampled to ``sr`` when it differs from the file's
        native rate. The returned sample rate always describes the returned
        waveform (BUG FIX: the previous version returned the file's original
        rate even after resampling, so callers computed sample offsets and
        wrote WAV files at the wrong rate).
        """
        audio, orig_sr = librosa.load(file_path, sr=None)
        if sr is not None and orig_sr != sr:
            audio = librosa.resample(audio, orig_sr=orig_sr, target_sr=sr)
            return audio, sr
        return audio, orig_sr

    def diarize_audio(self, file_path):
        """Return a list of ``{"start", "end", "speaker"}`` segments for the file.

        Uses pyannote.audio when available; otherwise falls back to fixed
        5-second windows with two alternating synthetic speaker labels.
        """
        if self.diarization_pipeline is None:
            # Fallback: chop the file into <=5 s windows, alternating speakers.
            duration = librosa.get_duration(path=file_path)
            segments = []
            current_time = 0
            segment_duration = min(5.0, duration)  # 5-second segments max
            speaker_id = 1
            while current_time < duration:
                end_time = min(current_time + segment_duration, duration)
                segments.append({
                    "start": current_time,
                    "end": end_time,
                    "speaker": f"Speaker {speaker_id}",
                })
                current_time = end_time
                speaker_id = 2 if speaker_id == 1 else 1  # alternate speakers
            return segments

        # Proper diarization with pyannote.audio.
        from pyannote.audio import Audio
        from pyannote.core import Annotation

        # read() returns (waveform, sample_rate); feed both to the pipeline.
        audio = Audio().read(file_path)
        diarization = self.diarization_pipeline(
            {"waveform": audio[0], "sample_rate": audio[1]}
        )

        # Flatten the pyannote annotation into plain dict segments.
        segments = []
        for segment, track, speaker in diarization.itertracks(yield_label=True):
            segments.append({
                "start": segment.start,
                "end": segment.end,
                "speaker": speaker,
            })
        return segments

    def transcribe_audio(self, file_path):
        """Transcribe an audio file with the Whisper pipeline.

        Returns the pipeline's result dict (contains at least ``"text"``); on
        failure returns a dict whose ``"text"`` carries the error message so
        downstream formatting still works.
        """
        try:
            # The ASR pipeline accepts file paths directly. Chunking + a large
            # batch size follow the insanely-fast-whisper speed recipe.
            result = self.asr_pipeline(
                file_path,
                chunk_length_s=30,
                batch_size=24,  # larger batches -> faster chunked decoding
                max_new_tokens=128,
                return_timestamps=True,
                generate_kwargs={"task": "transcribe"},  # explicit task for better performance
            )
            return result
        except Exception as e:
            print(f"Error during transcription: {str(e)}")
            return {"text": f"Transcription error: {str(e)}"}

    def detect_language(self, text):
        """Best-effort language detection; returns an ISO code, defaulting to "en".

        Prefers ``langdetect`` when installed; otherwise uses a crude
        non-ASCII-ratio heuristic.
        """
        if not text:
            # Guard: the heuristic below divides by len(text).
            return "en"
        try:
            from langdetect import detect
            return detect(text)
        except ImportError:
            # Basic heuristic only — a real deployment should ship langdetect
            # or a proper language-ID model.
            non_english_chars = sum(1 for c in text if ord(c) > 127)
            if non_english_chars / len(text) > 0.1:  # >10% non-ASCII chars
                # Placeholder non-English code for demo purposes.
                return "es"
            return "en"
        except Exception:
            # Detection failed (e.g. text too short for langdetect).
            return "en"

    def translate_text(self, text):
        """Translate ``text`` to English when it is not already English.

        Returns the original text unchanged on detection failure, when the
        text is English, or when translation raises.
        """
        try:
            detected_lang = self.detect_language(text)
            if detected_lang != "en":
                result = self.translation_pipeline(text)
                return result[0]['translation_text']
            return text
        except Exception as e:
            print(f"Translation error: {str(e)}")
            # Best-effort: fall back to the untranslated text.
            return text

    def process_audio(self, file_path):
        """Full pipeline: diarize, then transcribe + translate each segment.

        Returns a human-readable transcript string with timestamps and speaker
        labels, or an error message.
        """
        if not file_path:
            return "Please upload an audio file."

        try:
            # Step 1: diarize the audio into speaker segments.
            diarization_segments = self.diarize_audio(file_path)

            # Load the audio ONCE (previously reloaded per segment — O(n^2)).
            # sample_rate matches the returned waveform (see load_audio).
            audio, sample_rate = self.load_audio(file_path)

            # Step 2: transcribe and translate each segment.
            results = []
            for segment in diarization_segments:
                start_sample = int(segment["start"] * sample_rate)
                end_sample = int(segment["end"] * sample_rate)
                segment_audio = audio[start_sample:end_sample]

                if len(segment_audio) == 0:
                    continue  # skip empty segments

                # Pad segments shorter than 0.1 s with silence — Whisper
                # misbehaves on near-empty inputs.
                min_samples = int(sample_rate * 0.1)
                if len(segment_audio) < min_samples:
                    zeros_to_add = min_samples - len(segment_audio)
                    segment_audio = np.pad(
                        segment_audio, (0, zeros_to_add), mode='constant'
                    )

                # Write the segment to a temp WAV for the ASR pipeline;
                # PCM_16 subtype guarantees a well-formed WAV.
                temp_file = f"temp_segment_{segment['start']}_{segment['end']}.wav"
                sf.write(
                    temp_file, segment_audio, sample_rate,
                    format='WAV', subtype='PCM_16',
                )

                try:
                    transcription_result = self.transcribe_audio(temp_file)
                    # Normalize the pipeline's possible return shapes to text.
                    if isinstance(transcription_result, dict) and "text" in transcription_result:
                        transcribed_text = transcription_result["text"]
                    elif isinstance(transcription_result, str):
                        transcribed_text = transcription_result
                    else:
                        transcribed_text = str(transcription_result)
                except Exception as e:
                    print(f"Error transcribing segment {temp_file}: {str(e)}")
                    # Carry the error message through as the transcription.
                    transcribed_text = f"Transcription error: {str(e)}"
                finally:
                    # Always clean up the temp file, even if transcription raised.
                    try:
                        os.remove(temp_file)
                    except OSError:
                        pass  # best-effort cleanup

                translated_text = self.translate_text(transcribed_text)

                results.append({
                    "start": segment["start"],
                    "end": segment["end"],
                    "speaker": segment["speaker"],
                    "transcription": transcribed_text,
                    "translation": translated_text,
                })

            # Step 3: format the results as a readable transcript.
            transcript = []
            for result in results:
                start_time = str(timedelta(seconds=int(result["start"])))
                end_time = str(timedelta(seconds=int(result["end"])))
                # Show both original and translation only when they differ.
                if result["transcription"] != result["translation"]:
                    transcript.append(
                        f"[{start_time} - {end_time}] {result['speaker']}:\n"
                        f"  Original: {result['transcription']}\n"
                        f"  Translation: {result['translation']}\n"
                    )
                else:
                    transcript.append(
                        f"[{start_time} - {end_time}] {result['speaker']}: {result['translation']}"
                    )
            return "\n".join(transcript)

        except Exception as e:
            return f"Processing error: {str(e)}"


# Initialize the processor (loads all models at import time).
processor = DiarizationTranscriptionTranslation()


def process_audio_file(audio_file):
    """Wrapper function for the Gradio interface."""
    if audio_file is None:
        return "Please upload an audio file."
    return processor.process_audio(audio_file)


# Create Gradio interface
interface = gr.Interface(
    fn=process_audio_file,
    inputs=gr.Audio(type="filepath", label="Upload Audio File"),
    outputs=gr.Textbox(label="Transcript with Speaker Labels", lines=20),
    title="Speaker Diarization, Transcription & Translation",
    description="""
    This Space combines three powerful speech processing capabilities:
    
    1. Speaker Diarization - Distinguishes between different speakers in your audio
    2. Speech Transcription - Converts spoken words into text using Whisper
    3. Automatic Translation - Translates non-English content to English
    
    Upload an audio file (MP3, WAV, or other common formats) and get a timestamped transcript with speaker labels.
    """,
    examples=[
        # Add example files if available
    ],
    cache_examples=False,
)

if __name__ == "__main__":
    interface.launch(server_name="0.0.0.0", server_port=7860)