Spaces:
Running
Running
| """ | |
| Unified SRT processing module combining resegmentation and translation functionality. | |
| """ | |
| import os | |
| import re | |
| import concurrent.futures | |
| from typing import List, Tuple, Optional | |
| from dotenv import load_dotenv | |
| from openai import OpenAI | |
| # Load environment variables from .env if present | |
| load_dotenv(override=True) | |
| # ============================================================================ | |
| # Core SRT Utilities | |
| # ============================================================================ | |
| def read_srt(file_path: str) -> str: | |
| """Read SRT file content.""" | |
| with open(file_path, "r", encoding="utf-8") as f: | |
| return f.read() | |
| def write_srt(file_path: str, content: str) -> None: | |
| """Write content to SRT file.""" | |
| with open(file_path, "w", encoding="utf-8") as f: | |
| f.write(content) | |
| def parse_srt_blocks(srt_content: str) -> List[Tuple[str, str, List[str]]]: | |
| """ | |
| Parse SRT content into blocks. | |
| Returns list of (index, time, text_lines). | |
| """ | |
| blocks = re.split(r"\n\s*\n", srt_content.strip(), flags=re.MULTILINE) | |
| parsed: List[Tuple[str, str, List[str]]] = [] | |
| for block in blocks: | |
| lines = block.strip().splitlines() | |
| if len(lines) < 3: | |
| continue | |
| index = lines[0].strip() | |
| time_line = lines[1].strip() | |
| text_lines = [line.rstrip() for line in lines[2:]] | |
| parsed.append((index, time_line, text_lines)) | |
| return parsed | |
| def parse_srt_block(block: str) -> Optional[Tuple[str, str, List[str]]]: | |
| """Parse a single SRT block.""" | |
| lines = block.strip().splitlines() | |
| if len(lines) < 3: | |
| return None | |
| index = lines[0] | |
| time = lines[1] | |
| text_lines = lines[2:] | |
| return index, time, text_lines | |
| def build_srt_block(index: int, start_time: str, end_time: str, text: str) -> str: | |
| """Build SRT block with index, time range, and text.""" | |
| return f"{index}\n{start_time} --> {end_time}\n{text}" | |
| def build_srt_block_from_lines(index: str, time: str, text_lines: List[str]) -> str: | |
| """Build SRT block from parsed components.""" | |
| return f"{index}\n{time}\n" + "\n".join(text_lines) | |
| # ============================================================================ | |
| # Time Utilities | |
| # ============================================================================ | |
| def extract_times(time_line: str) -> Tuple[str, str]: | |
| """Extract start and end times from time line.""" | |
| # Expected format: HH:MM:SS,mmm --> HH:MM:SS,mmm | |
| parts = [p.strip() for p in time_line.split("-->")] | |
| if len(parts) != 2: | |
| raise ValueError(f"Invalid time line: {time_line}") | |
| return parts[0], parts[1] | |
| def time_str_to_ms(t: str) -> int: | |
| """Convert time string to milliseconds.""" | |
| # HH:MM:SS,mmm | |
| hms, ms = t.split(",") | |
| hours, minutes, seconds = hms.split(":") | |
| total_ms = ( | |
| int(hours) * 3600 * 1000 | |
| + int(minutes) * 60 * 1000 | |
| + int(seconds) * 1000 | |
| + int(ms) | |
| ) | |
| return total_ms | |
| def ms_to_time_str(ms: int) -> str: | |
| """Convert milliseconds to time string.""" | |
| if ms < 0: | |
| ms = 0 | |
| hours = ms // (3600 * 1000) | |
| ms %= 3600 * 1000 | |
| minutes = ms // (60 * 1000) | |
| ms %= 60 * 1000 | |
| seconds = ms // 1000 | |
| millis = ms % 1000 | |
| return f"{hours:02d}:{minutes:02d}:{seconds:02d},{millis:03d}" | |
| # ============================================================================ | |
| # Text Processing Utilities | |
| # ============================================================================ | |
| def ends_with_preferred_punctuation(text: str) -> bool: | |
| """Check if text ends with preferred punctuation.""" | |
| stripped = text.rstrip() | |
| return stripped.endswith(".") or stripped.endswith(",") | |
| def normalize_whitespace(text: str) -> str: | |
| """Normalize whitespace in text.""" | |
| return re.sub(r"\s+", " ", text).strip() | |
| def count_chars(text: str) -> int: | |
| """Count characters including spaces after normalization.""" | |
| return len(text) | |
| def split_text_into_chunks_by_chars_with_punctuation( | |
| text: str, max_chars: int | |
| ) -> List[str]: | |
| """Split text into chunks respecting punctuation boundaries.""" | |
| text = normalize_whitespace(text) | |
| chunks: List[str] = [] | |
| i = 0 | |
| n = len(text) | |
| while i < n: | |
| remaining = text[i:] | |
| if len(remaining) <= max_chars: | |
| chunks.append(remaining.strip()) | |
| break | |
| window = remaining[:max_chars] | |
| # Prefer last '.' or ',' within the window | |
| last_dot = window.rfind(".") | |
| last_comma = window.rfind(",") | |
| cut_at = max(last_dot, last_comma) | |
| if cut_at != -1: | |
| end = cut_at + 1 | |
| else: | |
| # If no punctuation found, look for the last space to avoid cutting words | |
| last_space = window.rfind(" ") | |
| if last_space != -1: | |
| end = last_space | |
| else: | |
| # If no space found, we have to cut at max_chars (single long word) | |
| end = max_chars | |
| chunk = remaining[:end].strip() | |
| if chunk: | |
| chunks.append(chunk) | |
| i += end | |
| # Skip any following spaces before next chunk | |
| while i < n and text[i] == " ": | |
| i += 1 | |
| return [c for c in chunks if c] | |
| # ============================================================================ | |
| # Translation Functionality | |
| # ============================================================================ | |
| def translate_text( | |
| text: str, target_lang: str, model: str, router: str = "dashscope" | |
| ) -> str: | |
| """Translate text using specified provider.""" | |
| if router == "dashscope": | |
| client = OpenAI( | |
| api_key=os.getenv("DASHSCOPE_API_KEY"), | |
| base_url="https://dashscope.aliyuncs.com/compatible-mode/v1", | |
| ) | |
| prompt = ( | |
| f"Translate the following subtitle text to {target_lang}. " | |
| "Do not translate timestamps or numbers. Only translate the spoken text. " | |
| "Return only the translated text, no explanations or formatting.\n\n" | |
| f"{text}" | |
| ) | |
| response = client.chat.completions.create( | |
| model=model, | |
| messages=[ | |
| { | |
| "role": "system", | |
| "content": "You are a helpful assistant that translates subtitles.", | |
| }, | |
| {"role": "user", "content": prompt}, | |
| ], | |
| temperature=0.3, | |
| max_tokens=1024, | |
| ) | |
| return response.choices[0].message.content.strip() | |
| elif router == "openrouter": | |
| client = OpenAI( | |
| api_key=os.getenv("OPENROUTER_API_KEY"), | |
| base_url="https://openrouter.ai/api/v1", | |
| ) | |
| prompt = ( | |
| f"Translate the following subtitle text to {target_lang}. " | |
| "Do not translate timestamps or numbers. Only translate the spoken text. " | |
| "Return only the translated text, no explanations or formatting.\n\n" | |
| f"{text}" | |
| ) | |
| # Optional attribution headers | |
| extra_headers = {} | |
| referer = os.getenv("OPENROUTER_SITE_URL") | |
| app_title = os.getenv("OPENROUTER_APP_TITLE") | |
| if referer: | |
| extra_headers["HTTP-Referer"] = referer | |
| if app_title: | |
| extra_headers["X-Title"] = app_title | |
| response = client.chat.completions.create( | |
| model=model, | |
| messages=[ | |
| { | |
| "role": "system", | |
| "content": "You are a helpful assistant that translates subtitles.", | |
| }, | |
| {"role": "user", "content": prompt}, | |
| ], | |
| temperature=0.3, | |
| max_tokens=1024, | |
| extra_headers=extra_headers, | |
| ) | |
| return response.choices[0].message.content.strip() | |
| elif router == "openai": | |
| client = OpenAI() | |
| prompt = ( | |
| f"Translate the following subtitle text to {target_lang}. " | |
| "Do not translate timestamps or numbers. Only translate the spoken text. " | |
| "Return only the translated text, no explanations or formatting.\n\n" | |
| f"{text}" | |
| ) | |
| try: | |
| # Use Responses API for newer models (e.g., gpt-4.1, gpt-4o) | |
| if model and (model.startswith("gpt-4.1") or model.startswith("gpt-4o")): | |
| response = client.responses.create( | |
| model=model, | |
| input=prompt, | |
| instructions="You are a helpful assistant that translates subtitles.", | |
| temperature=0.3, | |
| max_output_tokens=1024, | |
| ) | |
| # Prefer helper if available | |
| try: | |
| return response.output_text.strip() | |
| except Exception: | |
| # Fallback parsing if helper is unavailable | |
| try: | |
| segments = [] | |
| if hasattr(response, "output") and response.output: | |
| for content_item in response.output[0].content: | |
| text_val = getattr(content_item, "text", None) | |
| if text_val: | |
| segments.append(text_val) | |
| if segments: | |
| return "\n".join(segments).strip() | |
| except Exception: | |
| pass | |
| return str(response).strip() | |
| else: | |
| # Backward compatibility: use Chat Completions for older models | |
| response = client.chat.completions.create( | |
| model=model, | |
| messages=[ | |
| { | |
| "role": "system", | |
| "content": "You are a helpful assistant that translates subtitles.", | |
| }, | |
| {"role": "user", "content": prompt}, | |
| ], | |
| temperature=0.3, | |
| max_tokens=1024, | |
| ) | |
| return response.choices[0].message.content.strip() | |
| except Exception as e: | |
| # Last-resort fallback to ensure we return something | |
| return str(e) | |
| else: | |
| return f"Unsupported provider: {router}" | |
| def translate_block(args: Tuple[str, str, str, str]) -> str: | |
| """Translate a single SRT block.""" | |
| block, target_lang, model, router = args | |
| parsed = parse_srt_block(block) | |
| if not parsed: | |
| return block | |
| index, time, text_lines = parsed | |
| text = "\n".join(text_lines) | |
| if text.strip(): | |
| translated_text = translate_text(text, target_lang, model=model, router=router) | |
| translated_text_lines = translated_text.splitlines() or [translated_text] | |
| else: | |
| translated_text_lines = text_lines | |
| translated_block = build_srt_block_from_lines(index, time, translated_text_lines) | |
| return translated_block | |
| def translate_srt( | |
| input_path: str, | |
| output_path: str, | |
| target_lang: str, | |
| model: Optional[str] = None, | |
| workers: int = 15, | |
| router: str = "dashscope", | |
| max_chars: int = 125, | |
| ) -> str: | |
| """Translate SRT file using specified provider with resegmentation.""" | |
| # Check API keys based on router | |
| if router == "openai": | |
| api_key = os.getenv("OPENAI_API_KEY") | |
| if not api_key: | |
| raise RuntimeError( | |
| "Error: OPENAI_API_KEY not found in environment variables." | |
| ) | |
| if not model: | |
| model = os.getenv("MODEL") or "gpt-4.1" | |
| elif router == "openrouter": | |
| openrouter_key = os.getenv("OPENROUTER_API_KEY") | |
| if not openrouter_key: | |
| raise RuntimeError( | |
| "Error: OPENROUTER_API_KEY not found in environment variables." | |
| ) | |
| if not model: | |
| model = os.getenv("MODEL") or "openai/gpt-4o" | |
| elif router == "dashscope": | |
| dashscope_key = os.getenv("DASHSCOPE_API_KEY") | |
| if not dashscope_key: | |
| raise RuntimeError( | |
| "Error: DASHSCOPE_API_KEY not found in environment variables." | |
| ) | |
| if not model: | |
| model = os.getenv("MODEL") or "qwen-max" | |
| else: | |
| raise RuntimeError( | |
| f"Error: Unknown provider '{router}'. Expected one of: openai, openrouter, dashscope." | |
| ) | |
| # First resegment the SRT to get optimal chunks for translation | |
| srt_content = read_srt(input_path) | |
| parsed_blocks = parse_srt_blocks(srt_content) | |
| resegmented_blocks = resegment_blocks(parsed_blocks, max_chars) | |
| # Now translate the resegmented blocks | |
| block_args = [(block, target_lang, model, router) for block in resegmented_blocks] | |
| translated_blocks = [] | |
| with concurrent.futures.ThreadPoolExecutor(max_workers=workers) as executor: | |
| for translated_block in executor.map(translate_block, block_args): | |
| translated_blocks.append(translated_block) | |
| translated_content = "\n\n".join(translated_blocks) | |
| write_srt(output_path, translated_content) | |
| return output_path | |
| # ============================================================================ | |
| # Resegmentation Functionality | |
| # ============================================================================ | |
| def resegment_blocks( | |
| parsed_blocks: List[Tuple[str, str, List[str]]], max_chars: int | |
| ) -> List[str]: | |
| """Resegment SRT blocks based on character limit.""" | |
| output_blocks: List[str] = [] | |
| current_index = 1 | |
| group_start_time: str = "" | |
| group_end_time: str = "" | |
| group_text_parts: List[str] = [] | |
| group_char_count = 0 | |
| def flush_group(): | |
| nonlocal current_index, group_start_time, group_end_time, group_text_parts, group_char_count | |
| if group_char_count > 0 and group_text_parts: | |
| block_text = normalize_whitespace(" ".join(group_text_parts)) | |
| output_blocks.append( | |
| build_srt_block( | |
| current_index, group_start_time, group_end_time, block_text | |
| ) | |
| ) | |
| current_index += 1 | |
| group_start_time = "" | |
| group_end_time = "" | |
| group_text_parts = [] | |
| group_char_count = 0 | |
| for _, time_line, text_lines in parsed_blocks: | |
| start_time_str, end_time_str = extract_times(time_line) | |
| start_ms = time_str_to_ms(start_time_str) | |
| end_ms = time_str_to_ms(end_time_str) | |
| duration_ms = max(0, end_ms - start_ms) | |
| text = normalize_whitespace(" ".join(text_lines)) | |
| if not text: | |
| continue | |
| this_count = count_chars(text) | |
| # If adding this block would exceed the limit, flush the current group first | |
| if group_char_count > 0 and (group_char_count + this_count) > max_chars: | |
| flush_group() | |
| # If the single block itself exceeds max_chars, split it internally | |
| if this_count > max_chars: | |
| # Ensure any pending group is flushed before inserting split pieces | |
| flush_group() | |
| sub_texts = split_text_into_chunks_by_chars_with_punctuation( | |
| text, max_chars | |
| ) | |
| # Distribute timings proportionally by character count | |
| total_chars = sum(count_chars(st) for st in sub_texts) or 1 | |
| accumulated_ms = 0 | |
| for idx, st in enumerate(sub_texts): | |
| chars_in_chunk = count_chars(st) or 1 | |
| # compute chunk duration (last chunk takes remaining to avoid rounding drift) | |
| if idx < len(sub_texts) - 1: | |
| chunk_ms = int(duration_ms * (chars_in_chunk / total_chars)) | |
| else: | |
| chunk_ms = max(0, duration_ms - accumulated_ms) | |
| chunk_start_ms = start_ms + accumulated_ms | |
| chunk_end_ms = chunk_start_ms + chunk_ms | |
| accumulated_ms += chunk_ms | |
| output_blocks.append( | |
| build_srt_block( | |
| current_index, | |
| ms_to_time_str(chunk_start_ms), | |
| ms_to_time_str(chunk_end_ms), | |
| st, | |
| ) | |
| ) | |
| current_index += 1 | |
| # Done with this overlong block | |
| continue | |
| # Otherwise, safe to merge this whole block into the group | |
| if group_char_count == 0: | |
| group_start_time = start_time_str | |
| group_text_parts.append(text) | |
| group_end_time = end_time_str | |
| group_char_count += this_count | |
| # Prefer flushing on punctuation at the end of this block | |
| if ends_with_preferred_punctuation(text): | |
| flush_group() | |
| elif group_char_count >= max_chars: | |
| flush_group() | |
| # Flush any remaining group | |
| if group_char_count > 0: | |
| flush_group() | |
| return output_blocks | |
| def resegment_srt(input_path: str, output_path: str, max_chars: int = 125) -> str: | |
| """Resegment SRT file based on character limit.""" | |
| srt_content = read_srt(input_path) | |
| parsed = parse_srt_blocks(srt_content) | |
| merged_blocks = resegment_blocks(parsed, max_chars=max_chars) | |
| output_content = "\n\n".join(merged_blocks) + "\n" | |
| write_srt(output_path, output_content) | |
| return output_path | |
| # ============================================================================ | |
| # Combined Processing Functions | |
| # ============================================================================ | |
| def process_srt_file( | |
| input_path: str, | |
| output_path: str, | |
| operation: str = "resegment", | |
| max_chars: int = 125, | |
| target_lang: Optional[str] = None, | |
| model: Optional[str] = None, | |
| workers: int = 15, | |
| router: str = "dashscope", | |
| ) -> str: | |
| """ | |
| Process SRT file with specified operation. | |
| Args: | |
| input_path: Path to input SRT file | |
| output_path: Path to output SRT file | |
| operation: "resegment" or "translate" | |
| max_chars: Maximum characters per segment (for resegmentation) | |
| target_lang: Target language code (for translation) | |
| model: Model to use for translation | |
| workers: Number of concurrent workers for translation | |
| router: Translation provider ("dashscope", "openai", "openrouter") | |
| Returns: | |
| Path to output file | |
| """ | |
| if operation == "resegment": | |
| return resegment_srt(input_path, output_path, max_chars) | |
| elif operation == "translate": | |
| if not target_lang: | |
| raise ValueError("target_lang is required for translation") | |
| return translate_srt( | |
| input_path, output_path, target_lang, model, workers, router, max_chars | |
| ) | |
| else: | |
| raise ValueError( | |
| f"Unknown operation: {operation}. Must be 'resegment' or 'translate'" | |
| ) | |
| # ============================================================================ | |
| # CLI Interface (for backward compatibility) | |
| # ============================================================================ | |
| if __name__ == "__main__": | |
| import argparse | |
| parser = argparse.ArgumentParser( | |
| description="Unified SRT processing tool for resegmentation and translation. Translation automatically includes resegmentation for optimal chunk sizes." | |
| ) | |
| parser.add_argument("input", help="Input SRT file path") | |
| parser.add_argument("output", help="Output SRT file path") | |
| parser.add_argument( | |
| "--operation", | |
| choices=["resegment", "translate"], | |
| default="resegment", | |
| help="Operation to perform (default: resegment)", | |
| ) | |
| parser.add_argument( | |
| "--max-chars", | |
| dest="max_chars", | |
| type=int, | |
| default=125, | |
| help="Maximum characters per segment (default: 125)", | |
| ) | |
| parser.add_argument( | |
| "--target-lang", help="Target language code (e.g., fr, es, de, zh)" | |
| ) | |
| parser.add_argument( | |
| "--model", help="Model to use for translation (default: value of MODEL in .env)" | |
| ) | |
| parser.add_argument( | |
| "--workers", | |
| type=int, | |
| default=25, | |
| help="Number of concurrent workers for translation (default: 25)", | |
| ) | |
| parser.add_argument( | |
| "--provider", | |
| choices=["openai", "dashscope", "openrouter"], | |
| default="dashscope", | |
| help="Translation provider (default: dashscope)", | |
| ) | |
| args = parser.parse_args() | |
| try: | |
| result = process_srt_file( | |
| args.input, | |
| args.output, | |
| operation=args.operation, | |
| max_chars=args.max_chars, | |
| target_lang=args.target_lang, | |
| model=args.model, | |
| workers=args.workers, | |
| router=args.provider, | |
| ) | |
| print(f"Processing complete. Output written to {result}") | |
| except Exception as e: | |
| print(f"Error: {e}") | |
| exit(1) | |