# NOTE: "Spaces: Sleeping" lines were Hugging Face Space page-status residue
# captured by the scrape — not part of the application code.
import gc
import io
import json
import os
import tempfile

import gradio as gr
import numpy as np
import soundfile as sf
import torch
import whisperx
from huggingface_hub import HfFolder
from pyannote.audio import Pipeline
from transformers import pipeline
# --- Configuration ---
# Gated pyannote models require an HF token; warn early if it is absent.
HF_TOKEN = os.getenv("HF_TOKEN")
if not HF_TOKEN:
    print("WARNING: HF_TOKEN environment variable not set. Please set it as a Space secret or directly for local testing.")
    print("Visit https://huggingface.co/settings/tokens to create one and accept model conditions for pyannote/speaker-diarization, etc.")

# Pick hardware and a matching precision: fp16 on GPU, int8 quantization on CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
compute_type = "float16" if device == "cuda" else "int8"

# 'large-v2' is best but most resource intensive; 'small'/'medium' suit free tiers.
whisper_model_size = "medium"

# --- Global Models (loaded once, shared across requests) ---
whisper_model_global = None
diarize_pipeline_global = None
translation_pipeline_global = None
def load_all_models():
    """Load the ASR, diarization, and translation models into module globals.

    Intended to run once at startup so every request reuses the same loaded
    models. Raises ValueError when HF_TOKEN is missing, because the pyannote
    diarization weights are gated behind an accepted-conditions token.
    Translation is optional: on failure the pipeline global stays None and the
    app degrades to transcription-only.
    """
    global whisper_model_global, diarize_pipeline_global, translation_pipeline_global

    print(f"Loading WhisperX model ({whisper_model_size})...")
    whisper_model_global = whisperx.load_model(
        whisper_model_size, device=device, compute_type=compute_type
    )

    print("Loading Pyannote Diarization Pipeline...")
    if not HF_TOKEN:
        raise ValueError("Hugging Face token (HF_TOKEN) not set. Please set it as a Space secret.")
    diarize_pipeline_global = whisperx.diarize.DiarizationPipeline(
        use_auth_token=HF_TOKEN, device=device
    )

    print("Loading translation model (Helsinki-NLP/opus-mt-ta-en)...")
    try:
        translation_pipeline_global = pipeline(
            "translation",
            model="Helsinki-NLP/opus-mt-ta-en",
            device=0 if device == "cuda" else -1,
        )
    except Exception as e:
        # Best-effort: keep running without translation rather than crashing.
        print(f"Could not load translation model: {e}")
        translation_pipeline_global = None
# Load models once at import time so the Gradio app serves requests without
# per-request model loading.
load_all_models()
def _resample_waveform_to_16k(waveform, sample_rate):
    """Resample a 1-D float waveform to 16 kHz via torchaudio; no-op at 16 kHz."""
    if sample_rate == 16000:
        return waveform
    # Imported lazily, mirroring the original code, so torchaudio is only
    # required when resampling is actually needed.
    from torchaudio.transforms import Resample
    resampler = Resample(orig_freq=sample_rate, new_freq=16000)
    return resampler(torch.from_numpy(waveform).float()).numpy()


def _new_temp_wav_path():
    """Create an empty temp .wav file, close its handle, and return the path.

    The original code used NamedTemporaryFile(...).name without closing the
    handle, leaking a file descriptor (and blocking re-open on Windows).
    """
    with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp:
        return tmp.name


def convert_audio_for_whisper(audio_input):
    """
    Converts Gradio audio input (filepath or (sr, numpy_array)) to a 16kHz mono WAV file
    that WhisperX expects.

    Returns the path to the temporary WAV file, or None if the input is of an
    unrecognized type or conversion fails. The temp file is removed on failure
    (the original leaked it).
    """
    if isinstance(audio_input, str):  # Filepath from gr.Audio(type="filepath")
        temp_wav_path = _new_temp_wav_path()
        try:
            waveform, sample_rate = sf.read(audio_input)
            if waveform.ndim > 1:
                waveform = waveform.mean(axis=1)  # downmix stereo to mono
            if sample_rate != 16000:
                print(f"Warning: Audio sample rate is {sample_rate}Hz. Resampling to 16kHz.")
                waveform = _resample_waveform_to_16k(waveform, sample_rate)
            sf.write(temp_wav_path, waveform, 16000, format='WAV', subtype='PCM_16')
            return temp_wav_path
        except Exception as e:
            print(f"Error converting uploaded audio: {e}")
            if os.path.exists(temp_wav_path):
                os.unlink(temp_wav_path)  # don't leak the temp file on failure
            return None

    elif isinstance(audio_input, tuple):  # (sr, numpy_array) from gr.Audio(type="numpy") or microphone
        sample_rate, samples = audio_input
        if samples.ndim > 1:
            samples = samples.mean(axis=1)  # downmix to mono
        if samples.dtype != np.float32:
            # Normalize (typically int16 PCM) to float32 in [-1, 1].
            # BUG FIX: guard against an all-zero (silent) recording, which
            # previously divided by zero and produced NaNs.
            peak = np.max(np.abs(samples))
            samples = samples.astype(np.float32)
            if peak > 0:
                samples = samples / peak
        if sample_rate != 16000:
            print(f"Warning: Microphone audio sample rate is {sample_rate}Hz. Resampling to 16kHz.")
            samples = _resample_waveform_to_16k(samples, sample_rate)
        temp_wav_path = _new_temp_wav_path()
        try:
            sf.write(temp_wav_path, samples, 16000, format='WAV', subtype='PCM_16')  # always 16 kHz
            return temp_wav_path
        except Exception as e:
            print(f"Error writing microphone audio to temp file: {e}")
            if os.path.exists(temp_wav_path):
                os.unlink(temp_wav_path)
            return None

    # Unsupported input type (neither a path nor a (sr, array) tuple).
    return None
def process_audio_for_web(audio_input):
    """
    Run speaker diarization, transcription, and Tamil->English translation on a
    Gradio audio input (file upload or microphone recording).

    Returns a 4-tuple:
        (diarized transcript text, translated transcript text,
         detected-language message, path to a downloadable .txt transcript).
    On error the first element carries the error message and the file path is None.
    """
    if audio_input is None:
        return "Please upload an audio file or record from microphone.", "", "", None

    audio_file_path = convert_audio_for_whisper(audio_input)
    if not audio_file_path:
        return "Error: Could not process audio input. Please ensure it's a valid audio format.", "", "", None

    print(f"Processing audio from temp file: {audio_file_path}")
    try:
        audio = whisperx.load_audio(audio_file_path)

        # 1. Transcribe (batch_size=1 keeps memory low on free-tier hardware).
        print("Transcribing audio...")
        transcription_result = whisper_model_global.transcribe(audio, batch_size=1)
        detected_language = transcription_result["language"]
        print(f"Detected overall language: {detected_language}")

        # 2. Align word timestamps; the alignment model depends on the detected language.
        print("Aligning transcription with audio...")
        try:
            align_model_local, metadata = whisperx.load_align_model(language_code=detected_language, device=device)
        except Exception as e:
            # No alignment model exists for every language; report and bail out.
            print(f"Error loading alignment model for language '{detected_language}': {e}")
            import traceback
            print(traceback.format_exc())
            return f"Error: Could not load alignment model for language '{detected_language}'. Alignment is typically supported for English, French, German, Spanish, Italian, Japanese, Chinese, Dutch, and Portuguese. Details: {e}", "", "", None

        transcription_result = whisperx.align(transcription_result["segments"], align_model_local, audio, device, return_char_alignments=False)
        gc.collect()
        if device == "cuda":
            torch.cuda.empty_cache()

        # 3. Diarize and attach speaker labels to the aligned segments.
        print("Performing speaker diarization...")
        diarize_segments = diarize_pipeline_global(audio_file_path)
        final_result = whisperx.assign_word_speakers(diarize_segments, transcription_result)

        speaker_transcripts_raw = {}
        diarized_display_lines = []  # chronological view with speaker labels
        for segment in final_result["segments"]:
            speaker_id = segment.get("speaker", "UNKNOWN_SPEAKER")
            text = segment["text"].strip()
            start = segment["start"]
            end = segment["end"]
            speaker_transcripts_raw.setdefault(speaker_id, []).append({
                "start": start,
                "end": end,
                "text": text,
            })
            diarized_display_lines.append(f"[{start:.2f}s - {end:.2f}s] Speaker {speaker_id}: {text}")
        full_diarized_text_str = "\n".join(diarized_display_lines)

        # 4. Translate Tamil segments to English (only if the model loaded).
        # Explicit sentinel instead of the original's fragile
        # `'translated_speaker_data' in locals()` check.
        translated_speaker_data = None
        if translation_pipeline_global:
            translated_speaker_data = {}
            translated_display_lines = []
            for speaker, segments in speaker_transcripts_raw.items():
                translated_speaker_data[speaker] = []
                translated_display_lines.append(f"\n--- Speaker {speaker} (Original & Translated) ---")
                for seg in segments:
                    original_text = seg['text']
                    translated_text_output = original_text
                    # BUG FIX: the Tamil Unicode block is U+0B80..U+0BFF
                    # inclusive; the original open-interval check (> 0x0B80,
                    # < 0x0BFF) missed the block boundaries.
                    is_tamil_char_present = any(0x0B80 <= ord(char) <= 0x0BFF for char in original_text)
                    if original_text and (detected_language == 'ta' or is_tamil_char_present):
                        try:
                            translated_result = translation_pipeline_global(original_text, src_lang="ta", tgt_lang="en")
                            translated_text_output = translated_result[0]['translation_text']
                        except Exception as e:
                            # Keep the original text rather than failing the whole request.
                            print(f"Error translating segment for speaker {speaker}: '{original_text}'. Error: {e}. Keeping original text.")
                    translated_speaker_data[speaker].append({
                        "start": seg['start'],
                        "end": seg['end'],
                        "original_text": original_text,
                        "translated_text": translated_text_output,
                    })
                    translated_display_lines.append(f"[{seg['start']:.2f}s - {seg['end']:.2f}s] Original: {original_text}")
                    translated_display_lines.append(f"  Translated: {translated_text_output}")
            translated_output_str = "\n".join(translated_display_lines)
        else:
            translated_output_str = "Translation model not loaded. Skipping translation."

        # 5. Write a downloadable .txt transcript (original + translated, per speaker).
        output_filename = tempfile.NamedTemporaryFile(suffix=".txt", delete=False).name
        with open(output_filename, "w", encoding="utf-8") as f:
            f.write("--- Speaker-wise Original Transcription ---\n\n")
            for speaker, segments in speaker_transcripts_raw.items():
                f.write(f"\n### Speaker {speaker} ###\n")
                for seg in segments:
                    f.write(f"[{seg['start']:.2f}s - {seg['end']:.2f}s] {seg['text']}\n")
            f.write("\n\n--- Speaker-wise Translated Transcription (to English) ---\n\n")
            if translated_speaker_data is not None:
                for speaker, segments in translated_speaker_data.items():
                    f.write(f"\n### Speaker {speaker} ###\n")
                    for seg in segments:
                        f.write(f"[{seg['start']:.2f}s - {seg['end']:.2f}s] Original: {seg['original_text']}\n")
                        f.write(f"  Translated: {seg['translated_text']}\n")
            else:
                f.write("Translation output not available or translation model not loaded.\n")
            f.write(f"\n\nOverall Detected Language: {detected_language}")

        return full_diarized_text_str, translated_output_str, f"Detected overall language: {detected_language}", output_filename
    except Exception as e:
        import traceback
        error_message = f"An error occurred: {e}\n{traceback.format_exc()}"
        print(error_message)
        return error_message, "", "", None
    finally:
        # BUG FIX: the original leaked the temp audio file on the early
        # alignment-failure return path; `finally` cleans up on every exit.
        if audio_file_path and os.path.exists(audio_file_path):
            os.unlink(audio_file_path)
# --- Gradio Interface ---
with gr.Blocks(title="Language-Agnostic Speaker Diarization, Transcription, and Translation") as demo:
    gr.Markdown(
        """
        # Language-Agnostic Speaker Diarization, Transcription, and Translation
        Upload an audio file (WAV, MP3, etc.) or record directly from your microphone.
        The system will identify speakers, transcribe their speech (in detected language),
        and provide an English translation for relevant segments.
        """
    )
    with gr.Row():
        audio_input = gr.Audio(
            type="filepath",
            sources=["upload", "microphone"],
            label="Upload Audio File or Record from Microphone"
        )
    with gr.Row():
        process_button = gr.Button("Process Audio", variant="primary")
    with gr.Column():
        detected_language_output = gr.Textbox(label="Detected Overall Language")
        # Chronological transcript with speaker labels.
        diarized_transcription_output = gr.Textbox(label="Diarized Transcription (Chronological with Speaker Labels)", lines=10, interactive=False)
        # English translation grouped per speaker.
        translated_transcription_output = gr.Textbox(label="Translated Transcription (to English, per Speaker)", lines=10, interactive=False)
        # BUG FIX: the component was created with visible=False and nothing
        # ever toggled it visible, so the generated transcript file could
        # never be downloaded. Leave it visible (the default).
        download_button = gr.File(label="Download Transcription (.txt)", interactive=False)

    process_button.click(
        fn=process_audio_for_web,
        inputs=audio_input,
        outputs=[diarized_transcription_output, translated_transcription_output, detected_language_output, download_button]
    )

    gr.Examples(
        [
            # Add paths to example audio files here; they must be present in
            # the Space repository, e.g. "sample_two_speakers.wav".
        ],
        inputs=audio_input,
        outputs=[diarized_transcription_output, translated_transcription_output, detected_language_output, download_button],
        fn=process_audio_for_web,
        cache_examples=False
    )

demo.launch()