"""Gradio app for speaker diarization, transcription and Tamil-to-English translation.

Pipeline: the uploaded/recorded audio is converted to a 16 kHz mono WAV,
transcribed and word-aligned with WhisperX, diarized, and every segment that
looks Tamil is translated with Helsinki-NLP/opus-mt-ta-en.
"""
import gradio as gr
import os
import torch
import gc
import json
import whisperx
from pyannote.audio import Pipeline
from huggingface_hub import HfFolder
from transformers import pipeline
import numpy as np
import soundfile as sf
import io
import tempfile
import traceback

# --- Configuration ---
HF_TOKEN = os.getenv("HF_TOKEN")
if not HF_TOKEN:
    print("WARNING: HF_TOKEN environment variable not set. Please set it as a Space secret or directly for local testing.")
    print("Visit https://huggingface.co/settings/tokens to create one and accept model conditions for pyannote/speaker-diarization, etc.")

device = "cuda" if torch.cuda.is_available() else "cpu"
compute_type = "float16" if device == "cuda" else "int8"
whisper_model_size = "medium"  # 'large-v2' is best but most resource intensive.
# 'small' or 'medium' are better for free tiers.

# Sample rate of the intermediate WAV handed to WhisperX and the diarizer.
TARGET_SAMPLE_RATE = 16000

# --- Global Models (loaded once) ---
whisper_model_global = None
diarize_pipeline_global = None
translation_pipeline_global = None


def load_all_models():
    """Load the WhisperX, diarization and translation models into the module globals.

    Raises:
        ValueError: if HF_TOKEN is not set (the diarization pipeline needs it).

    A failure to load the translation model is not fatal: it is logged and
    ``translation_pipeline_global`` is left as ``None`` so translation is skipped.
    """
    global whisper_model_global, diarize_pipeline_global, translation_pipeline_global

    # Fail fast: do not spend time loading Whisper when diarization cannot load.
    if not HF_TOKEN:
        raise ValueError("Hugging Face token (HF_TOKEN) not set. Please set it as a Space secret.")

    print(f"Loading WhisperX model ({whisper_model_size})...")
    whisper_model_global = whisperx.load_model(whisper_model_size, device=device, compute_type=compute_type)

    print("Loading Pyannote Diarization Pipeline...")
    diarize_pipeline_global = whisperx.diarize.DiarizationPipeline(use_auth_token=HF_TOKEN, device=device)

    print("Loading translation model (Helsinki-NLP/opus-mt-ta-en)...")
    try:
        translation_pipeline_global = pipeline(
            "translation",
            model="Helsinki-NLP/opus-mt-ta-en",
            device=0 if device == "cuda" else -1,
        )
    except Exception as e:
        print(f"Could not load translation model: {e}")
        translation_pipeline_global = None


# Load models when the Gradio app starts
load_all_models()


def _new_temp_path(suffix):
    """Create an empty temporary file and return its path with the handle closed."""
    with tempfile.NamedTemporaryFile(suffix=suffix, delete=False) as handle:
        return handle.name


def _remove_file(path):
    """Best-effort removal of a temporary file; a missing file is not an error."""
    if not path:
        return
    try:
        os.unlink(path)
    except FileNotFoundError:
        pass  # already cleaned up
    except OSError as e:
        print(f"Warning: could not remove temporary file {path}: {e}")


def _resample_to_target(waveform, sample_rate):
    """Return *waveform* resampled to TARGET_SAMPLE_RATE (unchanged if already there)."""
    if sample_rate == TARGET_SAMPLE_RATE:
        return waveform
    # Imported lazily so torchaudio is only required when resampling is needed.
    from torchaudio.transforms import Resample

    resampler = Resample(orig_freq=sample_rate, new_freq=TARGET_SAMPLE_RATE)
    return resampler(torch.from_numpy(waveform).float()).numpy()


def _write_wav(waveform):
    """Write *waveform* as a 16-bit PCM WAV at TARGET_SAMPLE_RATE and return the path.

    The temporary file is removed again if writing fails; the error is re-raised.
    """
    wav_path = _new_temp_path(".wav")
    try:
        sf.write(wav_path, waveform, TARGET_SAMPLE_RATE, format='WAV', subtype='PCM_16')
    except Exception:
        _remove_file(wav_path)
        raise
    return wav_path


def convert_audio_for_whisper(audio_input):
    """Convert Gradio audio input to a temporary 16 kHz mono WAV file.

    Args:
        audio_input: either a file path (``gr.Audio(type="filepath")``) or a
            ``(sample_rate, numpy_array)`` tuple (``gr.Audio(type="numpy")``).

    Returns:
        Path of the temporary WAV file, or ``None`` when the input is of an
        unsupported type, is empty, or could not be read/written. The caller
        owns the returned file and is responsible for deleting it.
    """
    if isinstance(audio_input, str):  # Filepath from gr.Audio(type="filepath")
        try:
            waveform, sample_rate = sf.read(audio_input)
            if waveform.ndim > 1:
                waveform = waveform.mean(axis=1)  # Convert to mono if stereo
            if sample_rate != TARGET_SAMPLE_RATE:
                print(f"Warning: Audio sample rate is {sample_rate}Hz. Resampling to 16kHz.")
                waveform = _resample_to_target(waveform, sample_rate)
            return _write_wav(waveform)
        except Exception as e:
            print(f"Error converting uploaded audio: {e}")
            return None

    if isinstance(audio_input, tuple):  # (sr, numpy_array) from gr.Audio(type="numpy") or microphone
        sample_rate, samples = audio_input
        samples = np.asarray(samples)
        if samples.size == 0:
            # Nothing to transcribe; np.max() below would raise on an empty array.
            print("Error converting microphone audio: recording is empty.")
            return None

        # Ensure it's mono
        if samples.ndim > 1:
            samples = samples.mean(axis=1)

        # Convert to float32 and peak-normalize. The cast happens BEFORE abs()
        # so the most negative integer sample cannot overflow, and silent
        # input (peak == 0) is left as zeros instead of becoming NaN.
        if samples.dtype != np.float32:
            samples = samples.astype(np.float32)
            peak = float(np.max(np.abs(samples)))
            if peak > 0:
                samples = samples / peak

        if sample_rate != TARGET_SAMPLE_RATE:
            print(f"Warning: Microphone audio sample rate is {sample_rate}Hz. Resampling to 16kHz.")
            samples = _resample_to_target(samples, sample_rate)

        try:
            return _write_wav(samples)  # Always write at 16kHz
        except Exception as e:
            print(f"Error writing microphone audio to temp file: {e}")
            return None

    return None


def _contains_tamil(text):
    """Return True if *text* has a character in the Tamil Unicode block (U+0B80-U+0BFF)."""
    return any(0x0B80 <= ord(char) <= 0x0BFF for char in text)


def _translate_segment(text, speaker, detected_language):
    """Translate one segment to English when it is Tamil; otherwise return it unchanged.

    Translation errors are logged and the original text is kept.
    """
    if not text or not (detected_language == 'ta' or _contains_tamil(text)):
        return text
    try:
        translated_result = translation_pipeline_global(text, src_lang="ta", tgt_lang="en")
        return translated_result[0]['translation_text']
    except Exception as e:
        print(f"Error translating segment for speaker {speaker}: '{text}'. Error: {e}. Keeping original text.")
        return text


def _write_transcript_file(speaker_transcripts, translated_speaker_data, detected_language):
    """Write the downloadable .txt report and return its path.

    ``translated_speaker_data`` is ``None`` when translation was skipped.
    """
    output_filename = _new_temp_path(".txt")
    with open(output_filename, "w", encoding="utf-8") as f:
        f.write("--- Speaker-wise Original Transcription ---\n\n")
        for speaker, segments in speaker_transcripts.items():
            f.write(f"\n### Speaker {speaker} ###\n")
            for seg in segments:
                f.write(f"[{seg['start']:.2f}s - {seg['end']:.2f}s] {seg['text']}\n")

        f.write("\n\n--- Speaker-wise Translated Transcription (to English) ---\n\n")
        if translated_speaker_data is not None:
            for speaker, segments in translated_speaker_data.items():
                f.write(f"\n### Speaker {speaker} ###\n")
                for seg in segments:
                    f.write(f"[{seg['start']:.2f}s - {seg['end']:.2f}s] Original: {seg['original_text']}\n")
                    f.write(f" Translated: {seg['translated_text']}\n")
        else:
            f.write("Translation output not available or translation model not loaded.\n")

        f.write(f"\n\nOverall Detected Language: {detected_language}")
    return output_filename


def process_audio_for_web(audio_input):
    """Diarize, transcribe and translate one audio input for the Gradio UI.

    Args:
        audio_input: value produced by the ``gr.Audio`` component (file path,
            ``(sample_rate, numpy_array)`` tuple, or ``None``).

    Returns:
        A 4-tuple ``(diarized_text, translated_text, language_message,
        download_path)``. On failure the first element carries the error
        message, the next two are empty strings and ``download_path`` is ``None``.
    """
    if audio_input is None:
        return "Please upload an audio file or record from microphone.", "", "", None

    audio_file_path = convert_audio_for_whisper(audio_input)
    if not audio_file_path:
        return "Error: Could not process audio input. Please ensure it's a valid audio format.", "", "", None

    print(f"Processing audio from temp file: {audio_file_path}")

    try:
        audio = whisperx.load_audio(audio_file_path)

        # 1. Transcribe
        print("Transcribing audio...")
        transcription_result = whisper_model_global.transcribe(audio, batch_size=1)
        detected_language = transcription_result["language"]
        print(f"Detected overall language: {detected_language}")

        # 2. Align
        print("Aligning transcription with audio...")
        try:
            # Load the alignment model based on the detected language
            align_model_local, metadata = whisperx.load_align_model(language_code=detected_language, device=device)
        except Exception as e:
            # Handle cases where the alignment model for the detected language cannot be loaded
            print(f"Error loading alignment model for language '{detected_language}': {e}")
            print(traceback.format_exc())
            return f"Error: Could not load alignment model for language '{detected_language}'. Alignment is typically supported for English, French, German, Spanish, Italian, Japanese, Chinese, Dutch, and Portuguese. Details: {e}", "", "", None

        # whisperx.align takes the model metadata between the model and the
        # audio; omitting it shifts every following positional argument.
        transcription_result = whisperx.align(
            transcription_result["segments"],
            align_model_local,
            metadata,
            audio,
            device,
            return_char_alignments=False,
        )

        # Drop the last reference so the collection below can actually free it.
        del align_model_local
        gc.collect()
        if device == "cuda":
            torch.cuda.empty_cache()

        # 3. Diarize
        print("Performing speaker diarization...")
        diarize_segments = diarize_pipeline_global(audio_file_path)
        final_result = whisperx.assign_word_speakers(diarize_segments, transcription_result)

        speaker_transcripts_raw = {}
        diarized_display_lines = []  # chronological view for the first textbox
        for segment in final_result["segments"]:
            speaker_id = segment.get("speaker", "UNKNOWN_SPEAKER")
            text = segment["text"].strip()
            start = segment["start"]
            end = segment["end"]
            speaker_transcripts_raw.setdefault(speaker_id, []).append({
                "start": start,
                "end": end,
                "text": text,
            })
            diarized_display_lines.append(f"[{start:.2f}s - {end:.2f}s] Speaker {speaker_id}: {text}")
        full_diarized_text_str = "\n".join(diarized_display_lines)

        # 4. Translate
        translated_speaker_data = None  # stays None when translation is skipped
        if translation_pipeline_global:
            translated_speaker_data = {}
            translated_display_lines = []
            for speaker, segments in speaker_transcripts_raw.items():
                translated_speaker_data[speaker] = []
                translated_display_lines.append(f"\n--- Speaker {speaker} (Original & Translated) ---")
                for seg in segments:
                    original_text = seg['text']
                    translated_text_output = _translate_segment(original_text, speaker, detected_language)
                    translated_speaker_data[speaker].append({
                        "start": seg['start'],
                        "end": seg['end'],
                        "original_text": original_text,
                        "translated_text": translated_text_output,
                    })
                    translated_display_lines.append(f"[{seg['start']:.2f}s - {seg['end']:.2f}s] Original: {original_text}")
                    translated_display_lines.append(f" Translated: {translated_text_output}")
            translated_output_str = "\n".join(translated_display_lines)
        else:
            translated_output_str = "Translation model not loaded. Skipping translation."

        # Create a temporary file for download
        output_filename = _write_transcript_file(
            speaker_transcripts_raw, translated_speaker_data, detected_language
        )

        return full_diarized_text_str, translated_output_str, f"Detected overall language: {detected_language}", output_filename

    except Exception as e:
        error_message = f"An error occurred: {e}\n{traceback.format_exc()}"
        print(error_message)
        return error_message, "", "", None
    finally:
        # Runs on success, on the early alignment-failure return and on error,
        # so the intermediate WAV is never left behind.
        _remove_file(audio_file_path)


# --- Gradio Interface ---
with gr.Blocks(title="Language-Agnostic Speaker Diarization, Transcription, and Translation") as demo:
    gr.Markdown(
        """
# Language-Agnostic Speaker Diarization, Transcription, and Translation
Upload an audio file (WAV, MP3, etc.) or record directly from your microphone.
The system will identify speakers, transcribe their speech (in detected language), and provide an English translation for relevant segments.
"""
    )

    with gr.Row():
        audio_input = gr.Audio(
            type="filepath",
            sources=["upload", "microphone"],
            label="Upload Audio File or Record from Microphone",
        )

    with gr.Row():
        process_button = gr.Button("Process Audio", variant="primary")

    with gr.Column():
        detected_language_output = gr.Textbox(label="Detected Overall Language")
        # Diarized Transcription will still be chronological with speaker labels
        diarized_transcription_output = gr.Textbox(
            label="Diarized Transcription (Chronological with Speaker Labels)",
            lines=10,
            interactive=False,
        )
        # Translated transcription will now be clearly separated by speaker
        translated_transcription_output = gr.Textbox(
            label="Translated Transcription (to English, per Speaker)",
            lines=10,
            interactive=False,
        )
        download_button = gr.File(label="Download Transcription (.txt)", interactive=False, visible=False)

    process_button.click(
        fn=process_audio_for_web,
        inputs=audio_input,
        outputs=[diarized_transcription_output, translated_transcription_output, detected_language_output, download_button],
    ).then(
        # The File component starts hidden; reveal it only once a report exists,
        # otherwise the generated file could never be downloaded.
        fn=lambda report: gr.update(visible=report is not None),
        inputs=download_button,
        outputs=download_button,
    )

    gr.Examples(
        [
            # Add paths to your example audio files here.
            # These files must be present in your Hugging Face Space repository.
            # For example, if you have 'sample_two_speakers.wav' in your repo:
            # "sample_two_speakers.wav"
        ],
        inputs=audio_input,
        outputs=[diarized_transcription_output, translated_transcription_output, detected_language_output, download_button],
        fn=process_audio_for_web,
        cache_examples=False,
    )

demo.launch()