import base64
import inspect
import io
import json
import logging
import os
import sys
import traceback
from typing import Any, Dict, List, Optional, Tuple

import torch
from PIL import Image

# Set up logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


class EndpointHandler:
    """Inference handler wrapping the MeiGen-AI/MeiGen-MultiTalk model.

    Exactly one of three backends is active after construction:

    * ``self.pipeline`` -- a diffusers ``DiffusionPipeline`` (preferred);
    * ``self.model`` + ``self.tokenizer`` -- a transformers fallback;
    * neither -- "test mode", where requests are echoed back.
    """

    # Optional pipeline methods to try after loading, with the log line
    # emitted when each one succeeds.
    _MEMORY_OPTIMIZATIONS = (
        ("enable_attention_slicing", "Enabled attention slicing"),
        ("enable_vae_slicing", "Enabled VAE slicing"),
        ("enable_model_cpu_offload", "Enabled model CPU offload"),
    )

    def __init__(self, path=""):
        """Load the model, falling back to test mode if nothing loads.

        Args:
            path: Handler path supplied by the caller. It is only logged;
                the model is always loaded from the hard-coded model id.

        Raises:
            ImportError: If ``diffusers`` cannot be imported.
        """
        logger.info(f"Initializing handler with path: {path}")

        # Define every backend attribute up front so __call__ can read them
        # regardless of which loading path succeeded.
        self.pipeline = None
        self.model = None
        self.tokenizer = None

        # Import required libraries
        try:
            from diffusers import DiffusionPipeline
            logger.info("Successfully imported required libraries")
        except ImportError as e:
            logger.error(f"Failed to import required libraries: {e}")
            raise

        # Set device
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        logger.info(f"Using device: {self.device}")

        model_id = "MeiGen-AI/MeiGen-MultiTalk"
        dtype = torch.float16 if torch.cuda.is_available() else torch.float32

        # Load the actual MeiGen-MultiTalk model as a diffusion pipeline
        try:
            logger.info(f"Loading model from: {model_id}")
            self.pipeline = DiffusionPipeline.from_pretrained(
                model_id,
                torch_dtype=dtype,
                device_map="auto",
                low_cpu_mem_usage=True,
            )
        except Exception as e:
            logger.error(f"Failed to load model: {e}")
            self.pipeline = None
            self._load_transformer_fallback(model_id, dtype)
        else:
            self._enable_memory_optimizations()
            logger.info("Model loaded successfully")

    def _enable_memory_optimizations(self) -> None:
        """Apply each optional memory optimization on a best-effort basis."""
        for method_name, success_message in self._MEMORY_OPTIMIZATIONS:
            enable = getattr(self.pipeline, method_name, None)
            if enable is None:
                continue
            try:
                enable()
            except Exception as e:
                # An optimization being rejected (presumably e.g. CPU offload
                # on a device-mapped pipeline -- verify against the installed
                # diffusers version) must not discard a pipeline that loaded.
                logger.warning(f"Could not apply {method_name}: {e}")
            else:
                logger.info(success_message)

    def _load_transformer_fallback(self, model_id: str, dtype: Any) -> None:
        """Try loading the model via transformers; stay in test mode on failure."""
        try:
            logger.info("Attempting alternative loading method...")
            from transformers import AutoModel, AutoTokenizer

            # NOTE(security): trust_remote_code=True executes Python code
            # shipped in the model repository; only use with trusted repos.
            model = AutoModel.from_pretrained(
                model_id,
                torch_dtype=dtype,
                device_map="auto",
                trust_remote_code=True,
            )
            tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
        except Exception as e2:
            logger.error(f"Alternative loading also failed: {e2}")
            logger.warning("Running in test mode without actual model")
            return

        # Publish both together so the handler never holds a model
        # without its tokenizer.
        self.model = model
        self.tokenizer = tokenizer
        logger.info("Model loaded with alternative method")

    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """Process one inference request.

        Args:
            data: Request payload containing:
                - inputs: A prompt string, or a dict with ``prompt`` and an
                  optional base64-encoded ``image``; anything else is
                  converted with ``str()``.
                - parameters: Optional generation parameters
                  (``num_inference_steps``, ``guidance_scale``, ``height``,
                  ``width``, ``num_frames``). ``None`` is treated as empty.

        Returns:
            Dict with the generated output. Exceptions raised while handling
            the request are not propagated; they are returned as a dict with
            ``error``, ``traceback`` and ``message`` keys.
        """
        logger.info(f"Received request with data keys: {data.keys()}")

        try:
            # Extract inputs; an explicit null "parameters" counts as empty.
            inputs = data.get("inputs", "")
            parameters = data.get("parameters") or {}

            logger.info(f"Processing inputs: {type(inputs)}")
            logger.info(f"Parameters: {parameters}")

            prompt, image = self._parse_inputs(inputs)

            # Extract parameters with defaults
            num_inference_steps = parameters.get("num_inference_steps", 25)
            guidance_scale = parameters.get("guidance_scale", 7.5)
            height = parameters.get("height", 480)
            width = parameters.get("width", 640)
            num_frames = parameters.get("num_frames", 16)

            logger.info(f"Generation params: steps={num_inference_steps}, guidance={guidance_scale}, size={width}x{height}, frames={num_frames}")

            if self.pipeline is not None:
                return self._generate_with_pipeline(
                    prompt,
                    image,
                    num_inference_steps=num_inference_steps,
                    guidance_scale=guidance_scale,
                    height=height,
                    width=width,
                    num_frames=num_frames,
                )

            if self.model is not None:
                return self._generate_with_model(prompt)

            # Test mode response
            logger.warning("Running in test mode - no actual generation")
            return {
                "message": "Handler is running in test mode",
                "prompt": prompt,
                "parameters": parameters,
                "status": "test_mode",
            }

        except Exception as e:
            logger.exception(f"Error during inference: {e}")
            return {
                "error": str(e),
                "traceback": traceback.format_exc(),
                "message": "Error during generation",
            }

    @staticmethod
    def _parse_inputs(inputs: Any) -> Tuple[Any, Optional[Image.Image]]:
        """Split the raw ``inputs`` value into ``(prompt, image)``."""
        if isinstance(inputs, str):
            return inputs, None
        if not isinstance(inputs, dict):
            return str(inputs), None

        prompt = inputs.get("prompt", "A person speaking")
        image = None
        # Handle base64 encoded image if provided; a bad image is logged and
        # ignored rather than failing the whole request.
        if "image" in inputs:
            try:
                image_data = base64.b64decode(inputs["image"])
                image = Image.open(io.BytesIO(image_data))
                logger.info("Loaded input image")
            except Exception as e:
                logger.error(f"Failed to decode image: {e}")
                image = None
        return prompt, image

    @staticmethod
    def _encode_png(image: Image.Image) -> str:
        """Return *image* serialized as a base64-encoded PNG string."""
        buffered = io.BytesIO()
        image.save(buffered, format="PNG")
        return base64.b64encode(buffered.getvalue()).decode()

    def _pipeline_accepts(self, name: str) -> bool:
        """Return True if the pipeline's ``__call__`` declares parameter *name*.

        ``inspect.signature`` follows ``__wrapped__``, so the declared
        parameters are visible even when ``__call__`` is wrapped by a
        ``functools.wraps``-style decorator; reading ``__code__.co_varnames``
        would see only the wrapper's ``args``/``kwargs``.
        """
        try:
            declared = inspect.signature(self.pipeline.__call__).parameters
        except (TypeError, ValueError):
            return False
        return name in declared

    def _generate_with_pipeline(
        self,
        prompt: Any,
        image: Optional[Image.Image],
        *,
        num_inference_steps: Any,
        guidance_scale: Any,
        height: Any,
        width: Any,
        num_frames: Any,
    ) -> Dict[str, Any]:
        """Run the diffusion pipeline and encode its output for the response."""
        logger.info("Generating with diffusion pipeline...")

        # Prepare generation kwargs
        gen_kwargs = {
            "prompt": prompt,
            "height": height,
            "width": width,
            "num_inference_steps": num_inference_steps,
            "guidance_scale": guidance_scale,
        }

        # Add image if available
        if image is not None:
            gen_kwargs["image"] = image

        # Add num_frames if the pipeline supports it
        if self._pipeline_accepts("num_frames"):
            gen_kwargs["num_frames"] = num_frames

        # Generate
        with torch.no_grad():
            result = self.pipeline(**gen_kwargs)

        # Video output: a non-empty list of frames, or a list of frame lists
        # of which only the first is returned.
        frames = getattr(result, "frames", None)
        if isinstance(frames, list) and frames:
            first_clip = frames[0] if isinstance(frames[0], list) else frames
            encoded_frames = [
                self._encode_png(frame)
                for frame in first_clip
                if isinstance(frame, Image.Image)
            ]
            return {
                "frames": encoded_frames,
                "num_frames": len(encoded_frames),
                "message": "Video generated successfully",
            }

        # Image output
        images = getattr(result, "images", None)
        if images is not None:
            encoded_images = [
                self._encode_png(img)
                for img in images
                if isinstance(img, Image.Image)
            ]
            return {
                "images": encoded_images,
                "num_images": len(encoded_images),
                "message": "Images generated successfully",
            }

        # Anything else (including an unusable ``frames`` attribute) still
        # yields a well-formed response instead of an implicit None.
        return {
            "message": "Generation completed",
            "prompt": prompt,
            "result_type": str(type(result)),
        }

    def _generate_with_model(self, prompt: Any) -> Dict[str, Any]:
        """Generate text with the transformers fallback model."""
        logger.info("Generating with transformer model...")

        if self.tokenizer is None:
            return {
                "message": "Model loaded but tokenizer not available",
                "prompt": prompt,
            }

        encoded = self.tokenizer(prompt, return_tensors="pt").to(self.device)
        with torch.no_grad():
            outputs = self.model.generate(**encoded, max_length=100)
        text = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
        return {
            "generated_text": text,
            "message": "Text generated successfully",
        }