shivamkunkolikar committed
Commit f647f94 · Parent: 8475a1f

gradio update
Files changed (6):
  1. app.py +9 -0
  2. handler.py +38 -0
  3. inference.py +102 -0
  4. inpainting_model_best.pth +3 -0
  5. model.py +57 -0
  6. requirements.txt +5 -0
app.py ADDED
@@ -0,0 +1,9 @@
+ import gradio as gr
+ from inference import get_output
+ from PIL import Image
+
+ def predict(image, mask):
+     return get_output(image, mask)
+
+ iface = gr.Interface(fn=predict, inputs=[gr.Image(type="pil"), gr.Image(type="pil")], outputs="image")  # type="pil" so get_output receives PIL images
+ iface.launch()
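For reference, a quick offline smoke test of the predict path (a sketch: the solid-color image, the circular mask, and the output file name are all invented for illustration):

# smoke test for the predict path -- test data is invented
from PIL import Image, ImageDraw
from inference import get_output

image = Image.new("RGB", (256, 256), color=(120, 160, 200))  # dummy photo stand-in
mask = Image.new("L", (256, 256), color=0)
ImageDraw.Draw(mask).ellipse((96, 96, 160, 160), fill=255)   # white = hole to fill

result = get_output(image, mask)
result.save("inpainted.png")
print(result.size, result.mode)  # expect (256, 256) RGB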
handler.py ADDED
@@ -0,0 +1,38 @@
+ from model import UNetInpaint
+ from PIL import Image
+ import torch
+ import numpy as np
+ import io
+
+ class EndpointHandler:
+     def __init__(self, path=""):
+         self.model = UNetInpaint()
+         self.model.load_state_dict(torch.load("inpainting_model_best.pth", map_location="cpu"))  # checkpoint added in this commit
+         self.model.eval()
+
+     def __call__(self, data):
+         image_bytes = data.get("image")
+         mask_bytes = data.get("mask")
+
+         image = Image.open(io.BytesIO(image_bytes)).convert("RGB")
+         mask = Image.open(io.BytesIO(mask_bytes)).convert("L")
+
+         image_np = np.array(image).astype(np.float32) / 255.0
+         mask_np = np.array(mask).astype(np.float32) / 255.0
+         mask_np = (mask_np > 0.5).astype(np.float32)  # binarize: 1 = region to inpaint
+
+         mask_np = np.expand_dims(mask_np, axis=-1)
+         image_np = np.transpose(image_np, (2, 0, 1))  # HWC -> CHW
+         mask_np = np.transpose(mask_np, (2, 0, 1))
+
+         image_tensor = torch.tensor(image_np) * (1 - torch.tensor(mask_np))  # zero out the hole
+         input_tensor = torch.cat([image_tensor, torch.tensor(mask_np)], dim=0).unsqueeze(0)  # (1, 4, H, W)
+
+         with torch.no_grad():
+             output = self.model(input_tensor).squeeze(0).numpy().transpose(1, 2, 0)
+         output = (np.clip(output, 0, 1) * 255).astype(np.uint8)
+         result = Image.fromarray(output)
+
+         buf = io.BytesIO()
+         result.save(buf, format="PNG")
+         return {"image": buf.getvalue()}
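Outside an Inference Endpoint, the handler can be exercised locally along these lines (a sketch; input.png and mask.png are hypothetical files, and both must share the same dimensions, divisible by 16, since this handler does not resize):

# illustrative driver for EndpointHandler; file paths are hypothetical
from handler import EndpointHandler

handler = EndpointHandler()
with open("input.png", "rb") as f_img, open("mask.png", "rb") as f_mask:
    payload = {"image": f_img.read(), "mask": f_mask.read()}

response = handler(payload)  # returns {"image": <PNG bytes>}
with open("output.png", "wb") as f_out:
    f_out.write(response["image"])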
inference.py ADDED
@@ -0,0 +1,102 @@
+ # import torch
+ # import numpy as np
+ # from PIL import Image
+ # from model import UNetInpaint
+ # import io
+
+ # device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+ # model = UNetInpaint(input_channels=4, output_channels=3)
+ # model.load_state_dict(torch.load("inpainting_model_best.pth", map_location=device))
+ # model.eval().to(device)
+
+ # def preprocess(image: Image.Image, mask: Image.Image):
+ #     image = image.convert("RGB").resize((256, 256))  # Resize if needed
+ #     mask = mask.convert("L").resize((256, 256))
+
+ #     image = np.array(image).astype(np.float32) / 255.0
+ #     mask = np.array(mask).astype(np.float32) / 255.0
+ #     mask = (mask > 0.5).astype(np.float32)
+ #     mask = np.expand_dims(mask, axis=-1)
+
+ #     image = np.transpose(image, (2, 0, 1))
+ #     mask = np.transpose(mask, (2, 0, 1))
+
+ #     image = torch.tensor(image, dtype=torch.float32)
+ #     mask = torch.tensor(mask, dtype=torch.float32)
+
+ #     image = image * (1.0 - mask)
+ #     input_tensor = torch.cat([image, mask], dim=0).unsqueeze(0).to(device)
+
+ #     return input_tensor
+
+ # def predict(image: Image.Image, mask: Image.Image) -> Image.Image:
+ #     input_tensor = preprocess(image, mask)
+
+ #     with torch.no_grad():
+ #         output = model(input_tensor).squeeze(0).cpu().numpy().transpose(1, 2, 0)
+ #     output = np.clip(output, 0, 1)
+ #     out_img = Image.fromarray((output * 255).astype(np.uint8), mode="RGB")
+ #     return out_img
+
+
+ # from model import UNetInpaint
+ # from PIL import Image
+ # import torch
+ # import numpy as np
+ # import io
+
+ # model = UNetInpaint()
+ # model.load_state_dict(torch.load("model.pth", map_location="cpu"))
+ # model.eval()
+
+ # def predict(image: bytes, mask: bytes) -> bytes:
+ #     image = Image.open(io.BytesIO(image)).convert("RGB")
+ #     mask = Image.open(io.BytesIO(mask)).convert("L")
+
+ #     # preprocess
+ #     image_np = np.array(image).astype(np.float32) / 255.0
+ #     mask_np = np.array(mask).astype(np.float32) / 255.0
+ #     mask_np = (mask_np > 0.5).astype(np.float32)
+
+ #     mask_np = np.expand_dims(mask_np, axis=-1)
+ #     image_np = np.transpose(image_np, (2, 0, 1))
+ #     mask_np = np.transpose(mask_np, (2, 0, 1))
+
+ #     image_tensor = torch.tensor(image_np) * (1 - torch.tensor(mask_np))
+ #     input_tensor = torch.cat([image_tensor, torch.tensor(mask_np)], dim=0).unsqueeze(0)
+
+ #     with torch.no_grad():
+ #         output = model(input_tensor).squeeze(0).numpy().transpose(1, 2, 0)
+ #     output = (np.clip(output, 0, 1) * 255).astype(np.uint8)
+ #     result = Image.fromarray(output)
+
+ #     buffer = io.BytesIO()
+ #     result.save(buffer, format="PNG")
+ #     return buffer.getvalue()
+
+ from model import UNetInpaint
+ import torch
+ import numpy as np
+ from PIL import Image
+
+ model = UNetInpaint()
+ model.load_state_dict(torch.load("inpainting_model_best.pth", map_location="cpu"))  # checkpoint added in this commit
+ model.eval()
+
+ def get_output(image_pil, mask_pil):
+     image = np.array(image_pil.convert("RGB").resize((256, 256))).astype(np.float32) / 255.0
+     mask = np.array(mask_pil.convert("L").resize((256, 256))).astype(np.float32) / 255.0
+     mask = (mask > 0.5).astype(np.float32)  # binarize: 1 = region to inpaint
+
+     mask = np.expand_dims(mask, axis=-1)
+     image = np.transpose(image, (2, 0, 1))  # HWC -> CHW
+     mask = np.transpose(mask, (2, 0, 1))
+
+     image = torch.tensor(image) * (1 - torch.tensor(mask))  # zero out the hole
+     input_tensor = torch.cat([image, torch.tensor(mask)], dim=0).unsqueeze(0)  # (1, 4, 256, 256)
+
+     with torch.no_grad():
+         output = model(input_tensor).squeeze(0).numpy().transpose(1, 2, 0)
+     output = (np.clip(output, 0, 1) * 255).astype(np.uint8)
+
+     return Image.fromarray(output)
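To make the 4-channel input assembly explicit, here is a standalone shape walk-through of the preprocessing in get_output, using dummy arrays:

# shape walk-through of the preprocessing above (illustrative)
import numpy as np
import torch

image = np.zeros((256, 256, 3), dtype=np.float32)  # HWC, RGB in [0, 1]
mask = np.ones((256, 256, 1), dtype=np.float32)    # HWC, 1 = hole

image = np.transpose(image, (2, 0, 1))             # -> (3, 256, 256)
mask = np.transpose(mask, (2, 0, 1))               # -> (1, 256, 256)

masked = torch.tensor(image) * (1 - torch.tensor(mask))
x = torch.cat([masked, torch.tensor(mask)], dim=0).unsqueeze(0)
print(x.shape)  # torch.Size([1, 4, 256, 256]) -- matches input_channels=4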
inpainting_model_best.pth ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:81acd2ff348ef0593d74d7f40bebc514ad260020bf4d75827ede764c78dbc152
+ size 138201906
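This is a Git LFS pointer, so a clone without LFS leaves a ~130-byte text file in place of the weights. A small sanity check (the expected size is taken from the pointer's size field above):

# verify the checkpoint was actually pulled via LFS (illustrative)
import os

EXPECTED_SIZE = 138201906  # "size" field from the LFS pointer
actual = os.path.getsize("inpainting_model_best.pth")
assert actual == EXPECTED_SIZE, f"got {actual} bytes -- run `git lfs pull`?"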
model.py ADDED
@@ -0,0 +1,57 @@
+ import torch
+ import torch.nn as nn
+
+ class UNetInpaint(nn.Module):
+     def __init__(self, input_channels=4, output_channels=3):
+         super().__init__()
+         self.enc1 = self.conv_block(input_channels, 64)
+         self.enc2 = self.conv_block(64, 128)
+         self.enc3 = self.conv_block(128, 256)
+         self.enc4 = self.conv_block(256, 512)
+         self.pool = nn.MaxPool2d(2, 2)
+         self.bottleneck = self.conv_block(512, 1024)
+         self.upconv4 = self.up_conv_block(1024, 512)
+         self.dec4 = self.conv_block(1024, 512)
+         self.upconv3 = self.up_conv_block(512, 256)
+         self.dec3 = self.conv_block(512, 256)
+         self.upconv2 = self.up_conv_block(256, 128)
+         self.dec2 = self.conv_block(256, 128)
+         self.upconv1 = self.up_conv_block(128, 64)
+         self.dec1 = self.conv_block(128, 64)
+         self.out_conv = nn.Conv2d(64, output_channels, 1)
+         self.final_activation = nn.Sigmoid()  # outputs in (0, 1), matching the [0, 1] image range
+
+     def conv_block(self, in_channels, out_channels):
+         return nn.Sequential(
+             nn.Conv2d(in_channels, out_channels, 3, padding=1, bias=False),
+             nn.BatchNorm2d(out_channels),
+             nn.ReLU(inplace=True),
+             nn.Conv2d(out_channels, out_channels, 3, padding=1, bias=False),
+             nn.BatchNorm2d(out_channels),
+             nn.ReLU(inplace=True)
+         )
+
+     def up_conv_block(self, in_channels, out_channels):
+         return nn.Sequential(
+             nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True),
+             nn.Conv2d(in_channels, out_channels, 3, padding=1, bias=False),
+             nn.BatchNorm2d(out_channels),
+             nn.ReLU(inplace=True)
+         )
+
+     def forward(self, x):
+         e1 = self.enc1(x)
+         e2 = self.enc2(self.pool(e1))
+         e3 = self.enc3(self.pool(e2))
+         e4 = self.enc4(self.pool(e3))
+         b = self.bottleneck(self.pool(e4))
+         d4 = self.upconv4(b)
+         d4 = self.dec4(torch.cat([d4, e4], dim=1))  # skip connection from encoder
+         d3 = self.upconv3(d4)
+         d3 = self.dec3(torch.cat([d3, e3], dim=1))
+         d2 = self.upconv2(d3)
+         d2 = self.dec2(torch.cat([d2, e2], dim=1))
+         d1 = self.upconv1(d2)
+         d1 = self.dec1(torch.cat([d1, e1], dim=1))
+         out = self.out_conv(d1)
+         return self.final_activation(out)
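A quick forward-pass shape check for the architecture (a sketch; input spatial dimensions must be divisible by 16 because of the four 2× poolings):

# shape check for UNetInpaint (illustrative)
import torch
from model import UNetInpaint

model = UNetInpaint(input_channels=4, output_channels=3)
x = torch.randn(1, 4, 256, 256)  # masked RGB (3ch) + binary mask (1ch)
with torch.no_grad():
    y = model(x)
print(y.shape)  # torch.Size([1, 3, 256, 256]), values in (0, 1) from the Sigmoid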
requirements.txt ADDED
@@ -0,0 +1,5 @@
+ torch
+ torchvision
+ gradio
+ numpy
+ pillow