# object-eraser-model / inference.py
# Author: shivamkunkolikar — last edit May 15, 5:09 PM — commit 0af2f6b
# import torch
# import numpy as np
# from PIL import Image
# from model import UNetInpaint
# import io
# device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# model = UNetInpaint(input_channels=4, output_channels=3)
# model.load_state_dict(torch.load("inpainting_model_best.pth", map_location=device))
# model.eval().to(device)
# def preprocess(image: Image.Image, mask: Image.Image):
# image = image.convert("RGB").resize((256, 256)) # Resize if needed
# mask = mask.convert("L").resize((256, 256))
# image = np.array(image).astype(np.float32) / 255.0
# mask = np.array(mask).astype(np.float32) / 255.0
# mask = (mask > 0.5).astype(np.float32)
# mask = np.expand_dims(mask, axis=-1)
# image = np.transpose(image, (2, 0, 1))
# mask = np.transpose(mask, (2, 0, 1))
# image = torch.tensor(image, dtype=torch.float32)
# mask = torch.tensor(mask, dtype=torch.float32)
# image = image * (1.0 - mask)
# input_tensor = torch.cat([image, mask], dim=0).unsqueeze(0).to(device)
# return input_tensor
# def predict(image: Image.Image, mask: Image.Image) -> Image.Image:
# input_tensor = preprocess(image, mask)
# with torch.no_grad():
# output = model(input_tensor).squeeze(0).cpu().numpy().transpose(1, 2, 0)
# output = np.clip(output, 0, 1)
# out_img = Image.fromarray((output * 255).astype(np.uint8), mode="RGB")
# return out_img
# from model import UNetInpaint
# from PIL import Image
# import torch
# import numpy as np
# import io
# model = UNetInpaint()
# model.load_state_dict(torch.load("model.pth", map_location="cpu"))
# model.eval()
# def predict(image: bytes, mask: bytes) -> bytes:
# image = Image.open(io.BytesIO(image)).convert("RGB")
# mask = Image.open(io.BytesIO(mask)).convert("L")
# # preprocess
# image_np = np.array(image).astype(np.float32) / 255.0
# mask_np = np.array(mask).astype(np.float32) / 255.0
# mask_np = (mask_np > 0.5).astype(np.float32)
# mask_np = np.expand_dims(mask_np, axis=-1)
# image_np = np.transpose(image_np, (2, 0, 1))
# mask_np = np.transpose(mask_np, (2, 0, 1))
# image_tensor = torch.tensor(image_np) * (1 - torch.tensor(mask_np))
# input_tensor = torch.cat([image_tensor, torch.tensor(mask_np)], dim=0).unsqueeze(0)
# with torch.no_grad():
# output = model(input_tensor).squeeze(0).numpy().transpose(1, 2, 0)
# output = (np.clip(output, 0, 1) * 255).astype(np.uint8)
# result = Image.fromarray(output)
# buffer = io.BytesIO()
# result.save(buffer, format="PNG")
# return buffer.getvalue()
# from model import UNetInpaint
# import torch
# import numpy as np
# from PIL import Image
# model = UNetInpaint(input_channels=4, output_channels=3)
# model.load_state_dict(torch.load("inpainting_model_best_wgt.pth", map_location="cpu"))
# model.eval()
# def get_output(image_pil, mask_pil):
# if isinstance(mask_pil, np.ndarray):
# mask_pil = Image.fromarray(mask_pil.astype(np.uint8))
# # Now safely convert to grayscale
# mask = np.array(mask_pil.convert("L")).astype(np.float32) / 255.0
# # Same for the image if needed
# if isinstance(image_pil, np.ndarray):
# image_pil = Image.fromarray(image_pil.astype(np.uint8))
# image = np.array(image_pil).astype(np.float32) / 255.0
# mask = (mask > 0.5).astype(np.float32)
# mask = np.expand_dims(mask, axis=-1)
# image = np.transpose(image, (2, 0, 1))
# mask = np.transpose(mask, (2, 0, 1))
# image = torch.tensor(image) * (1 - torch.tensor(mask))
# input_tensor = torch.cat([image, torch.tensor(mask)], dim=0).unsqueeze(0)
# with torch.no_grad():
# output = model(input_tensor).squeeze(0).numpy().transpose(1, 2, 0)
# output = (np.clip(output, 0, 1) * 255).astype(np.uint8)
# return Image.fromarray(output)
import torch
import numpy as np
import cv2
from PIL import Image
from model import UNetInpaint # assuming your model class is here
# Load model (adjust paths and class accordingly).
# Inference runs on CPU only; all tensors below are moved to this device.
device = torch.device('cpu')
# 4 input channels = 3 RGB channels of the masked image + 1 binary mask channel;
# 3 output channels = reconstructed RGB image.
model = UNetInpaint(input_channels=4, output_channels=3)
# Weights file is expected alongside this script; loaded onto CPU regardless
# of where it was trained.
model.load_state_dict(torch.load("inpainting_model_best_wgt.pth", map_location="cpu"))
# Switch off dropout/batch-norm training behavior for inference.
model.eval()
def preprocess(image, mask):
    """Run the inpainting model on an image/mask pair.

    Despite its name, this performs the whole pipeline: preprocessing,
    a forward pass through the module-level ``model``, and conversion of
    the network output back into a PIL image.

    Args:
        image: H x W x 3 uint8 array (cv2-style). Only resized here, so
            the channel order is passed to the model unchanged.
        mask: H x W (grayscale) or H x W x 3 uint8 array; pixels > 127
            mark the region to erase/inpaint.

    Returns:
        PIL.Image.Image: 512 x 512 RGB inpainted result.
    """
    # Collapse a color mask to a single channel. The previous code read
    # mask.shape[2] unconditionally, which raised IndexError for 2-D
    # (already-grayscale) masks; test the rank instead.
    if mask.ndim == 3:
        mask = cv2.cvtColor(mask, cv2.COLOR_BGR2GRAY)
    image = image.astype(np.float32) / 255.0
    mask = mask.astype(np.float32) / 255.0
    mask = (mask > 0.5).astype(np.float32)
    # The model expects 512x512 input.
    image = cv2.resize(image, (512, 512))
    mask = cv2.resize(mask, (512, 512))
    # Resize interpolation reintroduces fractional values along mask
    # edges; re-binarize so the model sees a clean 0/1 mask.
    mask = (mask > 0.5).astype(np.float32)
    mask = np.expand_dims(mask, axis=-1)
    # HWC -> CHW for PyTorch.
    image = np.transpose(image, (2, 0, 1))
    mask = np.transpose(mask, (2, 0, 1))
    image = torch.tensor(image, dtype=torch.float32)
    mask = torch.tensor(mask, dtype=torch.float32)
    # Zero out the masked region; the network reconstructs it.
    image = image * (1.0 - mask)
    # 4-channel input: 3 masked-image channels + 1 mask channel.
    # (Renamed from `input`, which shadowed the builtin.)
    net_input = torch.cat([image, mask], dim=0).unsqueeze(0).to(device)
    with torch.no_grad():
        output = model(net_input)
    output = output.squeeze(0).cpu().numpy().transpose(1, 2, 0)
    output = np.clip(output, 0, 1)
    out_img = Image.fromarray((output * 255).astype(np.uint8), mode='RGB')
    return out_img
def get_output(image_np, mask_np):
    """Inpaint ``image_np`` under ``mask_np`` and return a PIL image.

    Thin public entry point: :func:`preprocess` currently performs the
    entire inference pipeline (preprocess, forward pass, postprocess),
    so its result is returned directly.
    """
    return preprocess(image_np, mask_np)