# server.py — FastAPI speech-to-speech inference server (MiniCPM checkpoint).
import base64
import io
import logging
import time
from typing import Optional

import librosa
import numpy as np
import torch
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
from transformers import AutoModel, AutoTokenizer
# Module-level logger; endpoints and model init report through it.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

app = FastAPI()

# Add CORS middleware
# NOTE(review): wildcard origins combined with allow_credentials=True is very
# permissive — tighten for production deployments.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
class AudioRequest(BaseModel):
    """Request body for /api/v1/inference."""
    audio_data: str   # base64-encoded bytes of an np.save()-serialized audio array
    sample_rate: int  # sample rate of the submitted audio, in Hz
class AudioResponse(BaseModel):
    """Response body for /api/v1/inference."""
    audio_data: str  # base64-encoded bytes of the np.save()-serialized response audio
    text: str = ""   # text returned by the model alongside the generated audio
class ConfigRequest(BaseModel):
    """Partial configuration update; fields left as None are unchanged."""
    temperature: Optional[float] = None   # clamped to [0.1, 1.0] by Model.update_config
    max_new_tokens: Optional[int] = None  # clamped to [50, 1024] by Model.update_config
    system_prompt: Optional[str] = None   # replaces the podcast system prompt
    voice_path: Optional[str] = None      # path to a new reference audio file for voice cloning
class ConfigResponse(BaseModel):
    """Result of a configuration read or update."""
    success: bool
    message: str
    current_config: dict  # snapshot produced by Model.get_current_config()
# Global model instance (populated by initialize_model() at startup; None until then)
model = None
# Startup progress flags read by the endpoints to report readiness / errors.
INITIALIZATION_STATUS = {
    "model_loaded": False,
    "error": None
}
class Model:
    """Loads the MiniCPM checkpoint and exposes speech-to-speech inference
    plus runtime configuration for podcast-style conversations."""

    def __init__(self):
        # Load the checkpoint once. nn.Module.eval() and .cuda() return the
        # same module, so chain them and keep a single reference on self
        # (the original kept a local that shadowed the module-level `model`).
        self.model = AutoModel.from_pretrained(
            './models/checkpoint',
            trust_remote_code=True,
            torch_dtype=torch.bfloat16,
            attn_implementation='sdpa'
        ).eval().cuda()
        self.tokenizer = AutoTokenizer.from_pretrained(
            './models/checkpoint',
            trust_remote_code=True
        )
        # Initialize TTS
        self.model.init_tts()
        self.model.tts.float()  # Convert TTS to float32 if needed
        self.model_in_sr = 16000   # sample rate the model expects on input
        self.model_out_sr = 24000  # sample rate the model produces
        # Reference audio used for voice cloning.
        self.ref_audio, _ = librosa.load('./ref_audios/female.wav', sr=self.model_in_sr, mono=True)
        # Configurable generation parameters (see update_config for clamping).
        self.temperature = 0.7
        self.max_new_tokens = 150
        self.top_p = 0.92
        self.repetition_penalty = 1.2
        # Custom podcast conversation system prompt
        self.podcast_prompt = "You are Speaker 2 in a podcast conversation. Listen carefully to Speaker 1 and respond naturally as if continuing a podcast dialogue. Keep your responses concise, engaging, and conversational. Maintain the flow and topic of the conversation. Avoid sounding like an assistant - you are a podcast co-host having a natural conversation."
        # System prompt combines the voice-cloning instruction, reference
        # audio and the conversational prompt; rebuilt whenever either changes.
        self.sys_prompt = self._build_sys_prompt()
        # Warmup: run one inference so the first real request is not slow.
        logger.info("Performing model warmup for podcast conversation...")
        audio_data = librosa.load('./ref_audios/female.wav', sr=self.model_in_sr, mono=True)[0]
        _ = self.inference(audio_data, self.model_in_sr)
        logger.info("Warmup complete. Model ready for podcast conversation.")

    def _build_sys_prompt(self) -> dict:
        """Assemble the system-prompt message from the current reference
        audio and podcast prompt."""
        return {
            "role": "user",
            "content": [
                "Clone the voice in the provided audio prompt.",
                self.ref_audio,
                self.podcast_prompt
            ]
        }

    def update_config(self, config_request: "ConfigRequest") -> dict:
        """Update model configuration based on request.

        Only non-None fields are applied; numeric fields are clamped to safe
        ranges. Returns a dict with a success flag, a human-readable message,
        and the resulting configuration snapshot.
        """
        changes = []
        if config_request.temperature is not None:
            # Clamp to a sane sampling range.
            self.temperature = max(0.1, min(1.0, config_request.temperature))
            changes.append(f"Temperature set to {self.temperature}")
        if config_request.max_new_tokens is not None:
            self.max_new_tokens = max(50, min(1024, config_request.max_new_tokens))
            changes.append(f"Max new tokens set to {self.max_new_tokens}")
        if config_request.system_prompt is not None:
            self.podcast_prompt = config_request.system_prompt
            self.sys_prompt = self._build_sys_prompt()
            changes.append("System prompt updated")
        if config_request.voice_path is not None:
            try:
                # Assignment only happens if the load succeeds, so a bad path
                # leaves the previous reference audio intact.
                self.ref_audio, _ = librosa.load(config_request.voice_path, sr=self.model_in_sr, mono=True)
                self.sys_prompt = self._build_sys_prompt()
                changes.append(f"Voice updated from {config_request.voice_path}")
            except Exception as e:
                return {
                    "success": False,
                    "message": f"Failed to load voice: {str(e)}",
                    "current_config": self.get_current_config()
                }
        return {
            "success": True,
            "message": "Configuration updated: " + "; ".join(changes) if changes else "No changes made",
            "current_config": self.get_current_config()
        }

    def get_current_config(self) -> dict:
        """Get current model configuration"""
        return {
            "temperature": self.temperature,
            "max_new_tokens": self.max_new_tokens,
            "top_p": self.top_p,
            "repetition_penalty": self.repetition_penalty,
            "system_prompt": self.podcast_prompt
        }

    def inference(self, audio_np, input_audio_sr):
        """Run one round of speech-to-speech chat.

        audio_np: 1-D audio samples; input_audio_sr: their sample rate in Hz.
        Returns (audio, text), with audio resampled back to input_audio_sr.
        """
        if input_audio_sr != self.model_in_sr:
            audio_np = librosa.resample(audio_np, orig_sr=input_audio_sr, target_sr=self.model_in_sr)
        user_question = {'role': 'user', 'content': [audio_np]}
        # round one
        msgs = [self.sys_prompt, user_question]
        res = self.model.chat(
            msgs=msgs,
            tokenizer=self.tokenizer,
            sampling=True,
            max_new_tokens=self.max_new_tokens,
            use_tts_template=True,
            generate_audio=True,
            temperature=self.temperature,
            top_p=self.top_p,
            repetition_penalty=self.repetition_penalty,
        )
        audio = res["audio_wav"].cpu().numpy()
        if self.model_out_sr != input_audio_sr:
            audio = librosa.resample(audio, orig_sr=self.model_out_sr, target_sr=input_audio_sr)
        return audio, res["text"]
def initialize_model():
    """Construct the global Model instance and record the outcome in
    INITIALIZATION_STATUS. Returns True on success, False on failure."""
    global model, INITIALIZATION_STATUS
    try:
        logger.info("Initializing model...")
        model = Model()
    except Exception as exc:
        # Keep the error text so /health can surface it to clients.
        INITIALIZATION_STATUS["error"] = str(exc)
        logger.error(f"Failed to initialize model: {exc}")
        return False
    INITIALIZATION_STATUS["model_loaded"] = True
    logger.info("MiniCPM model initialized successfully")
    return True
# NOTE(review): @app.on_event is deprecated in recent FastAPI versions in
# favor of lifespan handlers — left as-is to preserve behavior; verify the
# installed FastAPI version before migrating.
@app.on_event("startup")
async def startup_event():
    """Initialize model on startup"""
    # Synchronous call: the server only starts serving after this returns
    # (or after it fails, in which case /health reports the error).
    initialize_model()
@app.get("/api/v1/health")
def health_check():
    """Report whether the model has finished loading, plus any init error."""
    loaded = INITIALIZATION_STATUS["model_loaded"]
    return {
        "status": "healthy" if loaded else "initializing",
        "model_loaded": loaded,
        "error": INITIALIZATION_STATUS["error"],
    }
@app.post("/api/v1/inference")
async def inference(request: AudioRequest) -> AudioResponse:
    """Run inference with MiniCPM model.

    Expects base64-encoded bytes of an np.save()-serialized audio array plus
    its sample rate; returns the generated audio in the same encoding along
    with the model's text output. Raises 503 while the model is loading and
    500 on any inference failure.
    """
    if not INITIALIZATION_STATUS["model_loaded"]:
        raise HTTPException(
            status_code=503,
            detail=f"Model not ready. Status: {INITIALIZATION_STATUS}"
        )
    try:
        # Decode audio data: base64 -> raw bytes -> flat numpy array.
        audio_bytes = base64.b64decode(request.audio_data)
        audio_np = np.load(io.BytesIO(audio_bytes)).flatten()
        # Generate response, timing it for observability (lazy %-args keep
        # formatting off the hot path).
        start = time.time()
        logger.info("starting inference with audio length %s", audio_np.shape)
        audio_response, text_response = model.inference(audio_np, request.sample_rate)
        logger.info("inference took %s seconds", time.time() - start)
        # Serialize the response audio back to base64 (np.save format).
        buffer = io.BytesIO()
        np.save(buffer, audio_response)
        audio_b64 = base64.b64encode(buffer.getvalue()).decode()
        return AudioResponse(
            audio_data=audio_b64,
            text=text_response
        )
    except Exception as e:
        logger.error(f"Inference failed: {str(e)}")
        raise HTTPException(
            status_code=500,
            detail=str(e)
        )
@app.post("/api/v1/config")
async def update_config(request: ConfigRequest) -> ConfigResponse:
    """Update model configuration for podcast-style conversations"""
    if not INITIALIZATION_STATUS["model_loaded"]:
        raise HTTPException(
            status_code=503,
            detail=f"Model not ready. Status: {INITIALIZATION_STATUS}"
        )
    try:
        outcome = model.update_config(request)
        return ConfigResponse(
            success=outcome["success"],
            message=outcome["message"],
            current_config=outcome["current_config"],
        )
    except Exception as e:
        # Report the failure in-band rather than as an HTTP error.
        logger.error(f"Configuration update failed: {str(e)}")
        return ConfigResponse(
            success=False,
            message=f"Configuration update failed: {str(e)}",
            current_config=model.get_current_config(),
        )
@app.get("/api/v1/config")
async def get_config() -> ConfigResponse:
    """Get current model configuration"""
    if not INITIALIZATION_STATUS["model_loaded"]:
        raise HTTPException(
            status_code=503,
            detail=f"Model not ready. Status: {INITIALIZATION_STATUS}"
        )
    snapshot = model.get_current_config()
    return ConfigResponse(
        success=True,
        message="Current configuration",
        current_config=snapshot,
    )
# Development entry point: run a standalone uvicorn server when executed
# directly (uvicorn is imported lazily so importing this module has no
# server-side effects).
if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=8000)