# toanatp's picture
# Update app.py
# 428e9e4 verified
from __future__ import annotations
import json
import os
import re
import shutil
import subprocess
import tempfile
from typing import Any, Dict, Tuple
from fastapi import BackgroundTasks, Body, FastAPI, File, Header, HTTPException, Query, UploadFile
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import FileResponse, JSONResponse
# spaCy is an optional dependency: when the import fails for any reason the
# module-level name is bound to None, and the helpers below check for that
# before using it.
try:
    import spacy
except Exception:  # pragma: no cover - optional dependency
    spacy = None
# ASGI application object; the route decorators further down register on it.
app = FastAPI(title="Audio Normalizer", version="0.1.0")
# CORS: any origin, method and request header is accepted, without credentials.
# The custom X-* headers that /normalize sets on its response are listed in
# expose_headers so that browser clients are allowed to read them.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=False,
    allow_methods=["*"],
    allow_headers=["*"],
    expose_headers=[
        "X-Input-LUFS",
        "X-Input-TP",
        "X-Input-LRA",
        "X-Target-LUFS",
        "X-Applied-Gain",
    ],
)
def _run_ffmpeg(args: list[str]) -> subprocess.CompletedProcess[str]:
    """Run a command line and return the completed process.

    Args:
        args: Full argument vector, executable first.

    Returns:
        The ``CompletedProcess`` with stdout and stderr captured as text.

    Raises:
        HTTPException: 500 when the executable cannot be found, or when the
            process exits non-zero (detail is the last stderr line, or
            "ffmpeg failed" when stderr is empty).
    """
    try:
        completed = subprocess.run(args, capture_output=True, text=True, check=True)
    except subprocess.CalledProcessError as exc:
        # The last stderr line is used as the error summary.
        error_lines = (exc.stderr or "").strip().splitlines()
        detail = error_lines[-1] if error_lines else "ffmpeg failed"
        raise HTTPException(status_code=500, detail=detail) from exc
    except FileNotFoundError as exc:
        raise HTTPException(status_code=500, detail="ffmpeg not found in PATH") from exc
    return completed
def _extract_loudnorm_json(stderr: str) -> Dict[str, Any]:
    """Parse the last ``{...}`` span found in *stderr* as JSON.

    Args:
        stderr: Captured stderr text expected to contain a JSON object.

    Returns:
        The decoded JSON value.

    Raises:
        ValueError: If no opening brace is followed by a closing brace
            (``json.JSONDecodeError``, a ``ValueError`` subclass, is raised
            when the span is not valid JSON).
    """
    open_at = stderr.rfind("{")
    close_at = stderr.rfind("}")
    # rfind yields -1 when absent; once open_at is valid, "close_at <= open_at"
    # also covers a missing closing brace.
    if open_at < 0 or close_at <= open_at:
        raise ValueError("Unable to parse loudnorm output")
    return json.loads(stderr[open_at : close_at + 1])
def _map_measured(data: Dict[str, Any]) -> Dict[str, float]:
    """Rename first-pass measurement fields and convert them to floats.

    Args:
        data: Mapping with ``input_i``, ``input_tp``, ``input_lra``,
            ``input_thresh`` and ``target_offset`` entries.

    Returns:
        Dict with keys ``measured_I``, ``measured_TP``, ``measured_LRA``,
        ``measured_thresh`` and ``offset``.

    Raises:
        KeyError: If a source field is missing.
        ValueError: If a value cannot be converted by ``float``.
    """
    field_sources = (
        ("measured_I", "input_i"),
        ("measured_TP", "input_tp"),
        ("measured_LRA", "input_lra"),
        ("measured_thresh", "input_thresh"),
        ("offset", "target_offset"),
    )
    return {target: float(data[source]) for target, source in field_sources}
def _clamp_target(measured_i: float, target_i: float, max_gain_db: float | None) -> Tuple[float, float]:
    """Limit the gain change needed to move *measured_i* to *target_i*.

    Args:
        measured_i: Measured integrated loudness.
        target_i: Requested integrated loudness.
        max_gain_db: Largest allowed gain change in either direction, or
            ``None`` for no limit.

    Returns:
        ``(effective_target, applied_gain)``. When the requested gain exceeds
        the limit, the target is moved to ``measured_i`` plus or minus the
        limit.
    """
    requested_gain = target_i - measured_i
    if max_gain_db is not None:
        if requested_gain > max_gain_db:
            return measured_i + max_gain_db, max_gain_db
        if requested_gain < -max_gain_db:
            return measured_i - max_gain_db, -max_gain_db
    return target_i, requested_gain
# Base language code -> spaCy pipeline package name. _load_spacy_model looks
# codes up here before falling back to a blank pipeline.
SPACY_MODEL_MAP = {
    "ca": "ca_core_news_sm",
    "zh": "zh_core_web_sm",
    "hr": "hr_core_news_sm",
    "da": "da_core_news_sm",
    "nl": "nl_core_news_sm",
    "en": "en_core_web_sm",
    "fi": "fi_core_news_sm",
    "fr": "fr_core_news_sm",
    "de": "de_core_news_sm",
    "el": "el_core_news_sm",
    "it": "it_core_news_sm",
    "ja": "ja_core_news_sm",
    "ko": "ko_core_news_sm",
    "lt": "lt_core_news_sm",
    "mk": "mk_core_news_sm",
    "nb": "nb_core_news_sm",
    "pl": "pl_core_news_sm",
    "pt": "pt_core_news_sm",
    "ro": "ro_core_news_sm",
    "ru": "ru_core_news_sm",
    "sl": "sl_core_news_sm",
    "es": "es_core_news_sm",
    "sv": "sv_core_news_sm",
    "uk": "uk_core_news_sm",
}
# Process-wide cache of loaded pipelines, keyed by base language code.
_SPACY_CACHE: Dict[str, Any] = {}
# Part-of-speech tags treated as weak syntactic break points.
# NOTE(review): not referenced anywhere in the visible part of this file —
# confirm whether it is still needed.
SYNTACTIC_WEAK_BREAK_POS = {"CCONJ", "SCONJ", "ADP"}
def _load_spacy_model(lang_code: str | None):
    """Return a cached spaCy pipeline for *lang_code*, or None without spaCy.

    The code is lower-cased and reduced to its base part ("pt_BR" -> "pt").
    Resolution order: the packaged model from ``SPACY_MODEL_MAP``, then a
    blank pipeline for that language, then a blank multi-language ("xx")
    pipeline. A sentencizer is added when the pipeline has none.

    Args:
        lang_code: Language code as supplied by the caller; falsy values
            default to "en". Non-string values are converted with ``str``.

    Returns:
        A spaCy ``Language`` object, or ``None`` when spaCy is not installed.
    """
    if spacy is None:
        return None

    # str() guards against non-string codes arriving from a JSON payload.
    code = str(lang_code or "en").lower().replace("_", "-").split("-")[0]
    cached = _SPACY_CACHE.get(code)
    if cached is not None:
        return cached

    nlp = None
    model_name = SPACY_MODEL_MAP.get(code)
    if model_name:
        try:
            nlp = spacy.load(model_name)
        except Exception:
            # Best effort: the model package may simply not be installed.
            nlp = None

    cache_key = code
    if nlp is None:
        try:
            nlp = spacy.blank(code)
        except Exception:
            # Unknown language code. Share ONE multi-language pipeline under
            # the "xx" key instead of caching a new pipeline per bogus code,
            # which would let callers grow the cache without bound.
            cache_key = "xx"
            nlp = _SPACY_CACHE.get("xx")
            if nlp is None:
                nlp = spacy.blank("xx")

    if "sentencizer" not in nlp.pipe_names:
        nlp.add_pipe("sentencizer")
    _SPACY_CACHE[cache_key] = nlp
    return nlp
def _coerce_word_level(word_level: Dict[str, Any]) -> Dict[str, Any]:
    """Bring a word-level payload into ``{"segments": [...]}`` form.

    * A dict whose ``segments`` entry is a list is returned unchanged.
    * A dict whose ``words`` entry is a list is wrapped into one segment.
    * Anything else yields an empty segment list.
    """
    if isinstance(word_level, dict):
        if isinstance(word_level.get("segments"), list):
            return word_level
        flat_words = word_level.get("words")
        if isinstance(flat_words, list):
            return {"segments": [{"words": flat_words}]}
    return {"segments": []}
def _clean_word(text: str) -> str:
    """Drop characters other than word characters, whitespace and
    ``. , ? ! ; : ' " -`` from *text*, then strip surrounding whitespace."""
    disallowed = re.compile(r"[^\w\s.,?!;:'\"-]")
    cleaned = disallowed.sub("", text)
    return cleaned.strip()
def _normalize_words(word_level_result: Dict[str, Any], auto_clean: bool) -> Dict[str, Any]:
    """Return a sanitized copy of a word-level result.

    Each kept word becomes ``{"word": str, "start": float, "end": float}``.
    Malformed entries are skipped rather than raising, because the input is
    caller-supplied JSON: non-dict segments or words, a ``words`` value that
    is not a list, empty text, and timestamps that are missing, not
    convertible by ``float`` or not finite. Segments left without words are
    dropped.

    Args:
        word_level_result: Mapping with a ``segments`` list; each segment
            holds a ``words`` list whose items carry ``word`` (or ``text``),
            ``start`` and ``end``.
        auto_clean: When true, text goes through ``_clean_word``; otherwise
            it is only stripped of surrounding whitespace.

    Returns:
        ``{"segments": [{"words": [...]}, ...]}`` containing only valid words.
    """
    import math  # local import: the module's import block does not provide it

    raw_segments = word_level_result.get("segments", [])
    if not isinstance(raw_segments, (list, tuple)):
        raw_segments = []

    segments = []
    for segment in raw_segments:
        if not isinstance(segment, dict):
            continue
        raw_words = segment.get("words", [])
        if not isinstance(raw_words, (list, tuple)):
            continue
        words = []
        for word_info in raw_words:
            if not isinstance(word_info, dict):
                continue
            raw = word_info.get("word") or word_info.get("text") or ""
            if not raw:
                continue
            raw = str(raw)  # tolerate non-string text such as a bare number
            word_text = _clean_word(raw) if auto_clean else raw.strip()
            if not word_text:
                continue
            try:
                start = float(word_info.get("start"))
                end = float(word_info.get("end"))
            except (TypeError, ValueError):
                continue
            # NaN/inf timestamps cannot be compared or serialized meaningfully.
            if not (math.isfinite(start) and math.isfinite(end)):
                continue
            words.append({"word": word_text, "start": start, "end": end})
        if words:
            segments.append({"words": words})
    return {"segments": segments}
def _create_smart_tokens(word_level_result: Dict[str, Any]):
    """Flatten all timed words into token dicts plus the joined text.

    Words without a ``start`` key, or whose stripped text is empty, are
    skipped. Each token records the word with trailing punctuation split
    off (``text`` / ``punct``), its timings, the unsplit ``original`` text,
    the character offset of the word inside the joined text, and a
    ``spacy_token`` slot initialised to ``None``.

    Returns:
        ``(smart_tokens, full_text)`` where ``full_text`` is the originals
        joined by single spaces.
    """
    trailing_punct = re.compile(r"([^\w\s]+)$")
    smart_tokens = []
    offset = 0
    for segment in word_level_result.get("segments", []):
        for word_info in segment.get("words", []):
            if "start" not in word_info:
                continue
            surface = word_info.get("word", "").strip()
            if not surface:
                continue
            found = trailing_punct.search(surface)
            punct = found.group(1) if found else ""
            smart_tokens.append(
                {
                    "text": surface[: len(surface) - len(punct)],
                    "punct": punct,
                    "start": word_info.get("start"),
                    "end": word_info.get("end"),
                    "original": surface,
                    "char_start_index": offset,
                    "spacy_token": None,
                }
            )
            # +1 accounts for the space that joins consecutive words.
            offset += len(surface) + 1
    full_text = " ".join(token["original"] for token in smart_tokens)
    return smart_tokens, full_text
def _map_spacy_to_smart_tokens(smart_tokens, full_text, nlp_model):
    """Attach spaCy tokens to *smart_tokens* in place.

    Runs *nlp_model* over *full_text*, tags every token inside a noun chunk
    with a ``noun_chunk_id`` extension attribute when dependency annotation
    is available, and stores in each smart token's ``spacy_token`` slot the
    spaCy token that starts at the same character offset. Does nothing when
    *nlp_model* is falsy.
    """
    if not nlp_model:
        return

    doc = nlp_model(full_text)

    token_cls = spacy.tokens.Token
    if not token_cls.has_extension("noun_chunk_id"):
        token_cls.set_extension("noun_chunk_id", default=None)

    try:
        has_dependencies = doc.has_annotation("DEP")
    except Exception:
        has_dependencies = False

    if has_dependencies:
        try:
            for chunk_id, chunk in enumerate(doc.noun_chunks):
                for token in chunk:
                    token._.noun_chunk_id = chunk_id
        except (NotImplementedError, AttributeError, ValueError):
            # Best effort: noun-chunk tagging is skipped if iteration fails.
            pass

    tokens_by_offset = {token.idx: token for token in doc}
    for smart_tok in smart_tokens:
        aligned = tokens_by_offset.get(smart_tok["char_start_index"])
        if aligned is not None:
            smart_tok["spacy_token"] = aligned
def _get_break_score(current_token_index: int, smart_tokens: list, mode: str) -> int:
    """Score how suitable it is to end a subtitle block after a token.

    Components:
      * -10 outright when this token and the next share a noun chunk.
      * punctuation: 10 for ``. ? !``, 8 for ``, : ;``.
      * pause to the next token ("rhythmic" mode only): 20 / 15 / 10 for
        gaps above 0.5 / 0.3 / 0.15.
      * syntax: 7 when the next spaCy token's dependency is ``mark`` or
        ``relcl``, else 3 when this token is a ``CCONJ``, else 1 for ``ADP``.

    Returns:
        The summed score; 0 for an empty token entry.
    """
    current_token = smart_tokens[current_token_index]
    if not current_token:
        return 0

    following_index = current_token_index + 1
    has_following = following_index < len(smart_tokens)
    current_spacy = current_token.get("spacy_token")
    next_spacy = smart_tokens[following_index].get("spacy_token") if has_following else None

    # Never split inside a noun chunk.
    if (
        current_spacy
        and next_spacy
        and hasattr(current_spacy._, "noun_chunk_id")
        and hasattr(next_spacy._, "noun_chunk_id")
    ):
        chunk_id = current_spacy._.noun_chunk_id
        if chunk_id is not None and chunk_id == next_spacy._.noun_chunk_id:
            return -10

    punct = current_token["punct"] or ""
    if any(mark in punct for mark in ".?!"):
        semantic_score = 10
    elif any(mark in punct for mark in ",:;"):
        semantic_score = 8
    else:
        semantic_score = 0

    gap_score = 0
    if mode == "rhythmic" and has_following:
        gap = smart_tokens[following_index]["start"] - current_token["end"]
        for threshold, points in ((0.5, 20), (0.3, 15), (0.15, 10)):
            if gap > threshold:
                gap_score = points
                break

    syntactic_score = 0
    if current_spacy:
        if next_spacy and next_spacy.dep_ in {"mark", "relcl"}:
            syntactic_score = 7
        elif current_spacy.pos_ == "CCONJ":
            syntactic_score = 3
        elif current_spacy.pos_ == "ADP":
            syntactic_score = 1

    # gap_score stays 0 outside rhythmic mode, so one sum serves both modes.
    return gap_score + semantic_score + syntactic_score
def master_segmenter(
    word_level_result: Dict[str, Any],
    lang_code: str | None,
    max_chars: int,
    max_lines: int,
    nlp_model,
    mode: str = "semantic",
    min_len_percent: int = 60,
    flex_zone_percent: int = 100,
    max_extension_sec: float = 0.7,
    gap_threshold_ms: int = 10,
    high_score_threshold: int = 15,
):
    """Split word-level timings into subtitle blocks.

    For each block the function greedily collects tokens up to a character
    budget, scores every possible end position with ``_get_break_score``,
    picks the best-scoring break whose length is closest to ``max_chars``,
    optionally extends the block's end time into the following pause, and
    wraps the text into at most ``max_lines`` lines.

    Args:
        word_level_result: ``{"segments": [{"words": [...]}]}`` structure.
        lang_code: Accepted for interface compatibility; not used here.
        max_chars: Target characters per line.
        max_lines: Maximum lines per block.
        nlp_model: spaCy pipeline used for linguistic scoring, or a falsy
            value to score on punctuation (and pauses) only.
        mode: "rhythmic" adds pause-based scoring; any other value scores
            without pauses.
        min_len_percent: Breaks shorter than this percentage of
            ``max_chars`` are ignored unless their score is at least 10.
        flex_zone_percent: Scales the ``max_chars * max_lines`` budget used
            when collecting tokens for a block.
        max_extension_sec: Maximum time a block's end may be pushed past its
            last word.
        gap_threshold_ms: Minimum gap kept before the next block's first word
            when extending.
        high_score_threshold: In rhythmic mode, a break longer than
            ``max_chars`` is kept only if its score reaches this value and
            no shorter candidate exists.

    Returns:
        List of ``{"text", "start", "end"}`` dicts; lines inside ``text``
        are separated by ``\\n``.
    """
    if not word_level_result or not word_level_result.get("segments"):
        return []
    smart_tokens, full_text = _create_smart_tokens(word_level_result)
    if not smart_tokens:
        return []
    _map_spacy_to_smart_tokens(smart_tokens, full_text, nlp_model)

    # These depend only on the arguments, so compute them once.
    total_tokens = len(smart_tokens)
    build_limit = int(max_chars * max_lines * (flex_zone_percent / 100.0))
    min_len_threshold = int(max_chars * (min_len_percent / 100.0))
    gap_margin_sec = gap_threshold_ms / 1000.0

    final_blocks = []
    current_token_index = 0
    while current_token_index < total_tokens:
        # --- 1. Collect as many tokens as fit into the budget. ---
        # prefix_lengths[i] is the space-joined length of segment_tokens[: i + 1];
        # tracking it incrementally avoids re-summing every prefix.
        segment_tokens = []
        prefix_lengths = []
        running_len = -1  # so the first token contributes no joining space
        for i in range(current_token_index, total_tokens):
            token = smart_tokens[i]
            extended_len = running_len + 1 + len(token["original"])
            # The first token is always accepted, which guarantees progress.
            if extended_len > build_limit and segment_tokens:
                break
            segment_tokens.append(token)
            prefix_lengths.append(extended_len)
            running_len = extended_len

        # --- 2. Score every possible break, longest prefix first. ---
        candidates = []
        for i in range(len(segment_tokens) - 1, -1, -1):
            score = _get_break_score(current_token_index + i, smart_tokens, mode)
            if score <= 0:
                continue
            if prefix_lengths[i] < min_len_threshold and score < 10:
                continue
            candidates.append({"index": i, "score": score, "length": prefix_lengths[i]})

        # --- 3. Choose the break. Default: use the whole window. ---
        best_break_index = len(segment_tokens) - 1
        best_candidate_score = 0
        if candidates:
            # Scores are positive, so the top candidate always passes the floor.
            score_floor = max(c["score"] for c in candidates) * 0.8
            best_candidate = min(
                (c for c in candidates if c["score"] >= score_floor),
                key=lambda c: abs(c["length"] - max_chars),
            )
            best_break_index = best_candidate["index"]
            best_candidate_score = best_candidate["score"]

        # Rhythmic mode: an over-long block with a weak break falls back to
        # the best-scoring candidate that fits within max_chars, if any.
        if (
            mode == "rhythmic"
            and prefix_lengths[best_break_index] > max_chars
            and best_candidate_score < high_score_threshold
        ):
            safe_candidates = [c for c in candidates if c["length"] <= max_chars]
            if safe_candidates:
                best_break_index = max(safe_candidates, key=lambda c: c["score"])["index"]

        final_segment_tokens = segment_tokens[: best_break_index + 1]

        # --- 4. Timing: extend the end into the following pause. ---
        start_time = final_segment_tokens[0]["start"]
        original_end_time = final_segment_tokens[-1]["end"]
        new_end_time = original_end_time
        next_real_token_index = current_token_index + len(final_segment_tokens)
        if next_real_token_index < total_tokens:
            next_start_time = smart_tokens[next_real_token_index]["start"]
            safe_limit_end = next_start_time - gap_margin_sec
            if safe_limit_end > original_end_time:
                new_end_time = min(original_end_time + max_extension_sec, safe_limit_end)

        # --- 5. Wrap into lines; overflow goes onto the last allowed line. ---
        lines_text = []
        current_line_text = ""
        for token in final_segment_tokens:
            word_to_add = token["original"]
            if not current_line_text:
                current_line_text = word_to_add
            elif len(current_line_text) + 1 + len(word_to_add) <= max_chars:
                current_line_text += " " + word_to_add
            elif len(lines_text) < max_lines - 1:
                lines_text.append(current_line_text)
                current_line_text = word_to_add
            else:
                current_line_text += " " + word_to_add
        lines_text.append(current_line_text)

        final_blocks.append(
            {
                "text": "\n".join(lines_text),
                "start": start_time,
                "end": new_end_time,
            }
        )
        current_token_index = next_real_token_index
    return final_blocks
@app.get("/health")
def health() -> Dict[str, str]:
    """Liveness endpoint: always answers with a fixed status payload."""
    status_payload: Dict[str, str] = {"status": "ok"}
    return status_payload
@app.post("/normalize")
async def normalize_audio(
    background_tasks: BackgroundTasks,
    audio: UploadFile = File(...),
    target_lufs: float = Query(-16.0, description="Target integrated loudness (LUFS)"),
    true_peak: float = Query(-1.0, description="True peak limit (dBTP)"),
    lra: float = Query(11.0, description="Target loudness range"),
    sample_rate: int = Query(48000, description="Output sample rate"),
    channels: int = Query(1, description="Output channels"),
    max_gain_db: float | None = Query(20.0, description="Max gain change in dB"),
    output_format: str = Query("wav", description="Output format (wav|mp3)"),
    x_worker_auth: str | None = Header(default=None, alias="x-worker-auth"),
) -> FileResponse:
    """Loudness-normalize an uploaded audio file with two ffmpeg passes.

    Pass 1 runs the ``loudnorm`` filter with ``print_format=json`` against a
    null output to measure the input. Pass 2 applies ``loudnorm`` with those
    measurements (``linear=true``) and a target clamped by ``max_gain_db``,
    writing WAV or MP3. Measurements and the applied gain are returned in
    ``X-*`` response headers.

    Raises:
        HTTPException: 403 for a wrong worker secret (only enforced when one
            of the auth environment variables is set), 400 for a missing
            filename or unsupported ``output_format``, 500 when ffmpeg fails.
    """
    secret = os.getenv("NORMALIZE_WORKER_AUTH_KEY") or os.getenv("TTS_WORKER_AUTH_KEY")
    if secret and x_worker_auth != secret:
        raise HTTPException(status_code=403, detail="Invalid worker secret")
    if audio.filename is None:
        raise HTTPException(status_code=400, detail="Missing filename")
    normalized_format = output_format.strip().lower()
    if normalized_format not in {"wav", "mp3"}:
        raise HTTPException(status_code=400, detail="Unsupported output format")

    # The client-controlled filename must never become part of a path: an
    # absolute name or "../" components would escape the temp directory, and
    # a name equal to the output file would make ffmpeg read and write the
    # same file. Only a plain alphanumeric extension is kept.
    extension = os.path.splitext(audio.filename)[1]
    if not re.fullmatch(r"\.[A-Za-z0-9]{1,16}", extension):
        extension = ""

    tmp_dir = tempfile.mkdtemp(prefix="normalize_")
    input_path = os.path.join(tmp_dir, f"input{extension}")
    output_path = os.path.join(tmp_dir, f"normalized.{normalized_format}")
    try:
        with open(input_path, "wb") as out_file:
            shutil.copyfileobj(audio.file, out_file)

        # NOTE: _run_ffmpeg is synchronous, so this coroutine does not yield
        # to the event loop while ffmpeg is running.
        pass1 = _run_ffmpeg([
            "ffmpeg",
            "-hide_banner",
            "-y",
            "-i",
            input_path,
            "-af",
            f"loudnorm=I={target_lufs}:TP={true_peak}:LRA={lra}:print_format=json",
            "-f",
            "null",
            "-",
        ])
        measured = _map_measured(_extract_loudnorm_json(pass1.stderr))
        adjusted_target, applied_gain = _clamp_target(measured["measured_I"], target_lufs, max_gain_db)

        loudnorm_filter = (
            f"loudnorm=I={adjusted_target}:TP={true_peak}:LRA={lra}:"
            f"measured_I={measured['measured_I']}:"
            f"measured_TP={measured['measured_TP']}:"
            f"measured_LRA={measured['measured_LRA']}:"
            f"measured_thresh={measured['measured_thresh']}:"
            f"offset={measured['offset']}:"
            "linear=true:print_format=summary"
        )
        output_args = [
            "ffmpeg",
            "-hide_banner",
            "-y",
            "-i",
            input_path,
            "-af",
            loudnorm_filter,
            "-ar",
            str(sample_rate),
            "-ac",
            str(channels),
        ]
        if normalized_format == "mp3":
            output_args.extend(["-codec:a", "libmp3lame", "-q:a", "2"])
        output_args.append(output_path)
        _run_ffmpeg(output_args)
    except BaseException:
        # On failure no response will run the cleanup task, so remove the
        # temp directory here before propagating the error.
        shutil.rmtree(tmp_dir, ignore_errors=True)
        raise
    finally:
        await audio.close()

    headers = {
        "X-Input-LUFS": f"{measured['measured_I']:.2f}",
        "X-Input-TP": f"{measured['measured_TP']:.2f}",
        "X-Input-LRA": f"{measured['measured_LRA']:.2f}",
        "X-Target-LUFS": f"{adjusted_target:.2f}",
        "X-Applied-Gain": f"{applied_gain:.2f}",
    }
    # Success path: the temp directory is removed after the file is sent.
    background_tasks.add_task(shutil.rmtree, tmp_dir, ignore_errors=True)
    media_type = "audio/mpeg" if normalized_format == "mp3" else "audio/wav"
    return FileResponse(
        output_path,
        media_type=media_type,
        filename=f"normalized.{normalized_format}",
        headers=headers,
        background=background_tasks,
    )
@app.post("/subtitles")
async def generate_subtitles(
    payload: Dict[str, Any] = Body(...),
    x_worker_auth: str | None = Header(default=None, alias="x-worker-auth"),
) -> Dict[str, Any]:
    """Turn word-level timings into subtitle segments.

    Payload keys read here: ``word_level`` (or ``wordLevel``), optional
    ``settings`` (``auto_clean_special_chars``, ``auto_segment``,
    ``max_chars``, ``max_lines``, ``min_len_percent``, ``flex_zone_percent``,
    ``mode``) and optional ``language_code`` (default "en").

    Returns:
        ``{"segments": [{"text", "start", "end"}, ...]}``. With
        ``auto_segment`` disabled, a single segment spanning all words.

    Raises:
        HTTPException: 403 for a wrong worker secret (only enforced when one
            of the auth environment variables is set); 400 when
            ``word_level`` is missing, no usable words remain, or a setting
            or ``language_code`` has the wrong type.
    """
    secret = (
        os.getenv("SUBTITLE_WORKER_AUTH_KEY")
        or os.getenv("NORMALIZE_WORKER_AUTH_KEY")
        or os.getenv("TTS_WORKER_AUTH_KEY")
    )
    if secret and x_worker_auth != secret:
        raise HTTPException(status_code=403, detail="Invalid worker secret")

    word_level = payload.get("word_level") or payload.get("wordLevel")
    if not word_level:
        raise HTTPException(status_code=400, detail="Missing word_level")

    settings = payload.get("settings") or {}
    # A non-object "settings" would otherwise fail on .get() and surface as 500.
    if not isinstance(settings, dict):
        raise HTTPException(status_code=400, detail="settings must be an object")

    word_level_result = _coerce_word_level(word_level)
    auto_clean = bool(settings.get("auto_clean_special_chars", False))
    normalized_word_level = _normalize_words(word_level_result, auto_clean)
    if not normalized_word_level.get("segments"):
        raise HTTPException(status_code=400, detail="No words to segment")

    auto_segment = settings.get("auto_segment", True)
    if not auto_segment:
        # Kept segments always contain at least one word, so words is non-empty.
        words = [word for segment in normalized_word_level["segments"] for word in segment.get("words", [])]
        start = words[0]["start"]
        end = words[-1]["end"]
        text = " ".join(word["word"] for word in words)
        return {"segments": [{"text": text, "start": start, "end": end}]}

    # Client-supplied numbers: report bad values as a 400, not an unhandled error.
    try:
        max_chars = int(settings.get("max_chars", 42))
        max_lines = int(settings.get("max_lines", 2))
        min_len_percent = int(settings.get("min_len_percent", 60))
        flex_zone_percent = int(settings.get("flex_zone_percent", 130))
    except (TypeError, ValueError) as exc:
        raise HTTPException(status_code=400, detail="Invalid numeric setting") from exc
    mode = settings.get("mode", "semantic")

    language_code = payload.get("language_code") or "en"
    if not isinstance(language_code, str):
        raise HTTPException(status_code=400, detail="language_code must be a string")

    nlp_model = _load_spacy_model(language_code)
    segments = master_segmenter(
        normalized_word_level,
        language_code,
        max_chars,
        max_lines,
        nlp_model,
        mode=mode,
        min_len_percent=min_len_percent,
        flex_zone_percent=flex_zone_percent,
    )
    return {"segments": segments}
@app.exception_handler(Exception)
async def handle_unexpected_error(_, exc: Exception):
    """Registered for ``Exception``: reply with status 500 and a JSON body
    carrying the exception's string form under the ``error`` key."""
    body = {"error": str(exc)}
    return JSONResponse(content=body, status_code=500)