ajwestfield committed on
Commit
ab4557b
·
0 Parent(s):

Add custom handler for MeiGen-MultiTalk Inference Endpoint

Browse files
Files changed (3) hide show
  1. README.md +83 -0
  2. handler.py +242 -0
  3. requirements.txt +17 -0
README.md ADDED
@@ -0,0 +1,83 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ license: apache-2.0
3
+ tags:
4
+ - text-to-video
5
+ - image-to-video
6
+ - custom
7
+ - inference-endpoints
8
+ library_name: diffusers
9
+ ---
10
+
11
+ # MeiGen-MultiTalk Endpoint Handler
12
+
13
+ This repository contains a custom handler for deploying MeiGen-AI's MultiTalk model on Hugging Face Inference Endpoints.
14
+
15
+ ## Model Description
16
+
17
+ MeiGen-MultiTalk is an advanced model for generating audio-driven multi-person conversational videos. This handler wraps the original model to work with HF Inference Endpoints.
18
+
19
+ ## Features
20
+
21
+ - Text-to-video generation
22
+ - Image-to-video generation
23
+ - Multi-person conversation synthesis
24
+ - Support for various resolutions (480p, 720p)
25
+ - Optimized for A100 GPUs
26
+
27
+ ## Usage with Inference Endpoints
28
+
29
+ ### Recommended Configuration
30
+
31
+ - **Hardware**: GPU · A100 · 1x GPU (80 GB)
32
+ - **Autoscaling**:
33
+ - Min replicas: 0
34
+ - Max replicas: 1
35
+ - Scale to zero after: 300 seconds
36
+
37
+ ### API Example
38
+
39
+ ```python
40
+ import requests
41
+ import json
42
+ import base64
43
+
44
+ API_URL = "https://YOUR-ENDPOINT-URL.endpoints.huggingface.cloud"
45
+ headers = {
46
+ "Authorization": "Bearer YOUR_HF_TOKEN",
47
+ "Content-Type": "application/json"
48
+ }
49
+
50
+ # Text-to-video generation
51
+ data = {
52
+ "inputs": {
53
+ "prompt": "A person giving a presentation"
54
+ },
55
+ "parameters": {
56
+ "num_frames": 16,
57
+ "height": 480,
58
+ "width": 640,
59
+ "num_inference_steps": 25,
60
+ "guidance_scale": 7.5
61
+ }
62
+ }
63
+
64
+ response = requests.post(API_URL, headers=headers, json=data)
65
+ result = response.json()
66
+ ```
67
+
68
+ ## Technical Details
69
+
70
+ The handler includes:
71
+ - Automatic model loading from MeiGen-AI/MeiGen-MultiTalk
72
+ - Memory optimization for GPU inference
73
+ - Support for both diffusion pipeline and transformer modes
74
+ - Error handling and logging
75
+ - Base64 encoding for image/video I/O
76
+
77
+ ## License
78
+
79
+ Apache 2.0 License
80
+
81
+ ## Credits
82
+
83
+ Based on the original [MeiGen-AI/MeiGen-MultiTalk](https://huggingface.co/MeiGen-AI/MeiGen-MultiTalk) model.
handler.py ADDED
@@ -0,0 +1,242 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
import base64
import inspect
import io
import json
import logging
import os
import sys
import traceback
from typing import Any, Dict, List

import torch
from PIL import Image
10
+
11
+ # Set up logging
12
+ logging.basicConfig(level=logging.INFO)
13
+ logger = logging.getLogger(__name__)
14
+
15
class EndpointHandler:
    """Custom handler for serving MeiGen-MultiTalk on HF Inference Endpoints.

    Three loading strategies are attempted at start-up; the attributes record
    which one succeeded:

    * ``self.pipeline`` set  -> diffusers ``DiffusionPipeline`` mode
    * ``self.model`` set     -> ``transformers`` fallback mode
    * neither set            -> "test mode": requests are echoed back so the
      endpoint stays debuggable instead of crash-looping
    """

    def __init__(self, path=""):
        """Load the MeiGen-AI/MeiGen-MultiTalk model.

        Args:
            path: Local repository path supplied by the Inference Endpoints
                runtime. Unused here -- weights are pulled from the Hub.

        Raises:
            ImportError: if diffusers/torch are missing (deployment error
                worth aborting on). Model-loading failures do NOT raise;
                the handler degrades to a fallback or test mode instead.
        """
        logger.info(f"Initializing handler with path: {path}")

        # Initialize every model attribute up front so all load paths leave
        # the handler in a consistent state. (Previously self.model and
        # self.tokenizer were undefined when the pipeline loaded
        # successfully -- a latent AttributeError trap.)
        self.pipeline = None
        self.model = None
        self.tokenizer = None

        # Import required libraries; failing here should abort deployment.
        try:
            from diffusers import DiffusionPipeline
            logger.info("Successfully imported required libraries")
        except ImportError as e:
            logger.error(f"Failed to import required libraries: {e}")
            raise

        # Prefer the GPU; CPU fallback keeps local smoke tests possible.
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        logger.info(f"Using device: {self.device}")

        # Hoisted out of the try block so the fallback loader can always
        # reference them regardless of where the first attempt failed.
        model_id = "MeiGen-AI/MeiGen-MultiTalk"
        dtype = torch.float16 if torch.cuda.is_available() else torch.float32

        try:
            logger.info(f"Loading model from: {model_id}")

            # Strategy 1: load as a diffusers pipeline.
            self.pipeline = DiffusionPipeline.from_pretrained(
                model_id,
                torch_dtype=dtype,
                device_map="auto",
                low_cpu_mem_usage=True,
            )
            self._enable_memory_optimizations()
            logger.info("Model loaded successfully")

        except Exception as e:
            logger.error(f"Failed to load model: {e}")
            self.pipeline = None
            # Strategy 2: raw transformers model (the repo may not ship a
            # diffusers pipeline config).
            try:
                logger.info("Attempting alternative loading method...")
                from transformers import AutoModel, AutoTokenizer

                self.model = AutoModel.from_pretrained(
                    model_id,
                    torch_dtype=dtype,
                    device_map="auto",
                    trust_remote_code=True,
                )
                self.tokenizer = AutoTokenizer.from_pretrained(
                    model_id, trust_remote_code=True
                )
                logger.info("Model loaded with alternative method")

            except Exception as e2:
                logger.error(f"Alternative loading also failed: {e2}")
                # Strategy 3: test mode -- keep the endpoint alive so the
                # deployment can be debugged through the API.
                self.model = None
                self.tokenizer = None
                logger.warning("Running in test mode without actual model")

    def _enable_memory_optimizations(self):
        """Apply best-effort GPU memory optimizations to the loaded pipeline.

        Each toggle is individually guarded: ``enable_model_cpu_offload`` in
        particular can raise on pipelines loaded with ``device_map="auto"``,
        and a failed optimization must not discard an otherwise usable
        pipeline (previously any such error fell into the load `except` and
        triggered the fallback loader).
        """
        optimizations = (
            ("enable_attention_slicing", "attention slicing"),
            ("enable_vae_slicing", "VAE slicing"),
            ("enable_model_cpu_offload", "model CPU offload"),
        )
        for method_name, label in optimizations:
            method = getattr(self.pipeline, method_name, None)
            if method is None:
                continue
            try:
                method()
                logger.info(f"Enabled {label}")
            except Exception as opt_err:
                logger.warning(f"Could not enable {label}: {opt_err}")

    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """Process one inference request.

        Args:
            data: Request payload containing:
                - "inputs": a prompt string, or a dict with "prompt" and an
                  optional base64-encoded "image"
                - "parameters": optional generation parameters
                  (num_inference_steps, guidance_scale, height, width,
                  num_frames)

        Returns:
            Dict with base64-encoded "frames"/"images", "generated_text",
            a test-mode echo, or an "error" description. Never raises --
            failures are reported in-band so the endpoint always returns
            JSON (previously an empty ``result.frames`` list fell through
            every branch and returned ``None``).
        """
        logger.info(f"Received request with data keys: {data.keys()}")

        try:
            inputs = data.get("inputs", "")
            parameters = data.get("parameters", {})

            logger.info(f"Processing inputs: {type(inputs)}")
            logger.info(f"Parameters: {parameters}")

            prompt, image = self._parse_inputs(inputs)

            # Generation parameters with endpoint defaults.
            num_inference_steps = parameters.get("num_inference_steps", 25)
            guidance_scale = parameters.get("guidance_scale", 7.5)
            height = parameters.get("height", 480)
            width = parameters.get("width", 640)
            num_frames = parameters.get("num_frames", 16)

            logger.info(f"Generation params: steps={num_inference_steps}, guidance={guidance_scale}, size={width}x{height}, frames={num_frames}")

            if self.pipeline is not None:
                logger.info("Generating with diffusion pipeline...")
                return self._generate_with_pipeline(
                    prompt,
                    image,
                    height=height,
                    width=width,
                    num_inference_steps=num_inference_steps,
                    guidance_scale=guidance_scale,
                    num_frames=num_frames,
                )

            if self.model is not None:
                logger.info("Generating with transformer model...")
                return self._generate_with_transformer(prompt)

            # Test mode: no model available; echo the request so callers
            # can verify endpoint wiring.
            logger.warning("Running in test mode - no actual generation")
            return {
                "message": "Handler is running in test mode",
                "prompt": prompt,
                "parameters": parameters,
                "status": "test_mode",
            }

        except Exception as e:
            # Report errors in-band; the traceback aids remote debugging.
            logger.error(f"Error during inference: {e}")
            return {
                "error": str(e),
                "traceback": traceback.format_exc(),
                "message": "Error during generation",
            }

    def _parse_inputs(self, inputs):
        """Normalize the request "inputs" field to ``(prompt, image)``.

        Accepts a bare prompt string, or a dict with "prompt" and an
        optional base64-encoded "image". A malformed image is logged and
        ignored rather than failing the whole request.
        """
        if isinstance(inputs, str):
            return inputs, None
        if isinstance(inputs, dict):
            prompt = inputs.get("prompt", "A person speaking")
            image = None
            if "image" in inputs:
                try:
                    image_bytes = base64.b64decode(inputs["image"])
                    # Force RGB: pipelines expect 3 channels, and uploaded
                    # PNGs are frequently RGBA or palette-based.
                    image = Image.open(io.BytesIO(image_bytes)).convert("RGB")
                    logger.info("Loaded input image")
                except Exception as e:
                    logger.error(f"Failed to decode image: {e}")
                    image = None
            return prompt, image
        # Anything else (e.g. a list) is coerced to a string prompt.
        return str(inputs), None

    def _generate_with_pipeline(self, prompt, image, *, height, width,
                                num_inference_steps, guidance_scale,
                                num_frames):
        """Run the diffusion pipeline and package its output as JSON."""
        gen_kwargs = {
            "prompt": prompt,
            "height": height,
            "width": width,
            "num_inference_steps": num_inference_steps,
            "guidance_scale": guidance_scale,
        }
        if image is not None:
            gen_kwargs["image"] = image

        # inspect.signature sees through wrapped/decorated __call__
        # implementations (diffusers decorates its pipelines), which the
        # raw __code__.co_varnames probe used previously does not.
        try:
            accepts_num_frames = "num_frames" in inspect.signature(
                self.pipeline.__call__
            ).parameters
        except (TypeError, ValueError):
            accepts_num_frames = False
        if accepts_num_frames:
            gen_kwargs["num_frames"] = num_frames

        with torch.no_grad():
            result = self.pipeline(**gen_kwargs)

        if hasattr(result, "frames"):
            frames = result.frames
            if isinstance(frames, list) and frames:
                # Video pipelines return either [frame, ...] or a batched
                # [[frame, ...]]; flatten the batch dimension.
                frame_list = frames[0] if isinstance(frames[0], list) else frames
                encoded_frames = self._encode_images(frame_list)
                return {
                    "frames": encoded_frames,
                    "num_frames": len(encoded_frames),
                    "message": "Video generated successfully",
                }
        elif hasattr(result, "images"):
            encoded_images = self._encode_images(result.images)
            return {
                "images": encoded_images,
                "num_images": len(encoded_images),
                "message": "Images generated successfully",
            }
        # Fallback (also covers an empty frames list): report what came back
        # instead of silently returning None.
        return {
            "message": "Generation completed",
            "prompt": prompt,
            "result_type": str(type(result)),
        }

    def _generate_with_transformer(self, prompt):
        """Run the transformers fallback model on the prompt."""
        if self.tokenizer:
            # Renamed from `inputs` to avoid shadowing the request field.
            encoded = self.tokenizer(prompt, return_tensors="pt").to(self.device)
            with torch.no_grad():
                outputs = self.model.generate(**encoded, max_length=100)
            text = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
            return {
                "generated_text": text,
                "message": "Text generated successfully",
            }
        return {
            "message": "Model loaded but tokenizer not available",
            "prompt": prompt,
        }

    @staticmethod
    def _encode_images(images):
        """Encode PIL images as base64 PNG strings; non-PIL entries skipped."""
        encoded = []
        for img in images:
            if isinstance(img, Image.Image):
                buffered = io.BytesIO()
                img.save(buffered, format="PNG")
                encoded.append(base64.b64encode(buffered.getvalue()).decode())
        return encoded
requirements.txt ADDED
@@ -0,0 +1,17 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ torch==2.4.1
2
+ torchvision==0.19.1
3
+ torchaudio==2.4.1
4
+ transformers>=4.44.0
5
+ diffusers>=0.31.0
6
+ accelerate>=0.34.0
7
+ xformers==0.0.28
8
+ sentencepiece
9
+ protobuf
10
+ Pillow
11
+ numpy
12
+ scipy
13
+ imageio
14
+ opencv-python-headless
15
+ librosa
16
+ soundfile
17
+ ffmpeg-python