# miner_v2t / server.py
# Author: siddhantoon — commit d81d4d3 ("Update server.py", verified).
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel, Field
from pathlib import Path
from transformers import AutoModelForCausalLM, AutoTokenizer
import tempfile
import traceback
import whisper
import librosa
import numpy as np
import torch
import outetts
import uvicorn
import base64
import io
import soundfile as sf
# --- Model loading: runs once at import time, all models placed on CUDA ---

# OuteTTS text-to-speech interface (used by the /api/v1/inference endpoint).
try:
    INTERFACE = outetts.Interface(
        config=outetts.ModelConfig(
            model_path="models/v10",
            tokenizer_path="models/v10",
            audio_codec_path="models/dsp/weights_24khz_1.5kbps_v1.0.pth",
            device="cuda",
            dtype=torch.bfloat16,
        )
    )
except Exception as e:
    # Add context about *what* failed and chain the cause so the original
    # traceback is reported as the direct cause.
    raise RuntimeError(f"Failed to initialise the OuteTTS interface: {e}") from e

# Whisper ASR model used by gt() to transcribe incoming audio.
asr_model = whisper.load_model("models/wpt/wpt.pt")

# Causal LM + tokenizer shared by chat() and sample(); eval() disables
# training-only behaviour such as dropout.
model_name = "models/Llama-3.2-1B-Instruct"
tok = AutoTokenizer.from_pretrained(model_name)
lm = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype=torch.bfloat16,
    device_map="cuda",
).eval()

# spk_001.wav located next to this file.
# NOTE(review): not referenced by any code visible here — confirm whether it
# is still needed.
SPEAKER_WAV_PATH = Path(__file__).with_name("spk_001.wav")
def chat(system_prompt: str, user_prompt: str, **gen_kwargs) -> str:
    """
    Run one turn of chat with a system + user message.

    The two messages are rendered with the tokenizer's chat template, passed
    to ``lm.generate()``, and only the newly generated tokens are decoded.

    Args:
        system_prompt: Content of the ``system`` message.
        user_prompt: Content of the ``user`` message.
        **gen_kwargs: Extra keyword arguments forwarded to ``generate()``.
            They override the default sampling settings below, e.g.
            ``chat(sys_msg, user_msg, max_new_tokens=256)``.

    Returns:
        The decoded reply with special tokens skipped and surrounding
        whitespace stripped.
    """
    messages = [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": user_prompt},
    ]
    # `add_generation_prompt=True` automatically appends the
    # <|start_header_id|>assistant … header so the model knows to respond.
    # `return_dict=True` returns both input_ids and attention_mask.
    inputs = tok.apply_chat_template(
        messages,
        add_generation_prompt=True,
        return_tensors="pt",
        return_dict=True,
    )
    input_ids = inputs["input_ids"].to(lm.device)
    attention_mask = inputs["attention_mask"].to(lm.device)

    # Default generation settings; anything the caller passes wins.
    generate_kwargs = {
        # Explicit pad token instead of relying on generate()'s fallback.
        "pad_token_id": tok.eos_token_id,
        "max_new_tokens": 2048,
        "do_sample": True,
        "temperature": 0.2,
        "repetition_penalty": 1.1,
        "top_k": 100,
        "top_p": 0.95,
    }
    generate_kwargs.update(gen_kwargs)

    with torch.inference_mode():
        output_ids = lm.generate(
            input_ids=input_ids,
            attention_mask=attention_mask,
            **generate_kwargs,
        )

    # Strip the prompt part and return only the newly generated answer.
    answer = tok.decode(
        output_ids[0][input_ids.shape[-1]:],
        skip_special_tokens=True,
        clean_up_tokenization_spaces=True,
    )
    return answer.strip()
def gt(audio: np.ndarray, sr: int):
    """
    Transcribe ``audio`` with the module-level Whisper model.

    Args:
        audio: Audio samples; singleton dimensions (e.g. the ``(1, N)`` shape
            the endpoints pass in) are squeezed away before use.
        sr: Sample rate of ``audio`` in Hz; resampled to 16 kHz if different.

    Returns:
        The transcribed text with surrounding whitespace stripped.
    """
    ss = audio.squeeze().astype(np.float32)
    if sr != 16_000:
        # Resample the squeezed float32 signal, not the raw input: the raw
        # array may still be 2-D (1, N) and of another dtype, which breaks
        # Whisper's transcription. Cast back to float32 defensively.
        ss = librosa.resample(ss, orig_sr=sr, target_sr=16_000).astype(np.float32)
    result = asr_model.transcribe(ss, fp16=False, language=None)
    return result["text"].strip()
def sample(rr: str) -> str:
    """
    Continue ``rr`` as plain text with the causal LM (no chat template).

    Args:
        rr: Prompt text. If it is empty or whitespace-only, ``"Hello "`` is
            used instead.

    Returns:
        Only the newly generated text, with special tokens skipped
        (whitespace is not stripped).
    """
    prompt = rr if rr.strip() else "Hello "
    inputs = tok(prompt, return_tensors="pt").to(lm.device)
    with torch.inference_mode():
        out_ids = lm.generate(
            **inputs,
            # Same explicit pad token as chat(), instead of generate()'s
            # fallback (which logs a warning on every call).
            pad_token_id=tok.eos_token_id,
            max_new_tokens=2048,
            do_sample=True,
            temperature=0.2,
            repetition_penalty=1.1,
            top_k=100,
            top_p=0.95,
        )
    prompt_len = inputs["input_ids"].shape[-1]
    return tok.decode(out_ids[0][prompt_len:], skip_special_tokens=True)
# Status flags reported by the health endpoint. Model loading earlier in this
# module raises on failure, so reaching this line means loading succeeded.
INITIALIZATION_STATUS = dict(model_loaded=True, error=None)
class GenerateRequest(BaseModel):
    """Request body accepted by the audio endpoints of this server."""

    # Base64 text of a serialized NumPy ``.npy`` file (decoded with
    # np.load(allow_pickle=False), so object/pickled arrays are rejected).
    audio_data: str = Field(
        ...,
        description="Base64-encoded bytes of a NumPy .npy file containing the input audio samples.",
    )
    # Must be positive: it is used for resampling and for sizing the speaker clip.
    sample_rate: int = Field(
        ...,
        gt=0,
        description="Sample rate of audio_data in Hz (also the sample rate of returned audio).",
    )
class GenerateResponse(BaseModel):
    """Response body of the speech-generation endpoint."""

    audio_data: str = Field(
        ...,
        description="Base64-encoded bytes of a NumPy .npy file holding float32 audio samples.",
    )
app = FastAPI(title="V1", version="0.1")

# CORS: any origin may call the API. Credentials are disabled because the CORS
# spec (and FastAPI's docs) forbid a wildcard origin together with
# credentials; with allow_credentials=True Starlette instead reflects the
# caller's Origin, letting any website make credentialed cross-origin
# requests. None of the endpoints here read cookies or auth headers.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=False,
    allow_methods=["*"],
    allow_headers=["*"],
)
def b64(b64: str) -> np.ndarray:
    """
    Decode a base64 string that holds a serialized ``.npy`` file.

    The parameter shares the function's name; it is kept as-is so keyword
    callers keep working. Inside the body it refers to the input string.
    ``allow_pickle=False`` makes np.load refuse pickled/object payloads.
    """
    payload = io.BytesIO(base64.b64decode(b64))
    return np.load(payload, allow_pickle=False)
def ab64(arr: np.ndarray, sr: int, orig_sr: int = 44100) -> str:
    """
    Resample audio to ``sr`` Hz and serialize it as base64-encoded ``.npy``.

    Args:
        arr: Audio samples at ``orig_sr`` Hz.
        sr: Target sample rate in Hz.
        orig_sr: Sample rate of ``arr`` in Hz. Defaults to 44100, the rate
            this server assumes for OuteTTS output.
            NOTE(review): confirm against the TTS output's reported rate.

    Returns:
        Base64 text of a ``.npy`` file containing float32 samples.
    """
    # Only resample when the rates actually differ.
    if orig_sr != sr:
        arr = librosa.resample(arr, orig_sr=orig_sr, target_sr=sr)
    buf = io.BytesIO()
    np.save(buf, np.asarray(arr, dtype=np.float32))
    return base64.b64encode(buf.getvalue()).decode()
def gs(
    audio: np.ndarray,
    sr: int,
    interface: outetts.Interface,
    max_seconds: float = 15.0,
):
    """
    Build an OuteTTS speaker profile from a reference recording.

    Only the last ``max_seconds`` of ``audio`` are used. The clip is written
    to a temporary WAV file for ``interface.create_speaker``, and that file
    is always removed afterwards.

    Args:
        audio: Reference samples; a 2-D array such as ``(1, N)`` is squeezed.
        sr: Sample rate of ``audio`` in Hz.
        interface: OuteTTS interface used to create the speaker.
        max_seconds: Length of the trailing window that is kept, in seconds.

    Returns:
        Whatever ``interface.create_speaker`` returns.
    """
    if audio.ndim == 2:
        audio = audio.squeeze()
    audio = audio.astype("float32")
    max_samples = int(max_seconds * sr)
    if audio.shape[-1] > max_samples:
        # Trim along the last (time) axis, keeping the most recent samples.
        audio = audio[..., -max_samples:]

    # delete=False so the path can be reopened by name; removed in `finally`
    # (previously one WAV file was leaked per request).
    with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as f:
        wav_path = f.name
    try:
        sf.write(wav_path, audio, sr)
        # NOTE(review): check whether create_speaker loads this Whisper
        # checkpoint on every call; the module already holds `asr_model`.
        return interface.create_speaker(
            wav_path,
            whisper_model="models/wpt/wpt.pt",
        )
    finally:
        Path(wav_path).unlink(missing_ok=True)
@app.get("/api/v1/health")
def health_check():
    """Return a liveness marker plus the model-initialisation flags."""
    return dict(
        status="healthy",
        model_loaded=INITIALIZATION_STATUS["model_loaded"],
        error=INITIALIZATION_STATUS["error"],
    )
@app.post("/api/v1/inference", response_model=GenerateResponse)
def generate_audio(req: GenerateRequest):
    """
    Speech-to-speech: transcribe the input audio, continue the transcript
    with the LM, and synthesize the result with a speaker profile built from
    the input audio.

    Raises:
        HTTPException(400): ``audio_data`` cannot be decoded or has no samples.
        HTTPException(500): any failure during ASR, text generation, TTS, or
            encoding of the output.
    """
    # Decoding errors are client errors, so report them as 400 instead of an
    # unhandled 500. binascii.Error is a ValueError subclass; np.load raises
    # ValueError/EOFError on malformed or empty payloads.
    try:
        audio_np = b64(req.audio_data)
    except (ValueError, EOFError) as e:
        raise HTTPException(status_code=400, detail=f"Invalid audio_data: {e}") from e
    if audio_np.size == 0:
        raise HTTPException(status_code=400, detail="audio_data contains no samples")
    if audio_np.ndim == 1:
        audio_np = audio_np.reshape(1, -1)

    try:
        text = gt(audio_np, req.sample_rate)
        out = INTERFACE.generate(
            config=outetts.GenerationConfig(
                text=sample(text),
                generation_type=outetts.GenerationType.CHUNKED,
                speaker=gs(audio_np, req.sample_rate, INTERFACE),
                sampler_config=outetts.SamplerConfig(),
            )
        )
        # .float() guards against a bfloat16 tensor (NumPy has no bf16);
        # it is a no-op for float32.
        audio_out = out.audio.squeeze().float().cpu().numpy()
        # Encoding is inside the try so its failures get the same handling.
        encoded = ab64(audio_out, req.sample_rate)
    except Exception as e:
        traceback.print_exc()
        raise HTTPException(status_code=500, detail=f"{e}") from e
    return GenerateResponse(audio_data=encoded)
@app.post("/api/v1/v2t")
def generate_text(req: GenerateRequest):
    """
    Voice-to-text chat: transcribe the input audio and answer it with the
    chat LM.

    Returns:
        ``{"text": <assistant reply>}``.

    Raises:
        HTTPException(400): ``audio_data`` cannot be decoded or has no samples.
        HTTPException(500): any failure during transcription or generation.
    """
    # Malformed payloads are client errors (400), not unhandled 500s.
    # binascii.Error is a ValueError subclass; np.load raises
    # ValueError/EOFError on malformed or empty payloads.
    try:
        audio_np = b64(req.audio_data)
    except (ValueError, EOFError) as e:
        raise HTTPException(status_code=400, detail=f"Invalid audio_data: {e}") from e
    if audio_np.size == 0:
        raise HTTPException(status_code=400, detail="audio_data contains no samples")
    if audio_np.ndim == 1:
        audio_np = audio_np.reshape(1, -1)

    system_prompt = "You are a helpful assistant who tries to help answer the user's question."
    try:
        text = gt(audio_np, req.sample_rate)
        print(f"Transcribed text: {text}")
        response_text = chat(system_prompt, user_prompt=text)
    except Exception as e:
        traceback.print_exc()
        raise HTTPException(status_code=500, detail=f"{e}") from e
    return {"text": response_text}
if __name__ == "__main__":
    # Pass the already-built `app` object rather than the "server:app" import
    # string. With a string, uvicorn imports this file a second time as module
    # "server", re-running every model load above (doubling GPU memory), and
    # it fails if the file is not importable as `server`. An app object is
    # allowed here because reload is off and only a single worker is used.
    uvicorn.run(app, host="0.0.0.0", port=8000, reload=False)