Spaces:
Sleeping
Sleeping
| import os, re, time, tempfile, traceback | |
| import gradio as gr | |
| from dotenv import load_dotenv | |
| from openai import OpenAI | |
| from langdetect import detect, DetectorFactory | |
| from srt_utils import ( | |
| parse_srt, blocks_to_srt, split_batches, | |
| validate_srt_batch, last_end_time_ms | |
| ) | |
| from prompts import build_prompt, RTL_LANGS | |
# Pull OPENAI_API_KEY (and friends) from a local .env when present.
load_dotenv()

# Fix langdetect's RNG so repeated runs give stable detections.
DetectorFactory.seed = 42

# Default glossary shown in the UI (English term - Hebrew rendering).
DEFAULT_GLOSSARY = """agency - יכולת פעולה עצמאית
attachment - היקשרות
awakening - התעוררות
alaya - אלאיה
ayatana - אייטנה (בסיס החושים)"""

# UI display name -> ISO-639-1 code; "auto" is the autodetect sentinel.
LANG_NAME_TO_CODE = {
    "English": "en",
    "Hebrew": "he",
    "Spanish": "es",
    "French": "fr",
    "German": "de",
    "Arabic": "ar",
    "Auto-detect": "auto",
}
def simple_token_estimate(text: str) -> int:
    """Rough token estimate for the dialogue text of an SRT document.

    Only the text lines of each block are counted (the index and
    timecode lines are skipped); whitespace-separated words are scaled
    by ~1.33 tokens-per-word, a common heuristic for English-like text.
    """
    word_count = 0
    for block in parse_srt(text):
        # block[0] is the index, block[1] the timecode; the rest is text.
        for line in block[2:]:
            word_count += len(re.findall(r"\S+", line))
    return int(word_count * 1.33)
def estimate_cost(total_in_tokens: int,
                  total_out_tokens: int,
                  price_in_per_million: float,
                  price_out_per_million: float) -> float:
    """Dollar cost of a request given per-million-token prices.

    Input and output tokens are priced independently; the sum is
    rounded to 4 decimal places for display.
    """
    dollars_in = (total_in_tokens / 1_000_000.0) * price_in_per_million
    dollars_out = (total_out_tokens / 1_000_000.0) * price_out_per_million
    return round(dollars_in + dollars_out, 4)
def autodetect_source_lang(srt_text: str) -> str:
    """Guess the source language display name from the SRT's text.

    Samples up to the first 50 blocks (text lines only, capped at
    1000 characters). Falls back to "English" when the sample is
    empty, detection raises, or the detected code is unsupported.
    """
    text_lines = []
    for block in parse_srt(srt_text)[:50]:
        text_lines.extend(block[2:])
    sample = " ".join(text_lines)[:1000].strip()
    if not sample:
        return "English"
    try:
        detected_code = detect(sample)
    except Exception:
        # langdetect raises on very short/ambiguous samples.
        return "English"
    # Reverse-map the ISO code to a display name the UI dropdowns use.
    for display_name, iso_code in LANG_NAME_TO_CODE.items():
        if iso_code == detected_code:
            return display_name
    return "English"
| def _read_srt_input(file_input): | |
| """ | |
| Accept bytes (from gr.File with type='binary'), dict payloads, a filepath string, | |
| or a NamedString-like object. Return decoded UTF-8 text. | |
| """ | |
| if file_input is None: | |
| return None | |
| # Dict payload (some gradio versions) | |
| if isinstance(file_input, dict): | |
| data = file_input.get("data") | |
| name = file_input.get("name") | |
| if isinstance(data, (bytes, bytearray)): | |
| try: | |
| return data.decode("utf-8", errors="replace") | |
| except Exception: | |
| return data.decode("latin-1", errors="replace") | |
| if isinstance(name, str) and os.path.exists(name): | |
| with open(name, "rb") as f: | |
| raw = f.read() | |
| try: | |
| return raw.decode("utf-8", errors="replace") | |
| except Exception: | |
| return raw.decode("latin-1", errors="replace") | |
| # bytes or bytearray | |
| if isinstance(file_input, (bytes, bytearray)): | |
| try: | |
| return file_input.decode("utf-8", errors="replace") | |
| except Exception: | |
| return file_input.decode("latin-1", errors="replace") | |
| # NamedString with .name -> temp path | |
| name = getattr(file_input, "name", None) | |
| if isinstance(name, str) and os.path.exists(name): | |
| with open(name, "rb") as f: | |
| raw = f.read() | |
| try: | |
| return raw.decode("utf-8", errors="replace") | |
| except Exception: | |
| return raw.decode("latin-1", errors="replace") | |
| # string path | |
| if isinstance(file_input, str) and os.path.exists(file_input): | |
| with open(file_input, "rb") as f: | |
| raw = f.read() | |
| try: | |
| return raw.decode("utf-8", errors="replace") | |
| except Exception: | |
| return raw.decode("latin-1", errors="replace") | |
| # Fallback | |
| return str(file_input) | |
def call_gpt(client: OpenAI, model: str, prompt: str, request_timeout: int = 60) -> str:
    """Send a single-turn prompt and return the SRT payload of the reply.

    The prompt asks the model to wrap its answer between <<<SRT>>> and
    <<<END>>> sentinels; when those are found the payload between them
    is returned, otherwise the whole stripped reply is used.
    """
    response = client.chat.completions.create(
        model=model,
        messages=[{"role": "user", "content": prompt}],
        temperature=0.0,
        top_p=1.0,
        extra_body={"verbosity": "low"},
        timeout=request_timeout,
    )
    reply = response.choices[0].message.content
    match = re.search(r'<<<SRT>>>\s*(.*?)\s*<<<END>>>', reply, re.DOTALL)
    if match:
        return match.group(1).strip()
    return reply.strip()
def prepare_download_file(content: str, suffix: str) -> str:
    """Write *content* to a fresh temp file and return its path.

    Uses mkstemp + os.fdopen so the descriptor is closed deterministically.
    The previous implementation leaked a file descriptor: the
    NamedTemporaryFile handle was never closed while a second handle
    re-opened the same path for writing.
    """
    fd, path = tempfile.mkstemp(suffix=suffix)
    with os.fdopen(fd, "w", encoding="utf-8") as f:
        f.write(content)
    return path
def compute_estimates(file_bytes, approx_blocks, use_prev_ctx,
                      price_in_per_million, price_out_per_million, debug=False):
    """Estimate token usage and dollar cost for a would-be translation.

    Returns (message, dl_srt_update, dl_log_update) — the download
    buttons are always hidden by the estimator.

    Bug fix: the message was previously built only inside `if debug:`,
    so with debug off the `return msg, ...` raised UnboundLocalError
    and the estimator always reported "[Estimator error] ...". The
    message is now built unconditionally, with the debug line prepended
    only when requested.
    """
    try:
        raw = _read_srt_input(file_bytes)
        if not raw:
            return "Upload an SRT to estimate.", gr.update(visible=False), gr.update(visible=False)
        base_tokens = simple_token_estimate(raw)
        blocks = parse_srt(raw)
        # Ceiling division: number of batches the file will be split into.
        batch_count = max(1, (len(blocks) + approx_blocks - 1) // approx_blocks)
        # Fixed overhead per batch for the instruction prefix.
        prefix_tokens_per_batch = 300
        context_overhead = 0
        if use_prev_ctx and batch_count > 1:
            # Each batch after the first re-sends ~half a batch of context.
            context_overhead = int((base_tokens / batch_count) * 0.5) * (batch_count - 1)
        in_tokens = base_tokens + batch_count * prefix_tokens_per_batch + context_overhead
        # Output size assumed comparable to input size.
        out_tokens = base_tokens
        total_cost = estimate_cost(in_tokens, out_tokens, price_in_per_million, price_out_per_million)
        msg = (
            f"Estimated tokens — input: ~{in_tokens:,}, output: ~{out_tokens:,}\n"
            f"Estimated total cost: ~${total_cost:.4f} (rates: in ${price_in_per_million}/M, out ${price_out_per_million}/M)\n"
            f"Assumptions: words→tokens≈1.33, per-batch prefix≈{prefix_tokens_per_batch}, "
            f"{'with' if use_prev_ctx else 'no'} previous-batch context."
        )
        if debug:
            msg = "[DEBUG] base tokens: ~%s, batches: %s\n" % (base_tokens, batch_count) + msg
        # Important: keep download buttons hidden for estimator
        return msg, gr.update(visible=False), gr.update(visible=False)
    except Exception as e:
        return f"[Estimator error] {e}", gr.update(visible=False), gr.update(visible=False)
def pipeline(file_bytes, user_api_key, source_lang, target_lang, glossary, extra, model, approx_blocks, use_prev_ctx, debug=False):
    """Translate an uploaded SRT batch-by-batch, streaming progress to Gradio.

    Yields (srt_preview, log_text, dl_srt_update, dl_log_update, direction)
    tuples after each batch.

    Bug fix: this is a generator (it yields progress updates), but the
    terminal results were previously emitted with `return value`. A
    generator's return value becomes StopIteration.value and is discarded
    by Gradio's streaming consumer, so error messages, the final SRT and
    the download paths never reached the UI. Every terminal result is now
    yielded first, followed by a bare `return`.
    """
    SAFE_FILE_HIDDEN = gr.update(value=None, visible=False)
    try:
        api_key = (user_api_key or "").strip() or os.getenv("OPENAI_API_KEY", "").strip()
        if not api_key:
            yield "", "Please paste your OpenAI API key or configure OPENAI_API_KEY.", SAFE_FILE_HIDDEN, SAFE_FILE_HIDDEN, "ltr"
            return
        client = OpenAI(api_key=api_key, timeout=60)
        raw = _read_srt_input(file_bytes)
        if not raw:
            yield "", "Please upload an SRT file.", SAFE_FILE_HIDDEN, SAFE_FILE_HIDDEN, "ltr"
            return
        in_blocks = parse_srt(raw)
        if source_lang == "Auto-detect":
            source_lang = autodetect_source_lang(raw)
        # Sanity check every block: numeric index and a timecode arrow.
        for b in in_blocks:
            if len(b) < 3 or not b[0].strip().isdigit() or "-->" not in b[1]:
                yield "", "Input SRT failed basic validation (numbers/timecodes).", SAFE_FILE_HIDDEN, SAFE_FILE_HIDDEN, "ltr"
                return
        # Text direction depends only on the target language; compute once.
        direction = "rtl" if target_lang.lower()[:2] in RTL_LANGS else "ltr"
        out_blocks_all, logs = [], []
        if debug:
            logs.append('[DEBUG] Starting pipeline with %d blocks' % len(in_blocks))
        prev_source, prev_target = None, None
        # initial progress ping
        yield "", "Starting translation…", gr.update(value=None, visible=False), gr.update(value=None, visible=False), direction
        for i, batch in enumerate(split_batches(in_blocks, approx_blocks), start=1):
            batch_srt_in = blocks_to_srt(batch)
            prompt = build_prompt(
                source_lang=source_lang, target_lang=target_lang,
                batch_srt=batch_srt_in, glossary_text=glossary, extra_instructions=extra,
                prev_source=prev_source if use_prev_ctx else None,
                prev_target=prev_target if use_prev_ctx else None
            )
            if debug:
                logs.append('[DEBUG] Prompt length chars=%d' % len(prompt))
            logs.append('Calling OpenAI…')
            yield blocks_to_srt(out_blocks_all), "\n".join(logs), gr.update(value=None, visible=False), gr.update(value=None, visible=False), direction
            try:
                translated = call_gpt(client, model, prompt, request_timeout=60)
            except Exception as e:
                # API failure: surface partial output plus downloads, then stop.
                logs.append(f"[ERROR] API call failed in batch {i}: {e}")
                srt_path = prepare_download_file(blocks_to_srt(out_blocks_all), ".srt")
                log_path = prepare_download_file("\n".join(logs), ".log.txt")
                yield blocks_to_srt(out_blocks_all), "\n".join(logs), srt_path, log_path, direction
                return
            out_batch = parse_srt(translated)
            prev_end = last_end_time_ms(out_blocks_all)
            ok, rep = validate_srt_batch(batch, out_batch, prev_last_end=prev_end)
            logs.append(f"Batch {i}: {'OK' if ok else 'ISSUES'}")
            logs += rep
            if not ok:
                # One retry with a stricter prompt before accepting the batch.
                prompt_strict = prompt + "\n\n(HARD MODE) Repeat EXACT numbers/timecodes/line counts. Output SRT only."
                try:
                    logs.append('Retrying with strict constraints…')
                    yield blocks_to_srt(out_blocks_all), "\n".join(logs), gr.update(value=None, visible=False), gr.update(value=None, visible=False), direction
                    translated2 = call_gpt(client, model, prompt_strict, request_timeout=60)
                    out_batch2 = parse_srt(translated2)
                    ok2, rep2 = validate_srt_batch(batch, out_batch2, prev_last_end=prev_end)
                    logs.append(f"Batch {i} (retry): {'OK' if ok2 else 'ISSUES'}")
                    logs += rep2
                    if ok2:
                        out_batch = out_batch2
                        ok = True
                except Exception as e:
                    logs.append(f"[ERROR] Retry failed in batch {i}: {e}")
            out_blocks_all.extend(out_batch)
            prev_source, prev_target = batch_srt_in, blocks_to_srt(out_batch)
            yield blocks_to_srt(out_blocks_all), "\n".join(logs), SAFE_FILE_HIDDEN, SAFE_FILE_HIDDEN, direction
            # Small pause so streamed UI updates don't coalesce.
            time.sleep(0.05)
        final_srt = blocks_to_srt(out_blocks_all)
        srt_path = prepare_download_file(final_srt, ".srt")
        log_path = prepare_download_file("\n".join(logs) if logs else "Done.", ".log.txt")
        yield final_srt, "\n".join(logs) if logs else "Done.", srt_path, log_path, direction
    except Exception as e:
        tb = traceback.format_exc()
        yield "", f"[FATAL] {e}\n\n{tb}", SAFE_FILE_HIDDEN, SAFE_FILE_HIDDEN, "ltr"
with gr.Blocks(title="Open Subtitle Translator (GPT-5)") as demo:
    gr.Markdown("## Open Subtitle Translator — GPT-5\nPaste your API key, upload an SRT, pick languages, and translate with strict SRT validation.")

    # --- credentials, model and batching controls ----------------------
    with gr.Row():
        key = gr.Textbox(label="OpenAI API key", type="password", placeholder="sk-...", info="Used only for this session; not stored.")
        model = gr.Dropdown(choices=["gpt-5", "gpt-5-mini"], value="gpt-5", label="Model")
        approx_blocks = gr.Slider(5, 20, value=10, step=1, label="Approx. SRT blocks per batch")
        use_prev = gr.Checkbox(value=True, label="Use previous-batch target as context")

    # --- language selection --------------------------------------------
    with gr.Row():
        src = gr.Dropdown(choices=["Auto-detect", "English", "Hebrew", "Spanish", "French", "German", "Arabic"], value="English", label="Source language")
        tgt = gr.Dropdown(choices=["Hebrew", "English", "Spanish", "French", "German", "Arabic"], value="Hebrew", label="Target language")

    # --- pricing inputs for the cost estimator -------------------------
    with gr.Row():
        price_in = gr.Number(value=1.25, precision=2, label="Price — input $/M tokens (configurable)")
        price_out = gr.Number(value=10.0, precision=2, label="Price — output $/M tokens (configurable)")

    # --- translation guidance and input file ---------------------------
    glossary = gr.Textbox(label="Glossary / Policy", value=DEFAULT_GLOSSARY, lines=6)
    extra = gr.Textbox(label="Extra instructions (optional)", lines=4, placeholder="Tone, domain hints, speaker info…")
    srt_in = gr.File(label="Upload SRT", file_types=[".srt"], type="binary")

    # --- actions --------------------------------------------------------
    with gr.Row():
        estimate_btn = gr.Button("Estimate Cost")
        run_btn = gr.Button("Translate")
    debug = gr.Checkbox(value=False, label="Debug mode")

    # --- outputs --------------------------------------------------------
    srt_preview = gr.Textbox(label="Translated SRT (preview)", lines=18)
    log = gr.Textbox(label="Validation / Log", lines=18)
    with gr.Row():
        # Hidden until the pipeline produces files to download.
        dl_srt = gr.File(label="Download Translated SRT", visible=False)
        dl_log = gr.File(label="Download Log", visible=False)
    dir_state = gr.State("ltr")

    # --- event wiring ---------------------------------------------------
    estimate_btn.click(
        fn=compute_estimates,
        inputs=[srt_in, approx_blocks, use_prev, price_in, price_out, debug],
        outputs=[log, dl_srt, dl_log],
        api_name="estimate"
    )
    run_btn.click(
        fn=pipeline,
        inputs=[srt_in, key, src, tgt, glossary, extra, model, approx_blocks, use_prev, debug],
        outputs=[srt_preview, log, dl_srt, dl_log, dir_state],
        api_name="translate",
        queue=True
    )

if __name__ == "__main__":
    demo.queue().launch()