Spaces:
Sleeping
Sleeping
File size: 1,059 Bytes
6bb078a |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 |
import torch
import torchvision.transforms as T
import numpy as np
from PIL import Image
import cv2
from segmentation_models_pytorch import Unet
def preprocess_image(image, size=(512, 512)):
    """Turn a PIL image into a normalized, batched tensor for the model.

    Args:
        image: PIL image in any mode. It is converted to RGB so the result
            always has exactly three channels.
        size: Target ``(width, height)`` handed to ``Image.resize``.

    Returns:
        A ``torch.float32`` tensor of shape ``(1, 3, height, width)`` whose
        values are the 8-bit pixel intensities divided by 255.
    """
    # Grayscale, palette and RGBA images do not produce an (H, W, 3) array;
    # without this conversion the transpose below either raises or yields a
    # channel count that does not match the 3-channel network input.
    rgb_image = image.convert("RGB").resize(size)
    pixels = np.asarray(rgb_image, dtype=np.float32) / 255.0
    # HWC -> CHW, made contiguous so the tensor has a plain memory layout.
    channels_first = np.ascontiguousarray(pixels.transpose(2, 0, 1))
    # Leading batch dimension of size 1.
    return torch.from_numpy(channels_first).unsqueeze(0)
def predict_mask(model, tensor):
    """Run the model on ``tensor`` and return sigmoid probabilities.

    Args:
        model: Callable (normally a ``torch.nn.Module``) that maps the input
            tensor to raw logits.
        tensor: Input tensor to feed to the model.

    Returns:
        A NumPy array with the sigmoid of the model output, with every
        size-1 dimension removed (``Tensor.squeeze()`` without arguments).
    """
    # The model may live on a GPU while the input tensor was built on the
    # CPU; feed the input from the same device as the model's weights.
    # Plain callables without ``parameters()`` are passed the tensor as-is.
    parameters = getattr(model, "parameters", None)
    first_param = next(iter(parameters()), None) if callable(parameters) else None
    if first_param is not None:
        tensor = tensor.to(first_param.device)

    # Inference only: skip autograd bookkeeping.
    with torch.no_grad():
        probabilities = torch.sigmoid(model(tensor))
    return probabilities.squeeze().cpu().numpy()
def postprocess_mask(mask_array, threshold=0.5):
    """Binarize a probability map and return it as a three-channel PIL image.

    Args:
        mask_array: NumPy array of scores to threshold.
        threshold: Values strictly greater than this become 255; all other
            values become 0.

    Returns:
        A PIL image built from the ``uint8`` mask replicated three times
        along a new last axis.
    """
    foreground = mask_array > threshold
    channel = foreground.astype(np.uint8) * 255
    # Same plane repeated three times so the mask can be shown as RGB.
    stacked = np.stack((channel, channel, channel), axis=-1)
    return Image.fromarray(stacked)
def load_model(path: str, device: torch.device):
    """Build the U-Net, load trained weights from ``path``, and prepare it for inference.

    Args:
        path: Filesystem path of a checkpoint containing the model's
            ``state_dict``.
        device: Device the weights are mapped to and the model is moved to.

    Returns:
        The model on ``device``, switched to evaluation mode.
    """
    # encoder_weights=None: every weight is overwritten by load_state_dict()
    # below, so fetching the pretrained ImageNet encoder first would only add
    # a network download (and a failure mode when running offline).
    model = Unet(
        encoder_name="resnet34",
        encoder_weights=None,
        in_channels=3,
        classes=1,
    )  # must match the architecture the checkpoint was trained with

    # NOTE(review): torch.load unpickles the file; only load checkpoints from
    # a trusted source (or pass weights_only=True where the installed torch
    # version supports it).
    state_dict = torch.load(path, map_location=device)
    model.load_state_dict(state_dict)
    model.to(device)
    model.eval()
    return model
|