root committed on
Commit
fee163c
·
1 Parent(s): 1868417

indentation

Browse files
Files changed (1) hide show
  1. handler.py +61 -61
handler.py CHANGED
@@ -148,64 +148,64 @@ class EndpointHandler():
148
 
149
  return os.path.join(os.getcwd(), output_path)
150
 
151
- def __call__(self, data: Any) -> Dict[str, str]:
152
- inputs = data.get("inputs", {})
153
- ref_image_base64 = inputs.get("ref_image", "")
154
- pose_video_path = inputs.get("pose_video_path", "")
155
- width = inputs.get("width", 512)
156
- height = inputs.get("height", 768)
157
- length = inputs.get("length", 24)
158
- num_inference_steps = inputs.get("num_inference_steps", 25)
159
- cfg = inputs.get("cfg", 3.5)
160
- seed = inputs.get("seed", 123)
161
-
162
- ref_image = Image.open(BytesIO(base64.b64decode(ref_image_base64)))
163
-
164
- # Get the base directory of the current file
165
- base_dir = os.path.dirname(os.path.abspath(__file__))
166
-
167
- # Update pose_video_path to use the base directory
168
- pose_video_path = os.path.join(base_dir, pose_video_path)
169
-
170
- if not os.path.exists(pose_video_path):
171
- raise FileNotFoundError(f"The pose video was not found at: {pose_video_path}")
172
-
173
- torch.manual_seed(seed)
174
- pose_images = read_frames(pose_video_path)
175
- src_fps = get_fps(pose_video_path)
176
-
177
- pose_list = []
178
- total_length = min(length, len(pose_images))
179
- for pose_image_pil in pose_images[:total_length]:
180
- pose_list.append(pose_image_pil)
181
-
182
- video = self.pipeline(
183
- ref_image,
184
- pose_list,
185
- width=width,
186
- height=height,
187
- video_length=total_length,
188
- num_inference_steps=num_inference_steps,
189
- guidance_scale=cfg
190
- ).videos
191
-
192
- save_dir = os.path.join(base_dir, "output", "gradio")
193
- if not os.path.exists(save_dir):
194
- os.makedirs(save_dir, exist_ok=True)
195
- animation_path = os.path.join(save_dir, "animation_output.mp4")
196
- save_videos_grid(video, animation_path, n_rows=1, fps=src_fps)
197
-
198
- # Crop the face from the reference image and save it
199
- cropped_face_path = os.path.join(save_dir, "cropped_face.jpg")
200
- cropped_face = self._crop_face(ref_image, save_path=cropped_face_path)
201
-
202
- # Perform face swapping
203
- final_video_path = self._swap_face(cropped_face, animation_path)
204
-
205
- # Encode the final video in base64
206
- with open(final_video_path, "rb") as video_file:
207
- video_base64 = base64.b64encode(video_file.read()).decode("utf-8")
208
-
209
- torch.cuda.empty_cache()
210
-
211
- return {"video": video_base64}
 
148
 
149
  return os.path.join(os.getcwd(), output_path)
150
 
151
+ def __call__(self, data: Any) -> Dict[str, str]:
152
+ inputs = data.get("inputs", {})
153
+ ref_image_base64 = inputs.get("ref_image", "")
154
+ pose_video_path = inputs.get("pose_video_path", "")
155
+ width = inputs.get("width", 512)
156
+ height = inputs.get("height", 768)
157
+ length = inputs.get("length", 24)
158
+ num_inference_steps = inputs.get("num_inference_steps", 25)
159
+ cfg = inputs.get("cfg", 3.5)
160
+ seed = inputs.get("seed", 123)
161
+
162
+ ref_image = Image.open(BytesIO(base64.b64decode(ref_image_base64)))
163
+
164
+ # Get the base directory of the current file
165
+ base_dir = os.path.dirname(os.path.abspath(__file__))
166
+
167
+ # Update pose_video_path to use the base directory
168
+ pose_video_path = os.path.join(base_dir, pose_video_path)
169
+
170
+ if not os.path.exists(pose_video_path):
171
+ raise FileNotFoundError(f"The pose video was not found at: {pose_video_path}")
172
+
173
+ torch.manual_seed(seed)
174
+ pose_images = read_frames(pose_video_path)
175
+ src_fps = get_fps(pose_video_path)
176
+
177
+ pose_list = []
178
+ total_length = min(length, len(pose_images))
179
+ for pose_image_pil in pose_images[:total_length]:
180
+ pose_list.append(pose_image_pil)
181
+
182
+ video = self.pipeline(
183
+ ref_image,
184
+ pose_list,
185
+ width=width,
186
+ height=height,
187
+ video_length=total_length,
188
+ num_inference_steps=num_inference_steps,
189
+ guidance_scale=cfg
190
+ ).videos
191
+
192
+ save_dir = os.path.join(base_dir, "output", "gradio")
193
+ if not os.path.exists(save_dir):
194
+ os.makedirs(save_dir, exist_ok=True)
195
+ animation_path = os.path.join(save_dir, "animation_output.mp4")
196
+ save_videos_grid(video, animation_path, n_rows=1, fps=src_fps)
197
+
198
+ # Crop the face from the reference image and save it
199
+ cropped_face_path = os.path.join(save_dir, "cropped_face.jpg")
200
+ cropped_face = self._crop_face(ref_image, save_path=cropped_face_path)
201
+
202
+ # Perform face swapping
203
+ final_video_path = self._swap_face(cropped_face, animation_path)
204
+
205
+ # Encode the final video in base64
206
+ with open(final_video_path, "rb") as video_file:
207
+ video_base64 = base64.b64encode(video_file.read()).decode("utf-8")
208
+
209
+ torch.cuda.empty_cache()
210
+
211
+ return {"video": video_base64}