# BoomConnext-demo / main.py
# Its-OMG
# Bumped torch to 2.7+ and torchcodec to 0.3 to restore ASR auto transcribe
# a2eb473
"""
BoomConnex Voice Studio — single FastAPI backend.

Hosts three TTS endpoints (Voice Clone, Voice Design, Emotion TTS) and
serves the built React SPA on the same origin. Designed for a single
HuggingFace Space on a dedicated GPU.

API:
    POST /api/voice-clone           (multipart) → audio/wav
    POST /api/voice-design          (multipart) → audio/wav
    POST /api/emotion-tts           (multipart) → audio/wav
    GET  /api/health
    GET  /api/languages
    GET  /api/voice-design/options
    GET  /                          serves React SPA (with client-side routing)
"""
from __future__ import annotations
import io
import logging
import os
import tempfile
from contextlib import asynccontextmanager
from pathlib import Path
from typing import Optional
import numpy as np
import soundfile as sf
from fastapi import FastAPI, File, Form, HTTPException, UploadFile
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import FileResponse, JSONResponse, Response
from fastapi.staticfiles import StaticFiles
# ---- Optional heavy imports ------------------------------------------------
# All ML deps are optional at import time so the server can boot for
# plumbing-only testing with `LOAD_*=0` and a minimal pip install. Endpoints
# that need a missing dep return 503 at request time.
try:
import torch
TORCH_AVAILABLE = True
except ImportError:
torch = None # type: ignore
TORCH_AVAILABLE = False
logging.warning("torch not installed — model inference disabled.")
try:
from omnivoice import OmniVoice, OmniVoiceGenerationConfig
from omnivoice.utils.lang_map import LANG_NAMES, lang_display_name
OMNIVOICE_AVAILABLE = True
except ImportError as e:
OMNIVOICE_AVAILABLE = False
LANG_NAMES = [] # type: ignore
def lang_display_name(n): # type: ignore
return n
logging.warning(f"OmniVoice not importable — voice-clone/design disabled: {e}")
try:
from LavaSR.model import LavaEnhance2
LAVASR_AVAILABLE = True
except ImportError:
LAVASR_AVAILABLE = False
logging.warning("LavaSR not installed — audio enhancement disabled.")
try:
from chatterbox.tts import ChatterboxTTS
CHATTERBOX_AVAILABLE = True
except ImportError:
CHATTERBOX_AVAILABLE = False
logging.warning("chatterbox-tts not installed — emotion TTS disabled.")
# ---------------------------------------------------------------------------
# Config
# ---------------------------------------------------------------------------
def _env_flag(name: str, default: str = "1") -> bool:
    """Return True only when environment variable *name* is exactly "1"."""
    return os.environ.get(name, default) == "1"


# Checkpoint identifiers; each can be overridden through the environment.
OMNIVOICE_CHECKPOINT = os.environ.get("OMNIVOICE_CHECKPOINT", "k2-fsa/OmniVoice")
LAVASR_CHECKPOINT = os.environ.get("LAVASR_CHECKPOINT", "YatharthS/LavaSR")

# Load switches: enabled by default, and on only for the literal value "1".
LOAD_OMNIVOICE = _env_flag("LOAD_OMNIVOICE")
LOAD_LAVASR = _env_flag("LOAD_LAVASR")
LOAD_CHATTERBOX = _env_flag("LOAD_CHATTERBOX")
LOAD_ASR = _env_flag("LOAD_ASR")

# Directory holding the built frontend.
STATIC_DIR = Path(os.environ.get("STATIC_DIR", "static"))
def get_best_device() -> str:
    """Pick the device string to run on: "cuda", then "mps", then "cpu".

    Returns "cpu" straight away when torch could not be imported.
    """
    if TORCH_AVAILABLE:
        if torch.cuda.is_available():
            return "cuda"
        if torch.backends.mps.is_available():
            return "mps"
    return "cpu"
# ---------------------------------------------------------------------------
# Voice Design taxonomy (mirrors integrated_demo.py)
# ---------------------------------------------------------------------------
# Maps each Voice Design attribute category to the option labels offered for
# it. Served as-is by the /api/voice-design/options endpoint.
VD_CATEGORIES = {
    "gender": [
        "Male",
        "Female",
    ],
    "age": [
        "Child",
        "Teenager",
        "Young Adult",
        "Middle-aged",
        "Elderly",
    ],
    "pitch": [
        "Very Low Pitch",
        "Low Pitch",
        "Moderate Pitch",
        "High Pitch",
        "Very High Pitch",
    ],
    "style": [
        "Whisper",
    ],
    "english_accent": [
        "American Accent",
        "Australian Accent",
        "British Accent",
        "Chinese Accent",
        "Canadian Accent",
        "Indian Accent",
        "Korean Accent",
        "Portuguese Accent",
        "Russian Accent",
        "Japanese Accent",
    ],
    "chinese_dialect": [
        "Henan Dialect",
        "Shaanxi Dialect",
        "Sichuan Dialect",
        "Guizhou Dialect",
        "Yunnan Dialect",
        "Guilin Dialect",
        "Jinan Dialect",
        "Shijiazhuang Dialect",
        "Gansu Dialect",
        "Ningxia Dialect",
        "Qingdao Dialect",
        "Northeast Dialect",
    ],
}
# ---------------------------------------------------------------------------
# State container for loaded models
# ---------------------------------------------------------------------------
class Models:
    """Holder for the loaded model objects and their runtime settings.

    Every attribute is declared at class level with a "not loaded" default;
    real values are assigned onto the shared ``models`` instance at startup.
    """

    # Device string the models run on.
    device: str = "cpu"

    # OmniVoice model and its output sample rate (0 until loaded).
    omnivoice: Optional["OmniVoice"] = None
    omnivoice_sr: int = 0

    # LavaSR enhancer.
    lava: Optional["LavaEnhance2"] = None

    # Chatterbox model and its output sample rate (0 until loaded).
    chatterbox: Optional["ChatterboxTTS"] = None
    chatterbox_sr: int = 0


# Single shared instance used by the startup hook and the request handlers.
models = Models()
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Load the enabled models into ``models`` before the app starts serving.

    A model is loaded only when its LOAD_* switch is on and its package was
    importable. A LavaSR load failure is logged and skipped; OmniVoice and
    Chatterbox load errors are not caught here and propagate to the caller.
    """
    device = get_best_device()
    models.device = device
    logging.info(f"Using device: {models.device}")

    # --- OmniVoice -----------------------------------------------------------
    if LOAD_OMNIVOICE:
        if OMNIVOICE_AVAILABLE:
            logging.info(f"Loading OmniVoice from {OMNIVOICE_CHECKPOINT}…")
            # Half precision on CUDA, full precision everywhere else.
            weights_dtype = torch.float16 if device == "cuda" else torch.float32
            omni = OmniVoice.from_pretrained(
                OMNIVOICE_CHECKPOINT,
                device_map=device,
                dtype=weights_dtype,
                load_asr=LOAD_ASR,
            )
            models.omnivoice = omni
            models.omnivoice_sr = omni.sampling_rate
            logging.info("OmniVoice loaded.")
        else:
            logging.warning("LOAD_OMNIVOICE=1 but OmniVoice not importable — skipping.")

    # --- LavaSR (best effort) ------------------------------------------------
    if LOAD_LAVASR and LAVASR_AVAILABLE:
        try:
            logging.info(f"Loading LavaSR from {LAVASR_CHECKPOINT}…")
            models.lava = LavaEnhance2(LAVASR_CHECKPOINT, device)
            logging.info("LavaSR loaded.")
        except Exception as e:
            logging.warning(f"Failed to load LavaSR: {e}")

    # --- Chatterbox ----------------------------------------------------------
    if LOAD_CHATTERBOX:
        if CHATTERBOX_AVAILABLE:
            logging.info("Loading ChatterboxTTS…")
            chatter = ChatterboxTTS.from_pretrained(device=device)
            models.chatterbox = chatter
            models.chatterbox_sr = chatter.sr
            logging.info("ChatterboxTTS loaded.")
        else:
            logging.warning("LOAD_CHATTERBOX=1 but chatterbox not importable — skipping.")

    yield
    # No teardown needed — process exit releases GPU memory.
app = FastAPI(title="BoomConnex Voice Studio", lifespan=lifespan)

# CORS is only relevant in local development, where the Vite dev server runs
# on a different port. In production (Docker) the SPA is served from this same
# origin, so the middleware has nothing to do.
_DEV_ORIGINS = ["http://localhost:8080", "http://127.0.0.1:8080"]

app.add_middleware(
    CORSMiddleware,
    allow_origins=_DEV_ORIGINS,
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
async def save_upload_to_tempfile(upload: UploadFile) -> str:
    """Copy an uploaded file to a named temporary file and return its path.

    The upload is read in fixed-size chunks so a large file is never held in
    memory all at once. The temp file keeps the upload's filename extension
    (".wav" when there is none) and is NOT deleted automatically: the caller
    owns the returned path and must unlink it.

    Raises:
        Whatever ``upload.read`` or the file write raises. The partially
        written temp file is removed before the error propagates.
    """
    suffix = Path(upload.filename or "audio.wav").suffix or ".wav"
    chunk_size = 1024 * 1024  # 1 MiB per read
    tmp = tempfile.NamedTemporaryFile(suffix=suffix, delete=False)
    try:
        with tmp:
            while True:
                chunk = await upload.read(chunk_size)
                if not chunk:
                    break
                tmp.write(chunk)
    except BaseException:
        # delete=False means nothing else removes the file. BaseException is
        # used so a cancelled request does not leave it behind either.
        try:
            os.unlink(tmp.name)
        except OSError:
            pass
        raise
    return tmp.name
def numpy_to_wav_response(audio: np.ndarray, sr: int) -> Response:
    """Encode a waveform as a 16-bit PCM WAV HTTP response.

    Args:
        audio: Samples shaped (T,) or (1, T). Floating-point input of any
            width is clipped to [-1.0, 1.0] and scaled to int16. int16 input
            is written unchanged. Any other dtype is cast to int16 without
            scaling.
        sr: Sample rate passed to the WAV writer.

    Returns:
        A ``Response`` with media type ``audio/wav``.
    """
    if audio.ndim == 2:
        audio = audio.squeeze(0)
    if np.issubdtype(audio.dtype, np.floating):
        # Covers float16 too. It used to fall into the integer branch, where
        # casting [-1, 1] samples straight to int16 truncates them to silence.
        # Narrow floats are widened first because 32767 is not representable
        # in float16.
        if audio.dtype.itemsize < 4:
            audio = audio.astype(np.float32)
        audio = np.clip(audio, -1.0, 1.0)
        audio = (audio * 32767.0).astype(np.int16)
    elif audio.dtype != np.int16:
        audio = audio.astype(np.int16)
    buf = io.BytesIO()
    sf.write(buf, audio, sr, format="WAV", subtype="PCM_16")
    return Response(content=buf.getvalue(), media_type="audio/wav")
def parse_optional_float(value: Optional[str]) -> Optional[float]:
    """Parse an optional numeric form field.

    Returns None when the field is missing, blank (including whitespace-only),
    or one of the stringified "no value" markers a client may send: "null",
    "none" or "undefined", in any letter case. Anything else is stripped and
    converted with ``float``.

    Raises:
        ValueError: if the remaining text is not a valid float literal.
    """
    if value is None:
        return None
    text = value.strip()
    if not text or text.lower() in {"null", "none", "undefined"}:
        return None
    return float(text)
def parse_bool(value: Optional[str], default: bool = False) -> bool:
    """Interpret a form-field string as a boolean.

    Returns *default* when *value* is None. Otherwise returns True only for
    "1", "true", "yes" or "on", ignoring letter case and surrounding
    whitespace. Every other string, including the empty string, is False.
    """
    if value is None:
        return default
    return value.strip().lower() in {"1", "true", "yes", "on"}
def lavasr_enhance_reference(ref_path: str, input_sr: int = 24000) -> str:
    """Run LavaSR (with denoising) over a reference audio file.

    Args:
        ref_path: Path of the audio file to enhance.
        input_sr: Value forwarded to LavaSR's ``load_audio`` as ``input_sr``.

    Returns:
        *ref_path* unchanged when LavaSR is not loaded. Otherwise the path of
        a new temporary WAV file written at 48000 Hz, which the caller must
        delete.

    Raises:
        Whatever LavaSR or soundfile raises. If the WAV write fails, the
        temporary file is removed before the error propagates.
    """
    if models.lava is None:
        return ref_path
    input_tensor, _ = models.lava.load_audio(ref_path, input_sr=int(input_sr))
    output_tensor = models.lava.enhance(input_tensor, denoise=True, batch=False)
    out_np = np.clip(output_tensor.cpu().numpy().squeeze(), -1.0, 1.0)
    # Reserve a unique path and close the handle before soundfile opens the
    # file by name.
    with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp:
        out_path = tmp.name
    try:
        sf.write(out_path, out_np, 48000)
    except Exception:
        # delete=False: nothing else would clean up the reserved file.
        try:
            os.unlink(out_path)
        except OSError:
            pass
        raise
    return out_path
def lavasr_enhance_output(audio: np.ndarray, sr: int) -> tuple[np.ndarray, int]:
    """Run LavaSR (without denoising) over a generated waveform.

    Args:
        audio: Generated samples. int16 input is rescaled to float32 by
            dividing by 32767; any other dtype is cast to float32 as-is.
        sr: Sample rate of *audio*.

    Returns:
        ``(audio, sr)`` unchanged when LavaSR is not loaded. Otherwise
        ``(enhanced_int16, 48000)``.

    Raises:
        Whatever LavaSR or soundfile raises. The intermediate temporary file
        is removed in every case.
    """
    if models.lava is None:
        return audio, sr
    if audio.dtype == np.int16:
        wf = audio.astype(np.float32) / 32767.0
    else:
        wf = audio.astype(np.float32)
    # LavaSR is fed from a file, so round-trip the waveform through a temp WAV.
    with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp:
        tmp_path = tmp.name
    try:
        # The write sits inside the try so a failed write is cleaned up too.
        sf.write(tmp_path, wf, sr)
        input_tensor, _ = models.lava.load_audio(tmp_path, input_sr=sr)
    finally:
        try:
            os.unlink(tmp_path)
        except OSError:
            pass
    output_tensor = models.lava.enhance(input_tensor, denoise=False, batch=False)
    out_np = output_tensor.cpu().numpy().squeeze()
    out_np = (np.clip(out_np, -1.0, 1.0) * 32767.0).astype(np.int16)
    return out_np, 48000
def build_voice_design_instruct(
    gender: Optional[str],
    age: Optional[str],
    pitch: Optional[str],
    style: Optional[str],
    english_accent: Optional[str],
    chinese_dialect: Optional[str],
) -> Optional[str]:
    """Join the selected Voice Design attributes into one instruct string.

    Each value is stripped first. Values that are missing, blank, or equal to
    the "Auto" sentinel are then dropped. The rest are joined with ", " in
    parameter order.

    Returns:
        The joined string, or None when no attribute was selected.
    """
    parts = []
    for value in (gender, age, pitch, style, english_accent, chinese_dialect):
        if not value:
            continue
        # Strip before filtering so " Auto " and whitespace-only values are
        # discarded instead of producing bogus or empty parts.
        cleaned = value.strip()
        if cleaned and cleaned != "Auto":
            parts.append(cleaned)
    return ", ".join(parts) if parts else None
# ---------------------------------------------------------------------------
# OmniVoice generation core (used by Voice Clone + Voice Design)
# ---------------------------------------------------------------------------
def omnivoice_generate(
    text: str,
    language: Optional[str],
    ref_audio_path: Optional[str],
    ref_text: Optional[str],
    instruct: Optional[str],
    num_step: int,
    guidance_scale: float,
    denoise: bool,
    speed: Optional[float],
    duration: Optional[float],
    preprocess_prompt: bool,
    postprocess_output: bool,
    mode: str,  # "clone" | "design"
) -> tuple[np.ndarray, int]:
    """Generate speech with OmniVoice for the Voice Clone and Voice Design routes.

    Args:
        text: Text to synthesise; stripped before use.
        language: Language name; None, "" or "Auto" means no language is passed.
        ref_audio_path: Reference audio file. Required when *mode* is "clone".
        ref_text: Transcript of the reference audio. Blank counts as missing.
        instruct: Optional instruct string, forwarded only when non-empty.
        num_step, guidance_scale, denoise, preprocess_prompt,
        postprocess_output: Copied into ``OmniVoiceGenerationConfig``.
        speed: Forwarded only when set and different from 1.0.
        duration: Forwarded only when set and greater than 0.
        mode: "clone" builds a voice-clone prompt from the reference; any
            other value skips that step.

    Returns:
        ``(waveform, sample_rate)``, with the waveform as a numpy array.

    Raises:
        HTTPException: 503 when OmniVoice is not loaded; 400 when clone mode
            has no reference audio, or has no reference text while
            auto-transcription is disabled.
    """
    if models.omnivoice is None or not OMNIVOICE_AVAILABLE:
        raise HTTPException(503, "OmniVoice model not loaded.")

    gen_config = OmniVoiceGenerationConfig(
        num_step=int(num_step),
        guidance_scale=float(guidance_scale),
        denoise=bool(denoise),
        preprocess_prompt=bool(preprocess_prompt),
        postprocess_output=bool(postprocess_output),
    )
    lang = language if (language and language != "Auto") else None
    kw: dict = dict(text=text.strip(), language=lang, generation_config=gen_config)
    if speed is not None and float(speed) != 1.0:
        kw["speed"] = float(speed)
    if duration is not None and float(duration) > 0:
        kw["duration"] = float(duration)

    if mode == "clone":
        if not ref_audio_path:
            raise HTTPException(400, "Voice Clone requires a reference audio file.")
        # Normalise once so a whitespace-only transcript counts as missing
        # here and is not handed to the model as if it were real text.
        clean_ref_text = (ref_text or "").strip() or None
        if clean_ref_text is None and not LOAD_ASR:
            # With no transcript the reference would have to be
            # auto-transcribed, and the ASR model is only loaded when
            # LOAD_ASR=1.
            raise HTTPException(
                400,
                "Reference text is required: auto-transcribe is disabled in "
                "this deployment. Please paste the transcript of your "
                "reference audio in the 'Reference text' field.",
            )
        kw["voice_clone_prompt"] = models.omnivoice.create_voice_clone_prompt(
            ref_audio=ref_audio_path,
            ref_text=clean_ref_text,
        )

    if instruct:
        kw["instruct"] = instruct

    audio = models.omnivoice.generate(**kw)
    # detach()/cpu() are no-ops for a CPU tensor that needs no grad, and they
    # stop .numpy() raising if the tensor is still on the GPU.
    # NOTE(review): assumes audio[0] is a torch tensor shaped (1, T) — confirm.
    waveform = audio[0].squeeze(0).detach().cpu().numpy()
    return waveform, models.omnivoice_sr
# ---------------------------------------------------------------------------
# Routes — health & metadata
# ---------------------------------------------------------------------------
@app.get("/api/health")
def health():
    """Report the active device and which models are currently loaded."""
    status = {"device": models.device}
    # Response key -> attribute on ``models`` holding that model.
    for key, attr in (
        ("omnivoice", "omnivoice"),
        ("lavasr", "lava"),
        ("chatterbox", "chatterbox"),
    ):
        status[key] = getattr(models, attr) is not None
    return status
@app.get("/api/languages")
def list_languages():
    """Return the sorted display names of LANG_NAMES, led by the 'Auto' sentinel."""
    display_names = [lang_display_name(name) for name in LANG_NAMES]
    display_names.sort()
    return {"languages": ["Auto", *display_names]}
@app.get("/api/voice-design/options")
def voice_design_options():
    """Return the Voice Design dropdown choices, keyed by category."""
    payload = {"categories": VD_CATEGORIES}
    return payload
# ---------------------------------------------------------------------------
# Routes — Voice Clone
# ---------------------------------------------------------------------------
@app.post("/api/voice-clone")
async def voice_clone(
    text: str = Form(...),
    ref_audio: UploadFile = File(...),
    ref_text: Optional[str] = Form(None),
    language: Optional[str] = Form(None),
    instruct: Optional[str] = Form(None),
    num_step: int = Form(32),
    guidance_scale: float = Form(2.0),
    denoise: bool = Form(True),
    speed: float = Form(1.0),
    duration: Optional[str] = Form(None),
    preprocess_prompt: bool = Form(True),
    postprocess_output: bool = Form(True),
    enhance_reference: bool = Form(True),
):
    """
    Voice Clone: text + reference audio → cloned speech.

    The reference audio is run through LavaSR before cloning when
    ``enhance_reference`` is set and LavaSR is loaded. If enhancement fails,
    the failure is logged and the original reference is used.

    Returns a WAV response. Responds 400 for blank text or a non-numeric
    ``duration``; errors raised by ``omnivoice_generate`` pass through.
    """
    if not text.strip():
        raise HTTPException(400, "Text is required.")
    # Validate before touching the upload so a bad value is a clean 400
    # rather than an unhandled ValueError after the work is already done.
    try:
        duration_value = parse_optional_float(duration)
    except ValueError:
        raise HTTPException(400, "Duration must be a number.") from None

    ref_path = await save_upload_to_tempfile(ref_audio)
    enhanced_path: Optional[str] = None
    try:
        effective_ref = ref_path
        if enhance_reference and models.lava is not None:
            try:
                enhanced_path = lavasr_enhance_reference(ref_path, input_sr=24000)
                effective_ref = enhanced_path
            except Exception as e:
                # Best effort: fall back to the unenhanced reference.
                logging.warning(f"LavaSR reference enhancement failed: {e}")
        # NOTE(review): omnivoice_generate is a synchronous call, so the
        # event loop is blocked while it runs.
        waveform, sr = omnivoice_generate(
            text=text,
            language=language,
            ref_audio_path=effective_ref,
            ref_text=ref_text,
            instruct=instruct,
            num_step=num_step,
            guidance_scale=guidance_scale,
            denoise=denoise,
            speed=speed,
            duration=duration_value,
            preprocess_prompt=preprocess_prompt,
            postprocess_output=postprocess_output,
            mode="clone",
        )
        return numpy_to_wav_response(waveform, sr)
    finally:
        # Both temp files were created with delete=False; remove them here.
        for p in (ref_path, enhanced_path):
            if p and os.path.exists(p):
                try:
                    os.unlink(p)
                except OSError:
                    pass
# ---------------------------------------------------------------------------
# Routes — Voice Design
# ---------------------------------------------------------------------------
@app.post("/api/voice-design")
async def voice_design(
    text: str = Form(...),
    language: Optional[str] = Form(None),
    gender: Optional[str] = Form(None),
    age: Optional[str] = Form(None),
    pitch: Optional[str] = Form(None),
    style: Optional[str] = Form(None),
    english_accent: Optional[str] = Form(None),
    chinese_dialect: Optional[str] = Form(None),
    num_step: int = Form(32),
    guidance_scale: float = Form(2.0),
    denoise: bool = Form(True),
    speed: float = Form(1.0),
    duration: Optional[str] = Form(None),
    preprocess_prompt: bool = Form(True),
    postprocess_output: bool = Form(True),
    enhance_output: bool = Form(True),
):
    """
    Voice Design: text + attribute dropdowns → synthesised voice.

    The output is post-processed with LavaSR when ``enhance_output`` is set
    and LavaSR is loaded. If enhancement fails, the failure is logged and the
    unenhanced audio is returned.

    Returns a WAV response. Responds 400 for blank text or a non-numeric
    ``duration``; errors raised by ``omnivoice_generate`` pass through.
    """
    if not text.strip():
        raise HTTPException(400, "Text is required.")
    # Validate up front so a bad value is a clean 400 rather than an
    # unhandled ValueError.
    try:
        duration_value = parse_optional_float(duration)
    except ValueError:
        raise HTTPException(400, "Duration must be a number.") from None

    instruct = build_voice_design_instruct(
        gender, age, pitch, style, english_accent, chinese_dialect,
    )
    # NOTE(review): omnivoice_generate is a synchronous call, so the event
    # loop is blocked while it runs.
    waveform, sr = omnivoice_generate(
        text=text,
        language=language,
        ref_audio_path=None,
        ref_text=None,
        instruct=instruct,
        num_step=num_step,
        guidance_scale=guidance_scale,
        denoise=denoise,
        speed=speed,
        duration=duration_value,
        preprocess_prompt=preprocess_prompt,
        postprocess_output=postprocess_output,
        mode="design",
    )
    if enhance_output and models.lava is not None:
        try:
            waveform, sr = lavasr_enhance_output(waveform, sr)
        except Exception as e:
            # Best effort: keep the unenhanced audio.
            logging.warning(f"LavaSR output enhancement failed: {e}")
    return numpy_to_wav_response(waveform, sr)
# ---------------------------------------------------------------------------
# Routes — Emotion TTS (Chatterbox)
# ---------------------------------------------------------------------------
@app.post("/api/emotion-tts")
async def emotion_tts(
    text: str = Form(...),
    ref_audio: Optional[UploadFile] = File(None),
    exaggeration: float = Form(0.5),
    cfg_weight: float = Form(0.5),
    temperature: float = Form(1.0),
    seed: int = Form(0),
):
    """
    Emotional TTS through Chatterbox.

    The reference audio is optional; when none is uploaded, no audio prompt
    is passed to the model. The three sampling controls (exaggeration,
    cfg_weight, temperature) come straight from the request. A positive seed
    is fed to ``torch.manual_seed`` before generation.

    Responds 400 for blank text and 503 when Chatterbox is not loaded.
    """
    if not text.strip():
        raise HTTPException(400, "Text is required.")
    if not (CHATTERBOX_AVAILABLE and models.chatterbox is not None):
        raise HTTPException(503, "Chatterbox model not loaded.")

    seed_value = int(seed)
    if seed_value > 0 and TORCH_AVAILABLE:
        torch.manual_seed(seed_value)

    ref_path: Optional[str] = None
    try:
        has_upload = ref_audio is not None and bool(ref_audio.filename)
        if has_upload:
            ref_path = await save_upload_to_tempfile(ref_audio)

        gen_kwargs = dict(
            exaggeration=float(exaggeration),
            cfg_weight=float(cfg_weight),
            temperature=float(temperature),
        )
        if ref_path:
            gen_kwargs["audio_prompt_path"] = ref_path

        wav = models.chatterbox.generate(text, **gen_kwargs)
        samples = wav.squeeze(0).cpu().numpy()
        return numpy_to_wav_response(samples, models.chatterbox_sr)
    finally:
        # The uploaded reference was saved with delete=False; remove it here.
        if ref_path and os.path.exists(ref_path):
            try:
                os.unlink(ref_path)
            except OSError:
                pass
# ---------------------------------------------------------------------------
# Static SPA serving (must come last so /api/* routes win)
# ---------------------------------------------------------------------------
if STATIC_DIR.is_dir():
    # Serve hashed assets at /assets/*
    assets_dir = STATIC_DIR / "assets"
    if assets_dir.is_dir():
        app.mount("/assets", StaticFiles(directory=assets_dir), name="assets")

    # Canonical static root, resolved once, used for the containment check.
    _STATIC_ROOT = STATIC_DIR.resolve()

    @app.get("/{full_path:path}", include_in_schema=False)
    async def spa_fallback(full_path: str):
        """Serve a static file when one matches, otherwise the SPA index page.

        Paths under ``api/`` get a JSON 404. A request path is served as a
        file only if it resolves to a regular file inside the static
        directory. Everything else falls back to ``index.html`` so
        client-side routes work.
        """
        # API routes already matched above; anything else falls back to the SPA.
        if full_path.startswith("api/"):
            return JSONResponse({"detail": "Not Found"}, status_code=404)

        # Serve a real file if it exists (favicon, robots.txt, etc.).
        # SECURITY: full_path is client-controlled. "../" segments or an
        # absolute path would make a plain join point outside the static
        # directory, so resolve first and require the result to stay inside.
        static_file = None
        try:
            candidate = (_STATIC_ROOT / full_path).resolve()
            candidate.relative_to(_STATIC_ROOT)
            if candidate.is_file():
                static_file = candidate
        except (OSError, RuntimeError, ValueError):
            # Escapes the root, or is not a usable path: treat as no match.
            static_file = None
        if static_file is not None:
            return FileResponse(static_file)

        index = STATIC_DIR / "index.html"
        if index.is_file():
            return FileResponse(index)
        return JSONResponse({"detail": "Frontend not built"}, status_code=404)
else:
    logging.warning(f"STATIC_DIR {STATIC_DIR} not found — frontend will not be served.")
# ---------------------------------------------------------------------------
# Local dev entry-point
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    # Run the app directly with uvicorn. The bind address and port come from
    # the HOST and PORT environment variables, with the defaults below.
    import uvicorn

    logging.basicConfig(
        format="%(asctime)s %(name)s %(levelname)s: %(message)s",
        level=logging.INFO,
    )
    bind_host = os.environ.get("HOST", "0.0.0.0")
    bind_port = int(os.environ.get("PORT", "7860"))
    uvicorn.run("main:app", host=bind_host, port=bind_port, reload=False)