import difflib
import os
import re
import shutil
import subprocess
import sys
import tempfile
from contextlib import asynccontextmanager
from dataclasses import dataclass
from pathlib import Path
from typing import Literal, Optional

import numpy as np
import soundfile as sf
import webrtcvad
from fastapi import FastAPI, File, Form, HTTPException, UploadFile
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import PlainTextResponse
from faster_whisper import WhisperModel

DEVICE = "cpu"
MODEL_NAME = "large-v2"
COMPUTE_TYPE = "int8"

SrtMode = Literal["sentence", "paragraph"]

MIN_GAP_S = 0.08  # minimum gap enforced between consecutive subtitles (seconds)
MIN_DUR_S = 0.30  # minimum duration of a single subtitle (seconds)
SILENCE_GAP_S = 0.50  # silence at least this long splits cues / counts as instrumental (seconds)

whisper_model: Optional[WhisperModel] = None

@asynccontextmanager
async def lifespan(app: FastAPI):
    global whisper_model
    print(f"Startup: loading faster-whisper '{MODEL_NAME}' on {DEVICE} ({COMPUTE_TYPE})...")
    whisper_model = WhisperModel(MODEL_NAME, device=DEVICE, compute_type=COMPUTE_TYPE)
    print("Startup: ASR model ready")
    yield
    print("Shutdown: done")

app = FastAPI(title="LyricSync Backend", lifespan=lifespan)

app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

@app.get("/")  # route path assumed
async def root():
    return {
        "service": "LyricSync Backend",
        "engine": "faster-whisper + demucs + VAD",
        "model": MODEL_NAME,
        "device": DEVICE,
        "compute_type": COMPUTE_TYPE,
        "status": "operational",
    }

@app.get("/health")  # route path assumed
async def health():
    return {"status": "healthy"}

def _cleanup_temp_dir(path: str) -> None:
    shutil.rmtree(path, ignore_errors=True)

def _format_srt_time(seconds: float) -> str:
    milliseconds_total = int(max(0.0, float(seconds)) * 1000)
    hours = milliseconds_total // 3_600_000
    minutes = (milliseconds_total % 3_600_000) // 60_000
    secs = (milliseconds_total % 60_000) // 1_000
    millis = milliseconds_total % 1_000
    return f"{hours:02d}:{minutes:02d}:{secs:02d},{millis:03d}"

def _build_srt(segments: list[dict]) -> str:
    if not segments:
        return ""
    lines: list[str] = []
    for idx, seg in enumerate(segments, start=1):
        text = (seg.get("text") or "").strip()
        start = seg.get("start")
        end = seg.get("end")
        if not text or start is None or end is None:
            continue
        lines.append(str(idx))
        lines.append(f"{_format_srt_time(start)} --> {_format_srt_time(end)}")
        lines.append(text)
        lines.append("")
    return "\n".join(lines).rstrip() + "\n"

def _run_cmd(cmd: list[str]) -> None:
    try:
        subprocess.run(cmd, check=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True)
    except subprocess.CalledProcessError as e:
        stderr = (e.stderr or "") if isinstance(e.stderr, str) else str(e.stderr)
        stdout = (e.stdout or "") if isinstance(e.stdout, str) else str(e.stdout)

        def sanitize(stream: str) -> str:
            if not stream:
                return ""
            # Demucs/tqdm progress bars often use '\r' to rewrite the same line.
            stream = stream.replace("\r", "\n")
            # Keep only the last chunk to avoid flooding the UI.
            lines = [ln.rstrip() for ln in stream.splitlines() if ln.strip()]
            tail = lines[-60:]
            return "\n".join(tail)

        s_err = sanitize(stderr)
        s_out = sanitize(stdout)
        hint = ""
        if "No module named 'torchcodec'" in stderr or "TorchCodec is required" in stderr:
            hint = (
                "\nHint: Demucs failed while saving audio because torchaudio requires torchcodec. "
                "Install/ship the 'torchcodec' Python package in the backend environment."
            )
        detail_parts = [f"Command failed: {' '.join(cmd)}"]
        if s_err:
            detail_parts.append(s_err)
        elif s_out:
            detail_parts.append(s_out)
        if hint:
            detail_parts.append(hint)
        raise HTTPException(status_code=500, detail="\n".join(detail_parts)) from e

def _ffmpeg_to_wav(
    input_path: str,
    output_path: str,
    *,
    sample_rate: int,
    mono: bool,
) -> None:
    channels = "1" if mono else "2"
    cmd = [
        "ffmpeg",
        "-y",
        "-i",
        input_path,
        "-vn",
        "-ac",
        channels,
        "-ar",
        str(sample_rate),
        "-f",
        "wav",
        output_path,
    ]
    _run_cmd(cmd)
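
# e.g. _ffmpeg_to_wav("song.mp3", "vocals_16k.wav", sample_rate=16000, mono=True)
# runs: ffmpeg -y -i song.mp3 -vn -ac 1 -ar 16000 -f wav vocals_16k.wav
# (file names here are illustrative)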

def _demucs_extract_vocals(input_wav_path: str, out_dir: str) -> str:
    """
    Run Demucs vocals separation (two-stems=vocals) and return the vocals wav path.
    Uses the CLI for maximum compatibility across environments.
    """
    separated_dir = os.path.join(out_dir, "demucs_separated")
    os.makedirs(separated_dir, exist_ok=True)
    cmd = [
        sys.executable,  # current interpreter, rather than whatever "python" is on PATH
        "-m",
        "demucs.separate",
        "-n",
        "htdemucs",
        "--two-stems",
        "vocals",
        "-o",
        separated_dir,
        input_wav_path,
    ]
    _run_cmd(cmd)
    # Demucs outputs: <out>/htdemucs/<trackname>/vocals.wav (may vary by version)
    vocals_candidates = list(Path(separated_dir).rglob("vocals.wav"))
    if not vocals_candidates:
        raise HTTPException(status_code=500, detail="Demucs did not produce vocals.wav")
    # Pick the shallowest path deterministically (ties broken lexicographically).
    vocals_candidates.sort(key=lambda p: (len(p.parts), str(p)))
    return str(vocals_candidates[0])

def _read_wav_pcm16_mono(path: str, *, sample_rate: int) -> bytes:
    audio, sr = sf.read(path, dtype="int16", always_2d=True)
    if sr != sample_rate:
        raise HTTPException(status_code=500, detail=f"VAD expected {sample_rate}Hz mono wav; got {sr}Hz")
    if audio.shape[1] != 1:
        raise HTTPException(status_code=500, detail="VAD expected mono wav (1 channel)")
    return audio[:, 0].tobytes()

def _vad_voice_segments(
    wav_16k_mono_path: str,
    *,
    gap_s: float = SILENCE_GAP_S,
    vad_mode: int = 2,
) -> tuple[list[tuple[float, float]], list[tuple[float, float]]]:
    """
    Return (voice_segments, instrumental_gaps).
    - voice_segments: merged voiced ranges
    - instrumental_gaps: gaps between voiced ranges longer than gap_s
    """
    sample_rate = 16000
    pcm = _read_wav_pcm16_mono(wav_16k_mono_path, sample_rate=sample_rate)
    vad = webrtcvad.Vad(int(vad_mode))
    frame_ms = 30
    frame_bytes = int(sample_rate * (frame_ms / 1000.0) * 2)  # 16-bit mono
    voiced_frames: list[tuple[float, float]] = []
    offset_bytes = 0
    total_bytes = len(pcm)
    while offset_bytes + frame_bytes <= total_bytes:
        frame = pcm[offset_bytes : offset_bytes + frame_bytes]
        t0 = (offset_bytes / 2) / sample_rate
        t1 = ((offset_bytes + frame_bytes) / 2) / sample_rate
        if vad.is_speech(frame, sample_rate):
            voiced_frames.append((t0, t1))
        offset_bytes += frame_bytes
    if not voiced_frames:
        return ([], [])
    # Merge contiguous voiced frames with a small tolerance.
    merged: list[tuple[float, float]] = []
    cur_s, cur_e = voiced_frames[0]
    tol = 0.06
    for s, e in voiced_frames[1:]:
        if s <= cur_e + tol:
            cur_e = max(cur_e, e)
        else:
            merged.append((cur_s, cur_e))
            cur_s, cur_e = s, e
    merged.append((cur_s, cur_e))
    gaps: list[tuple[float, float]] = []
    for (_s1, e1), (s2, _e2) in zip(merged, merged[1:]):
        if (s2 - e1) >= gap_s:
            gaps.append((e1, s2))
    return (merged, gaps)
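
# Illustrative result shape (values not from a real run): vocals at 0-12 s and
# 20-35 s yield voice_segments == [(0.0, 12.0), (20.0, 35.0)] and
# instrumental_gaps == [(12.0, 20.0)].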

@dataclass
class WordTS:
    display: str
    norm: str
    start: Optional[float] = None
    end: Optional[float] = None
    boundary_after: bool = False  # punctuation boundary
    line_break_after: bool = False

_STRONG_BOUNDARY_RE = re.compile(r"[.!?]+$")
_PUNCT_STRIP_RE = re.compile(r"^[^\w']+|[^\w']+$", re.UNICODE)
_NONWORD_RE = re.compile(r"[^\w']+", re.UNICODE)

def _normalize_token(token: str) -> str:
    token = (token or "").strip().lower()
    token = token.replace("’", "'").replace("‘", "'").replace("´", "'")
    token = _PUNCT_STRIP_RE.sub("", token)
    token = _NONWORD_RE.sub("", token)
    return token
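
# e.g. _normalize_token("Don’t!") -> "don't"; _normalize_token("(Yeah)") -> "yeah"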

def _cleanup_spacing(text: str) -> str:
    text = re.sub(r"\s+([,.;:!?])", r"\1", text)
    text = re.sub(r"\(\s+", "(", text)
    text = re.sub(r"\s+\)", ")", text)
    return text.strip()
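
# e.g. _cleanup_spacing("Hello , world !") -> "Hello, world!"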

def _parse_lyrics_words(lyrics_text: str) -> list[WordTS]:
    words: list[WordTS] = []
    for line in (lyrics_text or "").splitlines():
        line = line.strip()
        if not line:
            # Preserve a strong boundary (line break) if we already have words.
            if words:
                last = words[-1]
                words[-1] = WordTS(
                    display=last.display,
                    norm=last.norm,
                    start=last.start,
                    end=last.end,
                    boundary_after=True,
                    line_break_after=True,
                )
            continue
        tokens = [t for t in re.split(r"\s+", line) if t]
        for idx, tok in enumerate(tokens):
            norm = _normalize_token(tok)
            if not norm:
                continue
            boundary_after = bool(_STRONG_BOUNDARY_RE.search(tok))
            line_break_after = idx == len(tokens) - 1
            words.append(WordTS(display=tok, norm=norm, boundary_after=boundary_after, line_break_after=line_break_after))
    return words

def _flatten_asr_words(transcribe_segments) -> list[WordTS]:
    out: list[WordTS] = []
    for seg in transcribe_segments:
        for w in (seg.words or []):
            tok = (w.word or "").strip()
            norm = _normalize_token(tok)
            if not norm:
                continue
            boundary_after = bool(_STRONG_BOUNDARY_RE.search(tok))
            out.append(WordTS(display=tok, norm=norm, start=float(w.start), end=float(w.end), boundary_after=boundary_after))
    out.sort(key=lambda x: (x.start or 0.0, x.end or 0.0))
    return out

def _similarity(a: str, b: str) -> float:
    if not a or not b:
        return 0.0
    if a == b:
        return 1.0
    # Cheap heuristic first: substring containment catches clipped ASR tokens.
    if len(a) >= 3 and (a in b or b in a):
        return 0.86
    # Standard-library fuzzy match.
    return difflib.SequenceMatcher(None, a, b).ratio()
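
# e.g. _similarity("singing", "singin") -> 0.86 (substring heuristic);
#      _similarity("world", "word") -> ~0.89 (difflib ratio)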

def _align_words_dp(lyrics: list[WordTS], asr: list[WordTS]) -> list[Optional[int]]:
    """
    Needleman–Wunsch alignment on normalized token sequences.
    Returns: mapping lyric_index -> asr_index or None.
    """
    n = len(lyrics)
    m = len(asr)
    if n == 0 or m == 0:
        return [None] * n
    # Backpointers for each row: 0=diag, 1=up (delete lyric), 2=left (insert asr)
    back: list[bytearray] = [bytearray(m + 1) for _ in range(n + 1)]
    prev = [float(j) for j in range(m + 1)]
    for i in range(1, n + 1):
        cur = [float(i)] + [0.0] * m
        for j in range(1, m + 1):
            sim = _similarity(lyrics[i - 1].norm, asr[j - 1].norm)
            sub_cost = 0.0 if sim >= 0.90 else (0.25 if sim >= 0.82 else (0.6 if sim >= 0.74 else 1.0))
            diag = prev[j - 1] + sub_cost
            up = prev[j] + 1.0
            left = cur[j - 1] + 1.0
            best = diag
            move = 0
            if up < best:
                best = up
                move = 1
            if left < best:
                best = left
                move = 2
            cur[j] = best
            back[i][j] = move
        prev = cur
    mapping: list[Optional[int]] = [None] * n
    i, j = n, m
    while i > 0 or j > 0:
        move = back[i][j]
        if i > 0 and j > 0 and move == 0:
            sim = _similarity(lyrics[i - 1].norm, asr[j - 1].norm)
            if sim >= 0.74:
                mapping[i - 1] = j - 1
            i -= 1
            j -= 1
        elif i > 0 and (j == 0 or move == 1):
            i -= 1
        else:
            j -= 1
    return mapping
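
# Worked example (by hand, illustrative): lyrics ["hello", "world"] vs ASR
# ["hello", "word", "yeah"] gives mapping == [0, 1]: "world"/"word" scores
# ~0.89 >= 0.74, and the spurious trailing "yeah" is absorbed as an insertion.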

def _interpolate_missing_timestamps(
    words: list[WordTS],
    *,
    voice_segments: Optional[list[tuple[float, float]]] = None,
    default_dur: float = 0.25,
) -> list[WordTS]:
    starts = [w.start for w in words]
    ends = [w.end for w in words]
    matched_durs = [float(e) - float(s) for s, e in zip(starts, ends) if s is not None and e is not None and e > s]
    avg_dur = float(np.median(matched_durs)) if matched_durs else default_dur
    avg_dur = float(max(0.08, min(0.60, avg_dur)))

    def set_word(idx: int, s: float, e: float) -> None:
        w = words[idx]
        words[idx] = WordTS(
            display=w.display,
            norm=w.norm,
            start=float(s),
            end=float(e),
            boundary_after=w.boundary_after,
            line_break_after=w.line_break_after,
        )

    def available_voice_ranges(left: float, right: float) -> list[tuple[float, float]]:
        if not voice_segments:
            return []
        out: list[tuple[float, float]] = []
        for vs, ve in voice_segments:
            s = max(left, float(vs))
            e = min(right, float(ve))
            if e > s:
                out.append((s, e))
        return out

    # Fill internal runs
    i = 0
    while i < len(words):
        if words[i].start is not None and words[i].end is not None:
            i += 1
            continue
        run_start = i
        while i < len(words) and (words[i].start is None or words[i].end is None):
            i += 1
        run_end = i - 1
        prev_idx = run_start - 1
        next_idx = i if i < len(words) else None
        if prev_idx >= 0 and next_idx is not None and words[prev_idx].end is not None and words[next_idx].start is not None:
            left_t = float(words[prev_idx].end)
            right_t = float(words[next_idx].start)
            k = (run_end - run_start) + 1
            voice_ranges = available_voice_ranges(left_t, right_t)
            total_voice = sum(e - s for s, e in voice_ranges)
            if total_voice >= 0.20:
                # Distribute words across voiced regions only.
                for r in range(k):
                    target = (r + 1) / (k + 1) * total_voice
                    t = left_t
                    cum_local = 0.0
                    for s, e in voice_ranges:
                        dur = e - s
                        if cum_local + dur >= target:
                            t = s + (target - cum_local)
                            t = min(max(t, s), e)
                            break
                        cum_local += dur
                    s0 = float(t)
                    # Keep the word fully inside the voice range when possible.
                    end_limit = right_t
                    for s, e in voice_ranges:
                        if s <= s0 <= e:
                            end_limit = e
                            break
                    e0 = min(float(end_limit), s0 + avg_dur)
                    set_word(run_start + r, s0, max(e0, s0 + 0.06))
            else:
                # Fallback: linear interpolation over the full span.
                span = max(0.001, right_t - left_t)
                step = span / (k + 1)
                for r in range(k):
                    s0 = left_t + step * (r + 1)
                    e0 = min(right_t, s0 + min(avg_dur, step * 0.9))
                    set_word(run_start + r, s0, max(e0, s0 + 0.06))
        elif next_idx is not None and words[next_idx].start is not None:
            right_t = float(words[next_idx].start)
            k = (run_end - run_start) + 1
            start_base = max(0.0, right_t - (avg_dur + 0.02) * k)
            for r in range(k):
                s = start_base + (avg_dur + 0.02) * r
                e = s + avg_dur
                set_word(run_start + r, s, e)
        elif prev_idx >= 0 and words[prev_idx].end is not None:
            left_t = float(words[prev_idx].end)
            k = (run_end - run_start) + 1
            for r in range(k):
                s = left_t + (avg_dur + 0.02) * (r + 1)
                e = s + avg_dur
                set_word(run_start + r, s, e)
        else:
            # All missing; assign a simple ramp.
            for r in range(run_end - run_start + 1):
                s = (avg_dur + 0.02) * r
                e = s + avg_dur
                set_word(run_start + r, s, e)
    return words

def _segment_from_words(
    words: list[WordTS],
    *,
    mode: SrtMode,
    silence_gap_s: float = SILENCE_GAP_S,
) -> list[dict]:
    if not words:
        return []
    max_words = 8 if mode == "sentence" else 24
    max_block_dur = 7.0 if mode == "sentence" else 14.0
    segs: list[dict] = []
    cur: list[WordTS] = []

    def flush() -> None:
        nonlocal cur
        if not cur:
            return
        start = float(cur[0].start or 0.0)
        end = float(cur[-1].end or start)
        text = _cleanup_spacing(" ".join(w.display for w in cur))
        if text:
            segs.append({"start": start, "end": end, "text": text})
        cur = []

    for w in words:
        if w.start is None or w.end is None:
            continue
        if cur:
            gap = float(w.start) - float(cur[-1].end or w.start)
            if gap >= silence_gap_s:
                flush()
        cur.append(w)
        # Splitting rules
        if len(cur) >= max_words:
            flush()
            continue
        if cur and (float(cur[-1].end or 0.0) - float(cur[0].start or 0.0)) >= max_block_dur:
            flush()
            continue
        if w.boundary_after:
            flush()
            continue
        if mode == "sentence" and w.line_break_after:
            flush()
            continue
        if mode == "paragraph" and w.line_break_after and len(cur) >= 16:
            flush()
    flush()
    return segs
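
# Note: "sentence" mode favors short cues (<= 8 words, <= 7.0 s) and flushes on
# every lyric line break; "paragraph" mode allows up to 24 words / 14.0 s and
# only honors a line break once the block has at least 16 words.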

def _enforce_timing_rules(segments: list[dict]) -> list[dict]:
    if not segments:
        return []
    segments = sorted(segments, key=lambda s: (float(s["start"]), float(s["end"])))
    fixed: list[dict] = []
    prev_end = 0.0
    for seg in segments:
        start = float(seg["start"])
        end = float(seg["end"])
        text = (seg.get("text") or "").strip()
        if not text:
            continue
        start = max(start, prev_end + MIN_GAP_S) if fixed else max(0.0, start)
        end = max(end, start + MIN_DUR_S)
        fixed.append({"start": start, "end": end, "text": text})
        prev_end = end
    return fixed
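
# Worked example (by hand): overlapping cues (1.0, 2.0) and (1.9, 2.1) become
# (1.0, 2.0) and (2.08, 2.38): the later start is pushed to prev_end + MIN_GAP_S
# and the end is padded out to at least MIN_DUR_S.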

def _overlaps_voice(start: float, end: float, voice_segments: list[tuple[float, float]]) -> bool:
    for vs, ve in voice_segments:
        if max(start, vs) < min(end, ve):
            return True
    return False

def _instrumental_tag_segments(gaps: list[tuple[float, float]]) -> list[dict]:
    out: list[dict] = []
    for s, e in gaps:
        if (e - s) >= SILENCE_GAP_S:
            out.append({"start": float(s), "end": float(e), "text": "[INSTRUMENTAL]"})
    return out

def _extract_window_wav(input_wav_16k: str, out_wav: str, start_s: float, end_s: float) -> None:
    cmd = [
        "ffmpeg",
        "-y",
        "-i",
        input_wav_16k,
        "-ss",
        f"{max(0.0, start_s):.3f}",
        "-to",
        f"{max(0.0, end_s):.3f}",
        "-ac",
        "1",
        "-ar",
        "16000",
        "-f",
        "wav",
        out_wav,
    ]
    _run_cmd(cmd)

def _transcribe_words(wav_16k_mono_path: str, *, beam_size: int) -> list[WordTS]:
    if whisper_model is None:
        raise HTTPException(status_code=503, detail="ASR model is not ready")
    segments, _info = whisper_model.transcribe(
        wav_16k_mono_path,
        word_timestamps=True,
        beam_size=int(beam_size),
        best_of=max(beam_size, 5),
        temperature=0.0,
        vad_filter=False,
        condition_on_previous_text=False,
    )
    return _flatten_asr_words(segments)

def _fill_lyrics_timestamps_with_fallback(
    lyrics_words: list[WordTS],
    asr_words: list[WordTS],
    vocals_16k_path: str,
    voice_segments: list[tuple[float, float]],
    temp_dir: str,
) -> list[WordTS]:
    mapping = _align_words_dp(lyrics_words, asr_words)
    # Apply direct timestamps where matched.
    filled: list[WordTS] = []
    for i, lw in enumerate(lyrics_words):
        j = mapping[i]
        if j is not None and asr_words[j].start is not None and asr_words[j].end is not None:
            aw = asr_words[j]
            filled.append(
                WordTS(
                    display=lw.display,
                    norm=lw.norm,
                    start=float(aw.start),
                    end=float(aw.end),
                    boundary_after=lw.boundary_after,
                    line_break_after=lw.line_break_after,
                )
            )
        else:
            filled.append(
                WordTS(
                    display=lw.display,
                    norm=lw.norm,
                    start=None,
                    end=None,
                    boundary_after=lw.boundary_after,
                    line_break_after=lw.line_break_after,
                )
            )
    # Identify long mismatch runs and try a windowed ASR pass (limited).
    max_windows = 3
    i = 0
    windows_done = 0
    while i < len(filled) and windows_done < max_windows:
        if filled[i].start is not None:
            i += 1
            continue
        run_start = i
        while i < len(filled) and filled[i].start is None:
            i += 1
        run_end = i - 1
        run_len = run_end - run_start + 1
        if run_len < 10:
            continue
        # Window bounds from neighboring known timestamps.
        left_end = None
        right_start = None
        if run_start - 1 >= 0:
            left_end = filled[run_start - 1].end
        if run_end + 1 < len(filled):
            right_start = filled[run_end + 1].start
        if left_end is None or right_start is None:
            continue
        w_start = max(0.0, float(left_end) - 0.8)
        w_end = float(right_start) + 0.8
        if (w_end - w_start) < 2.0:
            continue
        clip_path = os.path.join(temp_dir, f"asr_clip_{windows_done}.wav")
        _extract_window_wav(vocals_16k_path, clip_path, w_start, w_end)
        clip_words = _transcribe_words(clip_path, beam_size=10)
        # Offset clip words into the global timeline.
        clip_words_off = [
            WordTS(
                display=w.display,
                norm=w.norm,
                start=float(w.start or 0.0) + w_start,
                end=float(w.end or 0.0) + w_start,
                boundary_after=w.boundary_after,
            )
            for w in clip_words
        ]
        sub_lyrics = filled[run_start : run_end + 1]
        sub_mapping = _align_words_dp(sub_lyrics, clip_words_off)
        for k, j in enumerate(sub_mapping):
            if j is None:
                continue
            aw = clip_words_off[j]
            filled[run_start + k] = WordTS(
                display=filled[run_start + k].display,
                norm=filled[run_start + k].norm,
                start=float(aw.start or 0.0),
                end=float(aw.end or 0.0),
                boundary_after=filled[run_start + k].boundary_after,
                line_break_after=filled[run_start + k].line_break_after,
            )
        windows_done += 1
    return _interpolate_missing_timestamps(filled, voice_segments=voice_segments)

@app.post("/generate-srt")  # route path assumed
async def generate_srt(
    audio_file: UploadFile = File(...),
    lyrics_text: str = Form(""),
    srt_mode: str = Form("sentence"),
    add_instrumental_tags: bool = Form(False),
):
    """
    Production lyric-to-SRT pipeline (open-source only):
    1) Demucs vocal isolation (vocals stem)
    2) VAD on vocals stem (instrumental gaps)
    3) faster-whisper ASR on vocals stem (word timestamps)
    4) Optional lyrics-guided alignment (lyrics text becomes source of truth)
    5) Segment into SRT (sentence/paragraph) with silence-aware splits
    """
    if whisper_model is None:
        raise HTTPException(status_code=503, detail="ASR model is not ready")
    mode = (srt_mode or "").strip().lower()
    if mode not in ("sentence", "paragraph"):
        raise HTTPException(status_code=400, detail="Invalid srt_mode (expected 'sentence' or 'paragraph')")
    temp_dir = tempfile.mkdtemp(prefix="lyric-sync-")
    try:
        # basename() guards against path traversal via the uploaded filename.
        source_name = os.path.basename(audio_file.filename or "audio")
        input_path = os.path.join(temp_dir, source_name)
        with open(input_path, "wb") as f:
            shutil.copyfileobj(audio_file.file, f)
        # Convert to a stable wav for Demucs
        input_wav = os.path.join(temp_dir, "input_44k_stereo.wav")
        _ffmpeg_to_wav(input_path, input_wav, sample_rate=44100, mono=False)
        vocals_wav = _demucs_extract_vocals(input_wav, temp_dir)
        # Canonical vocals wav for VAD + ASR (16k mono)
        vocals_16k = os.path.join(temp_dir, "vocals_16k_mono.wav")
        _ffmpeg_to_wav(vocals_wav, vocals_16k, sample_rate=16000, mono=True)
        voice_segments, instrumental_gaps = _vad_voice_segments(vocals_16k, gap_s=SILENCE_GAP_S)
        # ASR pass on vocals (word timestamps)
        asr_words = _transcribe_words(vocals_16k, beam_size=6)
        if not asr_words:
            return PlainTextResponse(content="", media_type="application/x-subrip")
        # Choose source-of-truth tokens
        lyrics_provided = bool((lyrics_text or "").strip())
        if lyrics_provided:
            lyric_words = _parse_lyrics_words(lyrics_text)
            if not lyric_words:
                raise HTTPException(status_code=400, detail="Lyrics provided but no usable words were found")
            aligned_words = _fill_lyrics_timestamps_with_fallback(lyric_words, asr_words, vocals_16k, voice_segments, temp_dir)
        else:
            aligned_words = asr_words
        # Segment AFTER alignment/transcription
        segments = _segment_from_words(aligned_words, mode=mode)  # type: ignore[arg-type]
        # Enforce "no subtitles during instrumentals" via VAD (drop segments outside voice)
        segments = [s for s in segments if _overlaps_voice(float(s["start"]), float(s["end"]), voice_segments)]
        if add_instrumental_tags:
            segments.extend(_instrumental_tag_segments(instrumental_gaps))
        segments = _enforce_timing_rules(segments)
        return PlainTextResponse(content=_build_srt(segments), media_type="application/x-subrip")
    except HTTPException:
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e
    finally:
        try:
            audio_file.file.close()
        finally:
            _cleanup_temp_dir(temp_dir)

@app.post("/align")  # route path assumed for the legacy endpoint
async def align_compat(
    audio_file: UploadFile = File(...),
    lyrics_text: str = Form(""),
    srt_mode: str = Form("sentence"),
    add_instrumental_tags: bool = Form(False),
):
    # Backward-compat route for older frontend builds.
    return await generate_srt(
        audio_file=audio_file,
        lyrics_text=lyrics_text,
        srt_mode=srt_mode,
        add_instrumental_tags=add_instrumental_tags,
    )

if __name__ == "__main__":
    import uvicorn

    port = int(os.environ.get("PORT", 10000))
    print(f"Starting LyricSync backend on port {port}...")
    uvicorn.run(app, host="0.0.0.0", port=port)
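
# Example request against a local run (assumes the /generate-srt route above;
# file names are illustrative):
#   curl -X POST http://localhost:10000/generate-srt \
#     -F "audio_file=@song.mp3" \
#     -F "lyrics_text=<lyrics.txt" \
#     -F "srt_mode=sentence" \
#     -F "add_instrumental_tags=true" \
#     -o song.srt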