Spaces:
Running on Zero
Running on Zero
| import gradio as gr | |
| import torch | |
| import ffmpeg | |
| import json | |
| import os | |
| import uuid | |
| import tempfile | |
| import gc | |
| from io import BytesIO | |
| from concurrent.futures import ThreadPoolExecutor | |
| from typing import Optional, Tuple | |
| import whisperx | |
| import spaces | |
| import numpy as np | |
| import soundfile as sf | |
| from deep_translator import GoogleTranslator | |
# Load the mapping from language display name (the dropdown choices) to the
# Google Translate language code. Decode explicitly as UTF-8 so the result
# does not depend on the platform's default locale encoding.
with open('google_lang_codes.json', 'r', encoding='utf-8') as f:
    google_lang_codes = json.load(f)
# ============================================================================
# GLOBAL MODEL CACHE - Load once, reuse forever
# ============================================================================
# These start empty and are filled lazily by the loader functions below, so
# every request handled by this process reuses the same model objects.
_whisper_model: Optional[object] = None  # filled by get_whisper_model()
_align_models: dict = {}  # language code -> (model, metadata); filled by get_align_model()
_diarize_model: Optional[object] = None  # filled by diarize_audio()
# (device, compute_type) that the currently cached _whisper_model was loaded with.
_whisper_model_config: Optional[Tuple[str, str]] = None


def get_whisper_model(device: str, compute_type: str):
    """Return the process-wide WhisperX model (large-v3-turbo), loading it lazily.

    The model is cached in the module-level ``_whisper_model``. The cache is
    keyed on ``(device, compute_type)``. If a call asks for a different device
    or compute type than the cached model was loaded with, the model is
    reloaded rather than silently returning one that lives on the wrong device.

    Args:
        device: Device string forwarded to ``whisperx.load_model`` (the caller
            in this file passes "cuda" or "cpu").
        compute_type: Precision forwarded to ``whisperx.load_model`` (the
            caller passes "float16" on CUDA and "int8" on CPU).

    Returns:
        The loaded WhisperX ASR model.
    """
    global _whisper_model, _whisper_model_config
    requested = (device, compute_type)
    if _whisper_model is None or _whisper_model_config != requested:
        if _whisper_model is not None:
            print(f"[DEBUG] Reloading WhisperX model for {requested} (was {_whisper_model_config})")
            # Drop the old model first so both are never held at once.
            _whisper_model = None
            gc.collect()
        print("[DEBUG] Loading WhisperX model (large-v3-turbo)...")
        _whisper_model = whisperx.load_model(
            "large-v3-turbo",  # Faster than large-v3 with similar quality
            device,
            compute_type=compute_type
        )
        _whisper_model_config = requested
        print("[DEBUG] WhisperX model loaded successfully")
    return _whisper_model
| def get_align_model(language_code: str, device: str): | |
| """Get cached alignment model for a specific language.""" | |
| global _align_models | |
| if language_code not in _align_models: | |
| print(f"[DEBUG] Loading alignment model for language: {language_code}") | |
| model, metadata = whisperx.load_align_model( | |
| language_code=language_code, | |
| device=device, | |
| model_name="WAV2VEC2_ASR_LARGE_LV60K_960H" | |
| ) | |
| _align_models[language_code] = (model, metadata) | |
| print(f"[DEBUG] Alignment model for {language_code} loaded successfully") | |
| return _align_models[language_code] | |
| # ============================================================================ | |
| # Helper Functions | |
| # ============================================================================ | |
def ffmpeg_read(input_data_bytes: bytes, sampling_rate: int) -> np.ndarray:
    """Decode in-memory audio bytes to 16-bit PCM samples using ffmpeg.

    Args:
        input_data_bytes: Encoded audio, in any container or codec ffmpeg can read.
        sampling_rate: Output sample rate in Hz; ffmpeg resamples to it.

    Returns:
        1-D ``int16`` array of raw samples at ``sampling_rate``. No downmix
        is requested, so multi-channel input keeps its channels interleaved.

    Raises:
        ValueError: If ffmpeg exits with a non-zero status. The message
            includes ffmpeg's stderr.
    """
    process = (
        ffmpeg.input('pipe:0')
        # Raw 's16le' instead of 'wav': a WAV container would prepend a RIFF
        # header that np.frombuffer would then misread as leading samples.
        .output('pipe:1', format='s16le', acodec='pcm_s16le', ar=sampling_rate)
        .run_async(pipe_stdin=True, pipe_stdout=True, pipe_stderr=True)
    )
    out, err = process.communicate(input=input_data_bytes)
    if process.returncode != 0:
        details = err.decode(errors='replace').strip() if err else ""
        raise ValueError(f"ffmpeg failed to decode audio: {details}")
    return np.frombuffer(out, np.int16)
def format_timestamp(seconds: float) -> str:
    """Convert a time offset in seconds to an SRT timestamp ``HH:MM:SS,mmm``.

    The value is rounded to the nearest millisecond. Extracting the
    milliseconds by truncating the float fraction would turn e.g. 1.001 into
    ``,000``. Negative inputs are clamped to zero. ``float()`` coercion keeps
    numpy scalar inputs formatting as integers.

    >>> format_timestamp(3661.5)
    '01:01:01,500'
    """
    total_ms = max(0, int(round(float(seconds) * 1000)))
    total_seconds, millis = divmod(total_ms, 1000)
    hours, remainder = divmod(total_seconds, 3600)
    minutes, secs = divmod(remainder, 60)
    return f"{hours:02}:{minutes:02}:{secs:02},{millis:03}"
def translate_segment_text(text: str, target_language_code: str) -> str:
    """Translate one subtitle segment into ``target_language_code``.

    Whitespace-only text is handed back untouched without calling the
    translator. The translator receives the stripped text. If construction or
    translation raises, a warning is printed and the original (unstripped)
    text is returned, so one bad segment never aborts the batch.
    """
    stripped = text.strip()
    if stripped:
        try:
            translator = GoogleTranslator(source='auto', target=target_language_code)
            translated = translator.translate(stripped)
        except Exception as err:
            print(f"[WARNING] Translation failed for '{text[:50]}...': {err}")
        else:
            return translated
    return text
def translate_segments_parallel(segments: list, target_language_code: str) -> list:
    """Translate every segment's text concurrently and store it back in place.

    Each segment's stripped ``'text'`` goes through ``translate_segment_text``
    on a pool of 8 worker threads. ``Executor.map`` yields results in input
    order, so they line up with ``segments``. The same (mutated) list object
    is returned.
    """
    source_texts = [seg['text'].strip() for seg in segments]
    print(f"[DEBUG] Translating {len(source_texts)} segments in parallel...")

    def _translate(one_text):
        return translate_segment_text(one_text, target_language_code)

    with ThreadPoolExecutor(max_workers=8) as pool:
        results = pool.map(_translate, source_texts)
        translated = list(results)

    # Write translations back onto the original segment dicts.
    for seg, new_text in zip(segments, translated):
        seg['text'] = new_text
    return segments
def generate_srt(segments: list, filepath: str):
    """Write ``segments`` to ``filepath`` as a UTF-8 SRT subtitle file.

    Each segment must provide ``'start'`` and ``'end'`` (seconds) and
    ``'text'``. Segments whose text is empty after stripping are skipped.
    Such a cue would have no text line, only a timing line followed by blank
    lines, which stricter SRT parsers may reject. Cue numbers stay contiguous
    (1, 2, 3, ...) over the cues actually written.
    """
    cue_number = 0
    with open(filepath, "w", encoding="utf-8") as f:
        for segment in segments:
            text = segment['text'].strip()
            if not text:
                continue
            cue_number += 1
            start_time = format_timestamp(segment['start'])
            end_time = format_timestamp(segment['end'])
            f.write(f"{cue_number}\n{start_time} --> {end_time}\n{text}\n\n")
| # ============================================================================ | |
| # Main Processing Functions | |
| # ============================================================================ | |
def transcribe_and_align(
    audio_path: str,
    device: str,
    compute_type: str,
    progress: gr.Progress
) -> Tuple[list, str]:
    """
    Transcribe audio with WhisperX, then refine segment timestamps by alignment.

    Args:
        audio_path: Path to the audio file to transcribe.
        device: Device string forwarded to the model loaders ("cuda"/"cpu").
        compute_type: Precision forwarded to ``get_whisper_model``.
        progress: Gradio progress tracker, updated at 0.3 and 0.5.

    Returns:
        ``(segments, detected_language)``. If no alignment model can be
        loaded for the detected language, the unaligned transcription
        segments are returned instead of failing the whole job.

    Raises:
        ValueError: If the transcription produced no segments.
    """
    progress(0.3, desc="Transcribing audio...")
    audio = whisperx.load_audio(audio_path)
    try:
        whisper_model = get_whisper_model(device, compute_type)
        # WhisperX detects the language automatically during transcription.
        batch_size = 16
        result = whisper_model.transcribe(audio, batch_size=batch_size)
        # `or` also covers an explicit None language value.
        detected_language = result.get("language") or "en"
        print(f"[DEBUG] Detected language: {detected_language}")
        if not result.get("segments"):
            raise ValueError("No segments found in transcription")
        print(f"[DEBUG] Transcribed {len(result['segments'])} segments")

        progress(0.5, desc="Aligning timestamps...")
        try:
            align_model, align_metadata = get_align_model(detected_language, device)
        except ValueError as e:
            # NOTE(review): assumes whisperx signals "no alignment model for
            # this language" with ValueError; confirm for the pinned version.
            # Whisper's coarse segment timestamps are still usable subtitles.
            print(f"[WARNING] No alignment model for '{detected_language}', "
                  f"using unaligned timestamps: {e}")
            return result["segments"], detected_language

        aligned = whisperx.align(
            result["segments"],
            align_model,
            align_metadata,
            audio,
            device,
            return_char_alignments=False
        )
        print(f"[DEBUG] Aligned {len(aligned['segments'])} segments")
        return aligned["segments"], detected_language
    finally:
        # Release the decoded audio and cached GPU memory on every exit path,
        # including errors.
        del audio
        gc.collect()
        torch.cuda.empty_cache()
def diarize_audio(
    audio_path: str,
    segments: list,
    hf_token: Optional[str],
    device: str,
    progress: gr.Progress
) -> list:
    """Attach speaker labels to ``segments`` (best-effort, optional feature).

    The diarization pipeline is created once with ``hf_token`` and cached in
    the module-level ``_diarize_model``. Any failure, including pipeline
    construction (e.g. an invalid or unauthorized token), is logged and the
    input segments are returned unchanged, so diarization can never fail
    the whole job.

    Args:
        audio_path: Path to the audio file.
        segments: Transcription segments to annotate.
        hf_token: Hugging Face token; when falsy, diarization is skipped.
        device: Device string forwarded to the diarization pipeline.
        progress: Gradio progress tracker, updated at 0.6.

    Returns:
        Segments as returned by ``whisperx.assign_word_speakers``, or the
        input ``segments`` when diarization is skipped or fails.
    """
    global _diarize_model
    if not hf_token:
        print("[DEBUG] No HF token provided, skipping diarization")
        return segments
    progress(0.6, desc="Identifying speakers...")
    try:
        # Loading lives inside the try so a bad token falls back gracefully;
        # the cache stays None on failure, so the next request retries.
        if _diarize_model is None:
            print("[DEBUG] Loading diarization model...")
            _diarize_model = whisperx.DiarizationPipeline(
                use_auth_token=hf_token,
                device=device
            )
        audio = whisperx.load_audio(audio_path)
        diarize_segments = _diarize_model(audio)
        result = whisperx.assign_word_speakers(diarize_segments, {"segments": segments})
        print("[DEBUG] Diarization complete, found speakers")
        return result["segments"]
    except Exception as e:
        print(f"[WARNING] Diarization failed: {e}")
        return segments
| # ============================================================================ | |
| # Main Video Processing Function | |
| # ============================================================================ | |
# Subtitle languages (compared on the part before any "-", case-insensitively)
# that are rendered with a CJK-capable font instead of Arial.
_CJK_LANGUAGE_PREFIXES = ("ja", "zh", "ko")


def _subtitle_style_for(language_code: Optional[str]) -> str:
    """Return the ``force_style`` string for the ffmpeg ``subtitles`` filter."""
    base_language = (language_code or "").lower().split("-")[0]
    font = "Noto Sans CJK JP" if base_language in _CJK_LANGUAGE_PREFIXES else "Arial"
    return (
        f"FontName={font},PrimaryColour=&H00FFFFFF,OutlineColour=&H000000,"
        "BackColour=&H80000000,BorderStyle=3,Outline=2,Shadow=1"
    )


def _ffmpeg_error_text(error: ffmpeg.Error) -> str:
    """Return ffmpeg's captured stderr as text, or the exception text if none."""
    if error.stderr:
        return error.stderr.decode(errors="replace")
    return str(error)


def process_video(
    video_path: str,
    target_language: str,
    translate_video: bool,
    enable_diarization: bool,
    progress: gr.Progress = gr.Progress()
):
    """Transcribe a video, optionally translate, and burn subtitles into it.

    Pipeline: ffmpeg audio extraction (mono, 16 kHz) -> WhisperX transcription
    and alignment -> optional diarization -> optional translation -> SRT file
    -> H.264 re-encode with the subtitles burned in.

    Args:
        video_path: Path of the uploaded video.
        target_language: Display name looked up in ``google_lang_codes``
            (falls back to code "en" when unknown).
        translate_video: Translate segment text to the target language.
        enable_diarization: Run speaker diarization if HF_TOKEN is set.
        progress: Gradio progress tracker.

    Returns:
        ``(output_video_path, srt_path, transcription_text)``. Both files live
        in the system temp dir and are left in place for the UI to serve.

    Raises:
        gr.Error: No video was given, transcription found nothing, or an
            ffmpeg step failed.
    """
    print("=" * 60)
    print("VIDEO PROCESSING STARTED")
    print("=" * 60)
    if not video_path:
        raise gr.Error("Please upload a video file")

    target_language_code = google_lang_codes.get(target_language, "en")
    print(f"[DEBUG] Target language: {target_language} ({target_language_code})")

    # NOTE(review): this Space runs on ZeroGPU and imports `spaces`, but no
    # function is decorated with @spaces.GPU. Confirm a GPU is ever attached;
    # otherwise CUDA may be unavailable here and everything runs on CPU.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    compute_type = "float16" if device == "cuda" else "int8"
    print(f"[DEBUG] Device: {device}, Compute type: {compute_type}")

    # A per-job UUID prefix keeps concurrent requests from sharing temp files.
    job_id = uuid.uuid4()
    work_dir = tempfile.gettempdir()
    audio_file = os.path.join(work_dir, f"{job_id}_audio.wav")
    srt_file = os.path.join(work_dir, f"{job_id}_subtitles.srt")
    output_video = os.path.join(work_dir, f"{job_id}_output.mp4")

    try:
        progress(0.1, desc="Extracting audio from video...")
        try:
            print(f"[DEBUG] Extracting audio to {audio_file}")
            ffmpeg.input(video_path).output(audio_file, ac=1, ar=16000).run(
                quiet=True,
                overwrite_output=True
            )
        except ffmpeg.Error as e:
            raise gr.Error(f"Failed to extract audio: {_ffmpeg_error_text(e)}") from e

        progress(0.2, desc="Loading audio...")
        try:
            segments, detected_language = transcribe_and_align(
                audio_file,
                device,
                compute_type,
                progress
            )
        except ValueError as e:
            # e.g. transcribe_and_align raises ValueError when no speech is found.
            raise gr.Error(f"Transcription failed: {e}") from e

        hf_token = os.environ.get("HF_TOKEN")
        if enable_diarization:
            if hf_token:
                segments = diarize_audio(audio_file, segments, hf_token, device, progress)
            else:
                print("[WARNING] Speaker diarization requested but HF_TOKEN is not set; skipping")

        if translate_video:
            progress(0.7, desc=f"Translating to {target_language}...")
            print(f"[DEBUG] Translating {len(segments)} segments to {target_language_code}")
            segments = translate_segments_parallel(segments, target_language_code)

        progress(0.8, desc="Generating subtitles...")
        generate_srt(segments, srt_file)
        print(f"[DEBUG] Generated SRT file: {srt_file}")
        transcription_text = "\n".join(s['text'].strip() for s in segments)

        progress(0.9, desc="Embedding subtitles into video...")
        # Untranslated subtitles stay in the spoken language, so choose the
        # font from the language actually shown on screen.
        subtitle_language = target_language_code if translate_video else detected_language
        subtitle_style = _subtitle_style_for(subtitle_language)
        try:
            (
                ffmpeg
                .input(video_path)
                .output(
                    output_video,
                    vf=f"subtitles={srt_file}:force_style='{subtitle_style}'",
                    # vcodec, not codec: a bare -codec also applies to the audio
                    # stream, and ffmpeg rejects libx264 as an audio encoder.
                    vcodec="libx264",
                    preset="fast"
                )
                .run(quiet=True, overwrite_output=True)
            )
            print(f"[DEBUG] Output video created: {output_video}")
        except ffmpeg.Error as e:
            raise gr.Error(f"Failed to embed subtitles: {_ffmpeg_error_text(e)}") from e
    finally:
        # Only the intermediate audio is disposable. The SRT and output video
        # are returned to the UI below, so they must still exist afterwards.
        try:
            os.unlink(audio_file)
        except FileNotFoundError:
            pass
        except OSError as e:
            print(f"[WARNING] Could not remove temporary file {audio_file}: {e}")

    progress(1.0, desc="Complete!")
    print("=" * 60)
    print("VIDEO PROCESSING COMPLETE")
    print("=" * 60)
    return output_video, srt_file, transcription_text
| # ============================================================================ | |
| # Gradio Interface | |
| # ============================================================================ | |
# Build the Gradio UI. The layout nesting below is reconstructed (the scraped
# source lost its indentation); the components and wiring are unchanged.
with gr.Blocks(title="Video Transcription & Translation") as demo:
    # Page header and credits.
    gr.Markdown("""
# 🎬 Video Transcription & Translation
Powered by **WhisperX (large-v3-turbo)** for fast, accurate transcription with word-level timestamps.
Developed by [@artificialguybr](https://twitter.com/artificialguybr) • [Video Dubbing](https://huggingface.co/spaces/artificialguybr/video-dubbing)
""")
    with gr.Row():
        # Left column: input video and processing options.
        with gr.Column(scale=2):
            # NOTE(review): the "max 15 min" limit is only stated in this label
            # and in the notes below; process_video does not visibly enforce it.
            video_input = gr.Video(
                label="Upload Video (max 15 min)",
                include_audio=True
            )
            with gr.Row():
                # Choices are the display names (keys) of google_lang_codes;
                # process_video maps the selected name to a translator code.
                target_language = gr.Dropdown(
                    choices=list(google_lang_codes.keys()),
                    label="Target Language",
                    value="English"
                )
                translate_checkbox = gr.Checkbox(
                    label="Translate Subtitles",
                    value=True,
                    info="Translate to target language"
                )
            # process_video only diarizes when the HF_TOKEN env var is also set.
            diarization_checkbox = gr.Checkbox(
                label="Speaker Diarization",
                value=False,
                info="Identify different speakers (requires HF_TOKEN)"
            )
            process_btn = gr.Button("🚀 Process Video", variant="primary", size="lg")
        # Right column: results.
        with gr.Column(scale=2):
            output_video = gr.Video(label="Output Video")
            with gr.Row():
                srt_file = gr.File(label="Download .SRT")
            transcription_text = gr.Textbox(
                label="Transcription",
                lines=10,
                max_lines=20,
                interactive=False
            )
    gr.Markdown("""
---
**Notes:**
- Video limit: 15 minutes
- Uses WhisperX large-v3-turbo for fast transcription
- Automatic language detection
- Parallel translation for speed
- Speaker diarization optional (set HF_TOKEN secret)
""")
    # Inputs are passed positionally, in this order, to process_video's
    # (video_path, target_language, translate_video, enable_diarization).
    # Outputs receive its 3-tuple (video path, SRT path, transcript text).
    process_btn.click(
        fn=process_video,
        inputs=[video_input, target_language, translate_checkbox, diarization_checkbox],
        outputs=[output_video, srt_file, transcription_text]
    )
if __name__ == "__main__":
    demo.launch()