Spaces:
Runtime error
Runtime error
shivamkunkolikar commited on
Commit ·
f647f94
1
Parent(s): 8475a1f
gradio update
Browse files- app.py +9 -0
- handler.py +38 -0
- inference.py +102 -0
- inpainting_model_best.pth +3 -0
- model.py +57 -0
- requirements.txt +5 -0
app.py
ADDED
|
@@ -0,0 +1,9 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import gradio as gr
from inference import get_output
from PIL import Image


def predict(image, mask):
    """Inpaint *image* in the regions marked white in *mask*.

    Both arguments are PIL images (see the component config below);
    returns the inpainted result as a PIL image.
    """
    return get_output(image, mask)


# BUG FIX: the plain "image" component shorthand hands the callback numpy
# arrays, but get_output calls .convert(...) on the mask, which only PIL
# images support. Request PIL objects explicitly.
iface = gr.Interface(
    fn=predict,
    inputs=[gr.Image(type="pil", label="image"), gr.Image(type="pil", label="mask")],
    outputs=gr.Image(type="pil"),
)
iface.launch()
handler.py
ADDED
|
@@ -0,0 +1,38 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
from model import UNetInpaint
from PIL import Image
import torch
import numpy as np
import io


class EndpointHandler:
    """Inference-endpoint handler for the UNet inpainting model.

    Expects a payload dict with raw encoded image bytes under "image" and
    mask bytes under "mask"; returns the inpainted picture as PNG bytes.
    """

    def __init__(self, path=""):
        # NOTE(review): this loads "model.pth", while the commit ships
        # "inpainting_model_best.pth" -- confirm the weight file name.
        self.model = UNetInpaint()
        self.model.load_state_dict(torch.load("model.pth", map_location="cpu"))
        self.model.eval()

    def __call__(self, data):
        raw_image = data.get("image")
        raw_mask = data.get("mask")

        rgb = Image.open(io.BytesIO(raw_image)).convert("RGB")
        gray = Image.open(io.BytesIO(raw_mask)).convert("L")

        # Scale pixels to [0, 1]; binarize the mask at 0.5.
        img_arr = np.asarray(rgb, dtype=np.float32) / 255.0
        msk_arr = np.asarray(gray, dtype=np.float32) / 255.0
        msk_arr = (msk_arr > 0.5).astype(np.float32)

        # Channel-first layout: image (3, H, W), mask (1, H, W).
        img_chw = img_arr.transpose(2, 0, 1)
        msk_chw = msk_arr[np.newaxis, :, :]

        # Zero out the masked pixels, then stack the mask as a 4th channel.
        msk_t = torch.tensor(msk_chw)
        img_t = torch.tensor(img_chw) * (1 - msk_t)
        batch = torch.cat([img_t, msk_t], dim=0).unsqueeze(0)

        with torch.no_grad():
            pred = self.model(batch).squeeze(0).numpy().transpose(1, 2, 0)
        pred = (np.clip(pred, 0, 1) * 255).astype(np.uint8)

        out = io.BytesIO()
        Image.fromarray(pred).save(out, format="PNG")
        return {"image": out.getvalue()}
inference.py
ADDED
|
@@ -0,0 +1,102 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
"""Inference entry point for the UNet inpainting model."""
from model import UNetInpaint
import torch
import numpy as np
from PIL import Image

# BUG FIX: the module previously loaded "model.pth", which is not part of
# this repo -- the commit ships "inpainting_model_best.pth" (the missing
# file is what produced the Space's runtime error on startup).
model = UNetInpaint()
model.load_state_dict(torch.load("inpainting_model_best.pth", map_location="cpu"))
model.eval()


def get_output(image_pil, mask_pil):
    """Inpaint *image_pil* where *mask_pil* is white and return a PIL image.

    The mask is binarized at 0.5 (white = region to fill). Inputs of any
    mode/size are accepted: the image is converted to RGB, the mask is
    resized to match, and both are snapped to dimensions divisible by 16
    (the UNet max-pools four times); the result is resized back.
    """
    image_pil = image_pil.convert("RGB")
    mask_pil = mask_pil.convert("L")
    orig_size = image_pil.size  # (width, height)

    # Round dimensions down to multiples of 16 (identity for conforming
    # inputs) so the encoder/decoder skip connections line up.
    w = max(16, (orig_size[0] // 16) * 16)
    h = max(16, (orig_size[1] // 16) * 16)
    if (w, h) != image_pil.size:
        image_pil = image_pil.resize((w, h))
    if mask_pil.size != (w, h):
        mask_pil = mask_pil.resize((w, h))

    # Scale to [0, 1], binarize the mask, go channel-first.
    image = np.asarray(image_pil, dtype=np.float32) / 255.0
    mask = (np.asarray(mask_pil, dtype=np.float32) / 255.0 > 0.5).astype(np.float32)
    image = image.transpose(2, 0, 1)   # (3, H, W)
    mask = mask[np.newaxis, :, :]      # (1, H, W)

    # Zero out the masked region and append the mask as the 4th channel.
    mask_t = torch.tensor(mask)
    image_t = torch.tensor(image) * (1 - mask_t)
    input_tensor = torch.cat([image_t, mask_t], dim=0).unsqueeze(0)

    with torch.no_grad():
        output = model(input_tensor).squeeze(0).numpy().transpose(1, 2, 0)
    output = (np.clip(output, 0, 1) * 255).astype(np.uint8)

    result = Image.fromarray(output)
    if result.size != orig_size:
        result = result.resize(orig_size)
    return result
inpainting_model_best.pth
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:81acd2ff348ef0593d74d7f40bebc514ad260020bf4d75827ede764c78dbc152
|
| 3 |
+
size 138201906
|
model.py
ADDED
|
@@ -0,0 +1,57 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import torch
import torch.nn as nn


class UNetInpaint(nn.Module):
    """U-Net for image inpainting.

    Input:  (N, 4, H, W) -- the 3 RGB channels of the masked image plus a
            binary mask channel.
    Output: (N, 3, H, W) -- RGB prediction squashed into [0, 1] by a sigmoid.
    H and W must be divisible by 16 (four 2x2 max-pool stages).
    """

    def __init__(self, input_channels=4, output_channels=3):
        super().__init__()
        # Encoder: channel width doubles at each level.
        self.enc1 = self.conv_block(input_channels, 64)
        self.enc2 = self.conv_block(64, 128)
        self.enc3 = self.conv_block(128, 256)
        self.enc4 = self.conv_block(256, 512)
        self.pool = nn.MaxPool2d(2, 2)
        self.bottleneck = self.conv_block(512, 1024)
        # Decoder: each stage upsamples, then fuses the matching encoder
        # skip (hence the doubled input width of every dec block).
        self.upconv4 = self.up_conv_block(1024, 512)
        self.dec4 = self.conv_block(1024, 512)
        self.upconv3 = self.up_conv_block(512, 256)
        self.dec3 = self.conv_block(512, 256)
        self.upconv2 = self.up_conv_block(256, 128)
        self.dec2 = self.conv_block(256, 128)
        self.upconv1 = self.up_conv_block(128, 64)
        self.dec1 = self.conv_block(128, 64)
        self.out_conv = nn.Conv2d(64, output_channels, 1)
        self.final_activation = nn.Sigmoid()

    def conv_block(self, in_channels, out_channels):
        """Two 3x3 conv -> BatchNorm -> ReLU stages at a fixed width."""
        layers = []
        for cin in (in_channels, out_channels):
            layers += [
                nn.Conv2d(cin, out_channels, 3, padding=1, bias=False),
                nn.BatchNorm2d(out_channels),
                nn.ReLU(inplace=True),
            ]
        return nn.Sequential(*layers)

    def up_conv_block(self, in_channels, out_channels):
        """Bilinear 2x upsample followed by a 3x3 conv -> BatchNorm -> ReLU."""
        return nn.Sequential(
            nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True),
            nn.Conv2d(in_channels, out_channels, 3, padding=1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
        )

    def forward(self, x):
        # Contracting path; keep each level's output for the skip links.
        skip1 = self.enc1(x)
        skip2 = self.enc2(self.pool(skip1))
        skip3 = self.enc3(self.pool(skip2))
        skip4 = self.enc4(self.pool(skip3))

        up = self.bottleneck(self.pool(skip4))

        # Expanding path: upsample, concatenate the skip, refine.
        up = self.dec4(torch.cat([self.upconv4(up), skip4], dim=1))
        up = self.dec3(torch.cat([self.upconv3(up), skip3], dim=1))
        up = self.dec2(torch.cat([self.upconv2(up), skip2], dim=1))
        up = self.dec1(torch.cat([self.upconv1(up), skip1], dim=1))

        return self.final_activation(self.out_conv(up))
requirements.txt
ADDED
|
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
torch
|
| 2 |
+
torchvision
|
| 3 |
+
gradio
|
| 4 |
+
numpy
|
| 5 |
+
pillow
|