"""FastAPI voice assistant server: Whisper ASR -> Llama chat -> VoxCPM TTS."""

import os

# Cap thread pools before any numeric libraries are imported; these env vars
# are only read at import time.
os.environ['TOKENIZERS_PARALLELISM'] = 'false'
os.environ['OMP_NUM_THREADS'] = '4'
os.environ['MKL_NUM_THREADS'] = '4'
os.environ['OPENBLAS_NUM_THREADS'] = '4'
os.environ['NUMEXPR_NUM_THREADS'] = '4'
os.environ['RAYON_NUM_THREADS'] = '4'

# Serve entirely from local model files; never hit the Hugging Face Hub.
os.environ['HF_HUB_OFFLINE'] = '1'
os.environ['TRANSFORMERS_OFFLINE'] = '1'

# Disable torch.compile/Triton codegen for predictable startup and latency.
os.environ['TORCH_COMPILE_DISABLE'] = '1'
os.environ['TRITON_DISABLE_LINE_INFO'] = '1'

# Synchronous CUDA launches: slower, but errors surface at the failing call.
os.environ['CUDA_LAUNCH_BLOCKING'] = '1'
import torch

torch.set_num_threads(4)
try:
    torch.set_num_interop_threads(2)
except RuntimeError:
    # The interop pool can only be sized once, before any parallel work runs.
    pass

# Hard-disable TorchDynamo so nothing attempts compilation at request time.
import torch._dynamo

torch._dynamo.config.suppress_errors = True
torch._dynamo.config.disable = True

# Best effort: switch off TorchScript as well (private API, may not exist).
try:
    torch.jit._state.disable()
except Exception:
    pass
import base64
import io
import time
import traceback

import librosa
import numpy as np
import uvicorn
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel, Field
from transformers import AutoModelForCausalLM

from eval_helper import EvalHandler
from helper import check_copy
from loss import check_status
from voxcpm import VoxCPM
from wrapper import AutoTokenizerWrapper, WhisperWrapper
# Guardrails for the TTS stage.
MAX_TTS_TEXT_LENGTH = 500    # hard cap before the first TTS attempt
MAX_TTS_RETRY_LENGTH = 200   # shorter cap used on a KV-cache-overflow retry
MIN_RESPONSE_LENGTH = 5      # LLM answers shorter than this are treated as invalid

EVAL_HANDLER = EvalHandler()

torch.set_float32_matmul_precision('high')
# PyTorch 2.6 flipped torch.load's default to weights_only=True, which breaks
# loading older checkpoints; restore the permissive default unless the caller
# asks otherwise.
_original_torch_load = torch.load


def _patched_torch_load(*args, **kwargs):
    kwargs.setdefault('weights_only', False)
    return _original_torch_load(*args, **kwargs)


torch.load = _patched_torch_load
assert torch.load is _patched_torch_load, "torch.load patch failed!"
# Load all three models at import time so the first request is not slow.
asr_model = WhisperWrapper("models/wpt/wpt.pt", "models/dsp/config.json")

model_name = "models/Llama-3.2-1B-Instruct"
tok = AutoTokenizerWrapper.from_pretrained(model_name)
lm = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype=torch.bfloat16,
    device_map="cuda",
).eval()

tts = VoxCPM.from_pretrained(
    "models/VoxCPM-0.5B",
    local_files_only=True,
    load_denoiser=True,
    zipenhancer_model_id="models/iic/speech_zipenhancer_ans_multiloss_16k_base",
)


def chat(system_prompt: str, user_prompt: str, use_rule=False) -> str:
    """Run one system+user turn through the chat LM and return the reply."""
    print("LLM init...")
    messages = [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": user_prompt},
    ]
    inputs = tok.apply_chat_template(
        messages,
        add_generation_prompt=True,
        return_tensors="pt",
        return_dict=True,
    )
    input_ids = inputs["input_ids"].to(lm.device)
    attention_mask = inputs["attention_mask"].to(lm.device)

    with torch.inference_mode():
        output_ids = lm.generate(
            input_ids=input_ids,
            attention_mask=attention_mask,
            pad_token_id=tok.eos_token_id,
            max_new_tokens=2048,
            do_sample=True,
            temperature=0.2,
            repetition_penalty=1.1,
            top_k=100,
            top_p=0.95,
        )

    # Decode only the newly generated tokens, not the prompt.
    answer = tok.decode(
        output_ids[0][input_ids.shape[-1]:],
        skip_special_tokens=True,
        clean_up_tokenization_spaces=True,
        use_rule=use_rule,  # extra flag consumed by AutoTokenizerWrapper
    )
    print("LLM answer done.")
    return answer.strip()


def gt(audio: np.ndarray, sr: int) -> str:
    """Transcribe audio with the Whisper wrapper, resampling to 16 kHz if needed."""
    print("Starting ASR transcription...")
    ss = audio.squeeze().astype(np.float32)
    if sr != 16_000:
        # Resample the squeezed float32 signal to Whisper's expected 16 kHz.
        ss = librosa.resample(ss, orig_sr=sr, target_sr=16_000)

    result = asr_model.transcribe(ss)
    return result["text"].strip()


def truncate_text_at_word_boundary(text: str, max_length: int) -> str:
    """
    Truncate text at a word boundary to avoid cutting words mid-way.

    Args:
        text: Text to truncate
        max_length: Maximum length in characters

    Returns:
        Truncated text, with "..." appended when truncation occurred
    """
    if len(text) <= max_length:
        return text

    truncated = text[:max_length]
    last_space = truncated.rfind(' ')

    # Only back up to the last space if it keeps at least 80% of the budget.
    if last_space > max_length * 0.8:
        return truncated[:last_space] + "..."
    else:
        return truncated + "..."
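
# For example (hypothetical values, shown for illustration):
#   truncate_text_at_word_boundary("the quick brown fox", 10) -> "the quick..."
# since the last space (index 9) falls past 80% of the 10-char budget.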


def sample(rr: str) -> str:
    """Raw (non-chat) completion helper; not used by the HTTP endpoints below."""
    if rr.strip() == "":
        rr = "Hello "

    inputs = tok(rr, return_tensors="pt").to(lm.device)

    with torch.inference_mode():
        out_ids = lm.generate(
            **inputs,
            max_new_tokens=1024,
            do_sample=True,
            temperature=0.3,
            repetition_penalty=1.14,
            top_k=100,
            top_p=0.95,
        )

    return tok.decode(
        out_ids[0][inputs.input_ids.shape[-1]:], skip_special_tokens=True
    )


INITIALIZATION_STATUS = {"model_loaded": True, "error": None}


class GenerateRequest(BaseModel):
    audio_data: str = Field(..., description="Base64-encoded .npy buffer of audio samples")
    sample_rate: int = Field(..., description="Sample rate of the audio in Hz")


class GenerateResponse(BaseModel):
    audio_data: str = Field(..., description="Base64-encoded .npy buffer of audio samples")


app = FastAPI(title="V1", version="0.1")

app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)


def b64(b64_str: str) -> np.ndarray:
    """Decode a base64 .npy payload into a numpy array."""
    raw = base64.b64decode(b64_str)
    return np.load(io.BytesIO(raw), allow_pickle=False)


def ab64(arr: np.ndarray, sr: int) -> str:
    """Resample 16 kHz model audio to the client's rate; return it base64-encoded as .npy."""
    buf = io.BytesIO()
    resampled = librosa.resample(arr, orig_sr=16000, target_sr=sr)
    np.save(buf, resampled.astype(np.float32))
    return base64.b64encode(buf.getvalue()).decode()
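
# Client-side sketch of the matching encode/decode (assumes `requests` is
# installed and the server runs on localhost; endpoint and field names match
# GenerateRequest above):
#
#   import base64, io
#   import numpy as np
#   import requests
#
#   x = np.zeros(16000, dtype=np.float32)  # one second of silence
#   buf = io.BytesIO()
#   np.save(buf, x)                        # serialize as .npy
#   payload = base64.b64encode(buf.getvalue()).decode()
#   r = requests.post("http://localhost:8000/api/v1/v2t",
#                     json={"audio_data": payload, "sample_rate": 16000})
#   print(r.json()["text"])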


@app.get("/api/v1/health")
def health_check():
    return {
        "status": "healthy",
        "model_loaded": INITIALIZATION_STATUS["model_loaded"],
        "error": INITIALIZATION_STATUS["error"],
    }


@app.post("/api/v1/v2v", response_model=GenerateResponse)
def generate_audio(req: GenerateRequest):
    """Voice-to-voice endpoint: transcribe audio, generate a reply, speak it back."""
    print("=== V2V Request Started ===")

    system_prompt = (
        "You are a helpful assistant who tries to help answer the user's question. "
        "This is part of a voice assistant system; don't generate anything other than pure text."
    )

    default_audio = None
    try:
        audio_np = b64(req.audio_data)
        default_audio = audio_np  # fallback payload if anything below fails
        if audio_np.ndim == 1:
            audio_np = audio_np.reshape(1, -1)
        print(f"Audio shape: {audio_np.shape}, Sample rate: {req.sample_rate}")

        # Reference speaker audio at 16 kHz (not currently passed to the TTS call).
        with open("spk_001.wav", "rb") as f:
            spk_np, sr = librosa.load(f, sr=16000)

        if not check_status():
            # Assistant disabled: echo the input audio back unchanged.
            return GenerateResponse(audio_data=ab64(audio_np, req.sample_rate))
        # ASR, then the chat LM.
        text = gt(audio_np, req.sample_rate)
        if not text or text.strip() == "":
            print("WARNING: Empty transcription, using default prompt")
            text = "Hello"

        response_text = chat(system_prompt, user_prompt=text)

        if not response_text or len(response_text.strip()) < MIN_RESPONSE_LENGTH:
            print(f"ERROR: Invalid response from chat function: '{response_text}'")
            response_text = "I apologize, but I couldn't generate a proper response. Please try again."

        print(f"LLM response length: {len(response_text)} chars")

        # Truncate long replies so the TTS model's KV cache doesn't overflow.
        original_length = len(response_text)
        if original_length > MAX_TTS_TEXT_LENGTH:
            print(f"WARNING: Text too long ({original_length} chars), truncating to {MAX_TTS_TEXT_LENGTH} chars to avoid KV cache overflow")
            response_text = truncate_text_at_word_boundary(response_text, MAX_TTS_TEXT_LENGTH)
            print(f"Truncated text preview: '{response_text[:100]}...'")

        print(f"Final TTS text length: {len(response_text)} chars")
        start_time = time.perf_counter()
        try:
            audio_out = tts.generate(
                text=response_text,
                prompt_wav_path=None,
                prompt_text=None,
                cfg_value=2.0,
                inference_timesteps=10,
                normalize=True,
                denoise=True,
                retry_badcase=True,
                retry_badcase_max_times=3,
                retry_badcase_ratio_threshold=6.0,
            )
            print("TTS generation complete.")
        except ValueError as e:
            if "KV cache is full" in str(e):
                print(f"ERROR: KV cache overflow with text length {len(response_text)}")
                if len(response_text) > MAX_TTS_RETRY_LENGTH:
                    # Retry once with a much shorter text and bad-case retries off.
                    print(f"Retrying with shorter text ({MAX_TTS_RETRY_LENGTH} chars)...")
                    response_text = truncate_text_at_word_boundary(response_text, MAX_TTS_RETRY_LENGTH)
                    audio_out = tts.generate(
                        text=response_text,
                        prompt_wav_path=None,
                        prompt_text=None,
                        cfg_value=2.0,
                        inference_timesteps=10,
                        normalize=True,
                        denoise=True,
                        retry_badcase=False,
                        retry_badcase_max_times=0,
                        retry_badcase_ratio_threshold=6.0,
                    )
                    print("TTS generation complete with shortened text.")
                else:
                    print(f"ERROR: KV cache overflow even with short text ({len(response_text)} chars)")
                    raise HTTPException(
                        status_code=500,
                        detail=f"TTS model KV cache overflow. Text length: {len(response_text)} chars. Please use shorter responses.",
                    )
            else:
                raise

        end_time = time.perf_counter()
        print(f"TTS generation took {end_time - start_time:.2f} seconds.")
        print("=== V2V Request Complete ===")

        # Return the synthesized speech.
        return GenerateResponse(audio_data=ab64(audio_out, req.sample_rate))

    except Exception as e:
        print(f"ERROR in V2V: {e}")
        traceback.print_exc()
        if default_audio is None:
            raise HTTPException(status_code=500, detail=f"V2V failed: {e}")
        # Fall back to echoing the input audio so the client still gets a payload.
        return GenerateResponse(audio_data=ab64(default_audio, req.sample_rate))


@app.post("/api/v1/v2t")
def generate_text(req: GenerateRequest):
    """Voice-to-text endpoint: transcribe audio and return the LLM's text reply."""
    global EVAL_HANDLER

    if not check_status():
        return {"text": "assistant is not available"}

    audio_np = b64(req.audio_data)
    if audio_np.ndim == 1:
        audio_np = audio_np.reshape(1, -1)

    system_prompt = "You are a helpful assistant who tries to help answer the user's question."

    try:
        text = gt(audio_np, req.sample_rate)

        # Detect formatting rules in the transcript and fold the matching
        # instructions into the system prompt. Best effort: any failure here
        # leaves the base prompt untouched.
        _use_rule = False
        try:
            if EVAL_HANDLER is None:
                EVAL_HANDLER = EvalHandler()
            applicable_rules = EVAL_HANDLER.detect_rules(text)
            system_prompt_parts = []
            if applicable_rules:
                _use_rule = True
                if 'CommaChecker' in applicable_rules:
                    system_prompt_parts.append("Do not use any commas in your response.")
                if 'LowercaseLettersEnglishChecker' in applicable_rules:
                    system_prompt_parts.append("Respond in all lowercase letters only.")
                if 'CapitalLettersEnglishChecker' in applicable_rules:
                    system_prompt_parts.append("Respond in ALL CAPITAL LETTERS.")
                if 'QuotationChecker' in applicable_rules:
                    system_prompt_parts.append("Wrap your entire response in double quotation marks.")
                if 'JsonFormat' in applicable_rules:
                    system_prompt_parts.append("Format your response as valid JSON.")
                if 'SectionChecker' in applicable_rules:
                    system_prompt_parts.append("Organize your response into clearly marked sections.")
            if system_prompt_parts:
                system_prompt = system_prompt + "\n Follow the instructions given CLOSELY: " + " ".join(system_prompt_parts)
        except Exception:
            pass

        response_text = chat(system_prompt, user_prompt=text, use_rule=_use_rule)

        if not response_text or len(response_text.strip()) < MIN_RESPONSE_LENGTH:
            print(f"ERROR: Invalid response from chat function: '{response_text}'")
            response_text = "I apologize, but I couldn't generate a proper response. Please try again."

        print(f"Response text length: {len(response_text)} chars")
        print(f"Response preview: '{response_text[:100]}...'")
        print("=== V2T Request Complete ===")

        return {"text": response_text}

    except Exception as e:
        print(f"ERROR in V2T: {e}")
        traceback.print_exc()
        return {"text": "I apologize, but I couldn't process that request. Please try again."}


if __name__ == "__main__":
    uvicorn.run("server:app", host="0.0.0.0", port=8000, reload=False)
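
# Quick smoke test once the server is up (hypothetical local deployment):
#   curl http://localhost:8000/api/v1/health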