""" Trelis Chorus โ€” HF Space demo (CPU inference). Loads the merged Chorus model (base Whisper Turbo + LoRA merged + expanded tokenizer) once and serves a FastAPI + vanilla-JS UI that accepts uploaded or recorded audio and returns S1/S2 transcripts. CPU inference takes ~30-60s per 30s clip on the free HF Space tier. GPU tier would make this near-instant. """ import os, io, re, time from pathlib import Path import numpy as np import soundfile as sf import torch from fastapi import FastAPI, UploadFile, File, HTTPException from fastapi.responses import HTMLResponse, JSONResponse, FileResponse import uvicorn # Merged model containing base Whisper Turbo + LoRA merged in + expanded tokenizer MODEL_REPO = os.environ.get("CHORUS_MODEL_REPO", "Trelis/Chorus-v1") SPEAKER1_TOKEN = "<|speaker1|>" SPEAKER2_TOKEN = "<|speaker2|>" SR = 16_000 if torch.cuda.is_available(): DEVICE, DTYPE = "cuda", torch.float16 _GPU_NAME = torch.cuda.get_device_name(0) else: DEVICE, DTYPE = "cpu", torch.float32 _GPU_NAME = None print(f"[chorus-space] Device: {DEVICE} ({DTYPE}){' โ€” ' + _GPU_NAME if _GPU_NAME else ''}, model: {MODEL_REPO}") _model = None _processor = None _tok_ids: dict = {} _TS_START_ID: int = -1 _TS_END_ID: int = -1 _TS_STEP = 0.02 def load_model(): global _model, _processor, _tok_ids, _TS_START_ID, _TS_END_ID if _model is not None: return from transformers import WhisperForConditionalGeneration, WhisperProcessor hf_token = os.environ.get("HF_TOKEN") or os.environ.get("HUGGINGFACE_TOKEN") print(f"[chorus-space] Loading {MODEL_REPO}...") t = time.time() proc = WhisperProcessor.from_pretrained(MODEL_REPO, token=hf_token) m = WhisperForConditionalGeneration.from_pretrained(MODEL_REPO, token=hf_token, dtype=DTYPE) m = m.to(DEVICE).eval() m.generation_config.predict_timestamps = True m.generation_config.max_initial_timestamp_index = 1500 _tok_ids["spk1"] = proc.tokenizer.convert_tokens_to_ids(SPEAKER1_TOKEN) _tok_ids["spk2"] = proc.tokenizer.convert_tokens_to_ids(SPEAKER2_TOKEN) _tok_ids["en"] = proc.tokenizer.convert_tokens_to_ids("<|en|>") _tok_ids["transcribe"] = proc.tokenizer.convert_tokens_to_ids("<|transcribe|>") _TS_START_ID = proc.tokenizer.convert_tokens_to_ids("<|0.00|>") _TS_END_ID = proc.tokenizer.convert_tokens_to_ids("<|30.00|>") _processor = proc _model = m print(f"[chorus-space] Model ready in {time.time()-t:.1f}s (ts range: {_TS_START_ID}..{_TS_END_ID})") def _infer(arr: np.ndarray, spk_id: int) -> list[dict]: feats = _processor.feature_extractor( [arr], sampling_rate=SR, return_tensors="pt" ).input_features.to(DEVICE).to(DTYPE) forced = [[1, _tok_ids["en"]], [2, _tok_ids["transcribe"]], [3, spk_id]] with torch.no_grad(): out = _model.generate( feats, forced_decoder_ids=forced, return_timestamps=True, max_new_tokens=444, ) return _parse_segments(out[0].tolist()) def _parse_segments(ids: list[int]) -> list[dict]: segments = [] cur_start = None cur_text_ids: list[int] = [] for t in ids: if _TS_START_ID <= t <= _TS_END_ID: ts = (t - _TS_START_ID) * _TS_STEP if cur_start is None: cur_start = ts else: text = _processor.tokenizer.decode(cur_text_ids, skip_special_tokens=True).strip() if text: segments.append({"start": round(cur_start, 2), "end": round(ts, 2), "text": text}) cur_start = None cur_text_ids = [] elif cur_start is not None: cur_text_ids.append(t) return segments def _decode_audio(audio_bytes: bytes) -> tuple[np.ndarray, int]: try: return sf.read(io.BytesIO(audio_bytes)) except Exception: import subprocess, tempfile with tempfile.NamedTemporaryFile(suffix=".bin") as fin: fin.write(audio_bytes) fin.flush() result = subprocess.run( ["ffmpeg", "-i", fin.name, "-f", "wav", "-ac", "1", "-ar", str(SR), "-"], capture_output=True, check=True, ) return sf.read(io.BytesIO(result.stdout)) def transcribe_bytes(audio_bytes: bytes) -> dict: t0 = time.time() arr, orig_sr = _decode_audio(audio_bytes) arr = np.asarray(arr, dtype=np.float32) if arr.ndim > 1: arr = arr.mean(axis=1) if orig_sr != SR: import librosa arr = librosa.resample(arr, orig_sr=orig_sr, target_sr=SR) max_samples = 30 * SR if len(arr) > max_samples: arr = arr[:max_samples] s1 = _infer(arr, _tok_ids["spk1"]) s2 = _infer(arr, _tok_ids["spk2"]) return { "duration_s": float(len(arr) / SR), "elapsed_s": time.time() - t0, "speaker1": {"segments": s1}, "speaker2": {"segments": s2}, } INDEX_HTML = r""" Trelis Chorus

Separate two voices
from a single stream.

Multi-speaker Whisper fine-tune by Trelis. Upload audio of two people talking — possibly overlapping — and Trelis Chorus returns a transcript for each speaker with timestamps.

""" app = FastAPI() @app.on_event("startup") def startup(): load_model() @app.get("/", response_class=HTMLResponse) def index(): return INDEX_HTML @app.get("/info") def info(): return {"model_repo": MODEL_REPO, "device": DEVICE, "gpu_name": _GPU_NAME} _SAMPLES = { "podcast": "sample_podcast.wav", } @app.get("/sample/{name}") def sample(name: str): fname = _SAMPLES.get(name) if not fname: raise HTTPException(404, f"Unknown sample: {name}") path = Path(__file__).parent / "static" / fname if not path.exists(): raise HTTPException(404, f"Sample file not found: {fname}") return FileResponse(str(path), media_type="audio/wav") @app.post("/transcribe") async def transcribe(file: UploadFile = File(...)): audio_bytes = await file.read() if len(audio_bytes) > 50 * 1024 * 1024: raise HTTPException(400, "File too large (50MB max).") try: return JSONResponse(transcribe_bytes(audio_bytes)) except Exception as e: raise HTTPException(500, f"Inference failed: {e}") if __name__ == "__main__": port = int(os.environ.get("PORT", 7860)) # HF Spaces default port uvicorn.run(app, host="0.0.0.0", port=port)