"""Speech-in / speech-or-text-out inference server.

Loads an OuteTTS text-to-speech interface, a Whisper ASR model and a Llama
chat model at import time, and exposes them through a small FastAPI app.
"""

import base64
import io
import os
import tempfile
import traceback
from pathlib import Path

import librosa
import numpy as np
import outetts
import soundfile as sf
import torch
import uvicorn
import whisper
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel, Field
from transformers import AutoModelForCausalLM, AutoTokenizer

try:
    INTERFACE = outetts.Interface(
        config=outetts.ModelConfig(
            model_path="models/v10",
            tokenizer_path="models/v10",
            audio_codec_path="models/dsp/weights_24khz_1.5kbps_v1.0.pth",
            device="cuda",
            dtype=torch.bfloat16,
        )
    )
except Exception as e:
    # Chain the original exception so its type and traceback are not lost.
    raise RuntimeError(f"{e}") from e

asr_model = whisper.load_model("models/wpt/wpt.pt")

model_name = "models/Llama-3.2-1B-Instruct"
tok = AutoTokenizer.from_pretrained(model_name)
lm = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype=torch.bfloat16,
    device_map="cuda",
).eval()

SPEAKER_WAV_PATH = Path(__file__).with_name("spk_001.wav")

# Sampling settings shared by `chat` and `sample` so the two stay in sync.
_GEN_KWARGS = dict(
    max_new_tokens=2048,
    do_sample=True,
    temperature=0.2,
    repetition_penalty=1.1,
    top_k=100,
    top_p=0.95,
)


def chat(system_prompt: str, user_prompt: str, **gen_kwargs) -> str:
    """Run one chat turn with a system + user message and return the reply.

    Args:
        system_prompt: Content of the system message.
        user_prompt: Content of the user message.
        **gen_kwargs: Extra keyword arguments forwarded to ``lm.generate()``;
            they override the defaults in ``_GEN_KWARGS``.

    Returns:
        The newly generated assistant text, stripped of whitespace.
    """
    messages = [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": user_prompt},
    ]
    # `add_generation_prompt=True` appends the assistant header so the model
    # knows it should respond; `return_dict=True` also yields the attention mask.
    inputs = tok.apply_chat_template(
        messages,
        add_generation_prompt=True,
        return_tensors="pt",
        return_dict=True,
    )
    input_ids = inputs["input_ids"].to(lm.device)
    attention_mask = inputs["attention_mask"].to(lm.device)

    generate_kwargs = {**_GEN_KWARGS, "pad_token_id": tok.eos_token_id, **gen_kwargs}
    with torch.inference_mode():
        output_ids = lm.generate(
            input_ids=input_ids,
            attention_mask=attention_mask,
            **generate_kwargs,
        )

    # Strip the prompt tokens and decode only the newly generated answer.
    answer = tok.decode(
        output_ids[0][input_ids.shape[-1]:],
        skip_special_tokens=True,
        clean_up_tokenization_spaces=True,
    )
    return answer.strip()


def gt(audio: np.ndarray, sr: int) -> str:
    """Transcribe ``audio`` (sampled at ``sr`` Hz) with Whisper.

    Whisper is called with ``language=None`` (auto-detect) and ``fp16=False``.
    """
    ss = audio.squeeze().astype(np.float32)
    if sr != 16_000:
        # Resample the squeezed float32 signal so Whisper receives a 1-D
        # float32 array at 16 kHz.
        ss = librosa.resample(ss, orig_sr=sr, target_sr=16_000).astype(
            np.float32, copy=False
        )
    result = asr_model.transcribe(ss, fp16=False, language=None)
    return result["text"].strip()


def sample(rr: str) -> str:
    """Continue ``rr`` with the language model as plain text (no chat template).

    An empty or whitespace-only prompt is replaced by ``"Hello "``.
    Returns only the generated continuation.
    """
    if not rr.strip():
        rr = "Hello "
    inputs = tok(rr, return_tensors="pt").to(lm.device)
    with torch.inference_mode():
        out_ids = lm.generate(
            **inputs,
            pad_token_id=tok.eos_token_id,
            **_GEN_KWARGS,
        )
    return tok.decode(
        out_ids[0][inputs.input_ids.shape[-1]:], skip_special_tokens=True
    )


# Loading failures raise at import time, so a running server always reports loaded.
INITIALIZATION_STATUS = {"model_loaded": True, "error": None}


class GenerateRequest(BaseModel):
    audio_data: str = Field(
        ...,
        description="Base64-encoded NumPy .npy array holding the input waveform.",
    )
    sample_rate: int = Field(
        ...,
        gt=0,
        description="Sample rate of audio_data in Hz; also used as the output rate.",
    )


class GenerateResponse(BaseModel):
    audio_data: str = Field(
        ...,
        description="Base64-encoded NumPy .npy float32 waveform at the request's sample rate.",
    )


app = FastAPI(title="V1", version="0.1")

# NOTE(review): allowing every origin together with credentials is very
# permissive — consider restricting allow_origins before exposing this publicly.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)


def b64(b64: str) -> np.ndarray:
    """Decode a base64 string containing a serialized .npy array.

    ``allow_pickle=False`` refuses pickled (object) payloads from clients.
    """
    raw = base64.b64decode(b64)
    return np.load(io.BytesIO(raw), allow_pickle=False)


def _decode_request_audio(data: str) -> np.ndarray:
    """Decode request audio to an array of at least 2 dims; malformed input -> HTTP 400."""
    try:
        audio_np = b64(data)
    except (ValueError, OSError, EOFError) as e:
        raise HTTPException(status_code=400, detail=f"Invalid audio_data: {e}") from e
    # np.load may return an NpzFile for .npz payloads; only plain arrays are accepted.
    if not isinstance(audio_np, np.ndarray):
        raise HTTPException(
            status_code=400, detail="audio_data must encode a single .npy array"
        )
    if audio_np.size == 0:
        raise HTTPException(status_code=400, detail="audio_data contains no samples")
    if audio_np.ndim == 1:
        audio_np = audio_np.reshape(1, -1)
    return audio_np


def ab64(arr: np.ndarray, sr: int, orig_sr: int = 44100) -> str:
    """Resample ``arr`` from ``orig_sr`` to ``sr`` Hz and encode it as base64 .npy float32.

    ``orig_sr`` defaults to 44100, which this server assumes is the OuteTTS
    output rate — NOTE(review): confirm against the installed outetts version.
    """
    buf = io.BytesIO()
    resampled = librosa.resample(arr, orig_sr=orig_sr, target_sr=sr)
    np.save(buf, resampled.astype(np.float32))
    return base64.b64encode(buf.getvalue()).decode()


def gs(
    audio: np.ndarray,
    sr: int,
    interface: outetts.Interface,
    max_seconds: float = 15.0,
):
    """Build an OuteTTS speaker profile from the last ``max_seconds`` of ``audio``.

    The clip is written to a temporary WAV file (``create_speaker`` takes a
    path), and the file is deleted once the speaker has been created.
    """
    if audio.ndim == 2:
        audio = audio.squeeze()
    audio = audio.astype("float32")

    max_samples = int(max_seconds * sr)
    if audio.shape[-1] > max_samples:
        # Keep only the most recent samples along the time (last) axis.
        audio = audio[..., -max_samples:]

    fd, wav_path = tempfile.mkstemp(suffix=".wav")
    os.close(fd)
    try:
        sf.write(wav_path, audio, sr)
        speaker = interface.create_speaker(
            wav_path,
            whisper_model="models/wpt/wpt.pt",
        )
    finally:
        # Remove the file so temp WAVs do not pile up across requests.
        # Assumes create_speaker keeps no reference to the path — verify.
        os.unlink(wav_path)
    return speaker


@app.get("/api/v1/health")
def health_check():
    """Health check endpoint"""
    status = {
        "status": "healthy",
        "model_loaded": INITIALIZATION_STATUS["model_loaded"],
        "error": INITIALIZATION_STATUS["error"],
    }
    return status


@app.post("/api/v1/inference", response_model=GenerateResponse)
def generate_audio(req: GenerateRequest):
    """Transcribe the input, continue the text with the LM, and speak it in the caller's voice."""
    audio_np = _decode_request_audio(req.audio_data)
    try:
        text = gt(audio_np, req.sample_rate)
        out = INTERFACE.generate(
            config=outetts.GenerationConfig(
                text=sample(text),
                generation_type=outetts.GenerationType.CHUNKED,
                speaker=gs(audio_np, req.sample_rate, INTERFACE),
                sampler_config=outetts.SamplerConfig(),
            )
        )
        audio_out = out.audio.squeeze().cpu().numpy()
        # Encoding is inside the try so resampling failures are logged and reported too.
        audio_b64 = ab64(audio_out, req.sample_rate)
    except Exception as e:
        traceback.print_exc()
        raise HTTPException(status_code=500, detail=f"{e}") from e
    return GenerateResponse(audio_data=audio_b64)


@app.post("/api/v1/v2t")
def generate_text(req: GenerateRequest):
    """Transcribe the input audio and return the chat model's reply as text."""
    audio_np = _decode_request_audio(req.audio_data)
    try:
        text = gt(audio_np, req.sample_rate)
        print(f"Transcribed text: {text}")
        system_prompt = "You are a helpful assistant who tries to help answer the user's question."
        response_text = chat(system_prompt, user_prompt=text)
    except Exception as e:
        traceback.print_exc()
        raise HTTPException(status_code=500, detail=f"{e}") from e
    return {"text": response_text}


if __name__ == "__main__":
    # Pass the app object, not an import string such as "server:app".
    # An import string makes uvicorn import this file a second time under
    # another module name, which reloads every model onto the GPU. It also
    # fails when the file is not named server.py.
    uvicorn.run(app, host="0.0.0.0", port=8000, reload=False)