| | from fastapi import FastAPI, HTTPException |
| | from fastapi.middleware.cors import CORSMiddleware |
| | from pydantic import BaseModel, Field |
| | from pathlib import Path |
| | from transformers import AutoModelForCausalLM, AutoTokenizer |
| | import tempfile |
| | import traceback |
| | import whisper |
| | import librosa |
| | import numpy as np |
| | import torch |
| | import outetts |
| | import uvicorn |
| | import base64 |
| | import io |
| | import soundfile as sf |
| |
|
# Build the OuteTTS interface once at import time; request handlers reuse it.
try:
    INTERFACE = outetts.Interface(
        config=outetts.ModelConfig(
            model_path="models/v10",
            tokenizer_path="models/v10",
            audio_codec_path="models/dsp/weights_24khz_1.5kbps_v1.0.pth",
            device="cuda",
            dtype=torch.bfloat16,
        )
    )
except Exception as e:
    # Chain explicitly so the original exception (type and traceback) stays
    # attached as __cause__ rather than surviving only as message text.
    raise RuntimeError(f"{e}") from e
| |
|
# Speech-recognition model; gt() calls its transcribe() method.
asr_model = whisper.load_model("models/wpt/wpt.pt")

# Causal language model and its tokenizer; sample() uses both.
model_name = "models/lm"
tok = AutoTokenizer.from_pretrained(model_name)
lm = AutoModelForCausalLM.from_pretrained(
    model_name, torch_dtype=torch.bfloat16, device_map="cuda"
)
lm.eval()  # put the module in eval mode; `lm` stays bound to the same object

# Speaker clip path resolved next to this file.
# NOTE(review): not referenced anywhere in the visible code; confirm it is used.
SPEAKER_WAV_PATH = Path(__file__).parent / "spk_001.wav"
| |
|
| |
|
def gt(audio: np.ndarray, sr: int):
    """Transcribe an audio clip to text with the module-level ASR model.

    Args:
        audio: Audio samples; singleton dimensions are squeezed away and the
            data is cast to float32 before transcription.
        sr: Sample rate of ``audio`` in Hz. Anything other than 16000 is
            resampled to 16000 first.

    Returns:
        The transcribed text with leading/trailing whitespace removed.
    """
    ss = audio.squeeze().astype(np.float32)
    if sr != 16_000:
        # Resample the prepared signal (`ss`), not the raw `audio` argument,
        # so the squeeze and float32 cast above are not thrown away.
        ss = librosa.resample(ss, orig_sr=sr, target_sr=16_000)

    result = asr_model.transcribe(ss, fp16=False, language=None)
    return result["text"].strip()
| |
|
| |
|
def sample(rr: str) -> str:
    """Generate a short continuation of ``rr`` with the module-level LM.

    A blank (whitespace-only) prompt is replaced by ``"Hello "``. Only the
    newly generated tokens are decoded; the prompt tokens are sliced off.

    Args:
        rr: Prompt text.

    Returns:
        The decoded continuation with special tokens skipped.
    """
    prompt = "Hello " if rr.strip() == "" else rr

    encoded = tok(prompt, return_tensors="pt").to(lm.device)
    prompt_len = encoded.input_ids.shape[-1]

    # Sampling settings forwarded verbatim to generate().
    generation_kwargs = dict(
        max_new_tokens=45,
        do_sample=True,
        temperature=0.2,
        repetition_penalty=1.13,
        top_k=100,
        top_p=0.95,
    )

    with torch.inference_mode():
        out_ids = lm.generate(**encoded, **generation_kwargs)

    # Drop the echoed prompt so only the continuation is returned.
    continuation_ids = out_ids[0][prompt_len:]
    return tok.decode(continuation_ids, skip_special_tokens=True)
| |
|
| |
|
# Status values reported by the health-check endpoint.
# NOTE(review): nothing in the visible code ever updates these after import.
INITIALIZATION_STATUS = dict(model_loaded=True, error=None)
| |
|
| |
|
class GenerateRequest(BaseModel):
    # Request body of the inference endpoint.

    # Base64 text that b64() decodes and loads with np.load.
    audio_data: str = Field(..., description="")

    # Sample rate of the submitted audio; passed on as the resampling source
    # rate and as the rate of the temporary WAV file.
    sample_rate: int = Field(..., description="")
| |
|
| |
|
class GenerateResponse(BaseModel):
    # Response body of the inference endpoint.

    # Base64 text produced by ab64() from the generated audio array.
    audio_data: str = Field(..., description="")
| |
|
| |
|
app = FastAPI(title="V1", version="0.1")

# CORS settings: every origin, method and header is allowed.
# NOTE(review): wildcard origins combined with allow_credentials=True is very
# permissive; confirm this is intended for the deployment.
_CORS_OPTIONS = {
    "allow_origins": ["*"],
    "allow_credentials": True,
    "allow_methods": ["*"],
    "allow_headers": ["*"],
}
app.add_middleware(CORSMiddleware, **_CORS_OPTIONS)
| |
|
| |
|
def b64(b64: str) -> np.ndarray:
    """Decode base64 text and load the bytes with ``np.load``.

    Args:
        b64: Base64-encoded NumPy-serialized data.

    Returns:
        Whatever ``np.load`` yields for the decoded bytes; pickled payloads
        are refused because ``allow_pickle=False``.
    """
    stream = io.BytesIO(base64.b64decode(b64))
    return np.load(stream, allow_pickle=False)
| |
|
| |
|
def ab64(arr: np.ndarray, sr: int, orig_sr: int = 44100) -> str:
    """Resample audio to ``sr`` and encode it as base64 ``np.save`` data.

    Args:
        arr: Audio samples at ``orig_sr``.
        sr: Target sample rate in Hz.
        orig_sr: Sample rate of ``arr`` in Hz. Defaults to 44100, the value
            this function previously hard-coded, so existing two-argument
            callers behave exactly as before.

    Returns:
        Base64 text of the resampled array, cast to float32 and serialized
        with ``np.save``.
    """
    resampled = librosa.resample(arr, orig_sr=orig_sr, target_sr=sr)
    buf = io.BytesIO()
    np.save(buf, resampled.astype(np.float32))
    return base64.b64encode(buf.getvalue()).decode()
| |
|
| |
|
def gs(
    audio: np.ndarray,
    sr: int,
    interface: outetts.Interface,
):
    """Create a speaker profile from at most the last 15 seconds of audio.

    The clip is written to a temporary WAV file, handed to
    ``interface.create_speaker``, and the file is removed afterwards.

    Args:
        audio: Audio samples; a 2-D array is squeezed before use.
        sr: Sample rate of ``audio`` in Hz.
        interface: Interface whose ``create_speaker`` builds the profile.

    Returns:
        The value returned by ``interface.create_speaker``.
    """
    if audio.ndim == 2:
        audio = audio.squeeze()
    audio = audio.astype("float32")

    max_samples = int(15.0 * sr)
    if audio.shape[-1] > max_samples:
        # Trim along the last axis, the same axis the length check inspects.
        audio = audio[..., -max_samples:]

    # Reserve a unique path, then close the handle before writing to it.
    with tempfile.NamedTemporaryFile(suffix=".wav", dir="/tmp", delete=False) as f:
        wav_path = f.name

    try:
        sf.write(wav_path, audio, sr)
        speaker = interface.create_speaker(
            wav_path,
            whisper_model="models/wpt/wpt.pt",
        )
    finally:
        # delete=False leaves cleanup to us; without this every request
        # would leave a WAV file behind in /tmp.
        Path(wav_path).unlink(missing_ok=True)

    return speaker
| |
|
| |
|
@app.get("/api/v1/health")
def health_check():
    """Health check endpoint"""
    # "status" is a fixed string; the other two fields mirror the
    # module-level INITIALIZATION_STATUS mapping.
    return {
        "status": "healthy",
        "model_loaded": INITIALIZATION_STATUS["model_loaded"],
        "error": INITIALIZATION_STATUS["error"],
    }
| |
|
| |
|
@app.post("/api/v1/inference", response_model=GenerateResponse)
def generate_audio(req: GenerateRequest):
    """Transcribe the submitted audio, generate a reply and synthesize it.

    Pipeline: decode the request array, transcribe it with ``gt``, produce
    reply text with ``sample``, build a speaker profile with ``gs``, and
    synthesize speech through ``INTERFACE.generate``. The result is resampled
    to the request's sample rate and base64-encoded by ``ab64``.

    Raises:
        HTTPException: 400 when the request payload is unusable; 500 when
            transcription, text generation or synthesis fails.
    """
    if req.sample_rate <= 0:
        raise HTTPException(
            status_code=400, detail="sample_rate must be a positive integer"
        )

    # Decoding failures are the client's fault, so report them as 400 instead
    # of letting them escape as an unhandled server error.
    try:
        audio_np = b64(req.audio_data)
    except (ValueError, EOFError, OSError) as e:
        raise HTTPException(
            status_code=400, detail=f"Invalid audio_data: {e}"
        ) from e

    # np.load can return a non-array container for archive input; reject it
    # along with empty arrays.
    if not isinstance(audio_np, np.ndarray) or audio_np.size == 0:
        raise HTTPException(
            status_code=400, detail="audio_data must be a non-empty array"
        )

    if audio_np.ndim == 1:
        audio_np = audio_np.reshape(1, -1)

    try:
        text = gt(audio_np, req.sample_rate)
        out = INTERFACE.generate(
            config=outetts.GenerationConfig(
                text=sample(text),
                generation_type=outetts.GenerationType.CHUNKED,
                speaker=gs(audio_np, req.sample_rate, INTERFACE),
                sampler_config=outetts.SamplerConfig(),
            )
        )
        audio_out = out.audio.squeeze().cpu().numpy()
    except Exception as e:
        # Boundary handler: log the full traceback, then surface a 500.
        traceback.print_exc()
        raise HTTPException(status_code=500, detail=f"{e}") from e

    return GenerateResponse(audio_data=ab64(audio_out, req.sample_rate))
| |
|
| |
|
if __name__ == "__main__":
    # Pass the app object instead of the "server:app" import string. With the
    # string, uvicorn imports the module a second time, which re-runs the
    # model loading above, and it only works if this file is named server.py.
    # An import string is only required for reload or multiple workers.
    uvicorn.run(app, host="0.0.0.0", port=8000, reload=False)
| |
|