ajwestfield commited on
Commit
2034ad0
·
1 Parent(s): d992337

Add MultiTalk custom handler for HF Inference Endpoint

Browse files
Files changed (3) hide show
  1. README.md +77 -0
  2. handler.py +139 -0
  3. requirements.txt +16 -0
README.md ADDED
@@ -0,0 +1,77 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # MultiTalk Hugging Face Endpoint Handler
2
+
3
+ This custom handler enables the MeiGen-AI/MeiGen-MultiTalk model to run on Hugging Face Inference Endpoints.
4
+
5
+ ## Setup Instructions
6
+
7
+ 1. **Create a new Inference Endpoint** on Hugging Face:
8
+ - Go to https://huggingface.co/inference-endpoints
9
+ - Click "New endpoint"
10
+
11
+ 2. **Configure the endpoint**:
12
+ - **Model repository**: `ajwestfield/multitalk-handler` (you'll need to upload this handler to your HF account)
13
+ - **Task**: Custom
14
+ - **Framework**: Custom
15
+ - **Instance type**: GPU · A100 · 1x GPU (80 GB)
16
+
17
+ 3. **Advanced Configuration**:
18
+ - **Container type**: Custom
19
+ - **Custom image**: `pytorch/pytorch:2.4.1-cuda12.1-cudnn9-runtime`
20
+ - **Autoscaling**:
21
+ - Min replicas: 0
22
+ - Max replicas: 1
23
+ - Scale to zero after: 300 seconds (5 minutes)
24
+
25
+ 4. **Environment Variables** (add these in Settings):
26
+ ```
27
+ PYTORCH_CUDA_ALLOC_CONF=max_split_size_mb:512
28
+ CUDA_VISIBLE_DEVICES=0
29
+ ```
30
+
31
+ ## Uploading the Handler
32
+
33
+ 1. Create a new model repository on Hugging Face:
34
+ ```bash
35
+ huggingface-cli repo create multitalk-handler --type model
36
+ ```
37
+
38
+ 2. Upload the handler files:
39
+ ```bash
40
+ cd huggingface-endpoint/multitalk-handler
41
+ git init
42
+ git add .
43
+ git commit -m "Add MultiTalk custom handler"
44
+ git remote add origin https://huggingface.co/ajwestfield/multitalk-handler
45
+ git push -u origin main
46
+ ```
47
+
48
+ ## Usage
49
+
50
+ Once deployed, you can call the endpoint with:
51
+
52
+ ```python
53
+ import requests
54
+ import json
55
+
56
+ API_URL = "https://YOUR-ENDPOINT-URL.endpoints.huggingface.cloud"
57
+ headers = {
58
+ "Authorization": "Bearer YOUR_HF_TOKEN",
59
+ "Content-Type": "application/json"
60
+ }
61
+
62
+ data = {
63
+ "inputs": {
64
+ "prompt": "A person speaking naturally",
65
+ "image": "base64_encoded_image_optional"
66
+ },
67
+ "parameters": {
68
+ "num_frames": 16,
69
+ "height": 480,
70
+ "width": 640,
71
+ "num_inference_steps": 25
72
+ }
73
+ }
74
+
75
+ response = requests.post(API_URL, headers=headers, json=data)
76
+ result = response.json()
77
+ ```
handler.py ADDED
@@ -0,0 +1,139 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import json
3
+ import base64
4
+ import io
5
+ from typing import Dict, Any, List
6
+ from PIL import Image
7
+ import numpy as np
8
+
9
class EndpointHandler:
    """Custom Hugging Face Inference Endpoint handler for MeiGen-MultiTalk.

    Loads the diffusion pipeline once at startup and serves video-generation
    requests through ``__call__``, returning base64-encoded PNG frames.
    """

    def __init__(self, path=""):
        """Load the MultiTalk pipeline.

        Args:
            path: Local model directory supplied by the endpoint runtime.
                Falls back to the hub repo id ``MeiGen-AI/MeiGen-MultiTalk``
                when empty.

        Raises:
            ImportError: if a required dependency is missing.
            Exception: re-raised if the pipeline fails to load.
        """
        # Fail fast with a clear message when optional deps are absent.
        try:
            from diffusers import DiffusionPipeline
            import librosa  # noqa: F401 -- needed at runtime for audio features
        except ImportError as e:
            print(f"Missing dependency: {e}")
            print("Please ensure all requirements are installed")
            raise

        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        print(f"Using device: {self.device}")

        try:
            self.pipeline = DiffusionPipeline.from_pretrained(
                path if path else "MeiGen-AI/MeiGen-MultiTalk",
                torch_dtype=torch.float16,
                device_map="auto",
            )

            # Reduce peak VRAM where the pipeline supports it.
            if hasattr(self.pipeline, "enable_attention_slicing"):
                self.pipeline.enable_attention_slicing()
            if hasattr(self.pipeline, "enable_vae_slicing"):
                self.pipeline.enable_vae_slicing()

            print("Model loaded successfully")
        except Exception as e:
            print(f"Error loading model: {e}")
            raise

    @staticmethod
    def _parse_inputs(inputs):
        """Normalize request ``inputs`` into ``(prompt, image_or_None)``.

        Accepts a bare prompt string, a dict with ``prompt`` and an optional
        base64-encoded ``image``, or any other value (stringified).
        """
        if isinstance(inputs, str):
            return inputs, None
        if isinstance(inputs, dict):
            prompt = inputs.get("prompt", "")
            image = None
            if "image" in inputs:
                image_bytes = base64.b64decode(inputs["image"])
                # Normalize to RGB so palette/RGBA uploads don't break the pipeline.
                image = Image.open(io.BytesIO(image_bytes)).convert("RGB")
            return prompt, image
        return str(inputs), None

    @staticmethod
    def _encode_frame(frame):
        """Encode one output frame as a base64 PNG string.

        Generalized over the original: numpy arrays (the common diffusers
        output format) are converted to PIL first. Returns ``None`` for
        unsupported frame types, which callers skip.
        """
        if isinstance(frame, np.ndarray):
            if frame.dtype != np.uint8:
                # Assume float frames are in [0, 1], as diffusers emits.
                frame = (np.clip(frame, 0.0, 1.0) * 255).round().astype(np.uint8)
            frame = Image.fromarray(frame)
        if not isinstance(frame, Image.Image):
            return None
        buffered = io.BytesIO()
        frame.save(buffered, format="PNG")
        return base64.b64encode(buffered.getvalue()).decode()

    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """Process one inference request.

        Args:
            data: Request payload containing:
                - inputs: prompt string, or dict with ``prompt`` and optional
                  base64 ``image``.
                - parameters: optional generation overrides
                  (``num_inference_steps``, ``guidance_scale``, ``height``,
                  ``width``, ``num_frames``).

        Returns:
            Dict with base64 ``frames`` on success, or an ``error`` key
            (plus ``traceback`` for unexpected failures).
        """
        try:
            inputs = data.get("inputs", "")
            # ``or {}`` guards against an explicit ``"parameters": null``.
            parameters = data.get("parameters", {}) or {}

            prompt, image = self._parse_inputs(inputs)

            # Explicit guard replaces the original dead-code
            # ``hasattr(self.pipeline, "__call__")`` check.
            if not callable(getattr(self, "pipeline", None)):
                return {
                    "error": "Model pipeline not properly initialized"
                }

            with torch.no_grad():
                result = self.pipeline(
                    prompt=prompt,
                    image=image,
                    height=parameters.get("height", 480),
                    width=parameters.get("width", 640),
                    num_frames=parameters.get("num_frames", 16),
                    num_inference_steps=parameters.get("num_inference_steps", 25),
                    guidance_scale=parameters.get("guidance_scale", 7.5),
                )

            if not hasattr(result, "frames"):
                return {
                    "error": "Model output format not recognized",
                    "result": str(result)
                }

            # Pipelines return a batch; take the first video's frame list.
            frames = result.frames[0] if len(result.frames) > 0 else []
            encoded_frames = [
                enc for enc in (self._encode_frame(f) for f in frames)
                if enc is not None
            ]

            return {
                "frames": encoded_frames,
                "num_frames": len(encoded_frames),
                "message": "Video generated successfully"
            }

        except Exception as e:
            import traceback
            return {
                "error": str(e),
                "traceback": traceback.format_exc()
            }
requirements.txt ADDED
@@ -0,0 +1,16 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ torch==2.4.1
2
+ torchvision==0.19.1
3
+ torchaudio==2.4.1
4
+ xformers==0.0.28
5
+ flash-attn==2.7.4.post1
6
+ diffusers
7
+ transformers
8
+ accelerate
9
+ librosa
10
+ ffmpeg-python
11
+ opencv-python-headless
12
+ numpy
13
+ Pillow
14
+ scipy
15
+ imageio
16
+ moviepy