# Hugging Face Spaces status header ("Spaces: Sleeping") removed — not part of the app code.
import os
import re
import subprocess
from datetime import timedelta

import gradio as gr
import srt
import torch
import whisper
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
| # --- Configuration --- | |
| # Translation Model (NLLB) | |
| TRANSLATION_MODEL = "facebook/nllb-200-distilled-1.3B" | |
| # Whisper Model Size: "medium" is the best balance for CPU. | |
| # You can change to "large" or "large-v3" but it will be 2x slower. | |
| WHISPER_MODEL_SIZE = "medium" | |
| print("Loading Models...") | |
| # --- Load Translation Model (NLLB) --- | |
| tokenizer_nllb = AutoTokenizer.from_pretrained(TRANSLATION_MODEL) | |
| model_nllb = AutoModelForSeq2SeqLM.from_pretrained(TRANSLATION_MODEL) | |
| # --- Load Audio Model (Official OpenAI Whisper) --- | |
| # This downloads the model to the container | |
| print(f"Loading Whisper '{WHISPER_MODEL_SIZE}' model...") | |
| whisper_model = whisper.load_model(WHISPER_MODEL_SIZE, device="cpu") | |
| print("Models Loaded Successfully!") | |
| # --------------------------------------------------------- | |
| # Helper: Extract Audio | |
| # --------------------------------------------------------- | |
| def extract_audio(video_path): | |
| output_audio_path = "temp_audio.mp3" | |
| if os.path.exists(output_audio_path): | |
| os.remove(output_audio_path) | |
| # Simple FFMPEG extraction | |
| command = [ | |
| "ffmpeg", "-i", video_path, | |
| "-vn", "-acodec", "libmp3lame", | |
| "-y", output_audio_path | |
| ] | |
| subprocess.run(command, check=True) | |
| return output_audio_path | |
| # --------------------------------------------------------- | |
| # Helper: VTT Converter (For Browser Preview) | |
| # --------------------------------------------------------- | |
| def srt_to_vtt(srt_path): | |
| """Converts SRT to VTT format for the HTML5 video player.""" | |
| vtt_path = srt_path.replace(".srt", ".vtt") | |
| with open(srt_path, 'r', encoding='utf-8') as f: | |
| content = f.read() | |
| vtt_content = "WEBVTT\n\n" | |
| # Regex to convert SRT comma timestamps to VTT dot timestamps | |
| vtt_content += re.sub(r'(\d{2}:\d{2}:\d{2}),(\d{3})', r'\1.\2', content) | |
| with open(vtt_path, 'w', encoding='utf-8') as f: | |
| f.write(vtt_content) | |
| return vtt_path | |
| # --------------------------------------------------------- | |
| # Logic 1: Video to SRT (Using Native Whisper) | |
| # --------------------------------------------------------- | |
| def video_to_srt(video_path, progress=gr.Progress()): | |
| if video_path is None: return None, None | |
| # 1. Extract Audio | |
| progress(0.1, desc="Extracting Audio...") | |
| try: | |
| audio_path = extract_audio(video_path) | |
| except Exception as e: | |
| return None, f"Error: {str(e)}" | |
| # 2. Transcribe using Native Whisper | |
| progress(0.3, desc=f"Transcribing with Whisper {WHISPER_MODEL_SIZE}...") | |
| # The native transcribe function handles segmentation automatically! | |
| result = whisper_model.transcribe(audio_path, language="en") | |
| # 3. Format to SRT | |
| progress(0.8, desc="Formatting SRT...") | |
| srt_subtitles = [] | |
| for i, segment in enumerate(result["segments"]): | |
| start_seconds = segment["start"] | |
| end_seconds = segment["end"] | |
| text = segment["text"].strip() | |
| srt_subtitles.append( | |
| srt.Subtitle( | |
| index=i+1, | |
| start=timedelta(seconds=start_seconds), | |
| end=timedelta(seconds=end_seconds), | |
| content=text | |
| ) | |
| ) | |
| srt_path = "generated_captions.srt" | |
| with open(srt_path, 'w', encoding='utf-8') as f: | |
| f.write(srt.compose(srt_subtitles)) | |
| # 4. Create Preview | |
| vtt_path = srt_to_vtt(srt_path) | |
| html_preview = f""" | |
| <h3>Video Preview</h3> | |
| <video controls width="100%" height="400px" style="background:black"> | |
| <source src="/file={video_path}" type="video/mp4"> | |
| <track kind="captions" src="/file={vtt_path}" srclang="en" label="English" default> | |
| Your browser does not support the video tag. | |
| </video> | |
| """ | |
| return srt_path, html_preview | |
| # --------------------------------------------------------- | |
| # Logic 2: Translation (NLLB) | |
| # --------------------------------------------------------- | |
| def batch_translate(texts, src_lang, tgt_lang, batch_size=8): | |
| results = [] | |
| tokenizer_nllb.src_lang = src_lang | |
| for i in range(0, len(texts), batch_size): | |
| batch = texts[i : i + batch_size] | |
| inputs = tokenizer_nllb(batch, return_tensors="pt", padding=True, truncation=True, max_length=512) | |
| forced_bos_token_id = tokenizer_nllb.convert_tokens_to_ids(tgt_lang) | |
| with torch.no_grad(): | |
| generated_tokens = model_nllb.generate(**inputs, forced_bos_token_id=forced_bos_token_id, max_length=512) | |
| results.extend(tokenizer_nllb.batch_decode(generated_tokens, skip_special_tokens=True)) | |
| return results | |
| def process_translation(filepath, src_lang_code, tgt_lang_code): | |
| if filepath is None: return None | |
| try: | |
| with open(filepath, 'r', encoding='utf-8') as f: | |
| subtitles = list(srt.parse(f.read())) | |
| except Exception as e: | |
| return f"Error: {str(e)}" | |
| texts = [sub.content for sub in subtitles] | |
| translated = batch_translate(texts, src_lang_code, tgt_lang_code) | |
| for sub, trans in zip(subtitles, translated): | |
| sub.content = trans | |
| out_path = "translated_subtitles.srt" | |
| with open(out_path, 'w', encoding='utf-8') as f: | |
| f.write(srt.compose(subtitles)) | |
| return out_path | |
| # --------------------------------------------------------- | |
| # Gradio Interface | |
| # --------------------------------------------------------- | |
| with gr.Blocks(title="SRT Master Tool") as demo: | |
| gr.Markdown(f"# 🎬 Auto Subtitle (Whisper {WHISPER_MODEL_SIZE}) & Translator") | |
| with gr.Tabs(): | |
| # --- TAB 1 --- | |
| with gr.TabItem("Step 1: Video to SRT"): | |
| gr.Markdown("### 1. Upload Video -> 2. Check Preview -> 3. Download SRT") | |
| with gr.Row(): | |
| video_input = gr.Video(label="Upload Video", sources=["upload"]) | |
| with gr.Column(): | |
| preview_output = gr.HTML(label="Preview Player") | |
| srt_output_gen = gr.File(label="Download Generated SRT") | |
| btn1 = gr.Button("Generate SRT & Preview", variant="primary") | |
| btn1.click(video_to_srt, inputs=video_input, outputs=[srt_output_gen, preview_output]) | |
| # --- TAB 2 --- | |
| with gr.TabItem("Step 2: Translate SRT"): | |
| gr.Markdown("### Translate Subtitles to Arabic") | |
| with gr.Row(): | |
| srt_input = gr.File(label="Upload SRT") | |
| with gr.Column(): | |
| src_l = gr.Dropdown(["eng_Latn", "fra_Latn"], label="From", value="eng_Latn") | |
| tgt_l = gr.Dropdown(["arb_Arab", "arz_Arab"], label="To", value="arb_Arab") | |
| srt_output_trans = gr.File(label="Translated SRT") | |
| btn2 = gr.Button("Translate", variant="primary") | |
| btn2.click(process_translation, inputs=[srt_input, src_l, tgt_l], outputs=srt_output_trans) | |
| if __name__ == "__main__": | |
| demo.launch() |