import torch
import numpy as np
from PIL import Image, ImageOps
from model import PretrainedUNet


def load_model(model_path, device):
    """Build a PretrainedUNet on *device* and load its weights from *model_path*.

    Args:
        model_path: Checkpoint location, passed straight to ``torch.load``.
        device: Device the model is moved to and tensors are mapped onto.

    Returns:
        The model with the checkpoint loaded and switched to eval mode.
    """
    model = PretrainedUNet().to(device)
    # NOTE(review): torch.load unpickles the file, which can execute arbitrary
    # code -- only load trusted checkpoints. If the file is a plain state_dict,
    # consider weights_only=True (torch >= 1.13) to refuse arbitrary objects.
    checkpoint = torch.load(model_path, map_location=device)
    # The loaded object goes directly into load_state_dict, so the file is
    # expected to hold a bare state_dict rather than a wrapper dict -- confirm.
    model.load_state_dict(checkpoint)
    model.eval()
    return model


def run_inference(model, img_a_tensor, img_b_tensor, device):
    """Run *model* on one pair of images and return sigmoid probabilities.

    Args:
        model: Module called as ``model(img_a, img_b)``.
        img_a_tensor: First input, without a batch dimension (one is added).
        img_b_tensor: Second input, without a batch dimension (one is added).
        device: Device the inputs are moved to before the forward pass.

    Returns:
        ``sigmoid(model output)`` as a CPU numpy array with every size-1
        dimension squeezed away (presumably an (H, W) map -- confirm against
        the model's output shape).
    """
    # Add a leading batch dimension of size 1.
    img_a_tensor = img_a_tensor.unsqueeze(0).to(device)
    img_b_tensor = img_b_tensor.unsqueeze(0).to(device)
    with torch.no_grad():
        outputs = model(img_a_tensor, img_b_tensor)
    pred_mask = torch.sigmoid(outputs).cpu().squeeze().numpy()
    return pred_mask


def create_overlay(base_img, mask, color, threshold=0.5, alpha=0.5):
    """Create colored overlay on image.

    Pixels whose mask value is strictly greater than *threshold* are painted
    with *color* and blended with the original pixel; all other pixels are
    copied unchanged.

    Args:
        base_img: Image to draw on (a PIL image or anything ``np.array``
            accepts).
        mask: Either a PIL image -- converted to 8-bit grayscale and scaled
            to [0, 1] before thresholding -- or an array-like whose raw values
            are thresholded. If its height/width differ from *base_img* it is
            resized with nearest-neighbour sampling.
        color: Value written into masked pixels, e.g. an ``(R, G, B)`` tuple.
            A 3-element color on a single-channel image promotes the image to
            RGB (this combination previously raised).
        threshold: Mask values strictly above this are colored.
        alpha: Weight of *color* in the blend, in [0, 1]; the original pixel
            gets ``1 - alpha``. Defaults to the previous fixed 0.5.

    Returns:
        A new PIL image built from the uint8 result array.

    Raises:
        ValueError: If *alpha* is outside [0, 1].
    """
    if not 0.0 <= alpha <= 1.0:
        raise ValueError(f"alpha must be in [0, 1], got {alpha!r}")

    base_np = np.array(base_img)
    if base_np.ndim == 2 and np.ndim(color) == 1 and len(color) == 3:
        # A per-channel RGB color cannot be assigned into a single-channel
        # array; replicate the gray channel so the overlay can be colored.
        base_np = np.stack([base_np] * 3, axis=-1)
    height, width = base_np.shape[:2]

    if isinstance(mask, Image.Image):
        # Resize mask to match base image size
        mask = mask.resize((width, height), Image.NEAREST)
        mask_bool = np.array(mask.convert('L')) / 255.0 > threshold
    else:
        # Threshold on the raw values *before* any resize. Encoding the mask
        # as uint8(mask * 255) first wrapped uint8 0/255 masks (255*255 -> 1)
        # and truncated float probabilities, so the outcome used to depend on
        # whether a resize was needed.
        mask_bool = np.asarray(mask) > threshold
        if mask_bool.shape[:2] != (height, width):
            mask_pil = Image.fromarray(mask_bool.astype(np.uint8) * 255)
            mask_pil = mask_pil.resize((width, height), Image.NEAREST)
            # Nearest-neighbour keeps values exactly 0 or 255.
            mask_bool = np.array(mask_pil) > 127

    # Assign through a copy of the base so *color* is cast to the image dtype
    # exactly as before, then blend only the masked pixels; unmasked pixels
    # are copied verbatim so no float rounding can shift them.
    overlay = base_np.copy()
    overlay[mask_bool] = color
    result = base_np.astype(np.uint8)
    result[mask_bool] = (
        alpha * overlay[mask_bool] + (1 - alpha) * base_np[mask_bool]
    ).astype(np.uint8)
    return Image.fromarray(result)