Spaces:
Running
Running
| """SRT file formatting and output module for CapCut compatibility.""" | |
| from pathlib import Path | |
| from typing import Dict, List, Union | |
| from config import ( | |
| SRT_ENCODING, SRT_LINE_ENDING, | |
| ARABIC_PARTICLES, MIN_CAPTION_DURATION_MS, | |
| CAPCUT_FPS, SNAP_TO_FRAME, | |
| ) | |
| # --------------------------------------------------------------------------- | |
| # Grouping helpers | |
| # --------------------------------------------------------------------------- | |
| def _is_latin_word(word: str) -> bool: | |
| """True when the word contains only non-Arabic characters (French/Latin).""" | |
| return bool(word) and not any("\u0600" <= c <= "\u06FF" for c in word) | |
def snap_to_30fps(ms: int, fps: Union[int, None] = None) -> int:
    """Snap a millisecond timestamp down onto the CapCut import frame grid.

    The timestamp is mapped to the frame that contains it,
    ``frame = floor(ms * fps / 1000)``. That frame's start time is then
    returned, rounded *up* to a whole millisecond. Rounding up keeps the
    result on or after the exact frame boundary. The returned value therefore
    still lies in the same frame, and snapping it again is a no-op. For
    example, at 30 fps frame 1 starts at 33.33 ms, so it is reported as 34 ms;
    33 ms would fall back into frame 0.

    Args:
        ms: timestamp in milliseconds (non-negative in normal use).
        fps: frame rate of the grid. Defaults to ``CAPCUT_FPS`` from config.
            Despite the function name, any positive frame rate is honoured.

    Returns:
        Start of the containing frame in whole milliseconds. For integer
        ``ms`` the result never exceeds ``ms``.
    """
    if fps is None:
        fps = CAPCUT_FPS
    # Floor division keeps integer inputs exact (no float rounding).
    frame = int(ms * fps // 1000)
    # Ceiling division: first whole millisecond at or after the frame start.
    return -(-frame * 1000 // fps)
def group_words(word_segments: List[Dict]) -> List[Dict]:
    """Group flat word-level segments into natural caption blocks.

    Each input segment is a dict with at least "text", "start_ms" and
    "end_ms". Tokens are scanned left to right, and the first matching rule
    below decides how many tokens are consumed:

    1. "Ω" -> always grouped with the next token.
       - If the next token is "Ψ§Ω" and a third token follows that is not in
         ARABIC_PARTICLES, emit a 3-token block: Ω + Ψ§Ω + <word>.
         Otherwise emit the 2-token block.
    2. "Ψ§Ω" standalone -> grouped with the next token (the Arabic definite
       article).
    3. "ΨΉΩΩ", "Ω Ψ§" -> always grouped with the next token.
    4. "Ω" followed by a token whose first character is a digit -> Ω + digit
       + the following token (3-token block), or 2 tokens at end of input.
    5. "ΩΩΨ§" -> grouped with the next token only if that token is
       French/Latin (per _is_latin_word) or starts with "Ψ§Ω".
    6. French/Latin word followed by standalone "Ψ§ΩΩ" or "Ψ§Ω" -> grouped
       (e.g. cellulite Ψ§ΩΩ).

    All other tokens are emitted as-is. A grouped block takes the first
    token's start_ms and the last token's end_ms. Its text is the tokens
    joined with single spaces.

    Post-grouping steps:
    - _enforce_timing() pulls each block's end to the next block's start and
      pads the final block to MIN_CAPTION_DURATION_MS.
    - _merge_short_blocks() merges blocks still shorter than
      MIN_CAPTION_DURATION_MS into a neighbouring block.
    - Every block gets a 1-based "index".

    NOTE(review): in this copy, rules 1 and 4 compare ``w`` against the same
    literal. Rule 1 therefore always matches first and rule 4 can never fire.
    The particle literals also look like mis-encoded Arabic: they render as
    Greek letters, which _is_latin_word would classify as Latin. Presumably
    the real file uses two distinct Arabic letters here -- confirm against
    the source encoding.

    Args:
        word_segments: word-level segments in playback order (assumed sorted
            by start_ms -- TODO confirm with the aligner).

    Returns:
        A list of caption blocks. An empty input is returned unchanged.
    """
    if not word_segments:
        return word_segments
    grouped: List[Dict] = []
    i = 0
    n = len(word_segments)
    while i < n:
        seg = word_segments[i]
        w = seg["text"]
        # Rule 1: "Ω" -> always group with next.
        if w == "Ω" and i + 1 < n:
            nxt = word_segments[i + 1]
            if nxt["text"] == "Ψ§Ω" and i + 2 < n:
                third = word_segments[i + 2]
                # Only make a 3-token block if the third token is a content word
                # (not a particle; e.g. "Ω Ψ§Ω Ω" should not collapse to 3 tokens)
                if third["text"] not in ARABIC_PARTICLES:
                    grouped.append({
                        "text": f"Ω Ψ§Ω {third['text']}",
                        "start_ms": seg["start_ms"],
                        "end_ms": third["end_ms"],
                    })
                    i += 3
                else:
                    # Fallback: 2-token "Ω Ψ§Ω"
                    grouped.append({
                        "text": f"Ω {nxt['text']}",
                        "start_ms": seg["start_ms"],
                        "end_ms": nxt["end_ms"],
                    })
                    i += 2
            else:
                # Next token is not "Ψ§Ω" (or no third token): plain pair.
                grouped.append({
                    "text": f"Ω {nxt['text']}",
                    "start_ms": seg["start_ms"],
                    "end_ms": nxt["end_ms"],
                })
                i += 2
            continue
        # Rule 2: standalone "Ψ§Ω" -> group with next.
        if w == "Ψ§Ω" and i + 1 < n:
            nxt = word_segments[i + 1]
            grouped.append({
                "text": f"Ψ§Ω {nxt['text']}",
                "start_ms": seg["start_ms"],
                "end_ms": nxt["end_ms"],
            })
            i += 2
            continue
        # Rule 3: "ΨΉΩΩ", "Ω Ψ§" -> always group with next.
        if w in ("ΨΉΩΩ", "Ω Ψ§") and i + 1 < n:
            nxt = word_segments[i + 1]
            grouped.append({
                "text": f"{w} {nxt['text']}",
                "start_ms": seg["start_ms"],
                "end_ms": nxt["end_ms"],
            })
            i += 2
            continue
        # Rule 4: "Ω" + digit -> 3-token group (Ω N noun).
        # NOTE(review): unreachable as written -- rule 1 tests the same literal.
        if w == "Ω" and i + 1 < n:
            nxt1 = word_segments[i + 1]
            if nxt1["text"] and (nxt1["text"][0].isdigit() or nxt1["text"].isdigit()):
                if i + 2 < n:
                    nxt2 = word_segments[i + 2]
                    grouped.append({
                        "text": f"Ω {nxt1['text']} {nxt2['text']}",
                        "start_ms": seg["start_ms"],
                        "end_ms": nxt2["end_ms"],
                    })
                    i += 3
                else:
                    # Digit is the last token: emit the 2-token pair.
                    grouped.append({
                        "text": f"Ω {nxt1['text']}",
                        "start_ms": seg["start_ms"],
                        "end_ms": nxt1["end_ms"],
                    })
                    i += 2
                continue
        # Rule 5: "ΩΩΨ§" -> group only before a Latin word or an "Ψ§Ω"-prefixed word.
        if w == "ΩΩΨ§" and i + 1 < n:
            nxt = word_segments[i + 1]
            nxt_text = nxt["text"]
            if _is_latin_word(nxt_text) or nxt_text.startswith("Ψ§Ω"):
                grouped.append({
                    "text": f"ΩΩΨ§ {nxt_text}",
                    "start_ms": seg["start_ms"],
                    "end_ms": nxt["end_ms"],
                })
                i += 2
                continue
        # Rule 6: French/Latin word + trailing "Ψ§ΩΩ"/"Ψ§Ω".
        if _is_latin_word(w) and i + 1 < n:
            nxt = word_segments[i + 1]
            if nxt["text"] in ("Ψ§ΩΩ", "Ψ§Ω"):
                grouped.append({
                    "text": f"{w} {nxt['text']}",
                    "start_ms": seg["start_ms"],
                    "end_ms": nxt["end_ms"],
                })
                i += 2
                continue
        # Default: emit the token unchanged (same dict object as the input).
        grouped.append(seg)
        i += 1
    # Close gaps between blocks and pad the last block to the minimum duration.
    grouped = _enforce_timing(grouped)
    # Post-enforcement: merge blocks that are still shorter than
    # MIN_CAPTION_DURATION_MS (tight audio clusters leave too little room).
    grouped = _merge_short_blocks(grouped, threshold_ms=MIN_CAPTION_DURATION_MS)
    # Re-index from 1
    for idx, s in enumerate(grouped):
        s["index"] = idx + 1
    return grouped
| def _merge_short_blocks(segments: List[Dict], threshold_ms: int = 50) -> List[Dict]: | |
| """Merge blocks shorter than threshold_ms into the previous block. | |
| Handles tight audio clusters where a grouped token (e.g. "Ω Ψ§Ω") has | |
| insufficient duration. The merged block inherits the previous block's | |
| start_ms and the short block's end_ms, concatenating the text. | |
| """ | |
| if not segments: | |
| return segments | |
| result: List[Dict] = [] | |
| for seg in segments: | |
| dur = seg["end_ms"] - seg["start_ms"] | |
| if dur < threshold_ms and result: | |
| prev = result[-1] | |
| result[-1] = { | |
| "text": f"{prev['text']} {seg['text']}", | |
| "start_ms": prev["start_ms"], | |
| "end_ms": seg["end_ms"], | |
| } | |
| else: | |
| result.append(dict(seg)) | |
| return result | |
| def _enforce_timing(segments: List[Dict]) -> List[Dict]: | |
| """Enforce MIN_CAPTION_DURATION_MS and eliminate gaps between captions. | |
| Each caption's end time matches the next caption's start time exactly. | |
| Overlap (end > next_start) is never allowed. | |
| """ | |
| if not segments: | |
| return segments | |
| result = [dict(s) for s in segments] | |
| for i, seg in enumerate(result): | |
| if i + 1 < len(result): | |
| next_start = result[i + 1]["start_ms"] | |
| # Ensure minimum duration while eliminating gaps | |
| min_end = seg["start_ms"] + MIN_CAPTION_DURATION_MS | |
| if min_end <= next_start: | |
| # Set end time to match next start time exactly (no gap) | |
| seg["end_ms"] = next_start | |
| else: | |
| # If minimum duration would overlap next caption, clamp to 1ms before | |
| seg["end_ms"] = max(seg["start_ms"] + 1, next_start) | |
| else: | |
| # Last segment: just enforce minimum duration | |
| if seg["end_ms"] - seg["start_ms"] < MIN_CAPTION_DURATION_MS: | |
| seg["end_ms"] = seg["start_ms"] + MIN_CAPTION_DURATION_MS | |
| return result | |
def write_srt(segments: List[Dict], output_path: Union[str, Path],
              apply_grouping: bool = False) -> str:
    """Write aligned segments to an SRT file with CapCut-compatible formatting.

    Pipeline:
    1. Optional particle grouping via group_words().
    2. Frame snapping and gap removal via _finalize_for_capcut().
    3. Per-segment validation.
    4. Render and write.

    Validation runs before the output directory is created, so invalid input
    leaves the filesystem untouched.

    Every line break is SRT_LINE_ENDING (expected to be CRLF for CapCut; the
    value is defined in config). This includes line breaks inside a caption's
    text. Blank lines inside a caption are dropped, because a blank line
    would end the SRT block early.

    Args:
        segments: caption dicts with "text", "start_ms" and "end_ms". They
            also need "index" unless apply_grouping is True, in which case
            group_words() assigns it.
        output_path: destination file. Missing parent directories are created.
        apply_grouping: when True (word-level input), merge Arabic particles
            with adjacent tokens before writing.

    Returns:
        The output path as a string.

    Raises:
        ValueError: if segments is empty or any finalized segment is invalid.
        OSError: if the output directory cannot be created.
        RuntimeError: if the file cannot be written (the OSError is chained).
    """
    output_path = Path(output_path)
    if not segments:
        raise ValueError("No segments provided for SRT generation")

    # Apply particle-based grouping for word-level input
    if apply_grouping:
        segments = group_words(segments)

    # Finalize timestamps for CapCut import:
    # align/group -> snap to frame grid -> eliminate gaps -> validate/write
    segments = _finalize_for_capcut(segments)

    # Validate everything before touching the filesystem.
    for segment in segments:
        _validate_segment(segment)

    srt_lines: List[str] = []
    for segment in segments:
        # Normalize internal line breaks and drop blank inner lines.
        text_lines = (line.strip() for line in segment["text"].strip().splitlines())
        text = SRT_LINE_ENDING.join(line for line in text_lines if line)
        start_timestamp = _ms_to_srt_timestamp(segment["start_ms"])
        end_timestamp = _ms_to_srt_timestamp(segment["end_ms"])
        srt_lines.extend([
            str(segment["index"]),
            f"{start_timestamp} --> {end_timestamp}",
            text,
            "",  # Empty line between blocks
        ])
    srt_text = SRT_LINE_ENDING.join(srt_lines)

    output_path.parent.mkdir(parents=True, exist_ok=True)

    # newline='' stops Python from translating the line endings we chose.
    try:
        with open(output_path, 'w', encoding=SRT_ENCODING, newline='') as f:
            f.write(srt_text)
    except OSError as e:
        raise RuntimeError(f"Failed to write SRT file {output_path}: {e}") from e

    print(f"β SRT written β {output_path} ({len(segments)} captions)")
    return str(output_path)
def _finalize_for_capcut(segments: List[Dict]) -> List[Dict]:
    """Apply CapCut frame snapping and eliminate inter-caption gaps.

    All steps run on copies; the caller's dicts are not mutated.

    1. If ``SNAP_TO_FRAME`` is set, snap every ``start_ms``/``end_ms`` onto the
       ``CAPCUT_FPS`` grid with ``snap_to_30fps``.
    2. Fold any caption whose start is not strictly after the previous
       caption's start into that previous caption (texts joined with a
       space). Without this, two starts landing in the same frame would give
       the earlier caption a zero-length span in step 3. ``_validate_segment``
       rejects zero-length spans, which would abort the whole write. When any
       fold happens, captions are renumbered 1..n so indices stay sequential.
    3. Set every caption's end to the next caption's start (no gaps).
    4. If snapping collapsed the last caption to zero length, extend it by
       one frame, rounded up to a whole millisecond.

    Args:
        segments: caption dicts with "text", "start_ms" and "end_ms".

    Returns:
        The finalized list of caption dicts.
    """
    subtitles = [dict(segment) for segment in segments]
    if SNAP_TO_FRAME:
        for subtitle in subtitles:
            subtitle["start_ms"] = snap_to_30fps(subtitle["start_ms"])
            subtitle["end_ms"] = snap_to_30fps(subtitle["end_ms"])

    finalized: List[Dict] = []
    for subtitle in subtitles:
        if finalized and subtitle["start_ms"] <= finalized[-1]["start_ms"]:
            previous = finalized[-1]
            previous["text"] = f"{previous['text']} {subtitle['text']}"
            previous["end_ms"] = max(previous["end_ms"], subtitle["end_ms"])
        else:
            finalized.append(subtitle)

    if len(finalized) != len(subtitles):
        for number, subtitle in enumerate(finalized, start=1):
            subtitle["index"] = number

    for current, following in zip(finalized, finalized[1:]):
        current["end_ms"] = following["start_ms"]

    if SNAP_TO_FRAME and finalized:
        last = finalized[-1]
        if last["end_ms"] <= last["start_ms"]:
            # One frame duration, rounded up to whole milliseconds.
            last["end_ms"] = last["start_ms"] + int(-(-1000 // CAPCUT_FPS))

    return finalized
| def _validate_segment(segment: Dict) -> None: | |
| """Validate a single segment before SRT generation.""" | |
| # Check required fields | |
| required_fields = ["index", "text", "start_ms", "end_ms"] | |
| for field in required_fields: | |
| if field not in segment: | |
| raise ValueError(f"Missing required field '{field}' in segment") | |
| index = segment["index"] | |
| text = segment["text"] | |
| start_ms = segment["start_ms"] | |
| end_ms = segment["end_ms"] | |
| # Validate index | |
| if not isinstance(index, int) or index < 1: | |
| raise ValueError(f"Invalid segment index: {index}. Must be positive integer.") | |
| # Validate text | |
| if not isinstance(text, str): | |
| raise ValueError(f"Invalid text type in segment {index}: {type(text)}. Must be string.") | |
| if not text.strip(): | |
| raise ValueError(f"Empty text in segment {index}") | |
| # Validate timestamps | |
| if not isinstance(start_ms, int) or start_ms < 0: | |
| raise ValueError(f"Invalid start_ms in segment {index}: {start_ms}. Must be non-negative integer.") | |
| if not isinstance(end_ms, int) or end_ms < 0: | |
| raise ValueError(f"Invalid end_ms in segment {index}: {end_ms}. Must be non-negative integer.") | |
| if end_ms <= start_ms: | |
| raise ValueError(f"Invalid timestamp range in segment {index}: start={start_ms}ms, end={end_ms}ms") | |
| def _ms_to_srt_timestamp(milliseconds: int) -> str: | |
| """Convert milliseconds to SRT timestamp format: HH:MM:SS,mmm""" | |
| if milliseconds < 0: | |
| raise ValueError(f"Negative timestamp not allowed: {milliseconds}ms") | |
| # Calculate components | |
| total_seconds = milliseconds // 1000 | |
| ms = milliseconds % 1000 | |
| hours = total_seconds // 3600 | |
| minutes = (total_seconds % 3600) // 60 | |
| seconds = total_seconds % 60 | |
| # Format with leading zeros | |
| # Note: SRT uses comma as decimal separator, not period | |
| return f"{hours:02d}:{minutes:02d}:{seconds:02d},{ms:03d}" | |