|
|
import os |
|
|
import sys |
|
|
import torch |
|
|
import json |
|
|
import base64 |
|
|
import io |
|
|
from typing import Dict, Any, List |
|
|
from PIL import Image |
|
|
import logging |
|
|
|
|
|
|
|
|
logging.basicConfig(level=logging.INFO) |
|
|
logger = logging.getLogger(__name__) |
|
|
|
|
|
class EndpointHandler:
    """Inference endpoint handler for MeiGen-AI/MeiGen-MultiTalk.

    Loading strategy, in order:
      1. ``diffusers.DiffusionPipeline`` (preferred).
      2. Fallback: raw ``transformers`` ``AutoModel`` + ``AutoTokenizer``
         with ``trust_remote_code=True``.
      3. If both fail, the handler stays alive in a "test mode" that echoes
         the parsed request back, so the endpoint can be health-checked
         without model weights.
    """

    # Hub repository the weights are pulled from.
    MODEL_ID = "MeiGen-AI/MeiGen-MultiTalk"

    def __init__(self, path: str = ""):
        """Initialize the MultiTalk model handler.

        Args:
            path: Local model directory supplied by the serving runtime.
                Currently only logged; weights are fetched from ``MODEL_ID``.

        Raises:
            ImportError: if ``diffusers`` cannot be imported at all.
        """
        logger.info("Initializing handler with path: %s", path)

        # Define every model attribute up front so __call__ can rely on
        # them existing regardless of which loading path (if any) succeeds.
        # (Previously self.model/self.tokenizer were only set on the
        # fallback path, risking AttributeError later.)
        self.pipeline = None
        self.model = None
        self.tokenizer = None

        try:
            from diffusers import DiffusionPipeline
            logger.info("Successfully imported required libraries")
        except ImportError as e:
            logger.error("Failed to import required libraries: %s", e)
            raise

        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        logger.info("Using device: %s", self.device)

        model_id = self.MODEL_ID
        # fp16 only makes sense on GPU; CPU kernels need fp32.
        dtype = torch.float16 if torch.cuda.is_available() else torch.float32

        try:
            logger.info("Loading model from: %s", model_id)
            self.pipeline = DiffusionPipeline.from_pretrained(
                model_id,
                torch_dtype=dtype,
                device_map="auto",
                low_cpu_mem_usage=True,
            )
            self._enable_memory_optimizations()
            logger.info("Model loaded successfully")
        except Exception as e:
            logger.error("Failed to load model: %s", e)
            self.pipeline = None
            try:
                logger.info("Attempting alternative loading method...")
                from transformers import AutoModel, AutoTokenizer

                self.model = AutoModel.from_pretrained(
                    model_id,
                    torch_dtype=dtype,
                    device_map="auto",
                    trust_remote_code=True,
                )
                self.tokenizer = AutoTokenizer.from_pretrained(
                    model_id, trust_remote_code=True
                )
                logger.info("Model loaded with alternative method")
            except Exception as e2:
                logger.error("Alternative loading also failed: %s", e2)
                # Deliberately swallow the failure: keep serving requests
                # in test mode rather than crashing the endpoint at startup.
                self.model = None
                self.tokenizer = None
                logger.warning("Running in test mode without actual model")

    def _enable_memory_optimizations(self) -> None:
        """Turn on optional diffusers memory savers when the pipeline has them."""
        if hasattr(self.pipeline, "enable_attention_slicing"):
            self.pipeline.enable_attention_slicing()
            logger.info("Enabled attention slicing")
        if hasattr(self.pipeline, "enable_vae_slicing"):
            self.pipeline.enable_vae_slicing()
            logger.info("Enabled VAE slicing")
        if hasattr(self.pipeline, "enable_model_cpu_offload"):
            self.pipeline.enable_model_cpu_offload()
            logger.info("Enabled model CPU offload")

    @staticmethod
    def _parse_inputs(inputs: Any):
        """Normalize the request "inputs" field into ``(prompt, image)``.

        Accepts a plain string prompt, a dict with optional "prompt" and
        base64-encoded "image" keys, or anything else (coerced via ``str``).
        A malformed image payload is logged and ignored rather than failing
        the whole request.
        """
        if isinstance(inputs, str):
            return inputs, None
        if isinstance(inputs, dict):
            prompt = inputs.get("prompt", "A person speaking")
            image = None
            if "image" in inputs:
                try:
                    image_data = base64.b64decode(inputs["image"])
                    image = Image.open(io.BytesIO(image_data))
                    logger.info("Loaded input image")
                except Exception as e:
                    logger.error("Failed to decode image: %s", e)
            return prompt, image
        return str(inputs), None

    @staticmethod
    def _encode_png(img: "Image.Image") -> str:
        """Serialize a PIL image to a base64-encoded PNG string."""
        buffered = io.BytesIO()
        img.save(buffered, format="PNG")
        return base64.b64encode(buffered.getvalue()).decode()

    def _generate_with_pipeline(
        self,
        prompt: str,
        image,
        num_inference_steps: int,
        guidance_scale: float,
        height: int,
        width: int,
        num_frames: int,
    ) -> Dict[str, Any]:
        """Run the diffusion pipeline and base64-encode its frames/images."""
        logger.info("Generating with diffusion pipeline...")

        gen_kwargs = {
            "prompt": prompt,
            "height": height,
            "width": width,
            "num_inference_steps": num_inference_steps,
            "guidance_scale": guidance_scale,
        }
        if image is not None:
            gen_kwargs["image"] = image

        # Only pass num_frames to pipelines that accept it (video pipelines).
        # inspect.signature works for wrapped/bound callables where peeking
        # at __code__.co_varnames would raise or miss **kwargs.
        import inspect
        try:
            accepts_frames = (
                "num_frames" in inspect.signature(self.pipeline.__call__).parameters
            )
        except (TypeError, ValueError):
            accepts_frames = False
        if accepts_frames:
            gen_kwargs["num_frames"] = num_frames

        with torch.no_grad():
            result = self.pipeline(**gen_kwargs)

        if hasattr(result, "frames"):
            frames = result.frames
            if isinstance(frames, list) and frames:
                # Video pipelines may return a batch: a list of frame lists.
                frame_seq = frames[0] if isinstance(frames[0], list) else frames
                encoded_frames = [
                    self._encode_png(frame)
                    for frame in frame_seq
                    if isinstance(frame, Image.Image)
                ]
                return {
                    "frames": encoded_frames,
                    "num_frames": len(encoded_frames),
                    "message": "Video generated successfully",
                }
        elif hasattr(result, "images"):
            encoded_images = [
                self._encode_png(img)
                for img in result.images
                if isinstance(img, Image.Image)
            ]
            return {
                "images": encoded_images,
                "num_images": len(encoded_images),
                "message": "Images generated successfully",
            }
        # Generic fallback for unknown result types — also covers an empty
        # frames list, which previously fell through and returned None.
        return {
            "message": "Generation completed",
            "prompt": prompt,
            "result_type": str(type(result)),
        }

    def _generate_with_model(self, prompt: str) -> Dict[str, Any]:
        """Run the fallback transformers model as a plain text generator."""
        logger.info("Generating with transformer model...")
        if not self.tokenizer:
            return {
                "message": "Model loaded but tokenizer not available",
                "prompt": prompt,
            }
        encoded = self.tokenizer(prompt, return_tensors="pt").to(self.device)
        with torch.no_grad():
            outputs = self.model.generate(**encoded, max_length=100)
        result = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
        return {
            "generated_text": result,
            "message": "Text generated successfully",
        }

    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """Process one inference request.

        Args:
            data: Request payload containing:
                - "inputs": prompt string, or dict with optional "prompt"
                  and base64 "image" keys.
                - "parameters": optional generation overrides
                  (num_inference_steps, guidance_scale, height, width,
                  num_frames).

        Returns:
            Dict with generated frames/images/text, a test-mode echo when no
            model is loaded, or an error description. Never raises — errors
            are reported in the response body.
        """
        logger.info("Received request with data keys: %s", data.keys())

        try:
            inputs = data.get("inputs", "")
            parameters = data.get("parameters", {})

            logger.info("Processing inputs: %s", type(inputs))
            logger.info("Parameters: %s", parameters)

            prompt, image = self._parse_inputs(inputs)

            num_inference_steps = parameters.get("num_inference_steps", 25)
            guidance_scale = parameters.get("guidance_scale", 7.5)
            height = parameters.get("height", 480)
            width = parameters.get("width", 640)
            num_frames = parameters.get("num_frames", 16)

            logger.info(
                "Generation params: steps=%s, guidance=%s, size=%sx%s, frames=%s",
                num_inference_steps, guidance_scale, width, height, num_frames,
            )

            if self.pipeline is not None:
                return self._generate_with_pipeline(
                    prompt,
                    image,
                    num_inference_steps,
                    guidance_scale,
                    height,
                    width,
                    num_frames,
                )
            if self.model is not None:
                return self._generate_with_model(prompt)

            logger.warning("Running in test mode - no actual generation")
            return {
                "message": "Handler is running in test mode",
                "prompt": prompt,
                "parameters": parameters,
                "status": "test_mode",
            }

        except Exception as e:
            import traceback
            # logger.exception keeps the traceback in the server logs;
            # the response also carries it so clients can diagnose.
            logger.exception("Error during inference: %s", e)
            return {
                "error": str(e),
                "traceback": traceback.format_exc(),
                "message": "Error during generation",
            }