root committed on
Commit
5a9c05f
·
1 Parent(s): 6ae7fd3

remove background remover for now

Files changed (2):
  1. handler.py +1 -62
  2. requirements.txt +1 -5
handler.py CHANGED
@@ -10,7 +10,6 @@ from omegaconf import OmegaConf
 from transformers import CLIPVisionModelWithProjection
 import cv2
 import os
-from backgroundremover.bg import remove as remove_bg
 from src.models.pose_guider import PoseGuider
 from src.models.unet_2d_condition import UNet2DConditionModel
 from src.models.unet_3d import UNet3DConditionModel
@@ -133,55 +132,6 @@ class EndpointHandler():
 
         return os.path.join(os.getcwd(), output_path)
 
-
-    def remove_bg_from_image(self, image_data):
-        model_name = "u2net"  # Choose your preferred model: "u2net", "u2net_human_seg", "u2netp"
-        processed_image_data = remove_bg(
-            image_data,
-            model_name=model_name,
-            alpha_matting=True,
-            alpha_matting_foreground_threshold=240,
-            alpha_matting_background_threshold=10,
-            alpha_matting_erode_structure_size=10,
-            alpha_matting_base_size=1000
-        )
-        return processed_image_data
-
-    def _remove_background(self, input_path, output_path):
-        cap = cv2.VideoCapture(input_path)
-        if not cap.isOpened():
-            raise IOError(f"Error opening video file {input_path}")
-
-        width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
-        height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
-        fps = int(cap.get(cv2.CAP_PROP_FPS))
-
-        fourcc = cv2.VideoWriter_fourcc(*'mp4v')
-        out = cv2.VideoWriter(output_path, fourcc, fps, (width, height))
-
-        frame_count = 0
-        while cap.isOpened():
-            ret, frame = cap.read()
-            if not ret:
-                break
-
-            frame_count += 1
-            pil_frame = Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
-            frame_data = BytesIO()
-            pil_frame.save(frame_data, format="PNG")
-            frame_data = frame_data.getvalue()
-            processed_frame_data = self.remove_bg_from_image(frame_data)
-            processed_pil_frame = Image.open(BytesIO(processed_frame_data))
-            processed_frame = cv2.cvtColor(np.array(processed_pil_frame), cv2.COLOR_RGB2BGR)
-
-            out.write(processed_frame)
-
-        cap.release()
-        out.release()
-
-        if frame_count == 0:
-            raise IOError(f"No frames processed. Error with video file {input_path}")
-
     def __call__(self, data: Any) -> Dict[str, str]:
         inputs = data.get("inputs", {})
         ref_image_base64 = inputs.get("ref_image", "")
@@ -225,21 +175,10 @@ class EndpointHandler():
         cropped_face = self._crop_face(ref_image, save_path=cropped_face_path)
 
         # Perform face swapping
-        print(f"Starting face swap with cropped face: {cropped_face_path} and animation: {animation_path}")
         final_video_path = self._swap_face(cropped_face, animation_path)
-        print(f"Face swap completed. Final video path: {final_video_path}")
-
-        # Ensure the output file exists before trying to open it
-        if not os.path.exists(final_video_path):
-            raise FileNotFoundError(f"Expected output file not found: {final_video_path}")
-
-        # Remove the background from the final video
-        bg_removed_video_path = os.path.join(save_dir, "bg_removed_output.mp4")
-        self._remove_background(final_video_path, bg_removed_video_path)
-        print(f"Background removal completed. Output saved to: {bg_removed_video_path}")
 
         # Encode the final video in base64
-        with open(bg_removed_video_path, "rb") as video_file:
+        with open(final_video_path, "rb") as video_file:
             video_base64 = base64.b64encode(video_file.read()).decode("utf-8")
 
         torch.cuda.empty_cache()
 
requirements.txt CHANGED
@@ -32,8 +32,4 @@ omegaconf==2.2.3
 
 # Face swap related dependencies
 facenet-pytorch==2.5.2
-dlib==19.22.0
-
-
-# Background removal
-backgroundremover
+dlib==19.22.0
 