# SRT-Translator — app.py
# (Hugging Face Space page header residue removed: commit 72a5033 "fix" by benkamin)
import os, re, time, tempfile, traceback
import gradio as gr
from dotenv import load_dotenv
from openai import OpenAI
from langdetect import detect, DetectorFactory
from srt_utils import (
parse_srt, blocks_to_srt, split_batches,
validate_srt_batch, last_end_time_ms
)
from prompts import build_prompt, RTL_LANGS
load_dotenv()  # pull OPENAI_API_KEY (if set) from a local .env into the environment
DetectorFactory.seed = 42  # make langdetect deterministic across runs

# Default English→Hebrew glossary pre-filled in the UI, one "term - translation"
# pair per line. This is runtime data fed into the prompt; users edit it in the UI.
DEFAULT_GLOSSARY = """agency - יכולת פעולה עצמאית
attachment - היקשרות
awakening - התעוררות
alaya - אלאיה
ayatana - אייטנה (בסיס החושים)"""

# UI language names → ISO-639-1 codes as returned by langdetect.detect().
# "auto" is a UI sentinel for autodetection, not a langdetect code.
LANG_NAME_TO_CODE = {
    "English": "en",
    "Hebrew": "he",
    "Spanish": "es",
    "French": "fr",
    "German": "de",
    "Arabic": "ar",
    "Auto-detect": "auto",
}
def simple_token_estimate(text: str) -> int:
    """Rough token estimate for the subtitle text of an SRT document.

    Counts whitespace-separated words across the text lines of every block
    (index and timecode lines are skipped) and scales by ~1.33 tokens/word.
    """
    text_lines = [line for block in parse_srt(text) for line in block[2:]]
    word_count = len(re.findall(r"\S+", " ".join(text_lines)))
    return int(word_count * 1.33)
def estimate_cost(total_in_tokens: int,
                  total_out_tokens: int,
                  price_in_per_million: float,
                  price_out_per_million: float) -> float:
    """Return the combined API cost in dollars, rounded to 4 decimal places.

    Prices are expressed in dollars per million tokens.
    """
    per_million = 1_000_000.0
    input_cost = (total_in_tokens / per_million) * price_in_per_million
    output_cost = (total_out_tokens / per_million) * price_out_per_million
    return round(input_cost + output_cost, 4)
def autodetect_source_lang(srt_text: str) -> str:
    """Guess the source language name from the first 50 subtitle blocks.

    Falls back to "English" when the sample is empty, detection raises,
    or the detected code has no entry in LANG_NAME_TO_CODE.
    """
    fragments = []
    for block in parse_srt(srt_text)[:50]:
        fragments.extend(block[2:])
    sample = " ".join(fragments)[:1000].strip()
    if not sample:
        return "English"
    try:
        detected = detect(sample)
    except Exception:
        return "English"
    return next(
        (name for name, code in LANG_NAME_TO_CODE.items() if code == detected),
        "English",
    )
def _read_srt_input(file_input):
"""
Accept bytes (from gr.File with type='binary'), dict payloads, a filepath string,
or a NamedString-like object. Return decoded UTF-8 text.
"""
if file_input is None:
return None
# Dict payload (some gradio versions)
if isinstance(file_input, dict):
data = file_input.get("data")
name = file_input.get("name")
if isinstance(data, (bytes, bytearray)):
try:
return data.decode("utf-8", errors="replace")
except Exception:
return data.decode("latin-1", errors="replace")
if isinstance(name, str) and os.path.exists(name):
with open(name, "rb") as f:
raw = f.read()
try:
return raw.decode("utf-8", errors="replace")
except Exception:
return raw.decode("latin-1", errors="replace")
# bytes or bytearray
if isinstance(file_input, (bytes, bytearray)):
try:
return file_input.decode("utf-8", errors="replace")
except Exception:
return file_input.decode("latin-1", errors="replace")
# NamedString with .name -> temp path
name = getattr(file_input, "name", None)
if isinstance(name, str) and os.path.exists(name):
with open(name, "rb") as f:
raw = f.read()
try:
return raw.decode("utf-8", errors="replace")
except Exception:
return raw.decode("latin-1", errors="replace")
# string path
if isinstance(file_input, str) and os.path.exists(file_input):
with open(file_input, "rb") as f:
raw = f.read()
try:
return raw.decode("utf-8", errors="replace")
except Exception:
return raw.decode("latin-1", errors="replace")
# Fallback
return str(file_input)
def call_gpt(client: OpenAI, model: str, prompt: str, request_timeout: int = 60) -> str:
    """Send one user message to the chat completions API and return the SRT text.

    If the model wraps its answer in <<<SRT>>> ... <<<END>>> markers, only the
    wrapped portion is returned; otherwise the whole reply is returned stripped.

    Raises whatever the OpenAI client raises on transport/API errors.
    """
    resp = client.chat.completions.create(
        model=model,
        messages=[{"role": "user", "content": prompt}],
        temperature=0.0,
        top_p=1.0,
        extra_body={"verbosity": "low"},
        timeout=request_timeout
    )
    # BUGFIX: message.content is Optional in the OpenAI SDK; re.search(None)
    # would raise TypeError, so treat a missing body as an empty string.
    text = resp.choices[0].message.content or ""
    m = re.search(r'<<<SRT>>>\s*(.*?)\s*<<<END>>>', text, re.DOTALL)
    return (m.group(1).strip() if m else text.strip())
def prepare_download_file(content: str, suffix: str) -> str:
    """Write *content* to a fresh temp file and return its path.

    delete=False keeps the file on disk so Gradio can serve it for download
    after this function returns; text is written as UTF-8.

    BUGFIX: the original leaked the NamedTemporaryFile handle and reopened
    the file by name while it was still open (which fails on Windows);
    write through the managed handle instead.
    """
    with tempfile.NamedTemporaryFile(
        mode="w", encoding="utf-8", suffix=suffix, delete=False
    ) as tmp:
        tmp.write(content)
        return tmp.name
def compute_estimates(file_bytes, approx_blocks, use_prev_ctx,
                      price_in_per_million, price_out_per_million, debug=False):
    """Estimate token usage and API cost for a translation run.

    Returns (message, dl_srt_update, dl_log_update); the download components
    are always hidden by the estimator. Errors are reported in the message.
    """
    try:
        raw = _read_srt_input(file_bytes)
        if not raw:
            return "Upload an SRT to estimate.", gr.update(visible=False), gr.update(visible=False)
        base_tokens = simple_token_estimate(raw)
        blocks = parse_srt(raw)
        # Ceiling division: number of batches at ~approx_blocks blocks each.
        batch_count = max(1, (len(blocks) + approx_blocks - 1) // approx_blocks)
        prefix_tokens_per_batch = 300
        context_overhead = 0
        if use_prev_ctx and batch_count > 1:
            # Every batch after the first carries ~half a batch of prior context.
            context_overhead = int((base_tokens / batch_count) * 0.5) * (batch_count - 1)
        in_tokens = base_tokens + batch_count * prefix_tokens_per_batch + context_overhead
        out_tokens = base_tokens
        total_cost = estimate_cost(in_tokens, out_tokens, price_in_per_million, price_out_per_million)
        # BUGFIX: msg was only assigned inside `if debug:`, so the normal
        # (debug=False) path raised NameError, swallowed into the generic
        # "[Estimator error]" message. Build the summary unconditionally and
        # only prepend the debug line when requested.
        msg = (
            f"Estimated tokens — input: ~{in_tokens:,}, output: ~{out_tokens:,}\n"
            f"Estimated total cost: ~${total_cost:.4f} (rates: in ${price_in_per_million}/M, out ${price_out_per_million}/M)\n"
            f"Assumptions: words→tokens≈1.33, per-batch prefix≈{prefix_tokens_per_batch}, "
            f"{'with' if use_prev_ctx else 'no'} previous-batch context."
        )
        if debug:
            msg = "[DEBUG] base tokens: ~%s, batches: %s\n" % (base_tokens, batch_count) + msg
        # Important: keep download buttons hidden for estimator
        return msg, gr.update(visible=False), gr.update(visible=False)
    except Exception as e:
        return f"[Estimator error] {e}", gr.update(visible=False), gr.update(visible=False)
def pipeline(file_bytes, user_api_key, source_lang, target_lang, glossary, extra, model, approx_blocks, use_prev_ctx, debug=False):
    """Translate an uploaded SRT batch-by-batch, streaming progress to the UI.

    Generator handler for the "Translate" button. Yields
    (preview_srt, log_text, dl_srt_update, dl_log_update, text_direction)
    tuples after each batch, and produces a final tuple with download paths.

    NOTE(review): this generator also uses bare ``return value`` for early
    exits and for the final result; whether Gradio surfaces a generator's
    return value depends on the Gradio version — confirm that the early
    error paths actually reach the UI.
    """
    SAFE_FILE_HIDDEN = gr.update(value=None, visible=False)
    try:
        # Per-session key from the textbox wins; fall back to the environment.
        api_key = (user_api_key or "").strip() or os.getenv("OPENAI_API_KEY", "").strip()
        if not api_key:
            return "", "Please paste your OpenAI API key or configure OPENAI_API_KEY.", SAFE_FILE_HIDDEN, SAFE_FILE_HIDDEN, "ltr"
        client = OpenAI(api_key=api_key, timeout=60)
        raw = _read_srt_input(file_bytes)
        if not raw:
            return "", "Please upload an SRT file.", SAFE_FILE_HIDDEN, SAFE_FILE_HIDDEN, "ltr"
        in_blocks = parse_srt(raw)
        if source_lang == "Auto-detect":
            source_lang = autodetect_source_lang(raw)
        # Basic structural validation: each block needs a numeric index line,
        # a "-->" timecode line, and at least one text line.
        for b in in_blocks:
            if len(b) < 3 or not b[0].strip().isdigit() or "-->" not in b[1]:
                return "", "Input SRT failed basic validation (numbers/timecodes).", SAFE_FILE_HIDDEN, SAFE_FILE_HIDDEN, "ltr"
        out_blocks_all, logs = [], []
        if debug:
            logs.append('[DEBUG] Starting pipeline with %d blocks' % len(in_blocks))
        prev_source, prev_target = None, None
        # initial progress ping
        yield "", "Starting translation…", gr.update(value=None, visible=False), gr.update(value=None, visible=False), ("rtl" if target_lang.lower()[:2] in RTL_LANGS else "ltr")
        for i, batch in enumerate(split_batches(in_blocks, approx_blocks), start=1):
            batch_srt_in = blocks_to_srt(batch)
            prompt = build_prompt(
                source_lang=source_lang, target_lang=target_lang,
                batch_srt=batch_srt_in, glossary_text=glossary, extra_instructions=extra,
                # Previous batch is passed as context only when the checkbox is on.
                prev_source=prev_source if use_prev_ctx else None,
                prev_target=prev_target if use_prev_ctx else None
            )
            if debug:
                logs.append('[DEBUG] Prompt length chars=%d' % len(prompt))
            logs.append('Calling OpenAI…')
            # Stream the current partial result before the (slow) API call.
            yield blocks_to_srt(out_blocks_all), "\n".join(logs), gr.update(value=None, visible=False), gr.update(value=None, visible=False), ("rtl" if target_lang.lower()[:2] in RTL_LANGS else "ltr")
            try:
                translated = call_gpt(client, model, prompt, request_timeout=60)
            except Exception as e:
                # API failure: persist whatever was translated so far and stop.
                logs.append(f"[ERROR] API call failed in batch {i}: {e}")
                srt_path = prepare_download_file(blocks_to_srt(out_blocks_all), ".srt")
                log_path = prepare_download_file("\n".join(logs), ".log.txt")
                return blocks_to_srt(out_blocks_all), "\n".join(logs), srt_path, log_path, ("rtl" if target_lang.lower()[:2] in RTL_LANGS else "ltr")
            out_batch = parse_srt(translated)
            prev_end = last_end_time_ms(out_blocks_all)
            # Validate numbering/timecodes of the translated batch against the input.
            ok, rep = validate_srt_batch(batch, out_batch, prev_last_end=prev_end)
            logs.append(f"Batch {i}: {'OK' if ok else 'ISSUES'}")
            logs += rep
            if not ok:
                # One retry with a stricter prompt; on success the retried
                # batch replaces the failed one, otherwise the original
                # (flawed) batch is kept below.
                prompt_strict = prompt + "\n\n(HARD MODE) Repeat EXACT numbers/timecodes/line counts. Output SRT only."
                try:
                    logs.append('Retrying with strict constraints…')
                    yield blocks_to_srt(out_blocks_all), "\n".join(logs), gr.update(value=None, visible=False), gr.update(value=None, visible=False), ("rtl" if target_lang.lower()[:2] in RTL_LANGS else "ltr")
                    translated2 = call_gpt(client, model, prompt_strict, request_timeout=60)
                    out_batch2 = parse_srt(translated2)
                    ok2, rep2 = validate_srt_batch(batch, out_batch2, prev_last_end=prev_end)
                    logs.append(f"Batch {i} (retry): {'OK' if ok2 else 'ISSUES'}")
                    logs += rep2
                    if ok2:
                        out_batch = out_batch2
                        ok = True
                except Exception as e:
                    logs.append(f"[ERROR] Retry failed in batch {i}: {e}")
            out_blocks_all.extend(out_batch)
            # Remember this batch as context for the next one.
            prev_source, prev_target = batch_srt_in, blocks_to_srt(out_batch)
            yield blocks_to_srt(out_blocks_all), "\n".join(logs), SAFE_FILE_HIDDEN, SAFE_FILE_HIDDEN, ("rtl" if target_lang.lower()[:2] in RTL_LANGS else "ltr")
            time.sleep(0.05)  # brief pause so the UI can render the update
        final_srt = blocks_to_srt(out_blocks_all)
        direction = "rtl" if target_lang.lower()[:2] in RTL_LANGS else "ltr"
        srt_path = prepare_download_file(final_srt, ".srt")
        log_path = prepare_download_file("\n".join(logs) if logs else "Done.", ".log.txt")
        return final_srt, "\n".join(logs) if logs else "Done.", srt_path, log_path, direction
    except Exception as e:
        # Top-level boundary: report the full traceback in the log box.
        tb = traceback.format_exc()
        return "", f"[FATAL] {e}\n\n{tb}", SAFE_FILE_HIDDEN, SAFE_FILE_HIDDEN, "ltr"
# ---------------- Gradio UI ----------------
with gr.Blocks(title="Open Subtitle Translator (GPT-5)") as demo:
    gr.Markdown("## Open Subtitle Translator — GPT-5\nPaste your API key, upload an SRT, pick languages, and translate with strict SRT validation.")
    with gr.Row():
        key = gr.Textbox(label="OpenAI API key", type="password", placeholder="sk-...", info="Used only for this session; not stored.")
        model = gr.Dropdown(choices=["gpt-5", "gpt-5-mini"], value="gpt-5", label="Model")
        approx_blocks = gr.Slider(5, 20, value=10, step=1, label="Approx. SRT blocks per batch")
        use_prev = gr.Checkbox(value=True, label="Use previous-batch target as context")
    with gr.Row():
        src = gr.Dropdown(choices=["Auto-detect", "English", "Hebrew", "Spanish", "French", "German", "Arabic"], value="English", label="Source language")
        tgt = gr.Dropdown(choices=["Hebrew", "English", "Spanish", "French", "German", "Arabic"], value="Hebrew", label="Target language")
    with gr.Row():
        # User-editable $/M-token rates consumed by compute_estimates().
        price_in = gr.Number(value=1.25, precision=2, label="Price — input $/M tokens (configurable)")
        price_out = gr.Number(value=10.0, precision=2, label="Price — output $/M tokens (configurable)")
    glossary = gr.Textbox(label="Glossary / Policy", value=DEFAULT_GLOSSARY, lines=6)
    extra = gr.Textbox(label="Extra instructions (optional)", lines=4, placeholder="Tone, domain hints, speaker info…")
    # type="binary" hands raw bytes to the handlers (decoded by _read_srt_input).
    srt_in = gr.File(label="Upload SRT", file_types=[".srt"], type="binary")
    with gr.Row():
        estimate_btn = gr.Button("Estimate Cost")
        run_btn = gr.Button("Translate")
        debug = gr.Checkbox(value=False, label="Debug mode")
    srt_preview = gr.Textbox(label="Translated SRT (preview)", lines=18)
    log = gr.Textbox(label="Validation / Log", lines=18)
    with gr.Row():
        # Hidden until the pipeline produces files to download.
        dl_srt = gr.File(label="Download Translated SRT", visible=False)
        dl_log = gr.File(label="Download Log", visible=False)
    # Receives the "ltr"/"rtl" direction value yielded by pipeline().
    dir_state = gr.State("ltr")
    estimate_btn.click(
        fn=compute_estimates,
        inputs=[srt_in, approx_blocks, use_prev, price_in, price_out, debug],
        outputs=[log, dl_srt, dl_log],
        api_name="estimate"
    )
    run_btn.click(
        fn=pipeline,
        inputs=[srt_in, key, src, tgt, glossary, extra, model, approx_blocks, use_prev, debug],
        outputs=[srt_preview, log, dl_srt, dl_log, dir_state],
        api_name="translate",
        queue=True
    )

if __name__ == "__main__":
    # queue() enables streaming of the pipeline() generator's yields.
    demo.queue().launch()