"""FastAPI worker exposing audio loudness normalization and subtitle segmentation.

Endpoints:

- ``GET /health``: liveness probe.
- ``POST /normalize``: two-pass ffmpeg ``loudnorm`` normalization of an upload.
- ``POST /subtitles``: groups word-level timestamps into subtitle blocks,
  optionally guided by a spaCy pipeline when ``spacy`` is installed.
"""

from __future__ import annotations

import asyncio
import json
import math
import os
import re
import secrets
import shutil
import subprocess
import tempfile
from typing import Any, BinaryIO, Dict, Tuple

from fastapi import BackgroundTasks, Body, FastAPI, File, Header, HTTPException, Query, UploadFile
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import FileResponse, JSONResponse

try:
    import spacy
except Exception:  # pragma: no cover - optional dependency
    spacy = None

app = FastAPI(title="Audio Normalizer", version="0.1.0")

app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=False,
    allow_methods=["*"],
    allow_headers=["*"],
    # Let browser clients read the loudness measurements we attach to /normalize.
    expose_headers=[
        "X-Input-LUFS",
        "X-Input-TP",
        "X-Input-LRA",
        "X-Target-LUFS",
        "X-Applied-Gain",
    ],
)

# Only a short alphanumeric extension from the client filename is kept on disk.
_SAFE_EXTENSION_RE = re.compile(r"\.[A-Za-z0-9]{1,10}")
# A run of trailing punctuation on a word, e.g. "word," -> ("word", ",").
_TRAILING_PUNCT_RE = re.compile(r"([^\w\s]+)$")
# String spellings that _bool_setting treats as False.
_FALSE_STRINGS = frozenset({"", "0", "false", "no", "off"})


# --------------------------------------------------------------------------
# Auth / request helpers
# --------------------------------------------------------------------------


def _check_worker_auth(provided: str | None, *env_names: str) -> None:
    """Raise HTTP 403 unless *provided* matches the configured worker secret.

    The secret is the value of the first non-empty environment variable in
    *env_names*. When none is set, authentication is disabled. The comparison
    is constant-time so response timing does not leak the secret.
    """
    secret = None
    for name in env_names:
        secret = os.getenv(name)
        if secret:
            break
    if not secret:
        return
    if provided is None or not secrets.compare_digest(
        provided.encode("utf-8", "surrogatepass"),
        secret.encode("utf-8", "surrogatepass"),
    ):
        raise HTTPException(status_code=403, detail="Invalid worker secret")


def _int_setting(settings: Dict[str, Any], key: str, default: int, *, minimum: int | None = None) -> int:
    """Read an integer setting, turning bad client input into HTTP 400.

    Raises:
        HTTPException(400): if the value cannot be converted with ``int()``,
            or is below *minimum* when one is given.
    """
    raw = settings.get(key, default)
    try:
        value = int(raw)
    except (TypeError, ValueError):
        raise HTTPException(status_code=400, detail=f"Invalid setting {key!r}: expected an integer") from None
    if minimum is not None and value < minimum:
        raise HTTPException(status_code=400, detail=f"Invalid setting {key!r}: must be >= {minimum}")
    return value


def _bool_setting(settings: Dict[str, Any], key: str, default: bool) -> bool:
    """Read a boolean flag, accepting JSON booleans or common string spellings.

    Plain ``bool(value)`` would treat the string ``"false"`` as True, so the
    spellings in ``_FALSE_STRINGS`` are recognised explicitly. Any other value
    falls back to ordinary truthiness.
    """
    value = settings.get(key, default)
    if isinstance(value, str):
        return value.strip().lower() not in _FALSE_STRINGS
    return bool(value)


def _safe_upload_name(filename: str) -> str:
    """Return a fixed on-disk name for an upload, keeping only its extension.

    The client-supplied filename is untrusted. Joined directly, it could
    contain path separators or be absolute (writing outside the temp dir),
    be empty, or collide with the output file name. Only a short
    alphanumeric extension survives, so ffmpeg still sees the original
    suffix.
    """
    base = os.path.basename(filename.replace("\\", "/"))
    ext = os.path.splitext(base)[1]
    if not _SAFE_EXTENSION_RE.fullmatch(ext):
        ext = ""
    return f"input{ext}"


def _save_upload(source: BinaryIO, destination: str) -> None:
    """Copy an uploaded file object to *destination* (blocking I/O)."""
    with open(destination, "wb") as out_file:
        shutil.copyfileobj(source, out_file)


# --------------------------------------------------------------------------
# ffmpeg / loudnorm helpers
# --------------------------------------------------------------------------


def _run_ffmpeg(args: list[str]) -> subprocess.CompletedProcess[str]:
    """Run an ffmpeg command, converting failures into HTTP 500 errors.

    stderr is decoded with ``errors="replace"`` because ffmpeg echoes input
    metadata that is not guaranteed to be valid text. A strict decode would
    raise ``UnicodeDecodeError`` instead of returning the output.

    Raises:
        HTTPException(500): ffmpeg is missing from PATH, or exited non-zero
            (detail is the last stderr line).
    """
    try:
        return subprocess.run(
            args,
            stdout=subprocess.PIPE,
            stderr=subprocess.PIPE,
            text=True,
            errors="replace",
            check=True,
        )
    except FileNotFoundError as exc:
        raise HTTPException(status_code=500, detail="ffmpeg not found in PATH") from exc
    except subprocess.CalledProcessError as exc:
        stderr = (exc.stderr or "").strip()
        detail = stderr.splitlines()[-1] if stderr else "ffmpeg failed"
        raise HTTPException(status_code=500, detail=detail) from exc


def _extract_loudnorm_json(stderr: str) -> Dict[str, Any]:
    """Parse the JSON object that ``loudnorm=print_format=json`` prints last.

    Takes the text between the last ``{`` and the last ``}`` in *stderr*.

    Raises:
        ValueError: if no brace pair is found or the text is not valid JSON.
    """
    start = stderr.rfind("{")
    end = stderr.rfind("}")
    if start == -1 or end == -1 or end <= start:
        raise ValueError("Unable to parse loudnorm output")
    payload = stderr[start : end + 1]
    return json.loads(payload)


def _map_measured(data: Dict[str, Any]) -> Dict[str, float]:
    """Map loudnorm's first-pass JSON keys to the second-pass option names."""
    return {
        "measured_I": float(data["input_i"]),
        "measured_TP": float(data["input_tp"]),
        "measured_LRA": float(data["input_lra"]),
        "measured_thresh": float(data["input_thresh"]),
        "offset": float(data["target_offset"]),
    }


def _clamp_target(measured_i: float, target_i: float, max_gain_db: float | None) -> Tuple[float, float]:
    """Limit the loudness change to at most ``|max_gain_db|`` dB in either direction.

    Returns:
        ``(adjusted_target_lufs, applied_gain_db)``. With ``max_gain_db=None``
        the requested target is used unchanged.
    """
    gain = target_i - measured_i
    if max_gain_db is None:
        return target_i, gain
    # Use the magnitude so a negative limit doesn't invert the clamp.
    limit = abs(max_gain_db)
    if gain > limit:
        return measured_i + limit, limit
    if gain < -limit:
        return measured_i - limit, -limit
    return target_i, gain


def _loudnorm_two_pass(
    input_path: str,
    output_path: str,
    output_format: str,
    *,
    target_lufs: float,
    true_peak: float,
    lra: float,
    sample_rate: int,
    channels: int,
    max_gain_db: float | None,
) -> Tuple[Dict[str, float], float, float]:
    """Measure *input_path* with loudnorm, then render the normalized output.

    This is blocking; callers on the event loop should run it in a thread.

    Returns:
        ``(measured, adjusted_target, applied_gain)``.

    Raises:
        HTTPException(400): the first pass reported a non-finite measurement
            (e.g. ``-inf`` for silent input), which the second pass cannot use.
        HTTPException(500): either ffmpeg invocation failed.
        ValueError / KeyError: the first-pass output could not be parsed.
    """
    pass1 = _run_ffmpeg([
        "ffmpeg", "-hide_banner", "-y", "-i", input_path,
        "-af", f"loudnorm=I={target_lufs}:TP={true_peak}:LRA={lra}:print_format=json",
        "-f", "null", "-",
    ])
    measured = _map_measured(_extract_loudnorm_json(pass1.stderr))
    if not all(math.isfinite(value) for value in measured.values()):
        raise HTTPException(
            status_code=400,
            detail="Unable to measure input loudness (audio may be silent or too short)",
        )

    adjusted_target, applied_gain = _clamp_target(measured["measured_I"], target_lufs, max_gain_db)
    loudnorm_filter = (
        f"loudnorm=I={adjusted_target}:TP={true_peak}:LRA={lra}:"
        f"measured_I={measured['measured_I']}:"
        f"measured_TP={measured['measured_TP']}:"
        f"measured_LRA={measured['measured_LRA']}:"
        f"measured_thresh={measured['measured_thresh']}:"
        f"offset={measured['offset']}:"
        "linear=true:print_format=summary"
    )
    output_args = [
        "ffmpeg", "-hide_banner", "-y", "-i", input_path,
        "-af", loudnorm_filter,
        "-ar", str(sample_rate),
        "-ac", str(channels),
    ]
    if output_format == "mp3":
        output_args.extend(["-codec:a", "libmp3lame", "-q:a", "2"])
    output_args.append(output_path)
    _run_ffmpeg(output_args)
    return measured, adjusted_target, applied_gain


# --------------------------------------------------------------------------
# spaCy helpers
# --------------------------------------------------------------------------

SPACY_MODEL_MAP = {
    "ca": "ca_core_news_sm",
    "zh": "zh_core_web_sm",
    "hr": "hr_core_news_sm",
    "da": "da_core_news_sm",
    "nl": "nl_core_news_sm",
    "en": "en_core_web_sm",
    "fi": "fi_core_news_sm",
    "fr": "fr_core_news_sm",
    "de": "de_core_news_sm",
    "el": "el_core_news_sm",
    "it": "it_core_news_sm",
    "ja": "ja_core_news_sm",
    "ko": "ko_core_news_sm",
    "lt": "lt_core_news_sm",
    "mk": "mk_core_news_sm",
    "nb": "nb_core_news_sm",
    "pl": "pl_core_news_sm",
    "pt": "pt_core_news_sm",
    "ro": "ro_core_news_sm",
    "ru": "ru_core_news_sm",
    "sl": "sl_core_news_sm",
    "es": "es_core_news_sm",
    "sv": "sv_core_news_sm",
    "uk": "uk_core_news_sm",
}

# Loaded pipelines keyed by primary language subtag ("en", "pt", ...).
_SPACY_CACHE: Dict[str, Any] = {}

SYNTACTIC_WEAK_BREAK_POS = {"CCONJ", "SCONJ", "ADP"}


def _shared_multilang_model():
    """Return the cached blank multi-language ("xx") pipeline, creating it once."""
    nlp = _SPACY_CACHE.get("xx")
    if nlp is None:
        nlp = spacy.blank("xx")
        if "sentencizer" not in nlp.pipe_names:
            nlp.add_pipe("sentencizer")
        _SPACY_CACHE["xx"] = nlp
    return nlp


def _load_spacy_model(lang_code: str | None):
    """Return a cached spaCy pipeline for *lang_code*, or None without spaCy.

    The code is reduced to its primary subtag ("pt_BR" / "pt-BR" -> "pt").
    Resolution order:

    1. The packaged model from ``SPACY_MODEL_MAP``.
    2. ``spacy.blank(code)``.
    3. A single shared blank "xx" pipeline.

    Codes that reach step 3 are not cached under their own key. They come
    from the client, so caching them would grow memory without bound.
    Every pipeline returned has a ``sentencizer``.
    """
    if spacy is None:
        return None
    code = str(lang_code or "en").lower().replace("_", "-").split("-")[0]
    cached = _SPACY_CACHE.get(code)
    if cached is not None:
        return cached

    nlp = None
    model_name = SPACY_MODEL_MAP.get(code)
    if model_name:
        try:
            nlp = spacy.load(model_name)
        except Exception:
            nlp = None  # model package not installed; fall back to a blank pipeline
    if nlp is None:
        try:
            nlp = spacy.blank(code)
        except Exception:
            return _shared_multilang_model()
    if "sentencizer" not in nlp.pipe_names:
        nlp.add_pipe("sentencizer")
    _SPACY_CACHE[code] = nlp
    return nlp


# --------------------------------------------------------------------------
# Word-level input normalization
# --------------------------------------------------------------------------


def _coerce_word_level(word_level: Dict[str, Any]) -> Dict[str, Any]:
    """Accept ``{"segments": [...]}`` or ``{"words": [...]}``; return the segments form.

    Anything else yields ``{"segments": []}``.
    """
    if not isinstance(word_level, dict):
        return {"segments": []}
    if "segments" in word_level and isinstance(word_level["segments"], list):
        return word_level
    words = word_level.get("words")
    if isinstance(words, list):
        return {"segments": [{"words": words}]}
    return {"segments": []}


def _clean_word(text: str) -> str:
    """Drop characters other than word chars, whitespace and basic punctuation."""
    return re.sub(r"[^\w\s.,?!;:'\"-]", "", text).strip()


def _normalize_words(word_level_result: Dict[str, Any], auto_clean: bool) -> Dict[str, Any]:
    """Return segments holding only well-formed ``{"word", "start", "end"}`` entries.

    Word text is read from ``"word"`` or ``"text"``. It is stripped, or fully
    cleaned when *auto_clean* is set. The following are skipped:

    - non-dict segments and words;
    - missing or non-string text;
    - unparseable timestamps;
    - non-finite timestamps, which would later make the JSON response fail
      to serialize.

    Segments left empty are dropped.
    """
    segments = []
    for segment in word_level_result.get("segments", []) or []:
        if not isinstance(segment, dict):
            continue
        words = []
        for word_info in segment.get("words", []) or []:
            if not isinstance(word_info, dict):
                continue
            raw = word_info.get("word") or word_info.get("text") or ""
            if not raw or not isinstance(raw, str):
                continue
            word_text = _clean_word(raw) if auto_clean else raw.strip()
            if not word_text:
                continue
            try:
                start = float(word_info.get("start"))
                end = float(word_info.get("end"))
            except (TypeError, ValueError):
                continue
            if not (math.isfinite(start) and math.isfinite(end)):
                continue
            words.append({"word": word_text, "start": start, "end": end})
        if words:
            segments.append({"words": words})
    return {"segments": segments}


# --------------------------------------------------------------------------
# Segmentation
# --------------------------------------------------------------------------


def _create_smart_tokens(word_level_result: Dict[str, Any]):
    """Flatten words into tokens split into text and trailing punctuation.

    Each token records ``char_start_index``, its offset in the returned
    ``full_text`` (all words joined by single spaces). This offset is what
    aligns tokens with spaCy tokens.

    Returns:
        ``(smart_tokens, full_text)``.
    """
    smart_tokens = []
    all_words = [
        word
        for segment in word_level_result.get("segments", [])
        for word in segment.get("words", [])
        if "start" in word
    ]
    current_char_offset = 0
    for word_info in all_words:
        word_text = word_info.get("word", "").strip()
        if not word_text:
            continue
        text_part, punct_part = word_text, ""
        match = _TRAILING_PUNCT_RE.search(word_text)
        if match:
            punctuation = match.group(1)
            text_part = word_text[: -len(punctuation)]
            punct_part = punctuation
        smart_tokens.append({
            "text": text_part,
            "punct": punct_part,
            "start": word_info.get("start"),
            "end": word_info.get("end"),
            "original": word_text,
            "char_start_index": current_char_offset,
            "spacy_token": None,
        })
        current_char_offset += len(word_text) + 1  # +1 for the joining space
    full_text = " ".join(tok["original"] for tok in smart_tokens)
    return smart_tokens, full_text


def _map_spacy_to_smart_tokens(smart_tokens, full_text, nlp_model):
    """Attach spaCy tokens to smart tokens in place (no-op without a model).

    A smart token gets ``spacy_token`` set when a spaCy token starts at the
    same character offset. When the pipeline has a dependency parse, every
    token inside a noun chunk gets the chunk's index in the custom
    ``noun_chunk_id`` extension attribute.
    """
    if not nlp_model:
        return
    doc = nlp_model(full_text)
    if not spacy.tokens.Token.has_extension("noun_chunk_id"):
        spacy.tokens.Token.set_extension("noun_chunk_id", default=None)
    try:
        can_use_noun_chunks = doc.has_annotation("DEP")
    except Exception:
        can_use_noun_chunks = False
    if can_use_noun_chunks:
        try:
            for chunk_id, chunk in enumerate(doc.noun_chunks):
                for token in chunk:
                    token._.noun_chunk_id = chunk_id
        except (NotImplementedError, AttributeError, ValueError):
            pass  # some languages have no noun_chunks iterator; skip chunk info
    spacy_token_map = {spacy_tok.idx: spacy_tok for spacy_tok in doc}
    for smart_tok in smart_tokens:
        spacy_tok = spacy_token_map.get(smart_tok["char_start_index"])
        if spacy_tok is not None:
            smart_tok["spacy_token"] = spacy_tok


def _get_break_score(current_token_index: int, smart_tokens: list, mode: str) -> int:
    """Score how good it is to end a subtitle block right after this token.

    Returns:
        -10 when this token and the next share a noun chunk (never split
        there). Otherwise the sum of:

        - punctuation: 10 for ``.?!``, 8 for ``,:;``;
        - syntax: 7 if the next token is ``mark``/``relcl``, 3 for CCONJ,
          1 for ADP;
        - in ``"rhythmic"`` mode only, the silence gap to the next word:
          20 if > 0.5s, 15 if > 0.3s, 10 if > 0.15s.
    """
    current_token = smart_tokens[current_token_index]
    if not current_token:
        return 0
    has_next = (current_token_index + 1) < len(smart_tokens)
    current_spacy = current_token.get("spacy_token")
    next_spacy = smart_tokens[current_token_index + 1].get("spacy_token") if has_next else None

    if (
        current_spacy
        and next_spacy
        and hasattr(current_spacy._, "noun_chunk_id")
        and hasattr(next_spacy._, "noun_chunk_id")
    ):
        if current_spacy._.noun_chunk_id is not None and current_spacy._.noun_chunk_id == next_spacy._.noun_chunk_id:
            return -10

    semantic_score = 0
    if current_token["punct"]:
        if any(p in current_token["punct"] for p in ".?!"):
            semantic_score = 10
        elif any(p in current_token["punct"] for p in ",:;"):
            semantic_score = 8

    gap_score = 0
    if mode == "rhythmic" and has_next:
        gap = smart_tokens[current_token_index + 1]["start"] - current_token["end"]
        if gap > 0.5:
            gap_score = 20
        elif gap > 0.3:
            gap_score = 15
        elif gap > 0.15:
            gap_score = 10

    syntactic_score = 0
    if current_spacy:
        if next_spacy and next_spacy.dep_ in {"mark", "relcl"}:
            syntactic_score = 7
        elif current_spacy.pos_ == "CCONJ":
            syntactic_score = 3
        elif current_spacy.pos_ == "ADP":
            syntactic_score = 1

    if mode == "rhythmic":
        return gap_score + semantic_score + syntactic_score
    return semantic_score + syntactic_score


def _collect_window(smart_tokens: list, start_index: int, build_limit: int) -> Tuple[list, list]:
    """Greedily take tokens from *start_index* while their joined text fits *build_limit*.

    Always takes at least one token, so an oversized word still makes progress.

    Returns:
        ``(tokens, prefix_lens)``, where ``prefix_lens[i]`` is the length of
        ``" ".join`` over ``tokens[: i + 1]``.
    """
    tokens: list = []
    prefix_lens: list = []
    joined_len = 0
    for index in range(start_index, len(smart_tokens)):
        token = smart_tokens[index]
        word_len = len(token["original"])
        candidate_len = joined_len + 1 + word_len if tokens else word_len
        if tokens and candidate_len > build_limit:
            break
        tokens.append(token)
        prefix_lens.append(candidate_len)
        joined_len = candidate_len
    return tokens, prefix_lens


def _extended_end_time(
    smart_tokens: list,
    next_index: int,
    original_end: float,
    max_extension_sec: float,
    gap_threshold_ms: int,
) -> float:
    """Extend a block's end time toward the next word, keeping a safety gap.

    The end is pushed forward by up to *max_extension_sec*, but never past
    ``gap_threshold_ms`` before the next word starts. The last block of the
    transcript is not extended.
    """
    if next_index >= len(smart_tokens):
        return original_end
    next_start = smart_tokens[next_index]["start"]
    ideal_end = original_end + max_extension_sec
    safe_limit_end = next_start - (gap_threshold_ms / 1000.0)
    if safe_limit_end > original_end:
        return min(ideal_end, safe_limit_end)
    return original_end


def _wrap_lines(tokens: list, max_chars: int, max_lines: int) -> str:
    """Word-wrap tokens into at most *max_lines* lines of ~*max_chars* chars.

    Words that don't fit once the last allowed line is reached are appended
    to that line, so it may exceed *max_chars*.
    """
    lines: list = []
    current_line = ""
    for token in tokens:
        word = token["original"]
        if not current_line:
            current_line = word
        elif len(current_line) + 1 + len(word) <= max_chars:
            current_line += " " + word
        elif len(lines) < max_lines - 1:
            lines.append(current_line)
            current_line = word
        else:
            current_line += " " + word
    lines.append(current_line)
    return "\n".join(lines)


def master_segmenter(
    word_level_result: Dict[str, Any],
    lang_code: str | None,
    max_chars: int,
    max_lines: int,
    nlp_model,
    mode: str = "semantic",
    min_len_percent: int = 60,
    flex_zone_percent: int = 100,
    max_extension_sec: float = 0.7,
    gap_threshold_ms: int = 10,
    high_score_threshold: int = 15,
):
    """Split timed words into subtitle blocks ``{"text", "start", "end"}``.

    For each block:

    1. A window of up to ``max_chars * max_lines * flex_zone_percent / 100``
       characters is taken.
    2. Every positively scored break point (see ``_get_break_score``) becomes
       a candidate. Low-scoring (< 10) breaks shorter than
       ``max_chars * min_len_percent / 100`` are ignored.
    3. Among candidates within 80% of the best score, the one whose length is
       closest to *max_chars* wins.
    4. In ``"rhythmic"`` mode, an over-long block whose break scores below
       *high_score_threshold* is shortened to the best-scoring break that
       fits *max_chars*.

    Args:
        lang_code: unused here; kept for interface compatibility.
        nlp_model: optional spaCy pipeline used for syntactic break hints.
    """
    if not word_level_result or not word_level_result.get("segments"):
        return []
    smart_tokens, full_text = _create_smart_tokens(word_level_result)
    if not smart_tokens:
        return []
    _map_spacy_to_smart_tokens(smart_tokens, full_text, nlp_model)

    build_limit = int(max_chars * max_lines * (flex_zone_percent / 100.0))
    min_len_threshold = int(max_chars * (min_len_percent / 100.0))

    final_blocks = []
    current_token_index = 0
    while current_token_index < len(smart_tokens):
        segment_tokens, prefix_lens = _collect_window(smart_tokens, current_token_index, build_limit)

        # Scan from the longest prefix down. Candidate order matters:
        # min()/max() below keep the first of equal keys.
        candidates = []
        for i in range(len(segment_tokens) - 1, -1, -1):
            score = _get_break_score(current_token_index + i, smart_tokens, mode)
            if score > 0:
                if prefix_lens[i] < min_len_threshold and score < 10:
                    continue
                candidates.append({"index": i, "score": score, "length": prefix_lens[i]})

        best_break_index = len(segment_tokens) - 1
        if candidates:
            max_score_in_candidates = max(c["score"] for c in candidates)
            good_candidates = [c for c in candidates if c["score"] >= max_score_in_candidates * 0.8]
            if good_candidates:
                best_candidate = min(good_candidates, key=lambda c: abs(c["length"] - max_chars))
                best_break_index = best_candidate["index"]

        best_candidate_score = 0
        chosen = next((c for c in candidates if c["index"] == best_break_index), None)
        if chosen is not None:
            best_candidate_score = chosen["score"]

        if mode == "rhythmic" and prefix_lens[best_break_index] > max_chars and best_candidate_score < high_score_threshold:
            safe_candidates = [c for c in candidates if c["length"] <= max_chars]
            if safe_candidates:
                best_break_index = max(safe_candidates, key=lambda c: c["score"])["index"]

        final_segment_tokens = segment_tokens[: best_break_index + 1]
        next_index = current_token_index + len(final_segment_tokens)
        final_blocks.append({
            "text": _wrap_lines(final_segment_tokens, max_chars, max_lines),
            "start": final_segment_tokens[0]["start"],
            "end": _extended_end_time(
                smart_tokens,
                next_index,
                final_segment_tokens[-1]["end"],
                max_extension_sec,
                gap_threshold_ms,
            ),
        })
        current_token_index = next_index

    return final_blocks


# --------------------------------------------------------------------------
# HTTP endpoints
# --------------------------------------------------------------------------


@app.get("/health")
def health() -> Dict[str, str]:
    """Liveness probe."""
    return {"status": "ok"}


@app.post("/normalize")
async def normalize_audio(
    background_tasks: BackgroundTasks,
    audio: UploadFile = File(...),
    target_lufs: float = Query(-16.0, description="Target integrated loudness (LUFS)"),
    true_peak: float = Query(-1.0, description="True peak limit (dBTP)"),
    lra: float = Query(11.0, description="Target loudness range"),
    sample_rate: int = Query(48000, description="Output sample rate"),
    channels: int = Query(1, description="Output channels"),
    max_gain_db: float | None = Query(20.0, description="Max gain change in dB"),
    output_format: str = Query("wav", description="Output format (wav|mp3)"),
    x_worker_auth: str | None = Header(default=None, alias="x-worker-auth"),
) -> FileResponse:
    """Loudness-normalize an uploaded file with two-pass ffmpeg ``loudnorm``.

    The response is the normalized file. ``X-Input-*``, ``X-Target-LUFS`` and
    ``X-Applied-Gain`` headers report the measurements. The temp directory is
    removed by a background task after the response is sent, or immediately
    if processing fails.

    Blocking work (upload copy, ffmpeg) runs in a worker thread so the event
    loop stays responsive during long encodes.
    """
    _check_worker_auth(x_worker_auth, "NORMALIZE_WORKER_AUTH_KEY", "TTS_WORKER_AUTH_KEY")
    if audio.filename is None:
        raise HTTPException(status_code=400, detail="Missing filename")
    normalized_format = output_format.strip().lower()
    if normalized_format not in {"wav", "mp3"}:
        raise HTTPException(status_code=400, detail="Unsupported output format")

    tmp_dir = tempfile.mkdtemp(prefix="normalize_")
    input_path = os.path.join(tmp_dir, _safe_upload_name(audio.filename))
    output_path = os.path.join(tmp_dir, f"normalized.{normalized_format}")
    try:
        await asyncio.to_thread(_save_upload, audio.file, input_path)
        measured, adjusted_target, applied_gain = await asyncio.to_thread(
            _loudnorm_two_pass,
            input_path,
            output_path,
            normalized_format,
            target_lufs=target_lufs,
            true_peak=true_peak,
            lra=lra,
            sample_rate=sample_rate,
            channels=channels,
            max_gain_db=max_gain_db,
        )
    except BaseException:
        # On success the background task owns tmp_dir. On failure nothing
        # else will ever remove it, so clean up before propagating.
        shutil.rmtree(tmp_dir, ignore_errors=True)
        raise
    finally:
        await audio.close()

    headers = {
        "X-Input-LUFS": f"{measured['measured_I']:.2f}",
        "X-Input-TP": f"{measured['measured_TP']:.2f}",
        "X-Input-LRA": f"{measured['measured_LRA']:.2f}",
        "X-Target-LUFS": f"{adjusted_target:.2f}",
        "X-Applied-Gain": f"{applied_gain:.2f}",
    }
    background_tasks.add_task(shutil.rmtree, tmp_dir, ignore_errors=True)
    media_type = "audio/mpeg" if normalized_format == "mp3" else "audio/wav"
    return FileResponse(
        output_path,
        media_type=media_type,
        filename=f"normalized.{normalized_format}",
        headers=headers,
        background=background_tasks,
    )


@app.post("/subtitles")
async def generate_subtitles(
    payload: Dict[str, Any] = Body(...),
    x_worker_auth: str | None = Header(default=None, alias="x-worker-auth"),
) -> Dict[str, Any]:
    """Turn word-level timestamps into subtitle segments.

    Payload fields:

    - ``word_level`` (or ``wordLevel``): the timed words.
    - optional ``language_code``.
    - optional ``settings``: ``auto_clean_special_chars``, ``auto_segment``,
      ``max_chars``, ``max_lines``, ``min_len_percent``, ``flex_zone_percent``
      and ``mode``.

    Malformed input is reported as HTTP 400.
    """
    _check_worker_auth(
        x_worker_auth,
        "SUBTITLE_WORKER_AUTH_KEY",
        "NORMALIZE_WORKER_AUTH_KEY",
        "TTS_WORKER_AUTH_KEY",
    )

    word_level = payload.get("word_level") or payload.get("wordLevel")
    if not word_level:
        raise HTTPException(status_code=400, detail="Missing word_level")
    settings = payload.get("settings") or {}
    if not isinstance(settings, dict):
        raise HTTPException(status_code=400, detail="Invalid settings: expected an object")

    word_level_result = _coerce_word_level(word_level)
    auto_clean = _bool_setting(settings, "auto_clean_special_chars", False)
    normalized_word_level = _normalize_words(word_level_result, auto_clean)
    if not normalized_word_level.get("segments"):
        raise HTTPException(status_code=400, detail="No words to segment")

    if not _bool_setting(settings, "auto_segment", True):
        # One block spanning the whole transcript.
        words = [word for segment in normalized_word_level["segments"] for word in segment.get("words", [])]
        text = " ".join(word["word"] for word in words)
        return {"segments": [{"text": text, "start": words[0]["start"], "end": words[-1]["end"]}]}

    max_chars = _int_setting(settings, "max_chars", 42, minimum=1)
    max_lines = _int_setting(settings, "max_lines", 2, minimum=1)
    min_len_percent = _int_setting(settings, "min_len_percent", 60)
    flex_zone_percent = _int_setting(settings, "flex_zone_percent", 130)
    mode = settings.get("mode", "semantic")
    language_code = payload.get("language_code") or "en"

    nlp_model = _load_spacy_model(language_code)
    segments = master_segmenter(
        normalized_word_level,
        language_code,
        max_chars,
        max_lines,
        nlp_model,
        mode=mode,
        min_len_percent=min_len_percent,
        flex_zone_percent=flex_zone_percent,
    )
    return {"segments": segments}


@app.exception_handler(Exception)
async def handle_unexpected_error(_, exc: Exception):
    """Return uncaught errors as JSON ``{"error": message}`` with status 500."""
    return JSONResponse(status_code=500, content={"error": str(exc)})