Spaces:
Sleeping
Sleeping
import numpy as np
import torch
from PIL import Image, ImageOps

from model import PretrainedUNet
def load_model(model_path, device):
    """Load a PretrainedUNet checkpoint from *model_path* onto *device*.

    The checkpoint is expected to be a plain state dict saved with
    ``torch.save``; the model is returned in eval mode, ready for inference.
    """
    net = PretrainedUNet()
    net = net.to(device)
    # map_location keeps CPU-only hosts from failing on GPU-saved checkpoints
    state = torch.load(model_path, map_location=device)
    net.load_state_dict(state)
    net.eval()
    return net
def run_inference(model, img_a_tensor, img_b_tensor, device):
    """Run *model* on one image pair and return a sigmoid probability map.

    Both tensors get a batch dimension prepended before the forward pass;
    the batch dimension is squeezed away again on the way out, so the
    returned value is a plain numpy array of per-pixel probabilities.
    """
    batch_a = img_a_tensor.unsqueeze(0).to(device)
    batch_b = img_b_tensor.unsqueeze(0).to(device)
    with torch.no_grad():
        logits = model(batch_a, batch_b)
        probs = torch.sigmoid(logits)
    return probs.cpu().squeeze().numpy()
def create_overlay(base_img, mask, color, threshold=0.5):
    """Blend a colored mask region onto *base_img* at 50% opacity.

    Parameters
    ----------
    base_img : PIL.Image.Image
        Image to draw on. It is converted to RGB internally, so grayscale
        ('L') or RGBA inputs no longer crash the 3-channel color assignment.
    mask : PIL.Image.Image or numpy.ndarray
        Mask whose values are normalized to [0, 1]; resized to base_img's
        size with nearest-neighbor interpolation (keeps mask edges hard).
    color : sequence of 3 ints
        RGB color painted where the normalized mask exceeds *threshold*.
    threshold : float, optional
        Cutoff applied to the normalized mask (default 0.5).

    Returns
    -------
    PIL.Image.Image
        RGB image; masked pixels are a 50/50 blend of *color* and the
        original pixel, unmasked pixels are unchanged.
    """
    # Bug fix: without convert('RGB'), a grayscale base gives a 2-D array and
    # `overlay[mask_bool] = color` raises a broadcast error. RGB callers see
    # byte-identical output.
    base_np = np.array(base_img.convert('RGB'))
    if isinstance(mask, Image.Image):
        # PIL resize takes (width, height), i.e. (cols, rows) of the array.
        mask = mask.resize((base_np.shape[1], base_np.shape[0]), Image.NEAREST)
        mask_np = np.array(mask.convert('L')) / 255.0
    else:
        # numpy mask: resize via PIL only when shapes disagree.
        if mask.shape[:2] != base_np.shape[:2]:
            mask_pil = Image.fromarray((mask * 255).astype(np.uint8))
            mask_pil = mask_pil.resize((base_np.shape[1], base_np.shape[0]), Image.NEAREST)
            mask_np = np.array(mask_pil) / 255.0
        else:
            mask_np = mask
    overlay = base_np.copy()
    mask_bool = mask_np > threshold
    overlay[mask_bool] = color
    # Outside the mask overlay == base, so the blend leaves those pixels intact.
    alpha = 0.5
    result = (alpha * overlay + (1 - alpha) * base_np).astype(np.uint8)
    return Image.fromarray(result)