import torch
import torchvision.transforms as T
import numpy as np
from PIL import Image
import cv2
from segmentation_models_pytorch import Unet


def preprocess_image(image, size=(512, 512)) -> torch.Tensor:
    """Turn a PIL image into a normalized, batched float tensor.

    Args:
        image: A ``PIL.Image.Image`` in any mode; it is converted to RGB.
        size: Target ``(width, height)`` passed to ``Image.resize``.

    Returns:
        A ``float32`` tensor of shape ``(1, 3, height, width)`` with values
        scaled to the range ``[0, 1]``.
    """
    # Force three channels up front: a grayscale image yields a 2-D array that
    # cannot be transposed with (2, 0, 1), and RGBA/palette images would
    # produce a channel count the 3-channel model does not accept.
    image = image.convert("RGB").resize(size)
    img_array = np.array(image).astype(np.float32) / 255.0
    # HWC (PIL/NumPy layout) -> CHW (PyTorch layout); make it contiguous so
    # the tensor can be built directly from the buffer.
    img_array = np.ascontiguousarray(np.transpose(img_array, (2, 0, 1)))
    return torch.from_numpy(img_array).unsqueeze(0)


def predict_mask(model, tensor) -> np.ndarray:
    """Run the model on ``tensor`` and return sigmoid probabilities.

    Args:
        model: A callable model (typically a ``torch.nn.Module``) producing
            logits for ``tensor``.
        tensor: Input batch, e.g. the output of ``preprocess_image``.

    Returns:
        A NumPy array of probabilities with all size-1 dimensions squeezed
        out (``(H, W)`` for a single image and a single class).
    """
    # Move the input to wherever the model's weights live; otherwise a model
    # placed on GPU by ``load_model`` fails with a device-mismatch error when
    # given a CPU tensor.
    parameters = getattr(model, "parameters", None)
    first_param = next(parameters(), None) if callable(parameters) else None
    if first_param is not None:
        tensor = tensor.to(first_param.device)

    with torch.no_grad():
        output = torch.sigmoid(model(tensor))
    return output.squeeze().cpu().numpy()


def postprocess_mask(mask_array, threshold=0.5):
    """Binarize a probability map and return it as a 3-channel PIL image.

    Args:
        mask_array: Array-like of probabilities (e.g. from ``predict_mask``).
        threshold: Values strictly greater than this become 255 (white);
            everything else becomes 0 (black).

    Returns:
        A ``PIL.Image.Image`` whose three channels are identical copies of
        the binary mask.
    """
    mask_array = np.asarray(mask_array)
    mask = (mask_array > threshold).astype(np.uint8) * 255
    # Replicate the single channel along a new last axis to get H x W x 3.
    mask_rgb = np.stack([mask] * 3, axis=-1)
    return Image.fromarray(mask_rgb)


def load_model(path: str, device: torch.device):
    """Build the U-Net, load its weights from ``path`` and prepare it for inference.

    Args:
        path: Filesystem path to a saved ``state_dict`` checkpoint.
        device: Device to map the checkpoint onto and to place the model on.

    Returns:
        The model on ``device``, switched to evaluation mode.
    """
    # The architecture must match the one the checkpoint was trained with.
    # encoder_weights=None: the strict load_state_dict below overwrites every
    # weight, so fetching ImageNet weights first would only add a needless
    # download (and a failure when offline).
    model = Unet(
        encoder_name="resnet34",
        encoder_weights=None,
        in_channels=3,
        classes=1,
    )
    # NOTE(security): torch.load unpickles the file; only load checkpoints
    # from trusted sources (consider weights_only=True where supported).
    state_dict = torch.load(path, map_location=device)
    model.load_state_dict(state_dict)
    model.to(device)
    model.eval()
    return model