# screenshow / engine /captions.py
# unknownfriend00007's picture
# Upload 13 files
# 7731346 verified
"""
Caption generation helpers.
This module handles:
- Groq Whisper transcription with optional auto language detection
- SRT generation for compatibility
- ASS subtitle generation with animated word highlighting
- Converting custom SRT files into styled ASS overlays
"""
from __future__ import annotations

import logging
import math
import os
import re
import tempfile
from copy import deepcopy
from typing import Dict, List, Optional

import requests

from engine.processor import ProcessingError, extract_audio
# Module-level logger used by every function in this file.
logger = logging.getLogger("ShortsEditor.Captions")
# Endpoint that transcribe_with_groq() POSTs the extracted audio to.
GROQ_TRANSCRIPTION_URL = "https://api.groq.com/openai/v1/audio/transcriptions"
# Script resolution written to the ASS header (PlayResX / PlayResY) and used
# to convert fractional caption-box coordinates into pixel values.
PLAY_RES_X = 1080
PLAY_RES_Y = 1920
# Default caption box as fractions of the frame: x/y are the top-left corner,
# w/h the size. Used by normalize_caption_settings() when the caller gives no
# value for a coordinate.
# NOTE(review): the name is misspelled ("CAPION"); it is referenced under this
# spelling elsewhere in the module, so a rename must update those uses too.
DEFAULT_CAPION_BOX = {
    "x": 0.08,
    "y": 0.72,
    "w": 0.84,
    "h": 0.18,
}
# Built-in caption looks, keyed by preset name. normalize_caption_settings()
# deep-copies one of these (defaulting to "reels") as the base settings.
#
# Fields as consumed by this module:
#   font_name / font_size        -> Fontname / Fontsize of the ASS "Caption" style
#   primary_color                -> PrimaryColour (regular text), "#RRGGBB"
#   active_color                 -> SecondaryColour and the colour of the
#                                   currently highlighted word
#   outline_color / back_color   -> OutlineColour / BackColour
#   outline / shadow / spacing   -> Outline / Shadow / Spacing style fields
#   bold                         -> Bold style flag (-1 when true, else 0)
#   uppercase                    -> upper-case the caption text before escaping
#   active_scale                 -> starting \fscx/\fscy percentage of the
#                                   highlighted word
#   pop_ms                       -> duration of the \t() transform that returns
#                                   the highlighted word to 100 % scale
#   max_words                    -> maximum number of words per caption chunk
#   line_spacing                 -> NOTE(review): not read by any function
#                                   visible in this module — confirm whether it
#                                   is used elsewhere.
STYLE_PRESETS = {
    "reels": {
        "font_name": "Arial",
        "font_size": 72,
        "primary_color": "#FFFFFF",
        "active_color": "#18D7FF",
        "outline_color": "#000000",
        "back_color": "#000000",
        "outline": 6,
        "shadow": 0,
        "bold": True,
        "uppercase": True,
        "spacing": 1.0,
        "line_spacing": 18,
        "active_scale": 112,
        "pop_ms": 130,
        "max_words": 4,
    },
    "clean": {
        "font_name": "Arial",
        "font_size": 60,
        "primary_color": "#FFFFFF",
        "active_color": "#FFD54A",
        "outline_color": "#111111",
        "back_color": "#000000",
        "outline": 4,
        "shadow": 1,
        "bold": True,
        "uppercase": False,
        "spacing": 0.3,
        "line_spacing": 14,
        "active_scale": 106,
        "pop_ms": 110,
        "max_words": 5,
    },
    "impact": {
        "font_name": "Arial",
        "font_size": 82,
        "primary_color": "#FFFFFF",
        "active_color": "#FF9F1C",
        "outline_color": "#000000",
        "back_color": "#000000",
        "outline": 7,
        "shadow": 0,
        "bold": True,
        "uppercase": True,
        "spacing": 1.2,
        "line_spacing": 20,
        "active_scale": 116,
        "pop_ms": 150,
        "max_words": 3,
    },
}
def normalize_caption_settings(settings: Optional[dict] = None) -> dict:
    """Return sanitized caption layout/style settings.

    Starts from a deep copy of the requested style preset (falling back to
    ``"reels"`` for unknown names) and overlays the caller's ``font_size``,
    ``max_words`` and ``box`` after clamping each to a safe range.

    Args:
        settings: Optional raw settings. The keys read are ``preset``,
            ``font_size``, ``max_words`` and ``box`` (a mapping with
            fractional ``x``, ``y``, ``w``, ``h``); anything else is ignored.

    Returns:
        A new dict with every preset field plus ``preset``, ``font_size``,
        ``max_words`` and ``box``. Box values are fractions of the frame
        rounded to four decimals, and the box always fits inside the frame.
    """
    incoming = settings or {}
    preset_name = str(incoming.get("preset", "reels")).strip().lower()
    preset = deepcopy(STYLE_PRESETS.get(preset_name, STYLE_PRESETS["reels"]))

    font_size = int(_clamp(incoming.get("font_size", preset["font_size"]), 28, 160))
    max_words = int(_clamp(incoming.get("max_words", preset["max_words"]), 1, 8))

    # Smallest box the layout accepts, as fractions of the frame.
    min_w = 0.12
    min_h = 0.08

    box = incoming.get("box") or {}
    # The origin is capped so that a minimum-size box still fits in the frame;
    # otherwise x + w (or y + h) could exceed 1.0 and produce negative margins
    # in the generated ASS style.
    x = _clamp(box.get("x", DEFAULT_CAPION_BOX["x"]), 0.0, 1.0 - min_w)
    y = _clamp(box.get("y", DEFAULT_CAPION_BOX["y"]), 0.0, 1.0 - min_h)
    w = _clamp(box.get("w", DEFAULT_CAPION_BOX["w"]), min_w, 1.0 - x)
    h = _clamp(box.get("h", DEFAULT_CAPION_BOX["h"]), min_h, 1.0 - y)

    preset.update(
        {
            "preset": preset_name,
            "font_size": font_size,
            "max_words": max_words,
            "box": {
                "x": round(x, 4),
                "y": round(y, 4),
                "w": round(w, 4),
                "h": round(h, 4),
            },
        }
    )
    return preset
def transcribe_with_groq(
    video_path: str,
    api_key: str,
    language: str = "auto",
) -> dict:
    """Transcribe a video file with Groq Whisper and return verbose JSON.

    The audio track is extracted to a temporary WAV file, re-encoded to MP3
    when the WAV is larger than 25 MB, and uploaded to
    ``GROQ_TRANSCRIPTION_URL`` requesting segment- and word-level timestamps.
    The temporary directory is removed whether or not the call succeeds.

    Args:
        video_path: Path of the video whose audio should be transcribed.
        api_key: Groq API key. Surrounding whitespace is ignored.
        language: Language code forwarded to the API; ``"auto"`` (or an empty
            value) leaves the ``language`` field out of the request.

    Returns:
        The decoded JSON response.

    Raises:
        ProcessingError: If the key is missing or still the placeholder, the
            extracted audio is missing or empty, the API answers with a
            non-200 status, or the response has neither segments nor text.

    Exceptions raised by ``requests`` itself (connection errors, timeouts)
    are not translated and propagate to the caller.
    """
    import shutil

    # Strip the key once: a value padded with whitespace or a trailing newline
    # would otherwise pass the check and be sent verbatim in the header.
    key = (api_key or "").strip()
    if not key or key == "your_groq_api_key_here":
        raise ProcessingError(
            "Groq API key not set.\n"
            "Please add your key to the .env file:\n"
            "GROQ_API_KEY=gsk_your_key_here"
        )

    logger.info("Extracting audio for transcription...")
    work_dir = None
    try:
        work_dir = tempfile.mkdtemp(prefix="shorts_captions_")
        wav_path = os.path.join(work_dir, "audio.wav")
        extract_audio(video_path, wav_path)
        if not os.path.isfile(wav_path) or os.path.getsize(wav_path) == 0:
            raise ProcessingError("Audio extraction produced an empty file.")

        file_size_mb = os.path.getsize(wav_path) / (1024 * 1024)
        if file_size_mb > 25:
            logger.warning(
                "Audio file is %.1fMB, compressing before upload.",
                file_size_mb,
            )
            mp3_path = os.path.join(work_dir, "audio.mp3")
            _compress_audio(wav_path, mp3_path)
            upload_path = mp3_path
        else:
            upload_path = wav_path

        logger.info("Sending audio to Groq Whisper (language=%s)...", language)
        with open(upload_path, "rb") as audio_file:
            # A list of pairs (not a dict) so the repeated
            # "timestamp_granularities[]" field can be sent twice.
            data = [
                ("model", "whisper-large-v3-turbo"),
                ("response_format", "verbose_json"),
                ("temperature", "0"),
                ("timestamp_granularities[]", "segment"),
                ("timestamp_granularities[]", "word"),
            ]
            if language and language != "auto":
                data.append(("language", language))
            response = requests.post(
                GROQ_TRANSCRIPTION_URL,
                headers={"Authorization": f"Bearer {key}"},
                files={"file": (os.path.basename(upload_path), audio_file)},
                data=data,
                timeout=180,
            )

        if response.status_code != 200:
            error_msg = response.text[:500]
            raise ProcessingError(
                f"Groq Whisper API error (HTTP {response.status_code}):\n{error_msg}"
            )

        result = response.json()
        if not result.get("segments") and not result.get("text"):
            raise ProcessingError(
                "Groq Whisper returned no transcription.\n"
                "The audio may be too short, silent, or unsupported."
            )
        return result
    finally:
        # Best-effort cleanup; failures to delete temp files are ignored.
        if work_dir:
            shutil.rmtree(work_dir, ignore_errors=True)
def generate_srt_groq(
    video_path: str,
    output_srt_path: str,
    api_key: str,
    language: str = "auto",
) -> str:
    """Generate a classic SRT subtitle file using Groq Whisper.

    Segments from the transcription become SRT cues. When there are no
    segments, or none of them has usable text and timing, the whole
    transcript text is written as a single cue spanning the reported
    duration (10 seconds if the duration is missing or not positive).

    Args:
        video_path: Video to transcribe.
        output_srt_path: Destination path of the SRT file (UTF-8).
        api_key: Groq API key.
        language: Language code or ``"auto"``.

    Returns:
        ``output_srt_path``.

    Raises:
        ProcessingError: Propagated from ``transcribe_with_groq``.
    """
    result = transcribe_with_groq(video_path, api_key, language=language)

    srt_content = _segments_to_srt(result.get("segments") or [])
    if not srt_content:
        # Same fallback idea as build_ass_from_transcription(): one cue
        # holding the full text rather than an empty subtitle file.
        text = str(result.get("text") or "").strip()
        duration = _safe_float(result.get("duration"))
        if duration is None or duration <= 0:
            duration = 10.0
        srt_content = (
            "1\n"
            f"00:00:00,000 --> {_format_srt_timestamp(duration)}\n"
            f"{text}\n\n"
        )

    with open(output_srt_path, "w", encoding="utf-8") as handle:
        handle.write(srt_content)
    return output_srt_path
def generate_ass_groq(
    video_path: str,
    output_ass_path: str,
    api_key: str,
    language: str = "auto",
    settings: Optional[dict] = None,
) -> str:
    """Transcribe ``video_path`` with Groq Whisper and write animated ASS captions.

    Args:
        video_path: Video to transcribe.
        output_ass_path: Destination path of the ASS file (UTF-8).
        api_key: Groq API key.
        language: Language code or ``"auto"``.
        settings: Optional raw caption settings; normalized before use.

    Returns:
        ``output_ass_path``.
    """
    transcription = transcribe_with_groq(video_path, api_key, language=language)
    document = build_ass_from_transcription(
        transcription,
        normalize_caption_settings(settings),
    )
    with open(output_ass_path, "w", encoding="utf-8") as ass_file:
        ass_file.write(document)
    logger.info("ASS captions generated: %s", output_ass_path)
    return output_ass_path
def convert_srt_to_ass(
    srt_path: str,
    output_ass_path: str,
    settings: Optional[dict] = None,
) -> str:
    """Turn the SRT file at ``srt_path`` into a styled ASS file.

    Args:
        srt_path: Source SRT file.
        output_ass_path: Destination path of the ASS file (UTF-8).
        settings: Optional raw caption settings; normalized before use.

    Returns:
        ``output_ass_path``.

    Raises:
        ProcessingError: If no cue could be parsed from the SRT file.
    """
    parsed_cues = _parse_srt_file(srt_path)
    if not parsed_cues:
        raise ProcessingError("Uploaded SRT file is empty or invalid.")

    style = normalize_caption_settings(settings)
    document = build_ass_from_srt(parsed_cues, style)
    with open(output_ass_path, "w", encoding="utf-8") as ass_file:
        ass_file.write(document)

    logger.info("Converted SRT to ASS: %s", output_ass_path)
    return output_ass_path
def build_ass_from_transcription(result: dict, settings: Optional[dict] = None) -> str:
    """Build an ASS subtitle document from Groq verbose JSON.

    Event sources are tried in this order:

    1. word-level timestamps, giving per-word highlight events;
    2. segment timestamps, giving one event per segment;
    3. the plain transcript text as a single event from 0 to the reported
       duration (5 seconds if the duration is missing or not positive).

    Args:
        result: Transcription result dict.
        settings: Optional raw caption settings; normalized before use.

    Returns:
        The complete ASS document as a string.

    Raises:
        ProcessingError: If none of the sources yields an event.
    """
    caption_settings = normalize_caption_settings(settings)

    words = _normalize_word_items(result)
    if words:
        events = _build_word_highlight_events(words, caption_settings)
    else:
        # "segments" may be present but null; treat that like "no segments".
        events = _build_segment_events(result.get("segments") or [], caption_settings)

    if not events:
        text = str(result.get("text") or "").strip()
        if text:
            # A null or non-positive duration would make the fallback cue
            # invalid (end <= start), so use the default length instead.
            duration = _safe_float(result.get("duration"))
            if duration is None or duration <= 0:
                duration = 5.0
            fallback_segment = [{"start": 0.0, "end": duration, "text": text}]
            events = _build_segment_events(fallback_segment, caption_settings)

    if not events:
        raise ProcessingError("Could not build subtitle events from the transcription.")
    return _build_ass_document(events, caption_settings)
def build_ass_from_srt(cues: List[dict], settings: Optional[dict] = None) -> str:
    """Render parsed SRT cues as a styled ASS document.

    Each cue becomes one ``Caption`` event whose text is the box override,
    the pop-in override and the re-wrapped cue text, in that order.

    Args:
        cues: Dicts with ``start``, ``end`` (seconds) and ``text``.
        settings: Optional raw caption settings; normalized before use.

    Returns:
        The complete ASS document as a string.
    """
    caption_settings = normalize_caption_settings(settings)

    def render(cue: dict) -> dict:
        body = _wrap_plain_text(cue["text"], caption_settings)
        prefix = _box_override(caption_settings) + _cue_intro_override(caption_settings)
        return {
            "start": cue["start"],
            "end": cue["end"],
            "style": "Caption",
            "text": prefix + body,
        }

    return _build_ass_document([render(cue) for cue in cues], caption_settings)
def validate_srt_file(srt_path: str) -> bool:
    """Cheaply check whether ``srt_path`` looks like an SRT file.

    Only the first 2000 characters are inspected. The file passes when it
    exists, can be read as UTF-8, is not blank, and contains at least one
    ``HH:MM:SS,mmm --> HH:MM:SS,mmm`` timing line.
    """
    if not os.path.isfile(srt_path):
        return False

    try:
        with open(srt_path, "r", encoding="utf-8") as srt_file:
            head = srt_file.read(2000)
    except Exception:
        # Unreadable or undecodable files are simply reported as invalid.
        return False

    if head.strip() == "":
        return False

    timing_line = re.search(
        r"\d{2}:\d{2}:\d{2},\d{3}\s*-->\s*\d{2}:\d{2}:\d{2},\d{3}",
        head,
    )
    return timing_line is not None
def _compress_audio(input_wav: str, output_mp3: str):
    """Re-encode ``input_wav`` as a small MP3 so it fits Groq's 25MB limit.

    Runs FFmpeg with libmp3lame at 64 kbit/s, 16 kHz, mono, overwriting
    ``output_mp3`` if it exists.

    Raises:
        ProcessingError: If FFmpeg is not on PATH or exits with a non-zero
            status.
    """
    import shutil
    import subprocess

    ffmpeg_bin = shutil.which("ffmpeg")
    if not ffmpeg_bin:
        raise ProcessingError("FFmpeg not found, needed to compress audio.")

    encoder_options = ["-codec:a", "libmp3lame", "-b:a", "64k", "-ar", "16000", "-ac", "1"]
    command = [ffmpeg_bin, "-y", "-i", input_wav] + encoder_options + [output_mp3]

    # On Windows, keep FFmpeg from opening a console window.
    if os.name == "nt":
        flags = subprocess.CREATE_NO_WINDOW
    else:
        flags = 0

    completed = subprocess.run(
        command,
        capture_output=True,
        text=True,
        timeout=60,
        creationflags=flags,
    )
    if completed.returncode != 0:
        raise ProcessingError(f"Audio compression failed:\n{completed.stderr[:500]}")
def _normalize_word_items(result: dict) -> List[dict]:
    """Extract clean word timings from a Groq transcription result.

    The top-level ``words`` list is preferred; when it yields nothing, the
    ``words`` lists nested in each segment are used instead. Entries with
    empty text, unparsable timestamps or ``end <= start`` are dropped.

    Returns:
        Dicts with ``text``, ``start`` and ``end`` (seconds), in input order.
    """

    def usable(raw_items) -> List[dict]:
        kept = []
        for entry in raw_items:
            label = str(entry.get("word", "")).strip()
            begin = _safe_float(entry.get("start"))
            finish = _safe_float(entry.get("end"))
            if not label or begin is None or finish is None or finish <= begin:
                continue
            kept.append({"text": label, "start": begin, "end": finish})
        return kept

    top_level = usable(result.get("words") or [])
    if top_level:
        return top_level

    nested: List[dict] = []
    for segment in result.get("segments", []):
        nested.extend(usable(segment.get("words") or []))
    return nested
def _build_word_highlight_events(words: List[dict], settings: dict) -> List[dict]:
    """Create one ASS event per spoken word with that word highlighted.

    Words are grouped into chunks; every event of a chunk shows the whole
    chunk, with the current word wrapped in an override that colours it,
    starts it at ``active_scale`` percent and animates it back to 100 %
    over ``pop_ms`` milliseconds.

    An event lasts until the next word of the chunk starts (never ending
    before the word itself ends); the last word of a chunk is held for an
    extra 0.06 s. Every event is at least 0.04 s long.
    """
    highlight_color = _hex_to_ass_color(settings["active_color"])
    scale = int(settings["active_scale"])
    pop = int(settings["pop_ms"])
    open_tag = (
        "{"
        + f"\\c{highlight_color}\\fscx{scale}\\fscy{scale}"
        + f"\\t(0,{pop},\\fscx100\\fscy100)"
        + "}"
    )
    close_tag = "{\\r}"  # restores the base style after the highlighted word

    events = []
    for chunk in _chunk_words(words, settings):
        layout = _wrap_word_items(chunk, settings)
        last = len(chunk) - 1
        for position, spoken in enumerate(chunk):
            if position < last:
                finish = max(chunk[position + 1]["start"], spoken["end"])
            else:
                finish = spoken["end"] + 0.06

            rendered_rows = [
                " ".join(
                    f"{open_tag}{cell['display']}{close_tag}"
                    if cell["chunk_index"] == position
                    else cell["display"]
                    for cell in row
                )
                for row in layout
            ]
            events.append(
                {
                    "start": spoken["start"],
                    "end": max(finish, spoken["start"] + 0.04),
                    "style": "Caption",
                    "text": _box_override(settings) + r"\N".join(rendered_rows),
                }
            )
    return events
def _build_segment_events(segments: List[dict], settings: dict) -> List[dict]:
    """Create one ASS event per segment when word timings are unavailable.

    Segments with empty text, unparsable timestamps or ``end <= start`` are
    skipped. Each event's text is the box override, the pop-in override and
    the wrapped segment text.
    """
    events = []
    for seg in segments:
        body = str(seg.get("text", "")).strip()
        begin = _safe_float(seg.get("start"))
        finish = _safe_float(seg.get("end"))
        if not body or begin is None or finish is None or finish <= begin:
            continue

        wrapped = _wrap_plain_text(body, settings)
        prefix = _box_override(settings) + _cue_intro_override(settings)
        events.append(
            {
                "start": begin,
                "end": finish,
                "style": "Caption",
                "text": prefix + wrapped,
            }
        )
    return events
def _build_ass_document(events: List[dict], settings: dict) -> str:
    """Serialize ``events`` into a complete ASS script.

    The script has a ``[Script Info]`` header using the module's play
    resolution, a single ``Caption`` style derived from ``settings``, and
    one ``Dialogue`` line per event. The result ends with a newline.
    """
    box = settings["box"]
    margin_left = int(box["x"] * PLAY_RES_X)
    top_px = int(box["y"] * PLAY_RES_Y)
    margin_right = int((1 - box["x"] - box["w"]) * PLAY_RES_X)
    bottom_px = int((1 - box["y"] - box["h"]) * PLAY_RES_Y)
    margin_vertical = max(top_px, bottom_px)

    # Field order follows the "Format:" line of the [V4+ Styles] section.
    style_fields = [
        "Caption",
        settings["font_name"],
        settings["font_size"],
        _hex_to_ass_color(settings["primary_color"]),
        _hex_to_ass_color(settings["active_color"]),
        _hex_to_ass_color(settings["outline_color"]),
        _hex_to_ass_color(settings["back_color"], alpha=0x64),
        -1 if settings.get("bold") else 0,  # Bold
        0,  # Italic
        0,  # Underline
        0,  # StrikeOut
        100,  # ScaleX
        100,  # ScaleY
        settings["spacing"],
        0,  # Angle
        1,  # BorderStyle
        settings["outline"],
        settings["shadow"],
        8,  # Alignment
        margin_left,
        margin_right,
        margin_vertical,
        1,  # Encoding
    ]

    document = [
        "[Script Info]",
        "Title: ShortsEditor Animated Captions",
        "ScriptType: v4.00+",
        f"PlayResX: {PLAY_RES_X}",
        f"PlayResY: {PLAY_RES_Y}",
        "ScaledBorderAndShadow: yes",
        "WrapStyle: 2",
        "",
        "[V4+ Styles]",
        (
            "Format: Name, Fontname, Fontsize, PrimaryColour, SecondaryColour, "
            "OutlineColour, BackColour, Bold, Italic, Underline, StrikeOut, "
            "ScaleX, ScaleY, Spacing, Angle, BorderStyle, Outline, Shadow, "
            "Alignment, MarginL, MarginR, MarginV, Encoding"
        ),
        "Style: " + ",".join(str(field) for field in style_fields),
        "",
        "[Events]",
        "Format: Layer, Start, End, Style, Name, MarginL, MarginR, MarginV, Effect, Text",
    ]
    document.extend(
        f"Dialogue: 0,{_format_ass_timestamp(ev['start'])},"
        f"{_format_ass_timestamp(ev['end'])},"
        f"{ev['style']},,0,0,0,,{ev['text']}"
        for ev in events
    )
    return "\n".join(document) + "\n"
def _chunk_words(words: List[dict], settings: dict) -> List[List[dict]]:
"""Group words into short caption bursts that fit common reels pacing."""
max_words = int(settings["max_words"])
pause_threshold = 0.42
max_duration = 2.2
chunks = []
current = []
for word in words:
if not current:
current = [word]
continue
previous = current[-1]
gap = max(0.0, word["start"] - previous["end"])
duration = word["end"] - current[0]["start"]
previous_text = previous["text"]
if (
len(current) >= max_words
or gap >= pause_threshold
or duration >= max_duration
or previous_text.endswith((".", "?", "!", ";", ":"))
):
chunks.append(current)
current = [word]
else:
current.append(word)
if current:
chunks.append(current)
return chunks
def _wrap_word_items(chunk: List[dict], settings: dict) -> List[List[dict]]:
    """Lay out a chunk's words on at most two lines.

    Each word becomes a dict with its position in the chunk
    (``chunk_index``) and its display text. A second line is opened when
    adding a word would push the first line past the per-line character
    budget; once two lines exist, further words go on the second line
    regardless of length. Empty lines are not returned.
    """
    cells = [
        {"chunk_index": position, "display": _display_text(word["text"], settings)}
        for position, word in enumerate(chunk)
    ]
    limit = _max_chars_per_line(settings)

    rows: List[List[dict]] = [[]]
    for cell in cells:
        row = rows[-1]
        overflows = bool(row) and len(_plain_line(row + [cell])) > limit
        if overflows and len(rows) < 2:
            rows.append([cell])
        else:
            row.append(cell)
    return [row for row in rows if row]
def _wrap_plain_text(text: str, settings: dict) -> str:
    """Wrap cue text onto at most two lines joined by the ASS ``\\N`` break.

    The text is split on whitespace, each token is cased and escaped for
    display, and a second line is opened when the first would exceed the
    per-line character budget. Returns ``""`` for blank input.
    """
    words = [piece for piece in re.split(r"\s+", text.strip()) if piece]
    if not words:
        return ""

    limit = _max_chars_per_line(settings)
    rows: List[List[dict]] = [[]]
    for word in words:
        cell = {"display": _display_text(word, settings)}
        row = rows[-1]
        overflows = bool(row) and len(_plain_line(row + [cell])) > limit
        if overflows and len(rows) < 2:
            rows.append([cell])
        else:
            row.append(cell)
    return "\\N".join(_plain_line(row) for row in rows if row)
def _box_override(settings: dict) -> str:
    """Return the ASS override block that places a caption in its box.

    The block sets top-centre alignment (``\\an8``), positions the text at
    the horizontal centre of the box slightly below its top edge (15 % of
    the font size), and clips rendering to the box rectangle. All values
    are pixels in the module's play resolution.
    """
    box = settings["box"]
    x0 = int(box["x"] * PLAY_RES_X)
    y0 = int(box["y"] * PLAY_RES_Y)
    box_w = int(box["w"] * PLAY_RES_X)
    box_h = int(box["h"] * PLAY_RES_Y)

    x1 = x0 + box_w
    y1 = y0 + box_h
    anchor_x = x0 + box_w // 2
    anchor_y = y0 + int(settings["font_size"] * 0.15)

    return "{" + f"\\an8\\pos({anchor_x},{anchor_y})\\clip({x0},{y0},{x1},{y1})" + "}"
def _cue_intro_override(settings: dict) -> str:
"""A small pop-in transform for non-word-timed caption cues."""
return "{\\fad(50,80)\\fscx96\\fscy96\\t(0,120,\\fscx100\\fscy100)}"
def _parse_srt_file(srt_path: str) -> List[dict]:
    """Parse SRT blocks into cue dictionaries.

    The file is read as UTF-8 (a leading BOM is tolerated). Cues are
    separated by one or more blank lines; lines containing only spaces or
    tabs count as blank, and ``\\r\\n`` / ``\\r`` line endings are accepted.
    The numeric index line is optional. Multi-line cue text is joined with
    single spaces. Blocks without a timing line, with unparsable timestamps,
    with ``end <= start`` or with no text are skipped.

    Args:
        srt_path: Path of the SRT file.

    Returns:
        Dicts with ``start``, ``end`` (seconds) and ``text``, in file order.
    """
    with open(srt_path, "r", encoding="utf-8-sig") as handle:
        content = handle.read()

    # Normalize line endings first so the separator pattern only has to deal
    # with "\n".
    content = content.replace("\r\n", "\n").replace("\r", "\n")
    # "\s*" also swallows whitespace-only lines between cues; without this,
    # a separator line holding a stray space would merge two cues into one.
    blocks = re.split(r"\n\s*\n", content.strip())

    cues = []
    for block in blocks:
        lines = [line.rstrip() for line in block.splitlines() if line.strip()]
        if not lines:
            continue

        if "-->" in lines[0]:
            timing_line = lines[0]
            text_lines = lines[1:]
        elif len(lines) >= 2 and "-->" in lines[1]:
            # First line is the cue index; timing follows.
            timing_line = lines[1]
            text_lines = lines[2:]
        else:
            continue

        start_text, end_text = [part.strip() for part in timing_line.split("-->", 1)]
        start = _parse_srt_timestamp(start_text)
        end = _parse_srt_timestamp(end_text)
        if start is None or end is None or end <= start:
            continue

        text = " ".join(part.strip() for part in text_lines if part.strip())
        if text:
            cues.append({"start": start, "end": end, "text": text})
    return cues
def _segments_to_srt(segments: list) -> str:
    """Render Whisper segments as SRT text.

    Segments with empty text, unparsable timestamps or ``end <= start`` are
    skipped; the remaining cues are numbered consecutively from 1 and
    separated by a blank line. Returns ``""`` when nothing is usable.
    """
    entries = []
    for segment in segments:
        body = str(segment.get("text", "")).strip()
        begin = _safe_float(segment.get("start"))
        finish = _safe_float(segment.get("end"))
        if not body or begin is None or finish is None or finish <= begin:
            continue

        number = len(entries) + 1
        timing = f"{_format_srt_timestamp(begin)} --> {_format_srt_timestamp(finish)}"
        entries.append(f"{number}\n{timing}\n{body}\n")
    return "\n".join(entries)
def _format_srt_timestamp(seconds: float) -> str:
"""Convert seconds to HH:MM:SS,mmm."""
hours = int(seconds // 3600)
minutes = int((seconds % 3600) // 60)
secs = int(seconds % 60)
millis = int(round((seconds - int(seconds)) * 1000))
if millis == 1000:
secs += 1
millis = 0
return f"{hours:02d}:{minutes:02d}:{secs:02d},{millis:03d}"
def _format_ass_timestamp(seconds: float) -> str:
"""Convert seconds to H:MM:SS.cc used by ASS subtitles."""
total_centis = int(round(seconds * 100))
hours = total_centis // 360000
minutes = (total_centis % 360000) // 6000
secs = (total_centis % 6000) // 100
centis = total_centis % 100
return f"{hours}:{minutes:02d}:{secs:02d}.{centis:02d}"
def _parse_srt_timestamp(value: str) -> Optional[float]:
"""Parse an SRT timestamp into seconds."""
match = re.match(r"(\d+):(\d+):(\d+),(\d+)", value)
if not match:
return None
hours, minutes, seconds, millis = match.groups()
return (
int(hours) * 3600
+ int(minutes) * 60
+ int(seconds)
+ int(millis) / 1000.0
)
def _display_text(text: str, settings: dict) -> str:
    """Return ``text`` as it should appear in a caption.

    The text is upper-cased when ``settings["uppercase"]`` is truthy, then
    escaped for use in an ASS dialogue line.
    """
    if settings.get("uppercase"):
        text = text.upper()
    return _escape_ass_text(text)
def _escape_ass_text(text: str) -> str:
"""Escape special characters in ASS dialogue text."""
escaped = text.replace("\\", r"\\")
escaped = escaped.replace("{", r"\{").replace("}", r"\}")
return escaped.replace("\n", r"\N")
def _plain_line(items: List[dict]) -> str:
return " ".join(item["display"] for item in items)
def _max_chars_per_line(settings: dict) -> int:
    """Estimate how many characters fit on one caption line.

    The caption box width in pixels (at least 120) is divided by an
    approximate glyph width of 0.58 x font size (at least 14 px). The
    result is never less than 8.
    """
    width_px = int(settings["box"]["w"] * PLAY_RES_X)
    if width_px < 120:
        width_px = 120

    size = int(settings["font_size"])
    if size < 1:
        size = 1

    glyph_width = size * 0.58
    if glyph_width < 14:
        glyph_width = 14

    capacity = int(width_px / glyph_width)
    return capacity if capacity > 8 else 8
def _hex_to_ass_color(hex_color: str, alpha: int = 0x00) -> str:
"""Convert #RRGGBB into ASS color format &HAABBGGRR."""
value = hex_color.lstrip("#")
if len(value) != 6:
value = "FFFFFF"
rr = value[0:2]
gg = value[2:4]
bb = value[4:6]
return f"&H{alpha:02X}{bb}{gg}{rr}"
def _safe_float(value) -> Optional[float]:
try:
return float(value)
except (TypeError, ValueError):
return None
def _clamp(value, low, high):
try:
numeric = float(value)
except (TypeError, ValueError):
numeric = low
return max(low, min(high, numeric))