"""FastAPI server exposing a MiniCPM speech model for podcast-style dialogue.

Endpoints:
    GET  /api/v1/health     -- model initialization status
    POST /api/v1/inference  -- audio in, audio + text out
    POST /api/v1/config     -- update sampling parameters, prompt or voice
    GET  /api/v1/config     -- read the current configuration

Audio travels in both directions as a base64-encoded ``.npy`` buffer
(``numpy.save`` / ``numpy.load``).
"""

import base64
import binascii
import io
import logging
import time
from typing import Optional

import librosa
import numpy as np
import torch
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
from transformers import AutoModel, AutoTokenizer

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

app = FastAPI()

# NOTE(review): wildcard origins together with allow_credentials=True permits
# credentialed cross-origin requests from any site. Behavior is left unchanged
# here; restrict allow_origins before exposing this service publicly.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Locations of the model checkpoint and the default voice-cloning reference.
_CHECKPOINT_DIR = './models/checkpoint'
_DEFAULT_REF_AUDIO = './ref_audios/female.wav'


class AudioRequest(BaseModel):
    """Inference request: base64-encoded ``.npy`` audio and its sample rate."""

    audio_data: str
    sample_rate: int


class AudioResponse(BaseModel):
    """Inference response: base64-encoded ``.npy`` audio plus generated text."""

    audio_data: str
    text: str = ""


class ConfigRequest(BaseModel):
    """Partial configuration update; fields left as ``None`` are not changed."""

    temperature: Optional[float] = None
    max_new_tokens: Optional[int] = None
    system_prompt: Optional[str] = None
    voice_path: Optional[str] = None


class ConfigResponse(BaseModel):
    """Outcome of a configuration read or update."""

    success: bool
    message: str
    current_config: dict


# Global model instance, populated by initialize_model().
model = None

INITIALIZATION_STATUS = {
    "model_loaded": False,
    "error": None
}


class Model:
    """Holds the loaded checkpoint, tokenizer, voice prompt and sampling settings."""

    def __init__(self):
        """Load the checkpoint onto the GPU, set defaults and run one warmup pass."""
        # A distinct local name avoids shadowing the module-level ``model`` global.
        net = AutoModel.from_pretrained(
            _CHECKPOINT_DIR,
            trust_remote_code=True,
            torch_dtype=torch.bfloat16,
            attn_implementation='sdpa'
        )
        net = net.eval().cuda()
        self.model = net

        self.tokenizer = AutoTokenizer.from_pretrained(
            _CHECKPOINT_DIR,
            trust_remote_code=True
        )

        # Initialize TTS
        net.init_tts()
        net.tts.float()  # Convert TTS to float32 if needed

        self.model_in_sr = 16000
        self.model_out_sr = 24000

        # Reference audio used for the voice-cloning prompt.
        self.ref_audio, _ = librosa.load(_DEFAULT_REF_AUDIO, sr=self.model_in_sr, mono=True)

        # Configurable parameters
        self.temperature = 0.7
        self.max_new_tokens = 150
        self.top_p = 0.92
        self.repetition_penalty = 1.2

        # Custom podcast conversation system prompt
        self.podcast_prompt = (
            "You are Speaker 2 in a podcast conversation. "
            "Listen carefully to Speaker 1 and respond naturally as if continuing a podcast dialogue. "
            "Keep your responses concise, engaging, and conversational. "
            "Maintain the flow and topic of the conversation. "
            "Avoid sounding like an assistant - you are a podcast co-host having a natural conversation."
        )
        self.sys_prompt = self._build_sys_prompt()

        logger.info("Performing model warmup for podcast conversation...")
        # Warm up with the already-loaded reference clip instead of re-reading the
        # same file from disk; a copy keeps it independent of the prompt's array.
        _ = self.inference(self.ref_audio.copy(), self.model_in_sr)
        logger.info("Warmup complete. Model ready for podcast conversation.")

    def _build_sys_prompt(self) -> dict:
        """Assemble the system message from the current reference audio and prompt."""
        return {
            "role": "user",
            "content": [
                "Clone the voice in the provided audio prompt.",
                self.ref_audio,
                self.podcast_prompt
            ]
        }

    def update_config(self, config_request: ConfigRequest) -> dict:
        """Update model configuration based on request.

        The update is all-or-nothing: the new voice file (if any) is loaded
        before any setting is modified, so a failed load leaves the previous
        configuration fully intact.

        Args:
            config_request: Fields to change; ``None`` fields are skipped.
                ``temperature`` is clamped to [0.1, 1.0] and
                ``max_new_tokens`` to [50, 1024].

        Returns:
            A dict with ``success``, ``message`` and ``current_config`` keys.
        """
        new_ref_audio = None
        if config_request.voice_path is not None:
            # SECURITY NOTE(review): voice_path is supplied by the HTTP client and
            # is opened as-is on the server's filesystem; consider restricting it
            # to a known directory.
            try:
                new_ref_audio, _ = librosa.load(
                    config_request.voice_path, sr=self.model_in_sr, mono=True
                )
            except Exception as e:
                logger.exception("Failed to load voice from %s", config_request.voice_path)
                return {
                    "success": False,
                    "message": f"Failed to load voice: {str(e)}",
                    "current_config": self.get_current_config()
                }

        changes = []
        prompt_dirty = False

        if config_request.temperature is not None:
            self.temperature = max(0.1, min(1.0, config_request.temperature))
            changes.append(f"Temperature set to {self.temperature}")

        if config_request.max_new_tokens is not None:
            self.max_new_tokens = max(50, min(1024, config_request.max_new_tokens))
            changes.append(f"Max new tokens set to {self.max_new_tokens}")

        if config_request.system_prompt is not None:
            self.podcast_prompt = config_request.system_prompt
            prompt_dirty = True
            changes.append("System prompt updated")

        if new_ref_audio is not None:
            self.ref_audio = new_ref_audio
            prompt_dirty = True
            changes.append(f"Voice updated from {config_request.voice_path}")

        if prompt_dirty:
            # The system message embeds both the prompt text and the reference audio.
            self.sys_prompt = self._build_sys_prompt()

        if changes:
            message = "Configuration updated: " + "; ".join(changes)
        else:
            message = "No changes made"
        return {
            "success": True,
            "message": message,
            "current_config": self.get_current_config()
        }

    def get_current_config(self) -> dict:
        """Get current model configuration"""
        return {
            "temperature": self.temperature,
            "max_new_tokens": self.max_new_tokens,
            "top_p": self.top_p,
            "repetition_penalty": self.repetition_penalty,
            "system_prompt": self.podcast_prompt
        }

    def inference(self, audio_np, input_audio_sr):
        """Generate a spoken reply to one user utterance.

        Args:
            audio_np: Input audio samples as a NumPy array
                (presumably mono float samples -- confirm against callers).
            input_audio_sr: Sample rate of ``audio_np``; the reply is
                resampled back to this rate.

        Returns:
            Tuple ``(audio, text)``: reply audio as a NumPy array at
            ``input_audio_sr`` and the ``"text"`` entry of the chat result.
        """
        if input_audio_sr != self.model_in_sr:
            audio_np = librosa.resample(
                audio_np, orig_sr=input_audio_sr, target_sr=self.model_in_sr
            )

        user_question = {'role': 'user', 'content': [audio_np]}

        # round one
        msgs = [self.sys_prompt, user_question]
        res = self.model.chat(
            msgs=msgs,
            tokenizer=self.tokenizer,
            sampling=True,
            max_new_tokens=self.max_new_tokens,
            use_tts_template=True,
            generate_audio=True,
            temperature=self.temperature,
            top_p=self.top_p,
            repetition_penalty=self.repetition_penalty,
        )
        audio = res["audio_wav"].cpu().numpy()

        if self.model_out_sr != input_audio_sr:
            audio = librosa.resample(
                audio, orig_sr=self.model_out_sr, target_sr=input_audio_sr
            )

        return audio, res["text"]


def initialize_model():
    """Initialize the MiniCPM model.

    Returns:
        ``True`` on success; ``False`` if loading failed, in which case the
        error text is stored in ``INITIALIZATION_STATUS["error"]``.
    """
    global model
    try:
        logger.info("Initializing model...")
        model = Model()
        INITIALIZATION_STATUS["model_loaded"] = True
        INITIALIZATION_STATUS["error"] = None  # clear any error from an earlier attempt
        logger.info("MiniCPM model initialized successfully")
        return True
    except Exception as e:
        INITIALIZATION_STATUS["error"] = str(e)
        logger.exception("Failed to initialize model: %s", e)
        return False


def _require_model() -> None:
    """Raise HTTP 503 unless the global model finished loading."""
    if not INITIALIZATION_STATUS["model_loaded"]:
        raise HTTPException(
            status_code=503,
            detail=f"Model not ready. Status: {INITIALIZATION_STATUS}"
        )


def _decode_audio_payload(audio_b64: str) -> np.ndarray:
    """Decode a base64 ``.npy`` payload into a flat array, or raise HTTP 400."""
    try:
        raw = base64.b64decode(audio_b64)
        # allow_pickle=False is explicit: the payload is untrusted client input.
        loaded = np.load(io.BytesIO(raw), allow_pickle=False)
    except (binascii.Error, ValueError, EOFError, OSError) as exc:
        raise HTTPException(
            status_code=400,
            detail=f"Invalid audio payload: {exc}"
        ) from exc
    if not isinstance(loaded, np.ndarray):
        # e.g. an .npz archive was sent instead of a single array
        raise HTTPException(
            status_code=400,
            detail="Invalid audio payload: expected a single .npy array"
        )
    return loaded.flatten()


@app.on_event("startup")
async def startup_event():
    """Initialize model on startup"""
    initialize_model()


@app.get("/api/v1/health")
def health_check():
    """Health check endpoint"""
    status = {
        "status": "healthy" if INITIALIZATION_STATUS["model_loaded"] else "initializing",
        "model_loaded": INITIALIZATION_STATUS["model_loaded"],
        "error": INITIALIZATION_STATUS["error"]
    }
    return status


@app.post("/api/v1/inference")
async def inference(request: AudioRequest) -> AudioResponse:
    """Run inference with MiniCPM model.

    Raises:
        HTTPException: 503 if the model is not loaded, 400 for a malformed
            or empty audio payload or a non-positive sample rate, 500 if
            inference itself fails.
    """
    _require_model()

    # Client-side mistakes are reported as 400 rather than a generic 500.
    if request.sample_rate <= 0:
        raise HTTPException(status_code=400, detail="sample_rate must be positive")
    audio_np = _decode_audio_payload(request.audio_data)
    if audio_np.size == 0:
        raise HTTPException(status_code=400, detail="Audio payload is empty")

    try:
        # NOTE(review): model.inference is a blocking call inside an async
        # handler, so the event loop is occupied for the duration of a request.
        start = time.perf_counter()
        logger.info("starting inference with audio length %s", audio_np.shape)
        audio_response, text_response = model.inference(audio_np, request.sample_rate)
        logger.info("inference took %s seconds", time.perf_counter() - start)

        # Serialize the reply audio as .npy and encode to base64
        buffer = io.BytesIO()
        np.save(buffer, audio_response)
        audio_b64 = base64.b64encode(buffer.getvalue()).decode()

        return AudioResponse(
            audio_data=audio_b64,
            text=text_response
        )
    except Exception as e:
        logger.exception("Inference failed: %s", e)
        raise HTTPException(
            status_code=500,
            detail=str(e)
        ) from e


@app.post("/api/v1/config")
async def update_config(request: ConfigRequest) -> ConfigResponse:
    """Update model configuration for podcast-style conversations"""
    _require_model()

    try:
        result = model.update_config(request)
        return ConfigResponse(
            success=result["success"],
            message=result["message"],
            current_config=result["current_config"]
        )
    except Exception as e:
        logger.exception("Configuration update failed: %s", e)
        return ConfigResponse(
            success=False,
            message=f"Configuration update failed: {str(e)}",
            current_config=model.get_current_config()
        )


@app.get("/api/v1/config")
async def get_config() -> ConfigResponse:
    """Get current model configuration"""
    _require_model()

    return ConfigResponse(
        success=True,
        message="Current configuration",
        current_config=model.get_current_config()
    )


if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=8000)