|
|
import base64
import io
import logging
import time
from typing import Optional

import librosa
import numpy as np
import torch
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
from transformers import AutoModel, AutoTokenizer
|
|
|
|
|
# Configure root logging once at import time so every module logger emits
# INFO-level records, then create this module's logger.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
|
|
|
|
|
# ASGI application object; the route decorators below register onto it.
app = FastAPI()

# Permissive CORS: any origin, method and header is accepted.
# NOTE(review): combining wildcard origins with allow_credentials=True is
# discouraged by the CORS specification and the FastAPI documentation;
# confirm which browser clients need credentials and list their origins
# explicitly before exposing this service publicly.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class AudioRequest(BaseModel):
    """Request body for the inference endpoint."""

    # Base64 text wrapping the bytes of a NumPy ``.npy`` array (the endpoint
    # base64-decodes it and passes the result to ``np.load``).
    audio_data: str
    # Sample rate of the submitted audio; the reply is resampled to it.
    sample_rate: int
|
|
|
|
|
class AudioResponse(BaseModel):
    """Response body of the inference endpoint."""

    # Base64 text wrapping the ``np.save`` bytes of the generated audio array.
    audio_data: str
    # Text returned by the model alongside the audio; empty by default.
    text: str = ""
|
|
|
|
|
class ConfigRequest(BaseModel):
    """Partial configuration update; fields left as ``None`` are not changed."""

    # Sampling temperature (Model.update_config clamps it to [0.1, 1.0]).
    temperature: Optional[float] = None
    # Generation length cap (Model.update_config clamps it to [50, 1024]).
    max_new_tokens: Optional[int] = None
    # Replacement for the conversational instruction text.
    system_prompt: Optional[str] = None
    # Path of an audio file, readable by the server, to use as the new
    # reference voice.
    voice_path: Optional[str] = None
|
|
|
|
|
class ConfigResponse(BaseModel):
    """Result of reading or updating the model configuration."""

    # False when the requested update could not be applied.
    success: bool
    # Human-readable summary of what changed, or of the failure.
    message: str
    # Snapshot of the configuration as returned by Model.get_current_config().
    current_config: dict
|
|
|
|
|
|
|
|
# Process-wide model wrapper; stays None until initialize_model() succeeds.
model = None

# Shared readiness record consulted by every endpoint.
#   model_loaded: True once the model has been constructed successfully.
#   error: string form of the initialization failure, or None.
INITIALIZATION_STATUS = {
    "model_loaded": False,
    "error": None,
}
|
|
|
|
|
class Model:
    """Speech-in / speech-out wrapper around the local checkpoint.

    Holds the loaded network, its tokenizer, the reference voice audio, the
    sampling settings and the prompt message that is prepended to every
    inference call.
    """

    def __init__(self):
        """Load the checkpoint onto the GPU, set up TTS and run one warmup pass.

        Requires a CUDA device plus the ``./models/checkpoint`` directory and
        the ``./ref_audios/female.wav`` file; any failure propagates to the
        caller.
        """
        loaded = AutoModel.from_pretrained(
            './models/checkpoint',
            trust_remote_code=True,
            torch_dtype=torch.bfloat16,
            attn_implementation='sdpa',
        )
        # eval() and cuda() return the module itself, so a single attribute
        # is enough (no second local alias shadowing the module-level
        # ``model`` global).
        self.model = loaded.eval().cuda()

        self.tokenizer = AutoTokenizer.from_pretrained(
            './models/checkpoint',
            trust_remote_code=True,
        )

        # The TTS head is initialised separately and switched to float32
        # while the rest of the network was loaded as bfloat16.
        self.model.init_tts()
        self.model.tts.float()

        # Input audio is resampled to model_in_sr before inference; the
        # generated audio is treated as being at model_out_sr.
        self.model_in_sr = 16000
        self.model_out_sr = 24000
        self.ref_audio, _ = librosa.load(
            './ref_audios/female.wav', sr=self.model_in_sr, mono=True
        )

        # Sampling defaults; temperature and max_new_tokens can be changed
        # later through update_config().
        self.temperature = 0.7
        self.max_new_tokens = 150
        self.top_p = 0.92
        self.repetition_penalty = 1.2

        self.podcast_prompt = "You are Speaker 2 in a podcast conversation. Listen carefully to Speaker 1 and respond naturally as if continuing a podcast dialogue. Keep your responses concise, engaging, and conversational. Maintain the flow and topic of the conversation. Avoid sounding like an assistant - you are a podcast co-host having a natural conversation."

        self._refresh_sys_prompt()

        logger.info("Performing model warmup for podcast conversation...")
        # The already-loaded reference clip doubles as warmup input, which
        # avoids reading and resampling the same file a second time.
        _ = self.inference(self.ref_audio, self.model_in_sr)
        logger.info("Warmup complete. Model ready for podcast conversation.")

    def _refresh_sys_prompt(self) -> None:
        """Rebuild ``self.sys_prompt`` from the current voice and prompt text."""
        self.sys_prompt = {
            "role": "user",
            "content": [
                "Clone the voice in the provided audio prompt.",
                self.ref_audio,
                self.podcast_prompt,
            ],
        }

    def update_config(self, config_request: "ConfigRequest") -> dict:
        """Apply the non-``None`` fields of ``config_request``.

        The reference voice is loaded before anything is modified, so a
        failed voice load leaves the whole configuration untouched instead of
        reporting failure after some settings were already changed.

        Args:
            config_request: Object exposing ``temperature``,
                ``max_new_tokens``, ``system_prompt`` and ``voice_path``.

        Returns:
            Dict with ``success``, ``message`` and ``current_config`` keys.
        """
        new_ref_audio = None
        if config_request.voice_path is not None:
            # NOTE(review): voice_path comes straight from the request and is
            # opened on the server's filesystem; consider restricting it to a
            # known directory of voices.
            try:
                new_ref_audio, _ = librosa.load(
                    config_request.voice_path, sr=self.model_in_sr, mono=True
                )
            except Exception as e:
                return {
                    "success": False,
                    "message": f"Failed to load voice: {str(e)}",
                    "current_config": self.get_current_config(),
                }

        changes = []

        if config_request.temperature is not None:
            # Clamp to the supported range [0.1, 1.0].
            self.temperature = max(0.1, min(1.0, config_request.temperature))
            changes.append(f"Temperature set to {self.temperature}")

        if config_request.max_new_tokens is not None:
            # Clamp to the supported range [50, 1024].
            self.max_new_tokens = max(50, min(1024, config_request.max_new_tokens))
            changes.append(f"Max new tokens set to {self.max_new_tokens}")

        if config_request.system_prompt is not None:
            self.podcast_prompt = config_request.system_prompt
            self._refresh_sys_prompt()
            changes.append("System prompt updated")

        if new_ref_audio is not None:
            self.ref_audio = new_ref_audio
            self._refresh_sys_prompt()
            changes.append(f"Voice updated from {config_request.voice_path}")

        if changes:
            message = "Configuration updated: " + "; ".join(changes)
        else:
            message = "No changes made"

        return {
            "success": True,
            "message": message,
            "current_config": self.get_current_config(),
        }

    def get_current_config(self) -> dict:
        """Return the current sampling settings and prompt text as a dict."""
        return {
            "temperature": self.temperature,
            "max_new_tokens": self.max_new_tokens,
            "top_p": self.top_p,
            "repetition_penalty": self.repetition_penalty,
            "system_prompt": self.podcast_prompt,
        }

    def inference(self, audio_np, input_audio_sr):
        """Generate a spoken reply to ``audio_np``.

        Args:
            audio_np: Audio samples of the incoming utterance.
            input_audio_sr: Sample rate of ``audio_np``; the reply is
                resampled to this rate as well.

        Returns:
            Tuple ``(audio, text)``: the reply waveform as a NumPy array at
            ``input_audio_sr`` and the text returned by the model.
        """
        if input_audio_sr != self.model_in_sr:
            audio_np = librosa.resample(
                audio_np, orig_sr=input_audio_sr, target_sr=self.model_in_sr
            )

        user_question = {'role': 'user', 'content': [audio_np]}

        # The voice/prompt message always precedes the user's audio turn.
        msgs = [self.sys_prompt, user_question]
        res = self.model.chat(
            msgs=msgs,
            tokenizer=self.tokenizer,
            sampling=True,
            max_new_tokens=self.max_new_tokens,
            use_tts_template=True,
            generate_audio=True,
            temperature=self.temperature,
            top_p=self.top_p,
            repetition_penalty=self.repetition_penalty,
        )
        audio = res["audio_wav"].cpu().numpy()

        # Hand the audio back at the caller's own sample rate.
        if self.model_out_sr != input_audio_sr:
            audio = librosa.resample(
                audio, orig_sr=self.model_out_sr, target_sr=input_audio_sr
            )

        return audio, res["text"]
|
|
|
|
|
def initialize_model():
    """Initialize the MiniCPM model.

    Builds the global ``model`` and records the outcome in
    ``INITIALIZATION_STATUS``. Failures are logged and recorded rather than
    raised.

    Returns:
        True when the model was constructed, False otherwise.
    """
    global model
    try:
        logger.info("Initializing model...")
        model = Model()
    except Exception as e:
        # Keep the status record consistent even if this is a retry after an
        # earlier success.
        INITIALIZATION_STATUS["model_loaded"] = False
        INITIALIZATION_STATUS["error"] = str(e)
        # logger.exception keeps the traceback, which a plain error record
        # would drop.
        logger.exception("Failed to initialize model: %s", e)
        return False

    INITIALIZATION_STATUS["model_loaded"] = True
    # Clear any error left over from a previous failed attempt.
    INITIALIZATION_STATUS["error"] = None
    logger.info("MiniCPM model initialized successfully")
    return True
|
|
|
|
|
@app.on_event("startup")
async def startup_event():
    """Initialize model on startup.

    The call is synchronous, so application startup does not complete until
    the model has either loaded or failed; the outcome is recorded in
    ``INITIALIZATION_STATUS`` and the return value is ignored.
    """
    initialize_model()
|
|
|
|
|
@app.get("/api/v1/health")
def health_check():
    """Health check endpoint.

    Returns:
        Dict with ``status`` ("healthy", "error" or "initializing"),
        ``model_loaded`` and ``error``.
    """
    loaded = INITIALIZATION_STATUS["model_loaded"]
    error = INITIALIZATION_STATUS["error"]

    if loaded:
        state = "healthy"
    elif error is not None:
        # A recorded error means loading has finished unsuccessfully;
        # reporting "initializing" would leave pollers waiting forever.
        state = "error"
    else:
        state = "initializing"

    return {
        "status": state,
        "model_loaded": loaded,
        "error": error,
    }
|
|
|
|
|
def _decode_audio_payload(audio_data: str) -> np.ndarray:
    """Turn a base64-encoded ``.npy`` payload into a flat NumPy array.

    Raises:
        HTTPException: 400 when the text is not valid base64, the bytes are
            not a loadable ``.npy`` array, or the array is empty.
    """
    try:
        raw = base64.b64decode(audio_data)
        # np.load is left at its default allow_pickle=False, so payloads
        # that would need unpickling are rejected rather than executed.
        decoded = np.load(io.BytesIO(raw))
    except (ValueError, EOFError, OSError) as exc:
        raise HTTPException(
            status_code=400,
            detail=f"Invalid audio payload: {exc}",
        ) from exc

    # An .npz archive loads successfully but is not an array.
    if not isinstance(decoded, np.ndarray):
        raise HTTPException(
            status_code=400,
            detail="Invalid audio payload: expected a single .npy array",
        )

    samples = decoded.flatten()
    if samples.size == 0:
        raise HTTPException(
            status_code=400,
            detail="Invalid audio payload: array is empty",
        )
    return samples


@app.post("/api/v1/inference")
async def inference(request: AudioRequest) -> AudioResponse:
    """Run inference with MiniCPM model.

    Decodes the submitted audio, runs the model on it and returns the reply
    audio (base64 of ``np.save`` bytes) together with the reply text.

    Raises:
        HTTPException: 503 when the model is not loaded, 400 for a malformed
            payload or non-positive sample rate, 500 when inference fails.
    """
    if not INITIALIZATION_STATUS["model_loaded"]:
        raise HTTPException(
            status_code=503,
            detail=f"Model not ready. Status: {INITIALIZATION_STATUS}",
        )

    if request.sample_rate <= 0:
        raise HTTPException(
            status_code=400,
            detail="sample_rate must be a positive integer",
        )

    # Client-side payload problems are reported as 400 before the generic
    # 500 handler below is entered.
    audio_np = _decode_audio_payload(request.audio_data)

    try:
        start = time.perf_counter()
        logger.info("starting inference with audio length %s", audio_np.shape)
        # NOTE(review): this call is synchronous, so the event loop is
        # blocked for its whole duration; confirm that is acceptable.
        audio_response, text_response = model.inference(audio_np, request.sample_rate)
        logger.info("inference took %s seconds", time.perf_counter() - start)

        buffer = io.BytesIO()
        np.save(buffer, audio_response)
        audio_b64 = base64.b64encode(buffer.getvalue()).decode()

        return AudioResponse(
            audio_data=audio_b64,
            text=text_response,
        )
    except Exception as e:
        logger.exception("Inference failed: %s", e)
        raise HTTPException(
            status_code=500,
            detail=str(e),
        ) from e
|
|
|
|
|
@app.post("/api/v1/config")
async def update_config(request: ConfigRequest) -> ConfigResponse:
    """Update model configuration for podcast-style conversations.

    Delegates to ``model.update_config``. Failures are reported in the
    response body (``success=False``) rather than as an HTTP error.

    Raises:
        HTTPException: 503 when the model is not loaded.
    """
    if not INITIALIZATION_STATUS["model_loaded"]:
        raise HTTPException(
            status_code=503,
            detail=f"Model not ready. Status: {INITIALIZATION_STATUS}",
        )

    try:
        result = model.update_config(request)
    except Exception as e:
        # Keep the traceback in the log; the client only gets the message.
        logger.exception("Configuration update failed: %s", e)
        return ConfigResponse(
            success=False,
            message=f"Configuration update failed: {str(e)}",
            current_config=model.get_current_config(),
        )

    return ConfigResponse(
        success=result["success"],
        message=result["message"],
        current_config=result["current_config"],
    )
|
|
|
|
|
@app.get("/api/v1/config")
async def get_config() -> ConfigResponse:
    """Get current model configuration.

    Raises:
        HTTPException: 503 when the model is not loaded.
    """
    if not INITIALIZATION_STATUS["model_loaded"]:
        raise HTTPException(
            status_code=503,
            detail=f"Model not ready. Status: {INITIALIZATION_STATUS}",
        )

    return ConfigResponse(
        success=True,
        message="Current configuration",
        current_config=model.get_current_config(),
    )
|
|
|
|
|
if __name__ == "__main__":
    # Imported here so uvicorn is only needed when the file is run directly.
    import uvicorn

    # Listen on all interfaces, port 8000.
    uvicorn.run(app, host="0.0.0.0", port=8000)
|
|
|