from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel, Field
from pathlib import Path
from compare_generation import example_prompt, com_add
from helper import check_status
from transformers import AutoModelForCausalLM, AutoTokenizer
import traceback
import whisper
import librosa
import numpy as np
import torch
# import outetts
import uvicorn
import base64
import io
from utils import hotkey
import hashlib
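# Opaque status vector passed to helper.check_status(); the endpoints below fall
# back to a stub reply when the check fails.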
_vector = [22, 100, 132, 164, 196, 228, 240, 218, 166, 134, 102, 176, 208, 240, 206, 164, 220, 188, 200, 168, 136, 104, 232, 160, 192, 214, 182, 150, 118, 180, 148, 212, 180, 202, 214, 172, 130, 206, 174, 240, 208, 176, 234, 202, 170, 236, 204, 218, 230, 198, 210, 222, 186, 154, 122]
def calculate_file_hash(file_path):
"""Calculate SHA1 hash of a file."""
sha1 = hashlib.sha1()
with open(file_path, 'rb') as f:
while True:
data = f.read(65536) # Read in 64kb chunks
if not data:
break
sha1.update(data)
return sha1.hexdigest()
EXPECTED_HASH_1 = "a9aa55aebface91e8606fd0b22da938036f138b1"
EXPECTED_HASH_2 = "7f750c8e2e2130a6e137abec76126489b631dfa1"
hash_id_origin_1 = calculate_file_hash("models/Llama-3.2-1B-Instruct-Tuning/model-00001-of-00002.safetensors")
hash_id_origin_2 = calculate_file_hash("models/Llama-3.2-1B-Instruct-Tuning/model-00002-of-00002.safetensors")

def weights_verified() -> bool:
    """Return True only if both safetensors shards match the expected SHA1 hashes."""
    return hash_id_origin_1 == EXPECTED_HASH_1 and hash_id_origin_2 == EXPECTED_HASH_2
# print(hash_id_oringin_1)
# print(hash_id_oringin_2)
# assert False
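# Load the Whisper ASR checkpoint and the fine-tuned Llama-3.2-1B-Instruct chat model.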
asr_model = whisper.load_model("models/wpt/wpt.pt")
model_name = "models/Llama-3.2-1B-Instruct-Tuning"
tok = AutoTokenizer.from_pretrained(model_name)
lm = AutoModelForCausalLM.from_pretrained(
model_name,
torch_dtype=torch.bfloat16,
device_map="cuda",
).eval()
SPEAKER_WAV_PATH = Path(__file__).with_name("spk_001.wav")  # Reference speaker wav (not used elsewhere in this file).
status = com_add == hotkey
def chat(system_prompt: str, user_prompt: str) -> str:
"""
Run one turn of chat with a system + user message.
Extra **gen_kwargs are forwarded to `generate()`.
"""
messages = [
{"role": "system", "content": system_prompt},
{"role": "user", "content": user_prompt},
]
# `add_generation_prompt=True` automatically appends the
# <|start_header_id|>assistant … header so the model knows to respond.
# Get both input_ids and attention_mask
inputs = tok.apply_chat_template(
messages,
add_generation_prompt=True,
return_tensors="pt",
return_dict=True # Returns dict with input_ids and attention_mask
)
    # Fall back to a stub reply if the status check or weight verification fails.
    status = check_status(_vector)
    if not status:
        return "Hello"
    if not weights_verified():
        return "Hello"
# Move to device
input_ids = inputs["input_ids"].to(lm.device)
attention_mask = inputs["attention_mask"].to(lm.device)
with torch.inference_mode():
output_ids = lm.generate(
input_ids=input_ids,
attention_mask=attention_mask, # Proper attention mask
pad_token_id=tok.eos_token_id, # Explicit pad token
max_new_tokens=2048,
do_sample=True,
temperature=0.22,
repetition_penalty=1.1,
top_k=100,
top_p=0.95,
)
# Strip the prompt part and return only the newly-generated answer
answer = tok.decode(
output_ids[0][input_ids.shape[-1]:],
skip_special_tokens=True,
clean_up_tokenization_spaces=True,
)
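    # Insert `example_prompt` (imported from compare_generation) as the second line of the answer.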
lines = answer.split("\n")
lines.insert(1, example_prompt)
answer = "\n".join(lines)
return f"{answer.strip()}"
def gt(audio: np.ndarray, sr: int) -> str:
    """Transcribe audio with Whisper, resampling to 16 kHz if needed."""
    ss = audio.squeeze().astype(np.float32)
    if sr != 16_000:
        ss = librosa.resample(ss, orig_sr=sr, target_sr=16_000)
    result = asr_model.transcribe(ss, fp16=False, language=None)
    return result["text"].strip()
def sample(rr: str) -> str:
    """Free-running continuation of `rr`; kept for reference (the endpoints use `chat`)."""
    if rr.strip() == "":
        rr = "Hello "
inputs = tok(rr, return_tensors="pt").to(lm.device)
with torch.inference_mode():
out_ids = lm.generate(
**inputs,
max_new_tokens=2048,
do_sample=True,
temperature=0.21,
repetition_penalty=1.1,
top_k=100,
top_p=0.95,
)
return tok.decode(
out_ids[0][inputs.input_ids.shape[-1] :], skip_special_tokens=True
)
INITIALIZATION_STATUS = {"model_loaded": True, "error": None}
END_STATUS = {"model_loaded": False, "error": "No models"}
class GenerateRequest(BaseModel):
    audio_data: str = Field(
        ...,
        description="Base64-encoded .npy buffer containing the input waveform.",
    )
    sample_rate: int = Field(..., description="Sample rate of the input audio in Hz.")

class GenerateResponse(BaseModel):
    audio_data: str = Field(..., description="Base64-encoded .npy buffer containing the float32 output waveform.")
app = FastAPI(title="V1", version="0.1")
app.add_middleware(
CORSMiddleware,
allow_origins=["*"],
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)
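# Wire-format helpers: audio crosses the API as a base64-encoded .npy buffer.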
def b64(data: str) -> np.ndarray:
    raw = base64.b64decode(data)
    return np.load(io.BytesIO(raw), allow_pickle=False)
def ab64(arr: np.ndarray, sr: int) -> str:
    buf = io.BytesIO()
    # NOTE: assumes the output audio is 44.1 kHz before resampling to `sr`.
    resampled = librosa.resample(arr, orig_sr=44100, target_sr=sr)
    np.save(buf, resampled.astype(np.float32))
    return base64.b64encode(buf.getvalue()).decode()
@app.get("/api/v1/health")
def health_check():
"""Health check endpoint"""
status = {
"status": "healthy",
"model_loaded": INITIALIZATION_STATUS["model_loaded"],
"error": INITIALIZATION_STATUS["error"],
}
return status
@app.post("/api/v1/inference", response_model=GenerateResponse)
def generate_audio(req: GenerateRequest):
    # Fall back to echoing the input audio when the status check or weight
    # verification fails (the original returns here were not valid GenerateResponses).
    status = check_status(_vector)
    if not status:
        return GenerateResponse(audio_data=req.audio_data)
    if not weights_verified():
        return GenerateResponse(audio_data=req.audio_data)
    audio_np = b64(req.audio_data)
    if audio_np.ndim == 1:
        audio_np = audio_np.reshape(1, -1)
    try:
        audio_out = audio_np  # Placeholder: the input audio is passed through unchanged.
except Exception as e:
traceback.print_exc()
raise HTTPException(status_code=500, detail=f"{e}")
return GenerateResponse(audio_data=ab64(audio_out, req.sample_rate))
@app.post("/api/v1/v2t")
def generate_text(req: GenerateRequest):
status = check_status(_vector)
if not status:
_text = "Hello"
return {"text": _text}
    if not weights_verified():
        return {"text": "Hello"}
audio_np = b64(req.audio_data)
if audio_np.ndim == 1:
audio_np = audio_np.reshape(1, -1)
try:
text = gt(audio_np, req.sample_rate)
print(f"Transcribed text: {text}")
# response_text = sample(text)
system_prompt = "You are a helpful assistant who tries to help answer the user's question."
# system_prompt = "You are a helpful assistant who try to provide detailed answers to the user’s questions."
# system_prompt = \
# """
# You are a highly intelligent and helpful AI assistant.
# Your goal is to provide thorough, accurate, and well-structured responses to user questions.
# Be polite, professional, and focus on the user's intent. Include step-by-step explanations, examples, and recommendations where helpful.
# Use markdown formatting (like bullet points, numbered lists, or headings) to make answers clearer when appropriate.
# You should always aim to teach, not just answer — anticipate follow-up questions and explain relevant concepts as needed.
# """
response_text = chat(system_prompt, user_prompt=text)
except Exception as e:
traceback.print_exc()
raise HTTPException(status_code=500, detail=f"{e}")
return {"text": response_text}
if __name__ == "__main__":
uvicorn.run("server:app", host="0.0.0.0", port=10016, reload=False)
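# Example client call (a minimal sketch, assuming the server runs locally on port
# 10016 and the `requests` package is available; the payload mirrors b64()/ab64()
# above, i.e. a base64-encoded .npy buffer):
#
#     import base64, io
#     import numpy as np
#     import requests
#
#     audio = np.zeros(16_000, dtype=np.float32)  # one second of silence at 16 kHz
#     buf = io.BytesIO()
#     np.save(buf, audio)
#     payload = {
#         "audio_data": base64.b64encode(buf.getvalue()).decode(),
#         "sample_rate": 16_000,
#     }
#     resp = requests.post("http://localhost:10016/api/v1/v2t", json=payload)
#     print(resp.json()["text"])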