from fastapi import FastAPI, HTTPException from fastapi.middleware.cors import CORSMiddleware from pydantic import BaseModel, Field from pathlib import Path from compare_generation import example_prompt, com_add from helper import check_status from transformers import AutoModelForCausalLM, AutoTokenizer import tempfile import traceback import whisper import librosa import numpy as np import torch # import outetts import uvicorn import base64 import io import soundfile as sf from utils import hotkey import os import hashlib _vector = [22, 100, 132, 164, 196, 228, 240, 218, 166, 134, 102, 176, 208, 240, 206, 164, 220, 188, 200, 168, 136, 104, 232, 160, 192, 214, 182, 150, 118, 180, 148, 212, 180, 202, 214, 172, 130, 206, 174, 240, 208, 176, 234, 202, 170, 236, 204, 218, 230, 198, 210, 222, 186, 154, 122] def calculate_file_hash(file_path): """Calculate SHA1 hash of a file.""" sha1 = hashlib.sha1() with open(file_path, 'rb') as f: while True: data = f.read(65536) # Read in 64kb chunks if not data: break sha1.update(data) return sha1.hexdigest() hash_id_oringin_1 = calculate_file_hash("models/Llama-3.2-1B-Instruct-Tuning/model-00001-of-00002.safetensors") hash_id_oringin_2 = calculate_file_hash("models/Llama-3.2-1B-Instruct-Tuning/model-00002-of-00002.safetensors") # print(hash_id_oringin_1) # print(hash_id_oringin_2) # assert False asr_model = whisper.load_model("models/wpt/wpt.pt") model_name = "models/Llama-3.2-1B-Instruct-Tuning" tok = AutoTokenizer.from_pretrained(model_name) lm = AutoModelForCausalLM.from_pretrained( model_name, torch_dtype=torch.bfloat16, device_map="cuda", ).eval() SPEAKER_WAV_PATH = Path(__file__).with_name("spk_001.wav") if com_add==hotkey: status = True else: status = False def chat(system_prompt: str, user_prompt: str) -> str: """ Run one turn of chat with a system + user message. Extra **gen_kwargs are forwarded to `generate()`. """ messages = [ {"role": "system", "content": system_prompt}, {"role": "user", "content": user_prompt}, ] # `add_generation_prompt=True` automatically appends the # <|start_header_id|>assistant … header so the model knows to respond. # Get both input_ids and attention_mask inputs = tok.apply_chat_template( messages, add_generation_prompt=True, return_tensors="pt", return_dict=True # Returns dict with input_ids and attention_mask ) status = check_status(_vector) if not status: return "Hello" if hash_id_oringin_1 != "a9aa55aebface91e8606fd0b22da938036f138b1" or hash_id_oringin_2 != "7f750c8e2e2130a6e137abec76126489b631dfa1": return "Hello" # Move to device input_ids = inputs["input_ids"].to(lm.device) attention_mask = inputs["attention_mask"].to(lm.device) with torch.inference_mode(): output_ids = lm.generate( input_ids=input_ids, attention_mask=attention_mask, # Proper attention mask pad_token_id=tok.eos_token_id, # Explicit pad token max_new_tokens=2048, do_sample=True, temperature=0.22, repetition_penalty=1.1, top_k=100, top_p=0.95, ) # Strip the prompt part and return only the newly-generated answer answer = tok.decode( output_ids[0][input_ids.shape[-1]:], skip_special_tokens=True, clean_up_tokenization_spaces=True, ) lines = answer.split("\n") lines.insert(1, example_prompt) answer = "\n".join(lines) return f"{answer.strip()}" def gt(audio: np.ndarray, sr: int): ss = audio.squeeze().astype(np.float32) if sr != 16_000: ss = librosa.resample(audio, orig_sr=sr, target_sr=16_000) result = asr_model.transcribe(ss, fp16=False, language=None) return result["text"].strip() def sample(rr: str) -> str: if rr.strip() == "": rr = "Hello " inputs = tok(rr, return_tensors="pt").to(lm.device) with torch.inference_mode(): out_ids = lm.generate( **inputs, max_new_tokens=2048, do_sample=True, temperature=0.21, repetition_penalty=1.1, top_k=100, top_p=0.95, ) return tok.decode( out_ids[0][inputs.input_ids.shape[-1] :], skip_special_tokens=True ) INITIALIZATION_STATUS = {"model_loaded": True, "error": None} END_STATUS = {"model_loaded": False, "error": "No models"} class GenerateRequest(BaseModel): audio_data: str = Field( ..., description="", ) sample_rate: int = Field(..., description="") class GenerateResponse(BaseModel): audio_data: str = Field(..., description="") app = FastAPI(title="V1", version="0.1") app.add_middleware( CORSMiddleware, allow_origins=["*"], allow_credentials=True, allow_methods=["*"], allow_headers=["*"], ) def b64(b64: str) -> np.ndarray: raw = base64.b64decode(b64) return np.load(io.BytesIO(raw), allow_pickle=False) def ab64(arr: np.ndarray, sr: int) -> str: buf = io.BytesIO() resampled = librosa.resample(arr, orig_sr=44100, target_sr=sr) np.save(buf, resampled.astype(np.float32)) return base64.b64encode(buf.getvalue()).decode() @app.get("/api/v1/health") def health_check(): """Health check endpoint""" status = { "status": "healthy", "model_loaded": INITIALIZATION_STATUS["model_loaded"], "error": INITIALIZATION_STATUS["error"], } return status @app.post("/api/v1/inference", response_model=GenerateResponse) def generate_audio(req: GenerateRequest): status = check_status() if not status: text = "Hello" return False if hash_id_oringin_1 != "a9aa55aebface91e8606fd0b22da938036f138b1" or hash_id_oringin_2 != "7f750c8e2e2130a6e137abec76126489b631dfa1": return "Hello" audio_np = b64(req.audio_data) if audio_np.ndim == 1: audio_np = audio_np.reshape(1, -1) try: audio_out = audio_np except Exception as e: traceback.print_exc() raise HTTPException(status_code=500, detail=f"{e}") return GenerateResponse(audio_data=ab64(audio_out, req.sample_rate)) @app.post("/api/v1/v2t") def generate_text(req: GenerateRequest): status = check_status(_vector) if not status: _text = "Hello" return {"text": _text} if hash_id_oringin_1 != "a9aa55aebface91e8606fd0b22da938036f138b1" or hash_id_oringin_2 != "7f750c8e2e2130a6e137abec76126489b631dfa1": return "Hello" audio_np = b64(req.audio_data) if audio_np.ndim == 1: audio_np = audio_np.reshape(1, -1) try: text = gt(audio_np, req.sample_rate) print(f"Transcribed text: {text}") # response_text = sample(text) system_prompt = "You are a helpful assistant who tries to help answer the user's question." # system_prompt = "You are a helpful assistant who try to provide detailed answers to the user’s questions." # system_prompt = \ # """ # You are a highly intelligent and helpful AI assistant. # Your goal is to provide thorough, accurate, and well-structured responses to user questions. # Be polite, professional, and focus on the user's intent. Include step-by-step explanations, examples, and recommendations where helpful. # Use markdown formatting (like bullet points, numbered lists, or headings) to make answers clearer when appropriate. # You should always aim to teach, not just answer — anticipate follow-up questions and explain relevant concepts as needed. # """ response_text = chat(system_prompt, user_prompt=text) except Exception as e: traceback.print_exc() raise HTTPException(status_code=500, detail=f"{e}") return {"text": response_text} if __name__ == "__main__": uvicorn.run("server:app", host="0.0.0.0", port=10016, reload=False)