# Audio Normalizer service: two-pass ffmpeg loudness normalization plus
# word-level subtitle segmentation (FastAPI application).
| from __future__ import annotations | |
| import json | |
| import os | |
| import re | |
| import shutil | |
| import subprocess | |
| import tempfile | |
| from typing import Any, Dict, Tuple | |
| from fastapi import BackgroundTasks, Body, FastAPI, File, Header, HTTPException, Query, UploadFile | |
| from fastapi.middleware.cors import CORSMiddleware | |
| from fastapi.responses import FileResponse, JSONResponse | |
# spaCy is an optional dependency: when it is missing, subtitle segmentation
# still works but without syntactic break scoring (see _load_spacy_model).
try:
    import spacy
except Exception:  # pragma: no cover - optional dependency
    spacy = None

app = FastAPI(title="Audio Normalizer", version="0.1.0")

# Permissive CORS so browser clients on any origin can call the service.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=False,
    allow_methods=["*"],
    allow_headers=["*"],
    # Loudness-measurement headers set by the normalization handler must be
    # listed here or cross-origin JavaScript cannot read them.
    expose_headers=[
        "X-Input-LUFS",
        "X-Input-TP",
        "X-Input-LRA",
        "X-Target-LUFS",
        "X-Applied-Gain",
    ],
)
def _run_ffmpeg(args: list[str]) -> subprocess.CompletedProcess[str]:
    """Execute an ffmpeg command line, mapping failures onto HTTP 500 errors.

    Returns the completed process with stdout/stderr captured as text.
    Raises HTTPException(500) when the binary is absent or exits non-zero.
    """
    try:
        result = subprocess.run(args, capture_output=True, text=True, check=True)
    except FileNotFoundError as exc:
        raise HTTPException(status_code=500, detail="ffmpeg not found in PATH") from exc
    except subprocess.CalledProcessError as exc:
        # ffmpeg logs are verbose; surface only the final stderr line.
        message = (exc.stderr or "").strip()
        raise HTTPException(
            status_code=500,
            detail=message.splitlines()[-1] if message else "ffmpeg failed",
        ) from exc
    return result
| def _extract_loudnorm_json(stderr: str) -> Dict[str, Any]: | |
| start = stderr.rfind("{") | |
| end = stderr.rfind("}") | |
| if start == -1 or end == -1 or end <= start: | |
| raise ValueError("Unable to parse loudnorm output") | |
| payload = stderr[start : end + 1] | |
| return json.loads(payload) | |
| def _map_measured(data: Dict[str, Any]) -> Dict[str, float]: | |
| return { | |
| "measured_I": float(data["input_i"]), | |
| "measured_TP": float(data["input_tp"]), | |
| "measured_LRA": float(data["input_lra"]), | |
| "measured_thresh": float(data["input_thresh"]), | |
| "offset": float(data["target_offset"]), | |
| } | |
| def _clamp_target(measured_i: float, target_i: float, max_gain_db: float | None) -> Tuple[float, float]: | |
| gain = target_i - measured_i | |
| if max_gain_db is None: | |
| return target_i, gain | |
| if gain > max_gain_db: | |
| return measured_i + max_gain_db, max_gain_db | |
| if gain < -max_gain_db: | |
| return measured_i - max_gain_db, -max_gain_db | |
| return target_i, gain | |
# ISO 639-1 language code -> name of the small packaged spaCy pipeline for
# that language. Consumed by _load_spacy_model; codes not listed here fall
# back to a blank pipeline with only a sentencizer.
SPACY_MODEL_MAP = {
    "ca": "ca_core_news_sm",
    "zh": "zh_core_web_sm",
    "hr": "hr_core_news_sm",
    "da": "da_core_news_sm",
    "nl": "nl_core_news_sm",
    "en": "en_core_web_sm",
    "fi": "fi_core_news_sm",
    "fr": "fr_core_news_sm",
    "de": "de_core_news_sm",
    "el": "el_core_news_sm",
    "it": "it_core_news_sm",
    "ja": "ja_core_news_sm",
    "ko": "ko_core_news_sm",
    "lt": "lt_core_news_sm",
    "mk": "mk_core_news_sm",
    "nb": "nb_core_news_sm",
    "pl": "pl_core_news_sm",
    "pt": "pt_core_news_sm",
    "ro": "ro_core_news_sm",
    "ru": "ru_core_news_sm",
    "sl": "sl_core_news_sm",
    "es": "es_core_news_sm",
    "sv": "sv_core_news_sm",
    "uk": "uk_core_news_sm",
}
# Loaded pipelines keyed by two-letter language code, so each model is
# constructed at most once per process (see _load_spacy_model).
_SPACY_CACHE: Dict[str, Any] = {}
# Part-of-speech tags considered weak clause-break markers.
# NOTE(review): not referenced anywhere in the visible code -- _get_break_score
# hard-codes "CCONJ"/"ADP" instead; confirm whether this constant is dead.
SYNTACTIC_WEAK_BREAK_POS = {"CCONJ", "SCONJ", "ADP"}
def _load_spacy_model(lang_code: str | None):
    """Return a cached spaCy pipeline for *lang_code*, or None if spaCy is absent.

    Falls back to a blank pipeline for the language (or the multilingual
    "xx" blank) when the packaged model cannot be loaded, and guarantees a
    sentencizer component so sentence boundaries are always available.
    """
    if spacy is None:
        return None
    # Normalise e.g. "en-US" / "en_GB" / None down to a bare two-letter code.
    normalized = (lang_code or "en").lower().replace("_", "-").split("-")[0]
    cached = _SPACY_CACHE.get(normalized)
    if cached is not None:
        return cached
    pipeline = None
    packaged = SPACY_MODEL_MAP.get(normalized)
    if packaged:
        try:
            pipeline = spacy.load(packaged)
        except Exception:
            pipeline = None
    if pipeline is None:
        # Model not installed (or unknown language): degrade to a blank pipeline.
        try:
            pipeline = spacy.blank(normalized)
        except Exception:
            pipeline = spacy.blank("xx")
    if "sentencizer" not in pipeline.pipe_names:
        pipeline.add_pipe("sentencizer")
    _SPACY_CACHE[normalized] = pipeline
    return pipeline
| def _coerce_word_level(word_level: Dict[str, Any]) -> Dict[str, Any]: | |
| if not isinstance(word_level, dict): | |
| return {"segments": []} | |
| if "segments" in word_level and isinstance(word_level["segments"], list): | |
| return word_level | |
| words = word_level.get("words") | |
| if isinstance(words, list): | |
| return {"segments": [{"words": words}]} | |
| return {"segments": []} | |
| def _clean_word(text: str) -> str: | |
| return re.sub(r"[^\w\s.,?!;:'\"-]", "", text).strip() | |
| def _normalize_words(word_level_result: Dict[str, Any], auto_clean: bool) -> Dict[str, Any]: | |
| segments = [] | |
| for segment in word_level_result.get("segments", []): | |
| words = [] | |
| for word_info in segment.get("words", []): | |
| raw = word_info.get("word") or word_info.get("text") or "" | |
| if not raw: | |
| continue | |
| word_text = _clean_word(raw) if auto_clean else raw.strip() | |
| if not word_text: | |
| continue | |
| try: | |
| start = float(word_info.get("start")) | |
| end = float(word_info.get("end")) | |
| except (TypeError, ValueError): | |
| continue | |
| words.append({"word": word_text, "start": start, "end": end}) | |
| if words: | |
| segments.append({"words": words}) | |
| return {"segments": segments} | |
| def _create_smart_tokens(word_level_result: Dict[str, Any]): | |
| smart_tokens = [] | |
| punctuation_pattern = re.compile(r"([^\w\s]+)$") | |
| all_words = [ | |
| word | |
| for segment in word_level_result.get("segments", []) | |
| for word in segment.get("words", []) | |
| if "start" in word | |
| ] | |
| current_char_offset = 0 | |
| for word_info in all_words: | |
| word_text = word_info.get("word", "").strip() | |
| if not word_text: | |
| continue | |
| text_part, punct_part = word_text, "" | |
| match = punctuation_pattern.search(word_text) | |
| if match: | |
| punctuation = match.group(1) | |
| text_part = word_text[: -len(punctuation)] | |
| punct_part = punctuation | |
| smart_tokens.append({ | |
| "text": text_part, | |
| "punct": punct_part, | |
| "start": word_info.get("start"), | |
| "end": word_info.get("end"), | |
| "original": word_text, | |
| "char_start_index": current_char_offset, | |
| "spacy_token": None, | |
| }) | |
| current_char_offset += len(word_text) + 1 | |
| full_text = " ".join([tok["original"] for tok in smart_tokens]) | |
| return smart_tokens, full_text | |
def _map_spacy_to_smart_tokens(smart_tokens, full_text, nlp_model):
    """Attach spaCy tokens (and noun-chunk ids) to matching smart tokens.

    Matching is by character offset into full_text; smart tokens whose
    offset does not coincide with a spaCy token boundary keep
    spacy_token=None. No-op when no model was loaded.
    """
    if not nlp_model:
        return
    doc = nlp_model(full_text)
    # Register the custom Token attribute once per process.
    if not spacy.tokens.Token.has_extension("noun_chunk_id"):
        spacy.tokens.Token.set_extension("noun_chunk_id", default=None)
    try:
        has_parse = doc.has_annotation("DEP")
    except Exception:
        has_parse = False
    if has_parse:
        # noun_chunks requires a dependency parse; tag tokens with their
        # chunk id so breaks inside a noun phrase can be penalised later.
        try:
            for chunk_id, chunk in enumerate(doc.noun_chunks):
                for tok in chunk:
                    tok._.noun_chunk_id = chunk_id
        except (NotImplementedError, AttributeError, ValueError):
            pass
    by_offset = {tok.idx: tok for tok in doc}
    for smart_tok in smart_tokens:
        match = by_offset.get(smart_tok["char_start_index"])
        if match is not None:
            smart_tok["spacy_token"] = match
| def _get_break_score(current_token_index: int, smart_tokens: list, mode: str) -> int: | |
| current_token = smart_tokens[current_token_index] | |
| if not current_token: | |
| return 0 | |
| current_spacy = current_token.get("spacy_token") | |
| next_spacy = smart_tokens[current_token_index + 1].get("spacy_token") if (current_token_index + 1) < len(smart_tokens) else None | |
| if current_spacy and next_spacy and hasattr(current_spacy._, "noun_chunk_id") and hasattr(next_spacy._, "noun_chunk_id"): | |
| if current_spacy._.noun_chunk_id is not None and current_spacy._.noun_chunk_id == next_spacy._.noun_chunk_id: | |
| return -10 | |
| semantic_score = 0 | |
| if current_token["punct"]: | |
| if any(p in current_token["punct"] for p in ".?!"): | |
| semantic_score = 10 | |
| elif any(p in current_token["punct"] for p in ",:;"): | |
| semantic_score = 8 | |
| gap_score = 0 | |
| if mode == "rhythmic" and (current_token_index + 1) < len(smart_tokens): | |
| next_token = smart_tokens[current_token_index + 1] | |
| gap = next_token["start"] - current_token["end"] | |
| if gap > 0.5: | |
| gap_score = 20 | |
| elif gap > 0.3: | |
| gap_score = 15 | |
| elif gap > 0.15: | |
| gap_score = 10 | |
| syntactic_score = 0 | |
| if current_spacy: | |
| if next_spacy and next_spacy.dep_ in {"mark", "relcl"}: | |
| syntactic_score = 7 | |
| elif current_spacy.pos_ == "CCONJ": | |
| syntactic_score = 3 | |
| elif current_spacy.pos_ == "ADP": | |
| syntactic_score = 1 | |
| if mode == "rhythmic": | |
| return gap_score + semantic_score + syntactic_score | |
| return semantic_score + syntactic_score | |
def master_segmenter(
    word_level_result: Dict[str, Any],
    lang_code: str | None,
    max_chars: int,
    max_lines: int,
    nlp_model,
    mode: str = "semantic",
    min_len_percent: int = 60,
    flex_zone_percent: int = 100,
    max_extension_sec: float = 0.7,
    gap_threshold_ms: int = 10,
    high_score_threshold: int = 15,
):
    """Split timed words into subtitle blocks of at most max_lines lines.

    Greedily collects tokens up to a flexible character budget, picks the
    best break point via _get_break_score, wraps the chosen tokens into
    lines, and extends each block's end time into the following pause.

    Args:
        word_level_result: {"segments": [{"words": [...]}]} of timed words.
        lang_code: language hint. NOTE(review): not referenced inside this
            function -- the caller selects nlp_model from it; confirm intent.
        max_chars: target character width of one subtitle line.
        max_lines: maximum number of lines per block.
        nlp_model: optional spaCy pipeline used for syntactic break scoring.
        mode: "semantic" (punctuation/syntax only) or "rhythmic" (also
            rewards audible gaps between words).
        min_len_percent: minimum block length (% of max_chars) below which
            only strong breaks (score >= 10) are accepted.
        flex_zone_percent: how far past max_chars * max_lines the builder
            may search for a good break (100 = no slack).
        max_extension_sec: maximum extra display time past the last word.
        gap_threshold_ms: safety margin kept before the next word starts.
        high_score_threshold: rhythmic-mode score above which an over-length
            break is kept anyway.

    Returns:
        List of {"text": str (newline-joined lines), "start": float,
        "end": float} blocks; empty list for empty input.
    """
    if not word_level_result or not word_level_result.get("segments"):
        return []
    smart_tokens, full_text = _create_smart_tokens(word_level_result)
    if not smart_tokens:
        return []
    _map_spacy_to_smart_tokens(smart_tokens, full_text, nlp_model)
    final_blocks = []
    current_token_index = 0
    while current_token_index < len(smart_tokens):
        # Phase 1: accumulate tokens until the flexible character budget
        # (block capacity scaled by flex_zone_percent) would be exceeded.
        build_limit = int(max_chars * max_lines * (flex_zone_percent / 100.0))
        segment_tokens = []
        for i in range(current_token_index, len(smart_tokens)):
            token_to_add = smart_tokens[i]
            preview_segment = segment_tokens + [token_to_add]
            # Length counts the words plus one joining space between each.
            current_len = sum(len(t["original"]) for t in preview_segment) + (len(preview_segment) - 1)
            if current_len > build_limit and len(segment_tokens) > 0:
                break
            segment_tokens.append(token_to_add)
        if not segment_tokens:
            # A single token longer than the whole budget: take it anyway
            # so the loop always makes progress.
            if current_token_index < len(smart_tokens):
                segment_tokens.append(smart_tokens[current_token_index])
            else:
                break
        # Phase 2: score every break point in the window, scanning from the
        # end; short blocks are only allowed at strong (score >= 10) breaks.
        candidates = []
        min_len_threshold = int(max_chars * (min_len_percent / 100.0))
        for i in range(len(segment_tokens) - 1, -1, -1):
            temp_segment = segment_tokens[: i + 1]
            temp_len = sum(len(t["original"]) + 1 for t in temp_segment) - 1
            real_token_index = current_token_index + i
            score = _get_break_score(real_token_index, smart_tokens, mode)
            if score > 0:
                if temp_len < min_len_threshold and score < 10:
                    continue
                candidates.append({"index": i, "score": score, "length": temp_len})
        # Default: break at the end of the collected window.
        best_break_index = len(segment_tokens) - 1
        if candidates:
            # Among near-best scores (>= 80% of max), prefer the break whose
            # resulting length is closest to one full line.
            max_score_in_candidates = max(c["score"] for c in candidates)
            good_candidates = [c for c in candidates if c["score"] >= max_score_in_candidates * 0.8]
            if good_candidates:
                best_candidate = min(good_candidates, key=lambda c: abs(c["length"] - max_chars))
                best_break_index = best_candidate["index"]
        final_segment_tokens_preview = segment_tokens[: best_break_index + 1]
        final_len_preview = sum(len(t["original"]) + 1 for t in final_segment_tokens_preview) - 1
        best_candidate_score = 0
        if candidates:
            cand = next((c for c in candidates if c["index"] == best_break_index), None)
            if cand:
                best_candidate_score = cand["score"]
        # Rhythmic mode: back off to a within-max_chars break unless the
        # chosen break is strong enough to justify an over-length block.
        if mode == "rhythmic" and final_len_preview > max_chars and best_candidate_score < high_score_threshold:
            safe_candidates = [c for c in candidates if c["length"] <= max_chars]
            if safe_candidates:
                best_break_index = max(safe_candidates, key=lambda c: c["score"])["index"]
        final_segment_tokens = segment_tokens[: best_break_index + 1]
        if final_segment_tokens:
            start_time = final_segment_tokens[0]["start"]
            original_end_time = final_segment_tokens[-1]["end"]
            new_end_time = original_end_time
            # Phase 3: extend display time into the following pause, stopping
            # gap_threshold_ms short of the next word's start.
            next_real_token_index = current_token_index + len(final_segment_tokens)
            if next_real_token_index < len(smart_tokens):
                next_token_after_segment = smart_tokens[next_real_token_index]
                next_start_time = next_token_after_segment["start"]
                ideal_extended_end = original_end_time + max_extension_sec
                safe_limit_end = next_start_time - (gap_threshold_ms / 1000.0)
                if safe_limit_end > original_end_time:
                    new_end_time = min(ideal_extended_end, safe_limit_end)
            # Phase 4: greedy word-wrap into at most max_lines lines; words
            # that do not fit on the last line are appended to it rather
            # than dropped (the line may exceed max_chars).
            lines_text = []
            current_line_text = ""
            for token in final_segment_tokens:
                word_to_add = token["original"]
                if not current_line_text:
                    current_line_text = word_to_add
                elif len(current_line_text) + 1 + len(word_to_add) <= max_chars:
                    current_line_text += " " + word_to_add
                elif len(lines_text) < max_lines - 1:
                    lines_text.append(current_line_text)
                    current_line_text = word_to_add
                else:
                    current_line_text += " " + word_to_add
            lines_text.append(current_line_text)
            final_blocks.append({
                "text": "\n".join(lines_text),
                "start": start_time,
                "end": new_end_time,
            })
            current_token_index += len(final_segment_tokens)
        else:
            current_token_index += 1
    return final_blocks
def health() -> Dict[str, str]:
    """Liveness probe: report that the service is up.

    NOTE(review): no route decorator is visible here (e.g. @app.get) --
    confirm the route is registered elsewhere or was lost in extraction.
    """
    return {"status": "ok"}
async def normalize_audio(
    background_tasks: BackgroundTasks,
    audio: UploadFile = File(...),
    target_lufs: float = Query(-16.0, description="Target integrated loudness (LUFS)"),
    true_peak: float = Query(-1.0, description="True peak limit (dBTP)"),
    lra: float = Query(11.0, description="Target loudness range"),
    sample_rate: int = Query(48000, description="Output sample rate"),
    channels: int = Query(1, description="Output channels"),
    max_gain_db: float | None = Query(20.0, description="Max gain change in dB"),
    output_format: str = Query("wav", description="Output format (wav|mp3)"),
    x_worker_auth: str | None = Header(default=None, alias="x-worker-auth"),
) -> FileResponse:
    """Two-pass EBU R128 loudness normalization of an uploaded audio file.

    Pass 1 runs ffmpeg's loudnorm filter in measurement mode; pass 2 applies
    linear normalization toward target_lufs (clamped to +/- max_gain_db) and
    renders at the requested sample rate, channel count and format. The
    pass-1 measurements are echoed in X-* response headers.

    Raises HTTPException: 403 on a bad worker secret, 400 on a missing
    filename or unsupported format, 500 when ffmpeg is absent or fails.

    Fixes vs. previous revision:
    - temp directory was leaked when ffmpeg/parsing raised (the cleanup
      background task was only scheduled on success);
    - client-supplied filename was joined into the temp path unsanitized
      (path traversal / absolute-path write); now reduced to its basename.
    """
    # Optional shared-secret auth: enforced only when a key is configured.
    secret = os.getenv("NORMALIZE_WORKER_AUTH_KEY") or os.getenv("TTS_WORKER_AUTH_KEY")
    if secret and x_worker_auth != secret:
        raise HTTPException(status_code=403, detail="Invalid worker secret")
    if not audio.filename:
        raise HTTPException(status_code=400, detail="Missing filename")
    # basename() guards against "../" or absolute paths in the upload name.
    safe_name = os.path.basename(audio.filename)
    if not safe_name:
        raise HTTPException(status_code=400, detail="Missing filename")
    normalized_format = output_format.strip().lower()
    if normalized_format not in {"wav", "mp3"}:
        raise HTTPException(status_code=400, detail="Unsupported output format")
    tmp_dir = tempfile.mkdtemp(prefix="normalize_")
    input_path = os.path.join(tmp_dir, safe_name)
    output_path = os.path.join(tmp_dir, f"normalized.{normalized_format}")
    try:
        with open(input_path, "wb") as out_file:
            shutil.copyfileobj(audio.file, out_file)
        # Pass 1: measurement only (null muxer); loudnorm prints its JSON
        # report to stderr.
        pass1 = _run_ffmpeg([
            "ffmpeg",
            "-hide_banner",
            "-y",
            "-i",
            input_path,
            "-af",
            f"loudnorm=I={target_lufs}:TP={true_peak}:LRA={lra}:print_format=json",
            "-f",
            "null",
            "-",
        ])
        measured = _map_measured(_extract_loudnorm_json(pass1.stderr))
        adjusted_target, applied_gain = _clamp_target(measured["measured_I"], target_lufs, max_gain_db)
        # Pass 2: linear normalization using the pass-1 measurements.
        loudnorm_filter = (
            f"loudnorm=I={adjusted_target}:TP={true_peak}:LRA={lra}:"
            f"measured_I={measured['measured_I']}:"
            f"measured_TP={measured['measured_TP']}:"
            f"measured_LRA={measured['measured_LRA']}:"
            f"measured_thresh={measured['measured_thresh']}:"
            f"offset={measured['offset']}:"
            "linear=true:print_format=summary"
        )
        output_args = [
            "ffmpeg",
            "-hide_banner",
            "-y",
            "-i",
            input_path,
            "-af",
            loudnorm_filter,
            "-ar",
            str(sample_rate),
            "-ac",
            str(channels),
        ]
        if normalized_format == "mp3":
            output_args.extend(["-codec:a", "libmp3lame", "-q:a", "2"])
        output_args.append(output_path)
        _run_ffmpeg(output_args)
    except Exception:
        # On any failure the success-path background task never runs, so
        # remove the work directory here instead of leaking it.
        shutil.rmtree(tmp_dir, ignore_errors=True)
        raise
    finally:
        await audio.close()
    headers = {
        "X-Input-LUFS": f"{measured['measured_I']:.2f}",
        "X-Input-TP": f"{measured['measured_TP']:.2f}",
        "X-Input-LRA": f"{measured['measured_LRA']:.2f}",
        "X-Target-LUFS": f"{adjusted_target:.2f}",
        "X-Applied-Gain": f"{applied_gain:.2f}",
    }
    # Defer cleanup until after the file has been streamed to the client.
    background_tasks.add_task(shutil.rmtree, tmp_dir, ignore_errors=True)
    media_type = "audio/mpeg" if normalized_format == "mp3" else "audio/wav"
    return FileResponse(output_path, media_type=media_type, filename=f"normalized.{normalized_format}", headers=headers, background=background_tasks)
async def generate_subtitles(
    payload: Dict[str, Any] = Body(...),
    x_worker_auth: str | None = Header(default=None, alias="x-worker-auth"),
) -> Dict[str, Any]:
    """Turn word-level transcription timings into subtitle segments.

    The payload carries word_level/wordLevel data, optional settings, and a
    language_code. With auto_segment disabled the whole transcript becomes a
    single segment; otherwise master_segmenter splits it using the
    configured limits and an optional spaCy model for the language.

    Raises HTTPException: 403 on a bad worker secret, 400 when no usable
    word-level data is present.
    """
    # Same optional shared-secret scheme as the normalization handler.
    secret = (
        os.getenv("SUBTITLE_WORKER_AUTH_KEY")
        or os.getenv("NORMALIZE_WORKER_AUTH_KEY")
        or os.getenv("TTS_WORKER_AUTH_KEY")
    )
    if secret and x_worker_auth != secret:
        raise HTTPException(status_code=403, detail="Invalid worker secret")

    word_level = payload.get("word_level") or payload.get("wordLevel")
    if not word_level:
        raise HTTPException(status_code=400, detail="Missing word_level")

    settings = payload.get("settings") or {}
    auto_clean = bool(settings.get("auto_clean_special_chars", False))
    normalized_word_level = _normalize_words(_coerce_word_level(word_level), auto_clean)
    if not normalized_word_level.get("segments"):
        raise HTTPException(status_code=400, detail="No words to segment")

    if not settings.get("auto_segment", True):
        # Segmentation disabled: emit the full transcript as one block.
        flat = [w for seg in normalized_word_level["segments"] for w in seg.get("words", [])]
        joined = " ".join(w["word"] for w in flat)
        return {"segments": [{"text": joined, "start": flat[0]["start"], "end": flat[-1]["end"]}]}

    language_code = payload.get("language_code") or "en"
    segments = master_segmenter(
        normalized_word_level,
        language_code,
        int(settings.get("max_chars", 42)),
        int(settings.get("max_lines", 2)),
        _load_spacy_model(language_code),
        mode=settings.get("mode", "semantic"),
        min_len_percent=int(settings.get("min_len_percent", 60)),
        flex_zone_percent=int(settings.get("flex_zone_percent", 130)),
    )
    return {"segments": segments}
async def handle_unexpected_error(_, exc: Exception):
    """Catch-all handler: surface unexpected failures as a JSON 500 payload.

    NOTE(review): no @app.exception_handler decorator is visible here --
    confirm this handler is actually registered with the application.
    """
    return JSONResponse(status_code=500, content={"error": str(exc)})