Update MyPipe.py
Browse files
MyPipe.py
CHANGED
|
@@ -9,7 +9,7 @@ from PIL import Image
|
|
| 9 |
class RMBGPipe(Pipeline):
|
| 10 |
def __init__(self,**kwargs):
|
| 11 |
Pipeline.__init__(self,**kwargs)
|
| 12 |
-
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 13 |
self.model.to(self.device)
|
| 14 |
self.model.eval()
|
| 15 |
|
|
@@ -39,6 +39,7 @@ class RMBGPipe(Pipeline):
|
|
| 39 |
result = self.model(inputs.pop("image"))
|
| 40 |
inputs["result"] = result
|
| 41 |
return inputs
|
|
|
|
| 42 |
def postprocess(self,inputs,return_mask:bool=False ):
|
| 43 |
result = inputs.pop("result")
|
| 44 |
orig_im_size = inputs.pop("orig_im_size")
|
|
@@ -48,7 +49,7 @@ class RMBGPipe(Pipeline):
|
|
| 48 |
if return_mask ==True :
|
| 49 |
return pil_im
|
| 50 |
no_bg_image = Image.new("RGBA", pil_im.size, (0,0,0,0))
|
| 51 |
-
orig_image = Image.
|
| 52 |
no_bg_image.paste(orig_image, mask=pil_im)
|
| 53 |
return no_bg_image
|
| 54 |
|
|
@@ -59,10 +60,11 @@ class RMBGPipe(Pipeline):
|
|
| 59 |
im = im[:, :, np.newaxis]
|
| 60 |
# orig_im_size=im.shape[0:2]
|
| 61 |
im_tensor = torch.tensor(im, dtype=torch.float32).permute(2,0,1)
|
| 62 |
-
im_tensor = F.interpolate(torch.unsqueeze(im_tensor,0), size=model_input_size, mode='bilinear')
|
| 63 |
image = torch.divide(im_tensor,255.0)
|
| 64 |
image = normalize(image,[0.5,0.5,0.5],[1.0,1.0,1.0])
|
| 65 |
return image
|
|
|
|
| 66 |
def postprocess_image(self,result: torch.Tensor, im_size: list)-> np.ndarray:
|
| 67 |
result = torch.squeeze(F.interpolate(result, size=im_size, mode='bilinear') ,0)
|
| 68 |
ma = torch.max(result)
|
|
|
|
| 9 |
class RMBGPipe(Pipeline):
|
| 10 |
def __init__(self,**kwargs):
|
| 11 |
Pipeline.__init__(self,**kwargs)
|
| 12 |
+
self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
|
| 13 |
self.model.to(self.device)
|
| 14 |
self.model.eval()
|
| 15 |
|
|
|
|
| 39 |
result = self.model(inputs.pop("image"))
|
| 40 |
inputs["result"] = result
|
| 41 |
return inputs
|
| 42 |
+
|
| 43 |
def postprocess(self,inputs,return_mask:bool=False ):
|
| 44 |
result = inputs.pop("result")
|
| 45 |
orig_im_size = inputs.pop("orig_im_size")
|
|
|
|
| 49 |
if return_mask ==True :
|
| 50 |
return pil_im
|
| 51 |
no_bg_image = Image.new("RGBA", pil_im.size, (0,0,0,0))
|
| 52 |
+
orig_image = Image.fromarray(io.imread(im_path))
|
| 53 |
no_bg_image.paste(orig_image, mask=pil_im)
|
| 54 |
return no_bg_image
|
| 55 |
|
|
|
|
| 60 |
im = im[:, :, np.newaxis]
|
| 61 |
# orig_im_size=im.shape[0:2]
|
| 62 |
im_tensor = torch.tensor(im, dtype=torch.float32).permute(2,0,1)
|
| 63 |
+
im_tensor = F.interpolate(torch.unsqueeze(im_tensor,0), size=model_input_size, mode='bilinear')
|
| 64 |
image = torch.divide(im_tensor,255.0)
|
| 65 |
image = normalize(image,[0.5,0.5,0.5],[1.0,1.0,1.0])
|
| 66 |
return image
|
| 67 |
+
|
| 68 |
def postprocess_image(self,result: torch.Tensor, im_size: list)-> np.ndarray:
|
| 69 |
result = torch.squeeze(F.interpolate(result, size=im_size, mode='bilinear') ,0)
|
| 70 |
ma = torch.max(result)
|