import gradio as gr
import whisper
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
import srt
import torch
import os
from datetime import timedelta
import subprocess
import re
# --- Configuration ---
# Translation Model (NLLB): FLORES-200 multilingual seq2seq model.
TRANSLATION_MODEL = "facebook/nllb-200-distilled-1.3B"
# Whisper Model Size: "medium" is the best balance for CPU.
# You can change to "large" or "large-v3" but it will be 2x slower.
WHISPER_MODEL_SIZE = "medium"
print("Loading Models...")
# --- Load Translation Model (NLLB) ---
# Downloads/loads the tokenizer and weights at import time, so the app
# only becomes responsive after both models are in memory.
tokenizer_nllb = AutoTokenizer.from_pretrained(TRANSLATION_MODEL)
model_nllb = AutoModelForSeq2SeqLM.from_pretrained(TRANSLATION_MODEL)
# --- Load Audio Model (Official OpenAI Whisper) ---
# This downloads the model to the container
print(f"Loading Whisper '{WHISPER_MODEL_SIZE}' model...")
# device="cpu": this Space has no GPU, so force CPU inference.
whisper_model = whisper.load_model(WHISPER_MODEL_SIZE, device="cpu")
print("Models Loaded Successfully!")
# ---------------------------------------------------------
# Helper: Extract Audio
# ---------------------------------------------------------
def extract_audio(video_path):
    """Extract the audio track of *video_path* into an MP3 file.

    Returns the path of the written MP3 ("temp_audio.mp3").
    Raises RuntimeError carrying ffmpeg's stderr if extraction fails.
    """
    output_audio_path = "temp_audio.mp3"
    # "-y" makes ffmpeg overwrite any stale file from a previous run,
    # so no explicit os.path.exists / os.remove dance is needed.
    command = [
        "ffmpeg", "-i", video_path,
        "-vn",                    # drop the video stream
        "-acodec", "libmp3lame",  # encode the audio as MP3
        "-y", output_audio_path,
    ]
    # Capture output so a failure surfaces ffmpeg's actual error text in
    # the UI instead of an opaque CalledProcessError with no detail.
    result = subprocess.run(command, capture_output=True, text=True)
    if result.returncode != 0:
        raise RuntimeError(f"ffmpeg failed: {result.stderr.strip()}")
    return output_audio_path
# ---------------------------------------------------------
# Helper: VTT Converter (For Browser Preview)
# ---------------------------------------------------------
def srt_to_vtt(srt_path):
    """Convert an SRT file to WebVTT for the HTML5 <track> element.

    Writes the result next to the input (same basename, ``.vtt`` suffix)
    and returns the new path.
    """
    # splitext only touches the final extension, unlike str.replace which
    # would also rewrite a ".srt" appearing earlier in the path, and which
    # silently overwrote the input when the name had no ".srt" suffix.
    base, _ext = os.path.splitext(srt_path)
    vtt_path = base + ".vtt"
    with open(srt_path, 'r', encoding='utf-8') as f:
        content = f.read()
    # VTT requires a "WEBVTT" header and dot (not comma) millisecond
    # separators in the timestamps; everything else carries over as-is.
    body = re.sub(r'(\d{2}:\d{2}:\d{2}),(\d{3})', r'\1.\2', content)
    with open(vtt_path, 'w', encoding='utf-8') as f:
        f.write("WEBVTT\n\n" + body)
    return vtt_path
# ---------------------------------------------------------
# Logic 1: Video to SRT (Using Native Whisper)
# ---------------------------------------------------------
def video_to_srt(video_path, progress=gr.Progress()):
    """Transcribe an uploaded video and return (srt_file_path, html_preview)."""
    if video_path is None:
        return None, None

    # Step 1: pull the audio track out of the video container.
    progress(0.1, desc="Extracting Audio...")
    try:
        audio_path = extract_audio(video_path)
    except Exception as e:
        return None, f"Error: {str(e)}"

    # Step 2: run Whisper; its transcribe() segments the audio by itself.
    progress(0.3, desc=f"Transcribing with Whisper {WHISPER_MODEL_SIZE}...")
    result = whisper_model.transcribe(audio_path, language="en")

    # Step 3: turn Whisper's segment dicts into srt.Subtitle cues.
    progress(0.8, desc="Formatting SRT...")
    srt_subtitles = [
        srt.Subtitle(
            index=n + 1,
            start=timedelta(seconds=seg["start"]),
            end=timedelta(seconds=seg["end"]),
            content=seg["text"].strip(),
        )
        for n, seg in enumerate(result["segments"])
    ]
    srt_path = "generated_captions.srt"
    with open(srt_path, 'w', encoding='utf-8') as f:
        f.write(srt.compose(srt_subtitles))

    # Step 4: build an in-browser preview player with a captions track.
    vtt_path = srt_to_vtt(srt_path)
    html_preview = f"""
    <h3>Video Preview</h3>
    <video controls width="100%" height="400px" style="background:black">
        <source src="/file={video_path}" type="video/mp4">
        <track kind="captions" src="/file={vtt_path}" srclang="en" label="English" default>
        Your browser does not support the video tag.
    </video>
    """
    return srt_path, html_preview
# ---------------------------------------------------------
# Logic 2: Translation (NLLB)
# ---------------------------------------------------------
def batch_translate(texts, src_lang, tgt_lang, batch_size=8):
    """Translate a list of strings with NLLB from *src_lang* to *tgt_lang*.

    Language codes are NLLB FLORES-200 codes such as "eng_Latn"/"arb_Arab".
    Processes the list in mini-batches of *batch_size* to bound memory use.
    Returns the translations in the same order as the inputs.
    """
    tokenizer_nllb.src_lang = src_lang
    # The target-language BOS token id is the same for every batch —
    # compute it once instead of once per loop iteration.
    forced_bos_token_id = tokenizer_nllb.convert_tokens_to_ids(tgt_lang)
    results = []
    for i in range(0, len(texts), batch_size):
        batch = texts[i : i + batch_size]
        inputs = tokenizer_nllb(batch, return_tensors="pt", padding=True,
                                truncation=True, max_length=512)
        with torch.no_grad():  # inference only; skip autograd bookkeeping
            generated_tokens = model_nllb.generate(
                **inputs, forced_bos_token_id=forced_bos_token_id, max_length=512
            )
        results.extend(
            tokenizer_nllb.batch_decode(generated_tokens, skip_special_tokens=True)
        )
    return results
def process_translation(filepath, src_lang_code, tgt_lang_code):
    """Translate every cue of an SRT file; return the output file path."""
    if filepath is None:
        return None

    # Parse the uploaded SRT; surface parse problems to the UI as a string.
    try:
        with open(filepath, 'r', encoding='utf-8') as src_file:
            subtitles = list(srt.parse(src_file.read()))
    except Exception as exc:
        return f"Error: {str(exc)}"

    # Translate all cue texts in one batched pass, then write them back
    # onto the cue objects (timing and indices are untouched).
    originals = [cue.content for cue in subtitles]
    translations = batch_translate(originals, src_lang_code, tgt_lang_code)
    for cue, translation in zip(subtitles, translations):
        cue.content = translation

    out_path = "translated_subtitles.srt"
    with open(out_path, 'w', encoding='utf-8') as dst_file:
        dst_file.write(srt.compose(subtitles))
    return out_path
# ---------------------------------------------------------
# Gradio Interface
# ---------------------------------------------------------
# Two-tab UI: tab 1 transcribes a video to SRT, tab 2 translates an SRT.
with gr.Blocks(title="SRT Master Tool") as demo:
    gr.Markdown(f"# 🎬 Auto Subtitle (Whisper {WHISPER_MODEL_SIZE}) & Translator")
    with gr.Tabs():
        # --- TAB 1 ---
        with gr.TabItem("Step 1: Video to SRT"):
            gr.Markdown("### 1. Upload Video -> 2. Check Preview -> 3. Download SRT")
            with gr.Row():
                video_input = gr.Video(label="Upload Video", sources=["upload"])
                with gr.Column():
                    preview_output = gr.HTML(label="Preview Player")
                    srt_output_gen = gr.File(label="Download Generated SRT")
            btn1 = gr.Button("Generate SRT & Preview", variant="primary")
            # video_to_srt returns (srt_path, html_preview) in this order.
            btn1.click(video_to_srt, inputs=video_input, outputs=[srt_output_gen, preview_output])
        # --- TAB 2 ---
        with gr.TabItem("Step 2: Translate SRT"):
            gr.Markdown("### Translate Subtitles to Arabic")
            with gr.Row():
                srt_input = gr.File(label="Upload SRT")
                with gr.Column():
                    # Dropdown values are NLLB FLORES-200 language codes,
                    # passed straight through to batch_translate.
                    src_l = gr.Dropdown(["eng_Latn", "fra_Latn"], label="From", value="eng_Latn")
                    tgt_l = gr.Dropdown(["arb_Arab", "arz_Arab"], label="To", value="arb_Arab")
            srt_output_trans = gr.File(label="Translated SRT")
            btn2 = gr.Button("Translate", variant="primary")
            btn2.click(process_translation, inputs=[srt_input, src_l, tgt_l], outputs=srt_output_trans)
if __name__ == "__main__":
    demo.launch()