root
commited on
Commit
·
fee163c
1
Parent(s):
1868417
indentation
Browse files- handler.py +61 -61
handler.py
CHANGED
|
@@ -148,64 +148,64 @@ class EndpointHandler():
|
|
| 148 |
|
| 149 |
return os.path.join(os.getcwd(), output_path)
|
| 150 |
|
| 151 |
-
def __call__(self, data: Any) -> Dict[str, str]:
|
| 152 |
-
|
| 153 |
-
|
| 154 |
-
|
| 155 |
-
|
| 156 |
-
|
| 157 |
-
|
| 158 |
-
|
| 159 |
-
|
| 160 |
-
|
| 161 |
-
|
| 162 |
-
|
| 163 |
-
|
| 164 |
-
|
| 165 |
-
|
| 166 |
-
|
| 167 |
-
|
| 168 |
-
|
| 169 |
-
|
| 170 |
-
|
| 171 |
-
|
| 172 |
-
|
| 173 |
-
|
| 174 |
-
|
| 175 |
-
|
| 176 |
-
|
| 177 |
-
|
| 178 |
-
|
| 179 |
-
|
| 180 |
-
|
| 181 |
-
|
| 182 |
-
|
| 183 |
-
|
| 184 |
-
|
| 185 |
-
|
| 186 |
-
|
| 187 |
-
|
| 188 |
-
|
| 189 |
-
|
| 190 |
-
|
| 191 |
-
|
| 192 |
-
|
| 193 |
-
|
| 194 |
-
|
| 195 |
-
|
| 196 |
-
|
| 197 |
-
|
| 198 |
-
|
| 199 |
-
|
| 200 |
-
|
| 201 |
-
|
| 202 |
-
|
| 203 |
-
|
| 204 |
-
|
| 205 |
-
|
| 206 |
-
|
| 207 |
-
|
| 208 |
-
|
| 209 |
-
|
| 210 |
-
|
| 211 |
-
|
|
|
|
| 148 |
|
| 149 |
return os.path.join(os.getcwd(), output_path)
|
| 150 |
|
| 151 |
+
def __call__(self, data: Any) -> Dict[str, str]:
|
| 152 |
+
inputs = data.get("inputs", {})
|
| 153 |
+
ref_image_base64 = inputs.get("ref_image", "")
|
| 154 |
+
pose_video_path = inputs.get("pose_video_path", "")
|
| 155 |
+
width = inputs.get("width", 512)
|
| 156 |
+
height = inputs.get("height", 768)
|
| 157 |
+
length = inputs.get("length", 24)
|
| 158 |
+
num_inference_steps = inputs.get("num_inference_steps", 25)
|
| 159 |
+
cfg = inputs.get("cfg", 3.5)
|
| 160 |
+
seed = inputs.get("seed", 123)
|
| 161 |
+
|
| 162 |
+
ref_image = Image.open(BytesIO(base64.b64decode(ref_image_base64)))
|
| 163 |
+
|
| 164 |
+
# Get the base directory of the current file
|
| 165 |
+
base_dir = os.path.dirname(os.path.abspath(__file__))
|
| 166 |
+
|
| 167 |
+
# Update pose_video_path to use the base directory
|
| 168 |
+
pose_video_path = os.path.join(base_dir, pose_video_path)
|
| 169 |
+
|
| 170 |
+
if not os.path.exists(pose_video_path):
|
| 171 |
+
raise FileNotFoundError(f"The pose video was not found at: {pose_video_path}")
|
| 172 |
+
|
| 173 |
+
torch.manual_seed(seed)
|
| 174 |
+
pose_images = read_frames(pose_video_path)
|
| 175 |
+
src_fps = get_fps(pose_video_path)
|
| 176 |
+
|
| 177 |
+
pose_list = []
|
| 178 |
+
total_length = min(length, len(pose_images))
|
| 179 |
+
for pose_image_pil in pose_images[:total_length]:
|
| 180 |
+
pose_list.append(pose_image_pil)
|
| 181 |
+
|
| 182 |
+
video = self.pipeline(
|
| 183 |
+
ref_image,
|
| 184 |
+
pose_list,
|
| 185 |
+
width=width,
|
| 186 |
+
height=height,
|
| 187 |
+
video_length=total_length,
|
| 188 |
+
num_inference_steps=num_inference_steps,
|
| 189 |
+
guidance_scale=cfg
|
| 190 |
+
).videos
|
| 191 |
+
|
| 192 |
+
save_dir = os.path.join(base_dir, "output", "gradio")
|
| 193 |
+
if not os.path.exists(save_dir):
|
| 194 |
+
os.makedirs(save_dir, exist_ok=True)
|
| 195 |
+
animation_path = os.path.join(save_dir, "animation_output.mp4")
|
| 196 |
+
save_videos_grid(video, animation_path, n_rows=1, fps=src_fps)
|
| 197 |
+
|
| 198 |
+
# Crop the face from the reference image and save it
|
| 199 |
+
cropped_face_path = os.path.join(save_dir, "cropped_face.jpg")
|
| 200 |
+
cropped_face = self._crop_face(ref_image, save_path=cropped_face_path)
|
| 201 |
+
|
| 202 |
+
# Perform face swapping
|
| 203 |
+
final_video_path = self._swap_face(cropped_face, animation_path)
|
| 204 |
+
|
| 205 |
+
# Encode the final video in base64
|
| 206 |
+
with open(final_video_path, "rb") as video_file:
|
| 207 |
+
video_base64 = base64.b64encode(video_file.read()).decode("utf-8")
|
| 208 |
+
|
| 209 |
+
torch.cuda.empty_cache()
|
| 210 |
+
|
| 211 |
+
return {"video": video_base64}
|