# MeiGen-MultiTalk handler.py
# Custom handler for the MeiGen-MultiTalk Inference Endpoint (commit ab4557b)
import base64
import inspect
import io
import json
import logging
import os
import sys
from typing import Any, Dict, List

import torch
from PIL import Image
# Set up a module-level logger following the standard `__name__` convention;
# basicConfig at INFO so endpoint startup/inference progress shows in the logs.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
class EndpointHandler:
    """Inference Endpoint handler for MeiGen-AI/MeiGen-MultiTalk.

    Loading strategy (best effort, degrading gracefully):
      1. diffusers ``DiffusionPipeline`` (preferred),
      2. transformers ``AutoModel`` + ``AutoTokenizer`` fallback,
      3. "test mode" with no model, so the endpoint stays responsive and
         returns a diagnostic payload instead of crashing at startup.
    """

    # Hub id of the model to serve; the ``path`` constructor argument is
    # currently ignored in favour of this id (matching the original code).
    MODEL_ID = "MeiGen-AI/MeiGen-MultiTalk"

    def __init__(self, path=""):
        """Initialize the handler and load the model.

        Args:
            path: Local model path supplied by the Inference Endpoint
                runtime. NOTE(review): presently unused — loading always
                goes through ``MODEL_ID``; confirm this is intended.
        """
        logger.info("Initializing handler with path: %s", path)

        # Define every model attribute up front so __call__ can safely probe
        # them no matter which load path succeeded. (Previously self.model /
        # self.tokenizer were only assigned on failure paths, so a successful
        # pipeline load left them undefined — a latent AttributeError.)
        self.pipeline = None
        self.model = None
        self.tokenizer = None

        try:
            from diffusers import DiffusionPipeline
            logger.info("Successfully imported required libraries")
        except ImportError:
            logger.exception("Failed to import required libraries")
            raise

        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        # fp16 only makes sense on GPU; CPU inference stays in fp32.
        self.dtype = torch.float16 if torch.cuda.is_available() else torch.float32
        logger.info("Using device: %s", self.device)

        try:
            logger.info("Loading model from: %s", self.MODEL_ID)
            self.pipeline = DiffusionPipeline.from_pretrained(
                self.MODEL_ID,
                torch_dtype=self.dtype,
                device_map="auto",
                low_cpu_mem_usage=True,
            )
            self._enable_memory_optimizations()
            logger.info("Model loaded successfully")
        except Exception:
            logger.exception("Failed to load model")
            self.pipeline = None
            self._load_transformer_fallback()

    def _enable_memory_optimizations(self):
        """Turn on every memory optimization the loaded pipeline supports."""
        optimizations = (
            ("enable_attention_slicing", "attention slicing"),
            ("enable_vae_slicing", "VAE slicing"),
            ("enable_model_cpu_offload", "model CPU offload"),
        )
        for method_name, label in optimizations:
            if hasattr(self.pipeline, method_name):
                getattr(self.pipeline, method_name)()
                logger.info("Enabled %s", label)

    def _load_transformer_fallback(self):
        """Fallback load via transformers; on failure, enter test mode."""
        try:
            logger.info("Attempting alternative loading method...")
            from transformers import AutoModel, AutoTokenizer

            self.model = AutoModel.from_pretrained(
                self.MODEL_ID,
                torch_dtype=self.dtype,
                device_map="auto",
                trust_remote_code=True,
            )
            self.tokenizer = AutoTokenizer.from_pretrained(
                self.MODEL_ID, trust_remote_code=True
            )
            logger.info("Model loaded with alternative method")
        except Exception:
            logger.exception("Alternative loading also failed")
            self.model = None
            self.tokenizer = None
            logger.warning("Running in test mode without actual model")

    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """Process one inference request.

        Args:
            data: Request payload with:
                - ``inputs``: prompt string, or a dict with ``prompt`` and an
                  optional base64-encoded ``image``.
                - ``parameters``: optional generation parameters.

        Returns:
            Dict with the generated frames/images/text, a test-mode payload,
            or an ``error`` entry — never ``None`` (the original could fall
            through and return ``None`` when ``result.frames`` was empty).
        """
        logger.info("Received request with data keys: %s", data.keys())
        try:
            inputs = data.get("inputs", "")
            parameters = data.get("parameters", {}) or {}
            logger.info("Processing inputs: %s", type(inputs))
            logger.info("Parameters: %s", parameters)

            prompt, image = self._parse_inputs(inputs)

            if self.pipeline is not None:
                return self._generate_with_pipeline(prompt, image, parameters)
            if self.model is not None:
                return self._generate_with_transformer(prompt)

            # Test mode: no model could be loaded — echo a diagnostic payload.
            logger.warning("Running in test mode - no actual generation")
            return {
                "message": "Handler is running in test mode",
                "prompt": prompt,
                "parameters": parameters,
                "status": "test_mode",
            }
        except Exception as e:
            import traceback
            logger.exception("Error during inference")
            return {
                "error": str(e),
                "traceback": traceback.format_exc(),
                "message": "Error during generation",
            }

    @staticmethod
    def _parse_inputs(inputs):
        """Normalize ``inputs`` into ``(prompt, image-or-None)``."""
        if isinstance(inputs, str):
            return inputs, None
        if isinstance(inputs, dict):
            prompt = inputs.get("prompt", "A person speaking")
            image = None
            if "image" in inputs:
                try:
                    image = Image.open(io.BytesIO(base64.b64decode(inputs["image"])))
                    logger.info("Loaded input image")
                except Exception:
                    # Best-effort: a bad image falls back to text-only generation.
                    logger.exception("Failed to decode image")
                    image = None
            return prompt, image
        # Anything else (numbers, lists, ...) is coerced to a prompt string.
        return str(inputs), None

    def _generate_with_pipeline(self, prompt, image, parameters):
        """Run the diffusers pipeline and package its output."""
        num_inference_steps = parameters.get("num_inference_steps", 25)
        guidance_scale = parameters.get("guidance_scale", 7.5)
        height = parameters.get("height", 480)
        width = parameters.get("width", 640)
        num_frames = parameters.get("num_frames", 16)
        logger.info(
            "Generation params: steps=%s, guidance=%s, size=%sx%s, frames=%s",
            num_inference_steps, guidance_scale, width, height, num_frames,
        )

        logger.info("Generating with diffusion pipeline...")
        gen_kwargs = {
            "prompt": prompt,
            "height": height,
            "width": width,
            "num_inference_steps": num_inference_steps,
            "guidance_scale": guidance_scale,
        }
        if image is not None:
            gen_kwargs["image"] = image
        # Pass num_frames only if the pipeline accepts it. inspect.signature
        # works for wrapped/C-implemented callables, unlike poking at
        # __call__.__code__.co_varnames (the original approach).
        try:
            accepts_num_frames = (
                "num_frames" in inspect.signature(self.pipeline.__call__).parameters
            )
        except (TypeError, ValueError):
            accepts_num_frames = False
        if accepts_num_frames:
            gen_kwargs["num_frames"] = num_frames

        with torch.no_grad():
            result = self.pipeline(**gen_kwargs)
        return self._package_result(result, prompt)

    def _package_result(self, result, prompt):
        """Convert a pipeline result into a JSON-serializable response dict."""
        if hasattr(result, "frames"):
            frames = result.frames
            if isinstance(frames, list) and len(frames) > 0:
                # Some pipelines return a list of frame lists (one per prompt).
                frame_list = frames[0] if isinstance(frames[0], list) else frames
                encoded_frames = [
                    self._encode_png(frame)
                    for frame in frame_list
                    if isinstance(frame, Image.Image)
                ]
                return {
                    "frames": encoded_frames,
                    "num_frames": len(encoded_frames),
                    "message": "Video generated successfully",
                }
        elif hasattr(result, "images"):
            encoded_images = [
                self._encode_png(img)
                for img in result.images
                if isinstance(img, Image.Image)
            ]
            return {
                "images": encoded_images,
                "num_images": len(encoded_images),
                "message": "Images generated successfully",
            }
        # Catch-all (also reached when `frames` exists but is empty/non-list):
        # always return a dict instead of falling through to an implicit None.
        return {
            "message": "Generation completed",
            "prompt": prompt,
            "result_type": str(type(result)),
        }

    def _generate_with_transformer(self, prompt):
        """Run the transformers fallback model on ``prompt``."""
        logger.info("Generating with transformer model...")
        if not self.tokenizer:
            return {
                "message": "Model loaded but tokenizer not available",
                "prompt": prompt,
            }
        encoded = self.tokenizer(prompt, return_tensors="pt").to(self.device)
        with torch.no_grad():
            outputs = self.model.generate(**encoded, max_length=100)
        generated_text = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
        return {
            "generated_text": generated_text,
            "message": "Text generated successfully",
        }

    @staticmethod
    def _encode_png(img):
        """Encode a PIL image as a base64 PNG string."""
        buffered = io.BytesIO()
        img.save(buffered, format="PNG")
        return base64.b64encode(buffered.getvalue()).decode()