never_test_21 / server.py
Shinichie's picture
Upload folder using huggingface_hub
2c31fd4 verified
import os
from fastapi import FastAPI, HTTPException
from helper import check_status, prefix, filter_by_word_count
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel, Field
from transformers import AutoModelForCausalLM, AutoTokenizer
import traceback
import whisper
import librosa
import numpy as np
import torch
import uvicorn
import base64
import io
from voxcpm import VoxCPM
# --- Model loading (runs once, at import time) -------------------------------
# All three models are loaded from local paths under ``models/`` as soon as the
# module is imported, so importing this file is expensive and fails outright if
# any of the paths is missing.

# Speech-to-text: Whisper checkpoint loaded from a local ``.pt`` file.
asr_model = whisper.load_model("models/wpt/wpt.pt")
# Text generation: tokenizer and causal LM share the same local directory.
model_name = "models/Llama-3.2-1B-Instruct"
tok = AutoTokenizer.from_pretrained(model_name)
# The LM is loaded in bfloat16, placed on CUDA, and switched to eval mode.
lm = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype=torch.bfloat16,
    device_map="cuda",
).eval()
# Text-to-speech: VoxCPM restricted to local files, with its denoiser enabled
# and the denoiser model id pointing at a local directory as well.
tts = VoxCPM.from_pretrained(
    "models/VoxCPM-0.5B",
    local_files_only=True,
    load_denoiser=True,
    zipenhancer_model_id="models/iic/speech_zipenhancer_ans_multiloss_16k_base"
)
def chat(system_prompt: str, user_prompt: str) -> str:
    """Produce one assistant reply for a system/user prompt pair.

    The two prompts are rendered through the tokenizer's chat template, the
    module-level language model samples a continuation, and only the newly
    generated tokens are decoded.

    Args:
        system_prompt: Instruction text placed in the ``system`` turn.
        user_prompt: Text placed in the ``user`` turn.

    Returns:
        ``prefix`` (imported from ``helper``) concatenated with the decoded
        continuation, with surrounding whitespace stripped.
    """
    print("LLM init...")

    conversation = [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": user_prompt},
    ]
    encoded = tok.apply_chat_template(
        conversation,
        add_generation_prompt=True,
        return_tensors="pt",
        return_dict=True,
    )
    prompt_ids = encoded["input_ids"].to(lm.device)
    prompt_mask = encoded["attention_mask"].to(lm.device)

    # Sampling configuration, gathered in one place for readability.
    sampling = {
        "pad_token_id": tok.eos_token_id,
        "max_new_tokens": 2048,
        "do_sample": True,
        "temperature": 0.2,
        "repetition_penalty": 1.1,
        "top_k": 100,
        "top_p": 0.95,
    }
    with torch.inference_mode():
        generated = lm.generate(
            input_ids=prompt_ids,
            attention_mask=prompt_mask,
            **sampling,
        )

    # generate() returns prompt + continuation; drop the prompt tokens.
    prompt_len = prompt_ids.shape[-1]
    reply = tok.decode(
        generated[0][prompt_len:],
        skip_special_tokens=True,
        clean_up_tokenization_spaces=True,
    )
    print("LLM answer done.")
    return (prefix + reply).strip()
def gt(audio: np.ndarray, sr: int) -> str:
    """Transcribe an audio buffer to text with the module-level ASR model.

    Args:
        audio: Audio samples. Singleton dimensions are squeezed away and the
            data is cast to float32 before transcription.
        sr: Sample rate of ``audio`` in Hz. Anything other than 16 kHz is
            resampled to 16 kHz first.

    Returns:
        The transcribed text with surrounding whitespace stripped.
    """
    print("Starting ASR transcription...")
    ss = audio.squeeze().astype(np.float32)
    if sr != 16_000:
        # Resample the squeezed float32 signal, not the raw input: resampling
        # ``audio`` would throw away the squeeze/cast above and hand the ASR
        # model the original (possibly 2-D, possibly non-float) array.
        ss = librosa.resample(ss, orig_sr=sr, target_sr=16_000)
    result = asr_model.transcribe(ss, fp16=False, language=None)
    transcribed_text = result["text"].strip()
    print(f"ASR done. Transcribed: '{transcribed_text}'")
    return transcribed_text
def sample(rr: str) -> str:
    """Sample a raw (non-chat) continuation of ``rr`` from the language model.

    Args:
        rr: Prompt text. If it is empty or whitespace-only, ``"Hello "`` is
            used as the prompt instead.

    Returns:
        The decoded continuation only (prompt tokens removed), with special
        tokens skipped.
    """
    prompt = rr if rr.strip() else "Hello "
    encoded = tok(prompt, return_tensors="pt").to(lm.device)

    with torch.inference_mode():
        generated = lm.generate(
            **encoded,
            max_new_tokens=2048,
            do_sample=True,
            temperature=0.2,
            repetition_penalty=1.1,
            top_k=100,
            top_p=0.95,
        )

    # Everything before this index is the echoed prompt.
    prompt_len = encoded.input_ids.shape[-1]
    continuation = generated[0][prompt_len:]
    return tok.decode(continuation, skip_special_tokens=True)
# Status payload surfaced by the health endpoint. The values are hard-coded
# here and nothing in this file updates them afterwards; the models above are
# loaded at import time, so reaching this line means those loads did not raise.
INITIALIZATION_STATUS = {"model_loaded": True, "error": None}
# Request body shared by the v2v and v2t endpoints.
class GenerateRequest(BaseModel):
    # Base64 text; the handlers decode it with b64(), i.e. base64-decode and
    # then np.load() the resulting bytes as a NumPy array.
    audio_data: str = Field(..., description="")
    # Sample rate in Hz of the submitted audio. The v2v handler also passes it
    # to ab64() as the target rate for the returned audio.
    sample_rate: int = Field(..., description="")
# Response body of the v2v endpoint.
class GenerateResponse(BaseModel):
    # Base64 text produced by ab64(): the np.save() serialization of a float32
    # array, base64-encoded.
    audio_data: str = Field(..., description="")
# FastAPI application object; the route decorators below register on it and
# the __main__ guard serves it as "server:app".
app = FastAPI(title="V1", version="0.1")
# CORS is left fully open: any origin, method and header is accepted.
# NOTE(review): wildcard origins together with allow_credentials=True is a
# permissive combination — confirm this is intended outside local testing.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
def b64(b64: str) -> np.ndarray:
    """Decode a base64 string carrying ``np.save`` output into an array.

    Args:
        b64: Base64-encoded bytes of a serialized NumPy array.

    Returns:
        The array read back with ``np.load``; pickled payloads are rejected
        because ``allow_pickle`` is disabled.
    """
    # Note: inside this function the name ``b64`` refers to the argument.
    stream = io.BytesIO(base64.b64decode(b64))
    return np.load(stream, allow_pickle=False)
def ab64(arr: np.ndarray, sr: int, orig_sr: int = 16000) -> str:
    """Serialize audio as base64-encoded ``np.save`` bytes at rate ``sr``.

    Args:
        arr: Audio samples currently at ``orig_sr`` Hz.
        sr: Sample rate in Hz the encoded audio should have.
        orig_sr: Sample rate of ``arr`` in Hz. Defaults to 16000, the value
            that was previously hard-coded, so existing two-argument calls
            behave as before.

    Returns:
        Base64 text of the float32 array serialized with ``np.save``.
    """
    # Resampling to the same rate returns the signal unchanged, so skip the
    # call entirely when the rates already match.
    if sr != orig_sr:
        samples = librosa.resample(arr, orig_sr=orig_sr, target_sr=sr)
    else:
        samples = arr
    buf = io.BytesIO()
    np.save(buf, samples.astype(np.float32))
    return base64.b64encode(buf.getvalue()).decode()
@app.get("/api/v1/health")
def health_check():
    """Report service health from the module-level initialization status.

    Returns:
        A dict with a constant ``"status": "healthy"`` entry plus the
        ``model_loaded`` and ``error`` values from ``INITIALIZATION_STATUS``.
    """
    report = {"status": "healthy"}
    report["model_loaded"] = INITIALIZATION_STATUS["model_loaded"]
    report["error"] = INITIALIZATION_STATUS["error"]
    return report
@app.post("/api/v1/v2v", response_model=GenerateResponse)
def generate_audio(req: GenerateRequest):
    """Voice-to-voice endpoint: transcribe, answer with the LLM, speak the answer.

    Args:
        req: Request carrying base64-encoded ``np.save`` audio and its sample
            rate.

    Returns:
        A ``GenerateResponse`` whose ``audio_data`` is the synthesized reply,
        resampled to ``req.sample_rate`` and encoded by ``ab64``. When
        ``check_status()`` is truthy the request's own ``audio_data`` is
        echoed back instead and no model is run.

    Raises:
        HTTPException: 400 if ``audio_data`` cannot be decoded into an array;
            500 if transcription, generation or synthesis fails.
    """
    import time  # function-scoped, as in the original; only used for timing

    print("=== V2V Request Started ===")
    # A malformed payload is a client error, so report it as 400 instead of
    # letting the decode exception surface as an opaque 500.
    try:
        audio_np = b64(req.audio_data)
    except (ValueError, EOFError, OSError) as e:
        raise HTTPException(
            status_code=400, detail=f"Invalid audio_data: {e}"
        ) from e
    if audio_np.ndim == 1:
        audio_np = audio_np.reshape(1, -1)
    if check_status():
        # check_status() lives in ``helper`` and is not visible here. This
        # branch used to return the raw ndarray, which cannot be validated
        # against response_model=GenerateResponse; echo the already-encoded
        # input so the response matches the declared schema.
        return GenerateResponse(audio_data=req.audio_data)
    print(f"Audio shape: {audio_np.shape}, Sample rate: {req.sample_rate}")
    system_prompt = (
        "You are a helpful assistant who tries to help answer the user's question. "
        "This is a part of voice assistant system, don't generate anything other than pure text."
    )
    try:
        text = gt(audio_np, req.sample_rate)
        response_text = chat(system_prompt, user_prompt=text)
        print(f"LLM response len chars: '{len(response_text)}'")
        print(f"LLM response: '{response_text}'")
        start_time = time.perf_counter()
        audio_out = tts.generate(
            text=response_text,
            prompt_wav_path=None,
            prompt_text=None,
            cfg_value=2.0,
            inference_timesteps=10,
            normalize=True,
            denoise=True,
            retry_badcase=True,
            retry_badcase_max_times=3,
            retry_badcase_ratio_threshold=6.0,
        )
        print("TTS generation complete.")
        end_time = time.perf_counter()
        print(f"TTS generation took {end_time - start_time:.2f} seconds.")
        print("=== V2V Request Complete ===")
    except Exception as e:
        # Top-level boundary: log the full traceback, then surface a 500.
        print(f"ERROR in V2V: {e}")
        traceback.print_exc()
        raise HTTPException(status_code=500, detail=f"{e}")
    return GenerateResponse(audio_data=ab64(audio_out, req.sample_rate))
@app.post("/api/v1/v2t")
def generate_text(req: GenerateRequest):
    """Voice-to-text endpoint: transcribe the audio and answer with the LLM.

    Args:
        req: Request carrying base64-encoded ``np.save`` audio and its sample
            rate.

    Returns:
        ``{"text": <LLM reply>}``. When ``check_status()`` is truthy a fixed
        placeholder text is returned and the audio is never decoded.

    Raises:
        HTTPException: 400 if ``audio_data`` cannot be decoded into an array;
            500 if transcription or generation fails.
    """
    if check_status():
        return {"text": "You are a helpful assistant who tries to help answer the user's question."}
    # A malformed payload is a client error, so report it as 400 instead of
    # letting the decode exception surface as an opaque 500.
    try:
        audio_np = b64(req.audio_data)
    except (ValueError, EOFError, OSError) as e:
        raise HTTPException(
            status_code=400, detail=f"Invalid audio_data: {e}"
        ) from e
    if audio_np.ndim == 1:
        audio_np = audio_np.reshape(1, -1)
    try:
        text = gt(audio_np, req.sample_rate)
        print(f"Transcribed text: {text}")
        system_prompt = "You are a helpful assistant who tries to help answer the user's question."
        response_text = chat(system_prompt, user_prompt=text)
    except Exception as e:
        # Top-level boundary: log the full traceback, then surface a 500.
        traceback.print_exc()
        raise HTTPException(status_code=500, detail=f"{e}")
    return {"text": response_text}
# Script entry point: serve the app with uvicorn on all interfaces, port 8000,
# with auto-reload disabled. The "server:app" import string means uvicorn
# imports this module under the name ``server``.
if __name__ == "__main__":
    uvicorn.run("server:app", host="0.0.0.0", port=8000, reload=False)