# import torch # import numpy as np # from PIL import Image # from model import UNetInpaint # import io # device = torch.device("cuda" if torch.cuda.is_available() else "cpu") # model = UNetInpaint(input_channels=4, output_channels=3) # model.load_state_dict(torch.load("inpainting_model_best.pth", map_location=device)) # model.eval().to(device) # def preprocess(image: Image.Image, mask: Image.Image): # image = image.convert("RGB").resize((256, 256)) # Resize if needed # mask = mask.convert("L").resize((256, 256)) # image = np.array(image).astype(np.float32) / 255.0 # mask = np.array(mask).astype(np.float32) / 255.0 # mask = (mask > 0.5).astype(np.float32) # mask = np.expand_dims(mask, axis=-1) # image = np.transpose(image, (2, 0, 1)) # mask = np.transpose(mask, (2, 0, 1)) # image = torch.tensor(image, dtype=torch.float32) # mask = torch.tensor(mask, dtype=torch.float32) # image = image * (1.0 - mask) # input_tensor = torch.cat([image, mask], dim=0).unsqueeze(0).to(device) # return input_tensor # def predict(image: Image.Image, mask: Image.Image) -> Image.Image: # input_tensor = preprocess(image, mask) # with torch.no_grad(): # output = model(input_tensor).squeeze(0).cpu().numpy().transpose(1, 2, 0) # output = np.clip(output, 0, 1) # out_img = Image.fromarray((output * 255).astype(np.uint8), mode="RGB") # return out_img # from model import UNetInpaint # from PIL import Image # import torch # import numpy as np # import io # model = UNetInpaint() # model.load_state_dict(torch.load("model.pth", map_location="cpu")) # model.eval() # def predict(image: bytes, mask: bytes) -> bytes: # image = Image.open(io.BytesIO(image)).convert("RGB") # mask = Image.open(io.BytesIO(mask)).convert("L") # # preprocess # image_np = np.array(image).astype(np.float32) / 255.0 # mask_np = np.array(mask).astype(np.float32) / 255.0 # mask_np = (mask_np > 0.5).astype(np.float32) # mask_np = np.expand_dims(mask_np, axis=-1) # image_np = np.transpose(image_np, (2, 0, 1)) # mask_np = 
import torch
import numpy as np
import cv2
from PIL import Image

from model import UNetInpaint  # project-local model definition

# Load the trained inpainting network once at import time (CPU-only inference).
device = torch.device('cpu')
model = UNetInpaint(input_channels=4, output_channels=3)
model.load_state_dict(torch.load("inpainting_model_best_wgt.pth", map_location="cpu"))
model.eval()


def preprocess(image, mask):
    """Run the inpainting model on a numpy image/mask pair.

    Despite its name, this performs the full pipeline — preprocessing,
    forward pass under ``torch.no_grad()``, and conversion back to a PIL
    image. The name and signature are kept for backward compatibility
    with existing callers.

    Args:
        image: H x W x 3 uint8 (or float) numpy array.
            NOTE(review): assumed already RGB — the BGR->RGB conversion is
            commented out; confirm against the caller's image source.
        mask: H x W (grayscale) or H x W x 3 numpy array; pixels above
            127 mark the region to inpaint.

    Returns:
        PIL.Image.Image: 512 x 512 RGB inpainted result.
    """
    # Collapse a 3-channel mask to one channel. The previous code indexed
    # mask.shape[2] unconditionally, which raised IndexError for 2-D
    # grayscale masks — guard on the number of dimensions instead.
    if mask.ndim == 3:
        mask = cv2.cvtColor(mask, cv2.COLOR_BGR2GRAY)

    # Normalize to [0, 1]; binarize the mask at 0.5.
    image = image.astype(np.float32) / 255.0
    mask = (mask.astype(np.float32) / 255.0 > 0.5).astype(np.float32)

    # The model expects a fixed 512x512 input.
    image = cv2.resize(image, (512, 512))
    mask = cv2.resize(mask, (512, 512))

    # HWC -> CHW; the mask gains an explicit channel axis first.
    image = np.transpose(image, (2, 0, 1))
    mask = np.transpose(np.expand_dims(mask, axis=-1), (2, 0, 1))

    image_t = torch.from_numpy(image)
    mask_t = torch.from_numpy(mask)

    # Zero out the region to be inpainted, then stack the mask as the
    # fourth input channel (model was declared with input_channels=4).
    masked = image_t * (1.0 - mask_t)
    model_input = torch.cat([masked, mask_t], dim=0).unsqueeze(0).to(device)

    with torch.no_grad():
        output = model(model_input)

    # CHW -> HWC, clamp to [0, 1], convert to 8-bit RGB.
    output = output.squeeze(0).cpu().numpy().transpose(1, 2, 0)
    output = np.clip(output, 0, 1)
    return Image.fromarray((output * 255).astype(np.uint8), mode='RGB')


def get_output(image_np, mask_np):
    """Public entry point: inpaint ``image_np`` where ``mask_np`` is set.

    Returns the PIL image produced by ``preprocess`` (historically the
    variable was named ``input_tensor``, but it is the finished image).
    """
    return preprocess(image_np, mask_np)