# building_detection / inference.py
# Last commit: "Update inference.py" (732d9a1, verified) by mustafa2ak
import torch
import numpy as np
from PIL import Image
from model import PretrainedUNet
from PIL import Image, ImageOps
def load_model(model_path, device):
    """Create a PretrainedUNet, load saved weights into it and return it.

    Args:
        model_path: Checkpoint location handed to ``torch.load``.
        device: Device the network is moved to; also used as the
            ``map_location`` when reading the checkpoint.

    Returns:
        The ``PretrainedUNet`` instance after ``load_state_dict`` and
        ``eval()`` have been called on it.
    """
    network = PretrainedUNet()
    network = network.to(device)

    # The checkpoint is passed straight to load_state_dict, so it is
    # expected to be a bare state dict rather than a wrapper dictionary.
    state_dict = torch.load(model_path, map_location=device)
    network.load_state_dict(state_dict)

    # Switch to evaluation mode before handing the model to callers.
    network.eval()
    return network
def run_inference(model, img_a_tensor, img_b_tensor, device):
    """Run ``model`` on one pair of tensors and return sigmoid activations.

    Args:
        model: Callable invoked as ``model(a, b)`` with the two batched
            tensors.
        img_a_tensor: First input tensor; a leading axis of size 1 is added
            before it is passed to the model.
        img_b_tensor: Second input tensor, handled the same way.
        device: Device both inputs are moved to.

    Returns:
        A NumPy array holding ``sigmoid(model(a, b))`` with every size-1
        dimension squeezed away.
    """
    # Prepend a batch axis of length one to each input and move it over.
    batch_a = img_a_tensor.unsqueeze(0).to(device)
    batch_b = img_b_tensor.unsqueeze(0).to(device)

    # No gradients are tracked for the forward pass.
    with torch.no_grad():
        logits = model(batch_a, batch_b)
        probabilities = torch.sigmoid(logits)

    # Bring the result back to the CPU before converting it to NumPy.
    return probabilities.cpu().squeeze().numpy()
def create_overlay(base_img, mask, color, threshold=0.5, alpha=0.5):
    """Blend a solid color over the pixels of an image selected by a mask.

    Args:
        base_img: Image to draw on; it is converted with ``np.array`` and
            its first two dimensions are taken as (height, width).
        mask: Either a PIL image (converted to 8-bit grayscale and scaled to
            the 0..1 range) or an array whose values are compared directly
            with ``threshold``. A mask whose size differs from ``base_img``
            is resized with nearest-neighbour resampling.
        color: Value assigned to the selected pixels; it must broadcast to a
            pixel of ``base_img`` (e.g. an RGB triple for an RGB image).
        threshold: Mask values strictly greater than this are selected.
        alpha: Weight of the colored layer in the final blend, in [0, 1].
            The default of 0.5 reproduces the previous fixed blend.

    Returns:
        A PIL image where selected pixels are
        ``alpha * color + (1 - alpha) * original`` and all other pixels are
        unchanged (up to the uint8 cast).

    Raises:
        ValueError: If ``alpha`` lies outside [0, 1].
    """
    if not 0.0 <= alpha <= 1.0:
        raise ValueError(f"alpha must be within [0, 1], got {alpha!r}")

    base_np = np.array(base_img)
    height, width = base_np.shape[:2]

    if isinstance(mask, Image.Image):
        # PIL sizes are (width, height), the reverse of NumPy's shape order.
        resized = mask.resize((width, height), Image.NEAREST)
        mask_bool = np.array(resized.convert('L')) / 255.0 > threshold
    else:
        # Threshold BEFORE resizing: squeezing probabilities into uint8
        # first truncates them (0.501 -> 127/255 ~= 0.498) and silently drops
        # pixels just above the threshold. Nearest-neighbour resampling only
        # copies pixels, so resizing the binary mask is otherwise equivalent.
        mask_bool = np.asarray(mask) > threshold
        if mask_bool.shape[:2] != (height, width):
            binary = Image.fromarray(mask_bool.astype(np.uint8) * 255)
            binary = binary.resize((width, height), Image.NEAREST)
            mask_bool = np.array(binary) > 0

    # Paint the selected pixels on a copy, then mix it with the untouched
    # image so the underlying content stays visible through the color.
    overlay = base_np.copy()
    overlay[mask_bool] = color
    blended = alpha * overlay + (1 - alpha) * base_np
    return Image.fromarray(blended.astype(np.uint8))