| import os |
| |
# Process-wide settings, applied before torch / transformers are imported
# further down. Native thread pools (OpenMP, MKL, OpenBLAS, numexpr, Rayon)
# typically read their *_NUM_THREADS variable when they first initialise,
# which is why this runs ahead of those imports.
os.environ.update(
    {
        # CPU thread caps for the tokenizers library and the BLAS/OpenMP back-ends.
        'TOKENIZERS_PARALLELISM': 'false',
        'OMP_NUM_THREADS': '4',
        'MKL_NUM_THREADS': '4',
        'OPENBLAS_NUM_THREADS': '4',
        'NUMEXPR_NUM_THREADS': '4',
        'RAYON_NUM_THREADS': '4',
        # Hugging Face libraries: use local files only, never the network.
        'HF_HUB_OFFLINE': '1',
        'TRANSFORMERS_OFFLINE': '1',
        # Compilation-related switches (torch.compile off).
        'TORCH_COMPILE_DISABLE': '1',
        'TRITON_DISABLE_LINE_INFO': '1',
        # NOTE(review): CUDA_LAUNCH_BLOCKING=1 makes CUDA kernel launches
        # synchronous. It is a debugging aid that normally slows inference;
        # confirm it is meant to stay on in a serving process.
        'CUDA_LAUNCH_BLOCKING': '1',
    }
)
|
|
| |
| |
| import torch |
| |
# Torch CPU threading: 4 intra-op threads, 2 inter-op threads.
torch.set_num_threads(4)

try:
    # PyTorch only allows this once, before any inter-op parallel work has
    # started; if that already happened, keep the existing setting.
    torch.set_num_interop_threads(2)
except RuntimeError:
    pass

# Turn TorchDynamo (torch.compile) off; suppress_errors makes any dynamo
# failure fall back to eager execution instead of raising.
import torch._dynamo  # noqa: E402  (must follow the torch import above)

torch._dynamo.config.suppress_errors = True
torch._dynamo.config.disable = True

# Disable TorchScript through a private API (presumably the same effect as
# PYTORCH_JIT=0). This is best effort: the attribute may not exist in every
# torch version, so a failure only logs a warning. `except Exception` rather
# than a bare `except:` keeps Ctrl-C / SystemExit propagating.
try:
    torch.jit._state.disable()
except Exception as exc:
    print(f"WARNING: torch.jit._state.disable() failed; TorchScript stays enabled: {exc!r}")
|
|
| |
| |
| |
import functools

# Keep a handle on the real loader before replacing it.
_original_torch_load = torch.load


# Default `weights_only` to False for every torch.load call in the process.
# This restores the pre-2.6 torch default, presumably so the whisper / VoxCPM
# loaders below can read their checkpoints. An explicit `weights_only=` from
# the caller is left untouched. `weights_only` is keyword-only in torch.load,
# so inspecting kwargs is sufficient.
# SECURITY NOTE(review): weights_only=False unpickles arbitrary objects, so a
# malicious checkpoint can execute code. This looks acceptable only because
# the checkpoints loaded here are local files under models/; never route
# untrusted files through torch.load.
@functools.wraps(_original_torch_load)
def _patched_torch_load(*args, **kwargs):
    kwargs.setdefault('weights_only', False)
    return _original_torch_load(*args, **kwargs)


torch.load = _patched_torch_load

# Sanity check that the monkeypatch took effect. An explicit raise, unlike an
# assert, still runs under `python -O`.
if torch.load is not _patched_torch_load:
    raise RuntimeError("torch.load patch failed!")
|
|
| |
| from fastapi import FastAPI, HTTPException |
| from fastapi.middleware.cors import CORSMiddleware |
| from pydantic import BaseModel, Field |
| from transformers import AutoModelForCausalLM, AutoTokenizer |
| import traceback |
| import whisper |
| import librosa |
| import numpy as np |
| import uvicorn |
| import base64 |
| import io |
| from voxcpm import VoxCPM |
|
|
# All three models are loaded eagerly when this module is imported. Every
# path is local, and the Hugging Face / VoxCPM loaders are told not to
# download (local_files_only=True).

# --- Speech recognition (Whisper checkpoint on disk) ---
print("Loading ASR model...")
asr_model = whisper.load_model("models/wpt/wpt.pt")
print("ASR model loaded.")

# --- Language model: bfloat16 weights placed on the CUDA device ---
print("Loading LLM...")
model_name = "models/Llama-3.2-1B-Instruct"
tok = AutoTokenizer.from_pretrained(model_name, local_files_only=True)
lm = AutoModelForCausalLM.from_pretrained(
    model_name,
    local_files_only=True,
    device_map="cuda",
    torch_dtype=torch.bfloat16,
)
lm.eval()  # inference only
print("LLM loaded.")

# --- Text-to-speech, with the ZipEnhancer denoiser loaded alongside ---
print("Loading TTS model...")
tts = VoxCPM.from_pretrained(
    "models/VoxCPM-0.5B",
    zipenhancer_model_id="models/iic/speech_zipenhancer_ans_multiloss_16k_base",
    load_denoiser=True,
    local_files_only=True,
)
print("TTS model loaded.")
|
|
def chat(system_prompt: str, user_prompt: str) -> str:
    """Produce one assistant reply from the module-level LLM.

    A two-message conversation (system + user) is rendered with the
    tokenizer's chat template, and a completion is sampled from it.

    Args:
        system_prompt: Instructions placed in the "system" message.
        user_prompt: The user's turn.

    Returns:
        Only the newly generated text (prompt tokens are sliced off), with
        special tokens removed and surrounding whitespace stripped. Sampling
        is enabled (do_sample=True), so repeated calls can differ.
    """
    print("LLM init...")
    conversation = [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": user_prompt},
    ]
    encoded = tok.apply_chat_template(
        conversation,
        add_generation_prompt=True,
        return_dict=True,
        return_tensors="pt",
    )
    device = lm.device
    prompt_ids = encoded["input_ids"].to(device)
    prompt_mask = encoded["attention_mask"].to(device)
    prompt_len = prompt_ids.shape[-1]

    # Low-temperature nucleus/top-k sampling with a mild repetition penalty.
    sampling = dict(
        max_new_tokens=2048,
        do_sample=True,
        temperature=0.2,
        repetition_penalty=1.1,
        top_k=100,
        top_p=0.95,
    )
    with torch.inference_mode():
        generated = lm.generate(
            input_ids=prompt_ids,
            attention_mask=prompt_mask,
            pad_token_id=tok.eos_token_id,
            **sampling,
        )

    # Keep only the tokens produced after the prompt.
    completion_ids = generated[0, prompt_len:]
    answer = tok.decode(
        completion_ids,
        clean_up_tokenization_spaces=True,
        skip_special_tokens=True,
    )
    print("LLM answer done.")
    return answer.strip()
|
|
def gt(audio: np.ndarray, sr: int) -> str:
    """Transcribe an audio clip with the module-level Whisper model.

    Args:
        audio: Audio samples. Singleton axes are squeezed away (the endpoints
            pass shape (1, N)), so this assumes mono audio. NOTE(review):
            multi-channel input is not downmixed; confirm clients send mono.
        sr: Sample rate of `audio` in Hz. Audio at any other rate is
            resampled to 16 kHz, the rate this function feeds to Whisper.

    Returns:
        The transcription with surrounding whitespace stripped.
    """
    print("Starting ASR transcription...")
    # Normalize to a 1-D float32 signal before anything else touches it.
    ss = np.asarray(audio).squeeze().astype(np.float32)
    if sr != 16_000:
        # Resample the normalized `ss`, not the raw `audio`. Otherwise the
        # squeeze/cast above is discarded, and a 2-D or non-float array
        # reaches librosa / Whisper. Re-cast in case the resampler changes
        # the dtype.
        ss = librosa.resample(ss, orig_sr=sr, target_sr=16_000).astype(np.float32)

    # fp16=False: decode in full precision; language=None: let Whisper detect it.
    result = asr_model.transcribe(ss, fp16=False, language=None)
    transcribed_text = result["text"].strip()
    print(f"ASR done. Transcribed: '{transcribed_text}'")
    return transcribed_text
|
|
|
|
def sample(rr: str) -> str:
    """Continue raw text with the LLM, without applying the chat template.

    Args:
        rr: Prompt text. A blank or whitespace-only prompt is replaced by
            "Hello ".

    Returns:
        The generated continuation only (prompt tokens sliced off), with
        special tokens removed. Not stripped, unlike `chat`.

    NOTE(review): not called anywhere in the visible source. Unlike `chat`, it
    passes no pad_token_id, so transformers picks one itself.
    """
    prompt = rr if rr.strip() else "Hello "

    encoded = tok(prompt, return_tensors="pt").to(lm.device)
    prompt_len = encoded.input_ids.shape[-1]

    with torch.inference_mode():
        generated = lm.generate(
            **encoded,
            max_new_tokens=2048,
            do_sample=True,
            temperature=0.2,
            repetition_penalty=1.1,
            top_k=100,
            top_p=0.95,
        )

    return tok.decode(generated[0, prompt_len:], skip_special_tokens=True)
|
|
|
|
# Status reported by the health endpoint. The models load eagerly at import
# time, so reaching this line means they loaded; nothing in the visible
# source ever writes "error".
INITIALIZATION_STATUS = dict(model_loaded=True, error=None)
|
|
|
|
class GenerateRequest(BaseModel):
    """Request body shared by the /api/v1/v2v and /api/v1/v2t endpoints."""

    # Base64 text of an np.save-serialized audio array (decoded by `b64`).
    audio_data: str = Field(
        ...,
        description="Base64-encoded .npy array containing the input audio samples.",
    )
    # Must be positive. A zero or negative rate would otherwise fail only
    # later, inside resampling, and surface as a 500 instead of a 422.
    sample_rate: int = Field(
        ...,
        gt=0,
        description=(
            "Sample rate of audio_data in Hz; /api/v1/v2v also returns its "
            "audio at this rate."
        ),
    )
|
|
|
|
class GenerateResponse(BaseModel):
    """Response body of the /api/v1/v2v endpoint."""

    # Base64 text of an np.save-serialized float32 array (produced by `ab64`).
    audio_data: str = Field(
        ...,
        description=(
            "Base64-encoded float32 .npy array of the synthesized reply audio, "
            "at the request's sample_rate."
        ),
    )
|
|
|
|
# The ASGI application served by uvicorn in the __main__ block.
app = FastAPI(title="V1", version="0.1")

# Fully permissive CORS: every origin, method and header is allowed.
# NOTE(review): the CORS spec forbids a wildcard origin on credentialed
# requests; Starlette's middleware presumably echoes the caller's Origin
# instead, so any site can make credentialed calls. Narrow allow_origins if
# this server is reachable beyond a trusted network.
app.add_middleware(
    CORSMiddleware,
    allow_headers=["*"],
    allow_methods=["*"],
    allow_credentials=True,
    allow_origins=["*"],
)
|
|
|
|
| def b64(b64: str) -> np.ndarray: |
| raw = base64.b64decode(b64) |
| return np.load(io.BytesIO(raw), allow_pickle=False) |
|
|
|
|
def ab64(arr: np.ndarray, sr: int, orig_sr: int = 16_000) -> str:
    """Resample audio and encode it as a base64 `.npy` payload.

    Args:
        arr: Audio samples recorded at `orig_sr` Hz.
        sr: Target sample rate in Hz (the client's requested rate).
        orig_sr: Sample rate of `arr`. Defaults to 16 kHz, the rate this
            module assumes the TTS output uses. NOTE(review): confirm against
            the VoxCPM model's configured sample rate.

    Returns:
        ASCII base64 text of the float32 array serialized with `np.save`.
    """
    samples = np.asarray(arr)
    if sr != orig_sr:
        # Resample only when the rates differ; equal rates would be a no-op.
        samples = librosa.resample(samples, orig_sr=orig_sr, target_sr=sr)
    buf = io.BytesIO()
    np.save(buf, samples.astype(np.float32))
    return base64.b64encode(buf.getvalue()).decode()
|
|
|
|
@app.get("/api/v1/health")
def health_check():
    """Liveness probe.

    Always reports "healthy" and mirrors the two fields of the module-level
    INITIALIZATION_STATUS dict.
    """
    status = INITIALIZATION_STATUS
    payload = {"status": "healthy"}
    payload["model_loaded"] = status["model_loaded"]
    payload["error"] = status["error"]
    return payload
|
|
|
|
@app.post("/api/v1/v2v", response_model=GenerateResponse)
def generate_audio(req: GenerateRequest):
    """Voice-to-voice: transcribe the speech, answer it with the LLM, speak the answer.

    Pipeline: decode `req.audio_data` -> Whisper (`gt`) -> LLM (`chat`) ->
    VoxCPM TTS -> resample to `req.sample_rate` and re-encode (`ab64`).

    Raises:
        HTTPException: 400 when `audio_data` is not a decodable `.npy`
            array; 500 when transcription, generation, synthesis or encoding
            fails.
    """
    import time  # only used to time the TTS step

    print("=== V2V Request Started ===")
    # A payload that fails to decode is a client error: answer 400 instead of
    # letting it escape as an unhandled 500. binascii.Error (bad base64) is a
    # ValueError subclass; np.load raises ValueError / EOFError for
    # bad or empty payloads.
    try:
        audio_np = b64(req.audio_data)
    except (ValueError, EOFError) as e:
        raise HTTPException(status_code=400, detail=f"Invalid audio_data: {e}") from e
    if not isinstance(audio_np, np.ndarray):
        # np.load returns an NpzFile (not an array) for .npz archives.
        raise HTTPException(status_code=400, detail="Invalid audio_data: expected a .npy array")
    if audio_np.ndim == 1:
        audio_np = audio_np.reshape(1, -1)
    print(f"Audio shape: {audio_np.shape}, Sample rate: {req.sample_rate}")

    system_prompt = (
        "You are a helpful assistant who tries to help answer the user's question. "
        "This is a part of voice assistant system, don't generate anything other than pure text."
    )

    try:
        text = gt(audio_np, req.sample_rate)
        response_text = chat(system_prompt, user_prompt=text)
        print(f"LLM response len chars: '{len(response_text)}'")
        print(f"LLM response: '{response_text}'")

        start_time = time.perf_counter()
        audio_out = tts.generate(
            text=response_text,
            prompt_wav_path=None,
            prompt_text=None,
            cfg_value=2.0,
            inference_timesteps=10,
            normalize=True,
            denoise=True,
            retry_badcase=True,
            retry_badcase_max_times=3,
            retry_badcase_ratio_threshold=6.0,
        )
        print("TTS generation complete.")
        end_time = time.perf_counter()
        print(f"TTS generation took {end_time - start_time:.2f} seconds.")

        # Encode inside the try so resampling / serialization failures are
        # logged and reported like every other pipeline failure.
        # NOTE(review): ab64 assumes the TTS output is 16 kHz by default.
        audio_b64 = ab64(audio_out, req.sample_rate)
        print("=== V2V Request Complete ===")
    except Exception as e:
        print(f"ERROR in V2V: {e}")
        traceback.print_exc()
        raise HTTPException(status_code=500, detail=f"{e}") from e

    return GenerateResponse(audio_data=audio_b64)
|
|
|
|
@app.post("/api/v1/v2t")
def generate_text(req: GenerateRequest):
    """Voice-to-text: transcribe the request audio and return the LLM's reply.

    Returns:
        `{"text": <reply>}`.

    Raises:
        HTTPException: 400 when `audio_data` is not a decodable `.npy`
            array; 500 when transcription or generation fails.
    """
    # A payload that fails to decode is a client error (400). binascii.Error
    # is a ValueError subclass; np.load raises ValueError / EOFError for
    # bad or empty payloads.
    try:
        audio_np = b64(req.audio_data)
    except (ValueError, EOFError) as e:
        raise HTTPException(status_code=400, detail=f"Invalid audio_data: {e}") from e
    if not isinstance(audio_np, np.ndarray):
        # np.load returns an NpzFile (not an array) for .npz archives.
        raise HTTPException(status_code=400, detail="Invalid audio_data: expected a .npy array")
    if audio_np.ndim == 1:
        audio_np = audio_np.reshape(1, -1)

    try:
        text = gt(audio_np, req.sample_rate)
        print(f"Transcribed text: {text}")
        system_prompt = "You are a helpful assistant who tries to help answer the user's question."
        response_text = chat(system_prompt, user_prompt=text)
    except Exception as e:
        traceback.print_exc()
        raise HTTPException(status_code=500, detail=f"{e}") from e

    return {"text": response_text}
|
|
|
|
if __name__ == "__main__":
    # Serve the app directly, without auto-reload. A start-up failure is
    # reported with a traceback and then re-raised, so the process exits with
    # an error.
    try:
        print("Starting FastAPI server on port 8000...")
        uvicorn.run(
            app,
            log_level="info",
            reload=False,
            port=8000,
            host="0.0.0.0",
        )
    except Exception as exc:
        print(f"ERROR: Failed to start server: {exc}")
        traceback.print_exc()  # `traceback` is already imported at module level
        raise
|
|