# srt / app.py — Hugging Face Space by xTHExBEASTx
# (commit 0268049 "Update app.py"); the scraped page header was converted
# to comments so this file parses as Python.
import gradio as gr
import whisper
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
import srt
import torch
import os
from datetime import timedelta
import subprocess
import re
# --- Configuration ---
# Translation Model (NLLB)
TRANSLATION_MODEL = "facebook/nllb-200-distilled-1.3B"
# Whisper Model Size: "medium" is the best balance for CPU.
# You can change to "large" or "large-v3" but it will be 2x slower.
WHISPER_MODEL_SIZE = "medium"
print("Loading Models...")
# --- Load Translation Model (NLLB) ---
# Tokenizer and model are downloaded/cached from the Hugging Face Hub at
# import time, so the very first start of the Space is slow.
tokenizer_nllb = AutoTokenizer.from_pretrained(TRANSLATION_MODEL)
model_nllb = AutoModelForSeq2SeqLM.from_pretrained(TRANSLATION_MODEL)
# --- Load Audio Model (Official OpenAI Whisper) ---
# This downloads the model to the container
print(f"Loading Whisper '{WHISPER_MODEL_SIZE}' model...")
# device="cpu" is explicit — no GPU is assumed in this container.
whisper_model = whisper.load_model(WHISPER_MODEL_SIZE, device="cpu")
print("Models Loaded Successfully!")
# ---------------------------------------------------------
# Helper: Extract Audio
# ---------------------------------------------------------
def extract_audio(video_path):
    """Extract the audio track of ``video_path`` to an MP3 via ffmpeg.

    Returns:
        str: path of the written file ("temp_audio.mp3" in the CWD).

    Raises:
        subprocess.CalledProcessError: if ffmpeg exits non-zero (check=True).
    """
    output_audio_path = "temp_audio.mp3"
    # "-vn" drops the video stream; "-y" tells ffmpeg to overwrite any
    # stale temp file, so no explicit os.remove() is needed beforehand
    # (the previous exists/remove pair was redundant with "-y").
    command = [
        "ffmpeg", "-i", video_path,
        "-vn", "-acodec", "libmp3lame",
        "-y", output_audio_path,
    ]
    subprocess.run(command, check=True)
    return output_audio_path
# ---------------------------------------------------------
# Helper: VTT Converter (For Browser Preview)
# ---------------------------------------------------------
def srt_to_vtt(srt_path):
    """Convert an SRT file to WebVTT for the HTML5 video player.

    Args:
        srt_path: path to an existing .srt file.

    Returns:
        str: path of the written .vtt file (same location, ".vtt" suffix).
    """
    # Swap only the file EXTENSION. The old str.replace(".srt", ".vtt")
    # replaced every ".srt" occurrence in the path, corrupting paths such
    # as "my.srt.files/captions.srt".
    base, _ext = os.path.splitext(srt_path)
    vtt_path = base + ".vtt"
    with open(srt_path, 'r', encoding='utf-8') as f:
        content = f.read()
    # WebVTT requires this magic header before any cue.
    vtt_content = "WEBVTT\n\n"
    # Regex to convert SRT comma timestamps (00:00:01,000) to VTT dot
    # timestamps (00:00:01.000).
    vtt_content += re.sub(r'(\d{2}:\d{2}:\d{2}),(\d{3})', r'\1.\2', content)
    with open(vtt_path, 'w', encoding='utf-8') as f:
        f.write(vtt_content)
    return vtt_path
# ---------------------------------------------------------
# Logic 1: Video to SRT (Using Native Whisper)
# ---------------------------------------------------------
def video_to_srt(video_path, progress=gr.Progress()):
    """Transcribe an uploaded video to SRT captions with Whisper.

    Args:
        video_path: local path of the uploaded video (None if nothing uploaded).
        progress: Gradio progress tracker (injected by the UI).

    Returns:
        tuple: (path of the generated .srt file, HTML preview snippet),
        or (None, None) when no video was supplied, or (None, error text)
        when audio extraction fails.
    """
    if video_path is None:
        return None, None

    # Step 1: pull the audio track out of the video with ffmpeg.
    progress(0.1, desc="Extracting Audio...")
    try:
        audio_path = extract_audio(video_path)
    except Exception as e:
        return None, f"Error: {str(e)}"

    # Step 2: run Whisper; its transcribe() already segments the audio
    # into timestamped chunks for us.
    progress(0.3, desc=f"Transcribing with Whisper {WHISPER_MODEL_SIZE}...")
    transcription = whisper_model.transcribe(audio_path, language="en")

    # Step 3: map each Whisper segment onto an srt.Subtitle cue.
    progress(0.8, desc="Formatting SRT...")
    cues = [
        srt.Subtitle(
            index=n + 1,
            start=timedelta(seconds=seg["start"]),
            end=timedelta(seconds=seg["end"]),
            content=seg["text"].strip(),
        )
        for n, seg in enumerate(transcription["segments"])
    ]
    srt_path = "generated_captions.srt"
    with open(srt_path, 'w', encoding='utf-8') as f:
        f.write(srt.compose(cues))

    # Step 4: build a browser preview — the <track> element needs VTT,
    # and Gradio serves local files through the /file= route.
    vtt_path = srt_to_vtt(srt_path)
    html_preview = f"""
    <h3>Video Preview</h3>
    <video controls width="100%" height="400px" style="background:black">
    <source src="/file={video_path}" type="video/mp4">
    <track kind="captions" src="/file={vtt_path}" srclang="en" label="English" default>
    Your browser does not support the video tag.
    </video>
    """
    return srt_path, html_preview
# ---------------------------------------------------------
# Logic 2: Translation (NLLB)
# ---------------------------------------------------------
def batch_translate(texts, src_lang, tgt_lang, batch_size=8):
    """Translate a list of strings with the module-level NLLB model.

    Args:
        texts: source strings to translate.
        src_lang: NLLB source language code (e.g. "eng_Latn").
        tgt_lang: NLLB target language code (e.g. "arb_Arab").
        batch_size: number of strings fed to the model per forward pass.

    Returns:
        list[str]: translations, in the same order as ``texts``.
    """
    results = []
    tokenizer_nllb.src_lang = src_lang
    # The target-language BOS token id is loop-invariant — look it up once
    # instead of once per batch (it was previously recomputed in the loop).
    forced_bos_token_id = tokenizer_nllb.convert_tokens_to_ids(tgt_lang)
    for i in range(0, len(texts), batch_size):
        batch = texts[i : i + batch_size]
        inputs = tokenizer_nllb(batch, return_tensors="pt", padding=True, truncation=True, max_length=512)
        # Inference only: no_grad avoids building the autograd graph.
        with torch.no_grad():
            generated_tokens = model_nllb.generate(**inputs, forced_bos_token_id=forced_bos_token_id, max_length=512)
        results.extend(tokenizer_nllb.batch_decode(generated_tokens, skip_special_tokens=True))
    return results
def process_translation(filepath, src_lang_code, tgt_lang_code):
    """Translate every cue of an uploaded SRT file and write a new SRT.

    Args:
        filepath: path of the uploaded .srt file (None if nothing uploaded).
        src_lang_code: NLLB source language code.
        tgt_lang_code: NLLB target language code.

    Returns:
        str | None: path of the translated SRT, None when no file was
        supplied, or an error string when the upload cannot be parsed.
    """
    if filepath is None:
        return None

    try:
        with open(filepath, 'r', encoding='utf-8') as handle:
            cues = list(srt.parse(handle.read()))
    except Exception as e:
        return f"Error: {str(e)}"

    # Translate the cue texts in bulk, then write each result back onto
    # its cue — timestamps and indices are kept untouched.
    originals = [cue.content for cue in cues]
    translations = batch_translate(originals, src_lang_code, tgt_lang_code)
    for cue, new_text in zip(cues, translations):
        cue.content = new_text

    out_path = "translated_subtitles.srt"
    with open(out_path, 'w', encoding='utf-8') as handle:
        handle.write(srt.compose(cues))
    return out_path
# ---------------------------------------------------------
# Gradio Interface
# ---------------------------------------------------------
# Two-tab workflow: Tab 1 turns a video into English captions; Tab 2
# translates an existing SRT with NLLB.
with gr.Blocks(title="SRT Master Tool") as demo:
    gr.Markdown(f"# 🎬 Auto Subtitle (Whisper {WHISPER_MODEL_SIZE}) & Translator")
    with gr.Tabs():
        # --- TAB 1: video -> SRT + inline preview player ---
        with gr.TabItem("Step 1: Video to SRT"):
            gr.Markdown("### 1. Upload Video -> 2. Check Preview -> 3. Download SRT")
            with gr.Row():
                video_input = gr.Video(label="Upload Video", sources=["upload"])
                with gr.Column():
                    # HTML component hosts the <video>+<track> preview
                    # that video_to_srt builds.
                    preview_output = gr.HTML(label="Preview Player")
                    srt_output_gen = gr.File(label="Download Generated SRT")
            btn1 = gr.Button("Generate SRT & Preview", variant="primary")
            btn1.click(video_to_srt, inputs=video_input, outputs=[srt_output_gen, preview_output])
        # --- TAB 2: SRT -> translated SRT ---
        with gr.TabItem("Step 2: Translate SRT"):
            gr.Markdown("### Translate Subtitles to Arabic")
            with gr.Row():
                srt_input = gr.File(label="Upload SRT")
                with gr.Column():
                    # Dropdown values are NLLB language codes, passed
                    # straight through to batch_translate.
                    src_l = gr.Dropdown(["eng_Latn", "fra_Latn"], label="From", value="eng_Latn")
                    tgt_l = gr.Dropdown(["arb_Arab", "arz_Arab"], label="To", value="arb_Arab")
                    srt_output_trans = gr.File(label="Translated SRT")
            btn2 = gr.Button("Translate", variant="primary")
            btn2.click(process_translation, inputs=[srt_input, src_l, tgt_l], outputs=srt_output_trans)
if __name__ == "__main__":
    demo.launch()