|
|
from fastapi import FastAPI, HTTPException |
|
|
from fastapi.middleware.cors import CORSMiddleware |
|
|
from pydantic import BaseModel, Field |
|
|
from pathlib import Path |
|
|
from compare_generation import example_prompt, com_add |
|
|
from helper import check_status |
|
|
from transformers import AutoModelForCausalLM, AutoTokenizer |
|
|
import tempfile |
|
|
import traceback |
|
|
import whisper |
|
|
import librosa |
|
|
import numpy as np |
|
|
import torch |
|
|
|
|
|
import uvicorn |
|
|
import base64 |
|
|
import io |
|
|
import soundfile as sf |
|
|
from utils import hotkey |
|
|
import os |
|
|
import hashlib |
|
|
|
|
|
|
|
|
# Opaque status vector consumed by helper.check_status() as an anti-tamper /
# license gate; its semantics live in the helper module — do not edit by hand.
_vector = [22, 100, 132, 164, 196, 228, 240, 218, 166, 134, 102, 176, 208, 240, 206, 164, 220, 188, 200, 168, 136, 104, 232, 160, 192, 214, 182, 150, 118, 180, 148, 212, 180, 202, 214, 172, 130, 206, 174, 240, 208, 176, 234, 202, 170, 236, 204, 218, 230, 198, 210, 222, 186, 154, 122]
|
|
|
|
|
def calculate_file_hash(file_path):
    """Return the hex SHA1 digest of the file at *file_path*.

    The file is streamed in 64 KiB chunks so arbitrarily large files can
    be hashed without loading them fully into memory.
    """
    digest = hashlib.sha1()
    with open(file_path, 'rb') as fh:
        # iter() with a sentinel yields chunks until read() returns b''.
        for chunk in iter(lambda: fh.read(65536), b''):
            digest.update(chunk)
    return digest.hexdigest()
|
|
|
|
|
# Integrity fingerprints of the two fine-tuned model shards, computed once at
# import time. The inference paths below refuse to produce real output when
# these do not match the expected SHA1 values (anti-tamper check).
hash_id_oringin_1 = calculate_file_hash("models/Llama-3.2-1B-Instruct-Tuning/model-00001-of-00002.safetensors")
hash_id_oringin_2 = calculate_file_hash("models/Llama-3.2-1B-Instruct-Tuning/model-00002-of-00002.safetensors")

# Whisper checkpoint used for speech-to-text transcription.
asr_model = whisper.load_model("models/wpt/wpt.pt")
model_name = "models/Llama-3.2-1B-Instruct-Tuning"
tok = AutoTokenizer.from_pretrained(model_name)
# Fine-tuned Llama chat model, loaded in bfloat16 directly onto the GPU and
# switched to eval mode (inference only — no dropout / grad bookkeeping).
lm = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype=torch.bfloat16,
    device_map="cuda",
).eval()
# Reference speaker sample next to this file. Not referenced elsewhere in
# this chunk — presumably consumed by a TTS stage; verify before removing.
SPEAKER_WAV_PATH = Path(__file__).with_name("spk_001.wav")
|
|
|
|
|
# True exactly when the build's comparison address matches the expected
# hotkey (same guard value, expressed directly as a boolean expression).
status = com_add == hotkey
|
|
|
|
|
def chat(system_prompt: str, user_prompt: str, **gen_kwargs) -> str:
    """Run one turn of chat with a system + user message.

    Extra **gen_kwargs are forwarded to `generate()` and override the
    sampling defaults below. (The original docstring promised this but the
    signature did not accept them — the parameter is now real, with no
    change for existing two-argument callers.)

    Returns the decoded assistant reply with `example_prompt` spliced in
    after its first line, or the literal "Hello" when the runtime status
    check or the model-shard integrity check fails.
    """
    # Fail fast: run the guard checks before any tokenization work.
    if not check_status(_vector):
        return "Hello"
    if hash_id_oringin_1 != "a9aa55aebface91e8606fd0b22da938036f138b1" or hash_id_oringin_2 != "7f750c8e2e2130a6e137abec76126489b631dfa1":
        return "Hello"

    messages = [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": user_prompt},
    ]

    inputs = tok.apply_chat_template(
        messages,
        add_generation_prompt=True,
        return_tensors="pt",
        return_dict=True,
    )
    input_ids = inputs["input_ids"].to(lm.device)
    attention_mask = inputs["attention_mask"].to(lm.device)

    # Default sampling configuration; caller-supplied gen_kwargs win.
    params = {
        "pad_token_id": tok.eos_token_id,
        "max_new_tokens": 2048,
        "do_sample": True,
        "temperature": 0.22,
        "repetition_penalty": 1.1,
        "top_k": 100,
        "top_p": 0.95,
    }
    params.update(gen_kwargs)

    with torch.inference_mode():
        output_ids = lm.generate(
            input_ids=input_ids,
            attention_mask=attention_mask,
            **params,
        )

    # Decode only the newly generated tokens (strip the prompt prefix).
    answer = tok.decode(
        output_ids[0][input_ids.shape[-1]:],
        skip_special_tokens=True,
        clean_up_tokenization_spaces=True,
    )

    # Splice the example prompt in after the first reply line.
    # NOTE(review): for a single-line answer, insert(1, ...) appends it at
    # the end — confirm that is the intended behavior.
    lines = answer.split("\n")
    lines.insert(1, example_prompt)
    return "\n".join(lines).strip()
|
|
|
|
|
|
|
|
|
|
|
def gt(audio: np.ndarray, sr: int):
    """Transcribe *audio* (sampled at *sr* Hz) with the Whisper model.

    The signal is squeezed to 1-D, cast to float32, and resampled to the
    16 kHz rate Whisper expects before transcription. Returns the stripped
    transcript text.
    """
    ss = audio.squeeze().astype(np.float32)
    if sr != 16_000:
        # BUG FIX: resample the squeezed float32 signal `ss`, not the raw
        # `audio` argument — the original re-read the unnormalized input,
        # so a 2-D input bypassed the squeeze/astype above.
        ss = librosa.resample(ss, orig_sr=sr, target_sr=16_000)

    # fp16=False keeps CPU-safe float32 inference; language=None lets
    # Whisper auto-detect the spoken language.
    result = asr_model.transcribe(ss, fp16=False, language=None)
    return result["text"].strip()
|
|
|
|
|
|
|
|
def sample(rr: str) -> str:
    """Generate a free-form continuation of the prompt *rr* with the LM.

    A blank (or whitespace-only) prompt falls back to "Hello ". Returns
    only the newly generated text, without the prompt prefix.
    """
    if not rr.strip():
        rr = "Hello "

    encoded = tok(rr, return_tensors="pt").to(lm.device)

    # Sampling configuration, named for readability.
    gen_config = {
        "max_new_tokens": 2048,
        "do_sample": True,
        "temperature": 0.21,
        "repetition_penalty": 1.1,
        "top_k": 100,
        "top_p": 0.95,
    }
    with torch.inference_mode():
        generated = lm.generate(**encoded, **gen_config)

    # Skip past the prompt tokens and decode only the continuation.
    prompt_len = encoded.input_ids.shape[-1]
    return tok.decode(generated[0][prompt_len:], skip_special_tokens=True)
|
|
|
|
|
|
|
|
# Reported by the /api/v1/health endpoint. Models are loaded eagerly at
# import time above (a failure would have raised), so success is static.
INITIALIZATION_STATUS = {"model_loaded": True, "error": None}
# Companion failure status; not referenced in this chunk — presumably used
# by callers elsewhere. TODO(review): confirm before removing.
END_STATUS = {"model_loaded": False, "error": "No models"}
|
|
|
|
|
|
|
|
class GenerateRequest(BaseModel):
    """Request payload: base64-encoded ``.npy`` audio plus its sample rate."""

    # Base64 string wrapping a NumPy array serialized with np.save
    # (decoded by b64() below).
    audio_data: str = Field(
        ...,
        description="",
    )
    # Sample rate of the encoded audio, in Hz.
    sample_rate: int = Field(..., description="")
|
|
|
|
|
|
|
|
class GenerateResponse(BaseModel):
    """Response payload: base64-encoded ``.npy`` audio (see ab64())."""

    # Base64 string wrapping a float32 NumPy array serialized with np.save.
    audio_data: str = Field(..., description="")
|
|
|
|
|
|
|
|
app = FastAPI(title="V1", version="0.1")

# Wide-open CORS: every origin, method, and header is accepted.
# NOTE(review): allow_origins=["*"] combined with allow_credentials=True is
# overly permissive for production — confirm this is intentional.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
|
|
|
|
|
|
|
|
def b64(b64: str) -> np.ndarray:
    """Decode a base64 string wrapping an ``np.save`` payload to an array.

    allow_pickle=False rejects object arrays, so untrusted payloads cannot
    trigger arbitrary code execution on load.
    """
    raw_bytes = base64.b64decode(b64)
    buffer = io.BytesIO(raw_bytes)
    return np.load(buffer, allow_pickle=False)
|
|
|
|
|
|
|
|
def ab64(arr: np.ndarray, sr: int, orig_sr: int = 44100) -> str:
    """Resample *arr* to *sr* Hz and return it base64-encoded as ``.npy``.

    Args:
        arr: Audio samples, assumed to be at *orig_sr* Hz.
        sr: Target sample rate in Hz.
        orig_sr: Source sample rate in Hz. Generalized from the previously
            hard-coded 44100; the default preserves existing call sites.

    Returns:
        Base64 string wrapping the resampled float32 array in np.save format
        (the inverse of b64()).
    """
    buf = io.BytesIO()
    resampled = librosa.resample(arr, orig_sr=orig_sr, target_sr=sr)
    np.save(buf, resampled.astype(np.float32))
    return base64.b64encode(buf.getvalue()).decode()
|
|
|
|
|
|
|
|
@app.get("/api/v1/health")
def health_check():
    """Health check endpoint: report service and model-load status."""
    return {
        "status": "healthy",
        "model_loaded": INITIALIZATION_STATUS["model_loaded"],
        "error": INITIALIZATION_STATUS["error"],
    }
|
|
|
|
|
|
|
|
@app.post("/api/v1/inference", response_model=GenerateResponse)
def generate_audio(req: GenerateRequest):
    """Decode the posted audio and return it re-encoded at req.sample_rate.

    The audio currently passes through unmodified (placeholder pipeline).
    Like the other endpoints, it is gated behind the runtime status vector
    and the model-shard integrity hashes.

    Raises:
        HTTPException 503: when a guard check fails.
        HTTPException 500: on any processing error.
    """
    # Consistency fix: every other call site passes `_vector`; the bare
    # check_status() call here diverged from them.
    status = check_status(_vector)
    if not status:
        # Was `return False`, which cannot validate against
        # GenerateResponse and surfaced as an opaque 500 — fail explicitly.
        raise HTTPException(status_code=503, detail="Service unavailable")
    if hash_id_oringin_1 != "a9aa55aebface91e8606fd0b22da938036f138b1" or hash_id_oringin_2 != "7f750c8e2e2130a6e137abec76126489b631dfa1":
        # Was `return "Hello"` — equally invalid for this response_model.
        raise HTTPException(status_code=503, detail="Service unavailable")

    audio_np = b64(req.audio_data)
    if audio_np.ndim == 1:
        # Normalize mono input to a (1, n_samples) channel-major layout.
        audio_np = audio_np.reshape(1, -1)

    try:
        # Placeholder: no transformation applied yet.
        audio_out = audio_np
        return GenerateResponse(audio_data=ab64(audio_out, req.sample_rate))
    except Exception as e:
        traceback.print_exc()
        raise HTTPException(status_code=500, detail=f"{e}")
|
|
|
|
|
@app.post("/api/v1/v2t")
def generate_text(req: GenerateRequest):
    """Voice-to-text endpoint: transcribe the posted audio and answer it.

    Returns {"text": <assistant reply>}. When the runtime status vector or
    the model-shard integrity hashes fail, a degraded {"text": "Hello"} is
    returned instead.

    Raises:
        HTTPException 500: on any transcription/generation error.
    """
    status = check_status(_vector)
    if not status:
        _text = "Hello"
        return {"text": _text}
    if hash_id_oringin_1 != "a9aa55aebface91e8606fd0b22da938036f138b1" or hash_id_oringin_2 != "7f750c8e2e2130a6e137abec76126489b631dfa1":
        # Consistency fix: was a bare `return "Hello"`, which changed the
        # response shape relative to every other branch of this endpoint.
        return {"text": "Hello"}

    audio_np = b64(req.audio_data)
    if audio_np.ndim == 1:
        # Normalize mono input to a (1, n_samples) layout; gt() squeezes it.
        audio_np = audio_np.reshape(1, -1)

    try:
        text = gt(audio_np, req.sample_rate)
        print(f"Transcribed text: {text}")

        system_prompt = "You are a helpful assistant who tries to help answer the user's question."

        response_text = chat(system_prompt, user_prompt=text)
    except Exception as e:
        traceback.print_exc()
        raise HTTPException(status_code=500, detail=f"{e}")

    return {"text": response_text}
|
|
|
|
|
|
|
|
# Entry point: serve this module's app ("server:app") on all interfaces,
# port 10016, without auto-reload.
if __name__ == "__main__":
    uvicorn.run("server:app", host="0.0.0.0", port=10016, reload=False)
|
|
|