Update BEN2.py
Browse files
BEN2.py
CHANGED
|
@@ -921,6 +921,8 @@ class BEN_Base(nn.Module):
|
|
| 921 |
if isinstance(m, nn.GELU) or isinstance(m, nn.Dropout):
|
| 922 |
m.inplace = True
|
| 923 |
|
|
|
|
|
|
|
| 924 |
@torch.inference_mode()
|
| 925 |
@torch.autocast(device_type="cuda",dtype=torch.float16)
|
| 926 |
def forward(self, x):
|
|
@@ -1008,7 +1010,13 @@ class BEN_Base(nn.Module):
|
|
| 1008 |
# image = ImageOps.exif_transpose(image)
|
| 1009 |
if isinstance(image, Image.Image):
|
| 1010 |
image, h, w,original_image = rgb_loader_refiner(image)
|
| 1011 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1012 |
with torch.no_grad():
|
| 1013 |
res = self.forward(img_tensor)
|
| 1014 |
|
|
@@ -1035,7 +1043,11 @@ class BEN_Base(nn.Module):
|
|
| 1035 |
foregrounds = []
|
| 1036 |
for batch in image:
|
| 1037 |
image, h, w,original_image = rgb_loader_refiner(batch)
|
| 1038 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1039 |
|
| 1040 |
with torch.no_grad():
|
| 1041 |
res = self.forward(img_tensor)
|
|
@@ -1058,6 +1070,9 @@ class BEN_Base(nn.Module):
|
|
| 1058 |
|
| 1059 |
return foregrounds
|
| 1060 |
|
|
|
|
|
|
|
|
|
|
| 1061 |
def segment_video(self, video_path, output_path="./", fps=0, refine_foreground=False, batch=1, print_frames_processed=True, webm = False, rgb_value= (0, 255, 0)):
|
| 1062 |
|
| 1063 |
"""
|
|
@@ -1196,6 +1211,13 @@ img_transform = transforms.Compose([
|
|
| 1196 |
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
|
| 1197 |
])
|
| 1198 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1199 |
|
| 1200 |
|
| 1201 |
|
|
|
|
| 921 |
if isinstance(m, nn.GELU) or isinstance(m, nn.Dropout):
|
| 922 |
m.inplace = True
|
| 923 |
|
| 924 |
+
|
| 925 |
+
|
| 926 |
@torch.inference_mode()
|
| 927 |
@torch.autocast(device_type="cuda",dtype=torch.float16)
|
| 928 |
def forward(self, x):
|
|
|
|
| 1010 |
# image = ImageOps.exif_transpose(image)
|
| 1011 |
if isinstance(image, Image.Image):
|
| 1012 |
image, h, w,original_image = rgb_loader_refiner(image)
|
| 1013 |
+
if torch.cuda.is_available():
|
| 1014 |
+
|
| 1015 |
+
img_tensor = img_transform(image).unsqueeze(0).to(next(self.parameters()).device)
|
| 1016 |
+
else:
|
| 1017 |
+
img_tensor = img_transform32(image).unsqueeze(0).to(next(self.parameters()).device)
|
| 1018 |
+
|
| 1019 |
+
|
| 1020 |
with torch.no_grad():
|
| 1021 |
res = self.forward(img_tensor)
|
| 1022 |
|
|
|
|
| 1043 |
foregrounds = []
|
| 1044 |
for batch in image:
|
| 1045 |
image, h, w,original_image = rgb_loader_refiner(batch)
|
| 1046 |
+
if torch.cuda.is_available():
|
| 1047 |
+
|
| 1048 |
+
img_tensor = img_transform(image).unsqueeze(0).to(next(self.parameters()).device)
|
| 1049 |
+
else:
|
| 1050 |
+
img_tensor = img_transform32(image).unsqueeze(0).to(next(self.parameters()).device)
|
| 1051 |
|
| 1052 |
with torch.no_grad():
|
| 1053 |
res = self.forward(img_tensor)
|
|
|
|
| 1070 |
|
| 1071 |
return foregrounds
|
| 1072 |
|
| 1073 |
+
|
| 1074 |
+
|
| 1075 |
+
|
| 1076 |
def segment_video(self, video_path, output_path="./", fps=0, refine_foreground=False, batch=1, print_frames_processed=True, webm = False, rgb_value= (0, 255, 0)):
|
| 1077 |
|
| 1078 |
"""
|
|
|
|
| 1211 |
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
|
| 1212 |
])
|
| 1213 |
|
| 1214 |
+
img_transform32 = transforms.Compose([
|
| 1215 |
+
transforms.ToTensor(),
|
| 1216 |
+
transforms.ConvertImageDtype(torch.float32),
|
| 1217 |
+
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
|
| 1218 |
+
])
|
| 1219 |
+
|
| 1220 |
+
|
| 1221 |
|
| 1222 |
|
| 1223 |
|