import os
import time
from fastapi import FastAPI, HTTPException
from helper import check_status, prefix, filter_by_word_count
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel, Field
from transformers import AutoModelForCausalLM, AutoTokenizer
import traceback
import whisper
import librosa
import numpy as np
import torch
import uvicorn
import base64
import io
from voxcpm import VoxCPM

# Load all models once at import time so requests don't pay startup cost.
asr_model = whisper.load_model("models/wpt/wpt.pt")

model_name = "models/Llama-3.2-1B-Instruct"
tok = AutoTokenizer.from_pretrained(model_name)
lm = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype=torch.bfloat16,
    device_map="cuda",
).eval()

tts = VoxCPM.from_pretrained(
    "models/VoxCPM-0.5B",
    local_files_only=True,
    load_denoiser=True,
    zipenhancer_model_id="models/iic/speech_zipenhancer_ans_multiloss_16k_base",
)


def chat(system_prompt: str, user_prompt: str) -> str:
    """Run one system+user turn through the LLM and return the decoded answer."""
    print("LLM init...")
    messages = [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": user_prompt},
    ]
    inputs = tok.apply_chat_template(
        messages, add_generation_prompt=True, return_tensors="pt", return_dict=True
    )
    input_ids = inputs["input_ids"].to(lm.device)
    attention_mask = inputs["attention_mask"].to(lm.device)
    with torch.inference_mode():
        output_ids = lm.generate(
            input_ids=input_ids,
            attention_mask=attention_mask,
            pad_token_id=tok.eos_token_id,
            max_new_tokens=2048,
            do_sample=True,
            temperature=0.2,
            repetition_penalty=1.1,
            top_k=100,
            top_p=0.95,
        )
    # Decode only the newly generated tokens, skipping the prompt.
    answer = tok.decode(
        output_ids[0][input_ids.shape[-1]:],
        skip_special_tokens=True,
        clean_up_tokenization_spaces=True,
    )
    print("LLM answer done.")
    answer = prefix + answer
    return answer.strip()


def gt(audio: np.ndarray, sr: int) -> str:
    """Transcribe an audio array with Whisper, resampling to 16 kHz if needed."""
    print("Starting ASR transcription...")
    ss = audio.squeeze().astype(np.float32)
    if sr != 16_000:
        # Resample the squeezed 1-D signal (not the original, possibly 2-D,
        # array) to Whisper's expected 16 kHz rate.
        ss = librosa.resample(ss, orig_sr=sr, target_sr=16_000)
    result = asr_model.transcribe(ss, fp16=False, language=None)
    transcribed_text = result["text"].strip()
    print(f"ASR done. Transcribed: '{transcribed_text}'")
    return transcribed_text


def sample(rr: str) -> str:
    """Raw (non-chat-templated) continuation sampling; currently unused by the API."""
    if rr.strip() == "":
        rr = "Hello "
    inputs = tok(rr, return_tensors="pt").to(lm.device)
    with torch.inference_mode():
        out_ids = lm.generate(
            **inputs,
            max_new_tokens=2048,
            do_sample=True,
            temperature=0.2,
            repetition_penalty=1.1,
            top_k=100,
            top_p=0.95,
        )
    return tok.decode(
        out_ids[0][inputs.input_ids.shape[-1]:],
        skip_special_tokens=True,
    )


INITIALIZATION_STATUS = {"model_loaded": True, "error": None}


class GenerateRequest(BaseModel):
    audio_data: str = Field(..., description="Base64-encoded .npy buffer containing a float32 audio array")
    sample_rate: int = Field(..., description="Sample rate of the input audio in Hz")


class GenerateResponse(BaseModel):
    audio_data: str = Field(..., description="Base64-encoded .npy buffer containing float32 audio at the requested sample rate")


app = FastAPI(title="V1", version="0.1")
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)


def b64(data: str) -> np.ndarray:
    """Decode a base64 string into a numpy array (np.save format)."""
    raw = base64.b64decode(data)
    return np.load(io.BytesIO(raw), allow_pickle=False)


def ab64(arr: np.ndarray, sr: int) -> str:
    """Resample 16 kHz model output to the client's rate and base64-encode it."""
    buf = io.BytesIO()
    resampled = librosa.resample(arr, orig_sr=16_000, target_sr=sr)
    np.save(buf, resampled.astype(np.float32))
    return base64.b64encode(buf.getvalue()).decode()


@app.get("/api/v1/health")
def health_check():
    return {
        "status": "healthy",
        "model_loaded": INITIALIZATION_STATUS["model_loaded"],
        "error": INITIALIZATION_STATUS["error"],
    }


@app.post("/api/v1/v2v", response_model=GenerateResponse)
def generate_audio(req: GenerateRequest):
    print("=== V2V Request Started ===")
    audio_np = b64(req.audio_data)
    if audio_np.ndim == 1:
        audio_np = audio_np.reshape(1, -1)
    if check_status():
        # Short-circuit: echo the input payload back as a valid GenerateResponse.
        # Returning the raw numpy array here would fail response_model validation.
        return GenerateResponse(audio_data=req.audio_data)
    print(f"Audio shape: {audio_np.shape}, Sample rate: {req.sample_rate}")
    system_prompt = (
        "You are a helpful assistant who tries to help answer the user's question. "
        "This is a part of a voice assistant system, so don't generate anything other than pure text."
    )
    try:
        text = gt(audio_np, req.sample_rate)
        response_text = chat(system_prompt, user_prompt=text)
        print(f"LLM response length (chars): {len(response_text)}")
        print(f"LLM response: '{response_text}'")
        start_time = time.perf_counter()
        audio_out = tts.generate(
            text=response_text,
            prompt_wav_path=None,
            prompt_text=None,
            cfg_value=2.0,
            inference_timesteps=10,
            normalize=True,
            denoise=True,
            retry_badcase=True,
            retry_badcase_max_times=3,
            retry_badcase_ratio_threshold=6.0,
        )
        print("TTS generation complete.")
        end_time = time.perf_counter()
        print(f"TTS generation took {end_time - start_time:.2f} seconds.")
        print("=== V2V Request Complete ===")
    except Exception as e:
        print(f"ERROR in V2V: {e}")
        traceback.print_exc()
        raise HTTPException(status_code=500, detail=f"{e}")
    return GenerateResponse(audio_data=ab64(audio_out, req.sample_rate))


@app.post("/api/v1/v2t")
def generate_text(req: GenerateRequest):
    if check_status():
        return {"text": "You are a helpful assistant who tries to help answer the user's question."}
    audio_np = b64(req.audio_data)
    if audio_np.ndim == 1:
        audio_np = audio_np.reshape(1, -1)
    try:
        text = gt(audio_np, req.sample_rate)
        print(f"Transcribed text: {text}")
        system_prompt = "You are a helpful assistant who tries to help answer the user's question."
        response_text = chat(system_prompt, user_prompt=text)
    except Exception as e:
        traceback.print_exc()
        raise HTTPException(status_code=500, detail=f"{e}")
    return {"text": response_text}


if __name__ == "__main__":
    uvicorn.run("server:app", host="0.0.0.0", port=8000, reload=False)
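
# Example client call (a minimal sketch, not part of the server). It assumes
# the server is running on localhost:8000 and that the `requests` package is
# installed; the 440 Hz test tone is purely illustrative. The payload mirrors
# the b64()/ab64() convention above: a float32 numpy array serialized with
# np.save and base64-encoded.
#
#   import base64, io
#   import numpy as np
#   import requests
#
#   sr = 16_000
#   t = np.linspace(0, 1.0, sr, endpoint=False)
#   tone = (0.1 * np.sin(2 * np.pi * 440.0 * t)).astype(np.float32)
#   buf = io.BytesIO()
#   np.save(buf, tone)
#   payload = {
#       "audio_data": base64.b64encode(buf.getvalue()).decode(),
#       "sample_rate": sr,
#   }
#   r = requests.post("http://localhost:8000/api/v1/v2t", json=payload, timeout=300)
#   print(r.json()["text"])
#
# /api/v1/v2v returns audio in the same format; decode the response with
# np.load(io.BytesIO(base64.b64decode(r.json()["audio_data"]))).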