"""Inference pipeline for fine-tuned SAM ViT-H facade segmentation.

Loads ``facebook/sam-vit-huge`` (optionally overlaid with fine-tuned weights),
predicts a single box-prompted mask for one image, and writes a three-panel
visualization plus the raw binary mask as PNG.
"""
import os
import sys
import argparse

import numpy as np
import torch
from PIL import Image
import matplotlib

matplotlib.use('Agg')  # headless backend; must be selected before pyplot is imported
import matplotlib.pyplot as plt
from transformers import SamModel, SamProcessor


def load_model(checkpoint_path=None, device=None):
    """Load the SAM ViT-H processor and model, optionally with fine-tuned weights.

    Args:
        checkpoint_path: Path to a file readable by ``torch.load`` whose content
            is passed to ``model.load_state_dict``. If falsy or the path does
            not exist, the original pre-trained weights are kept.
        device: Torch device string. Defaults to ``"cuda"`` when available,
            otherwise ``"cpu"``.

    Returns:
        Tuple ``(model, processor, device)``; the model is in eval mode and
        already moved to ``device``.
    """
    if device is None:
        device = "cuda" if torch.cuda.is_available() else "cpu"

    processor = SamProcessor.from_pretrained("facebook/sam-vit-huge")
    model = SamModel.from_pretrained("facebook/sam-vit-huge").to(device)

    if checkpoint_path and os.path.exists(checkpoint_path):
        # NOTE(review): weights_only=False lets torch.load unpickle arbitrary
        # objects (i.e. execute code stored in the file). Only load trusted
        # checkpoints; a plain state_dict should also load with weights_only=True.
        state = torch.load(checkpoint_path, map_location=device, weights_only=False)
        model.load_state_dict(state)
        print(f"Loaded checkpoint from {checkpoint_path}")
    else:
        if checkpoint_path:
            # An explicit path that does not exist would otherwise fall back to
            # the base model without any indication that the fine-tune was skipped.
            print(f"Warning: checkpoint not found at {checkpoint_path}", file=sys.stderr)
        print("Using original pre-trained SAM ViT-H")

    model.eval()
    return model, processor, device


def predict_facade(model, processor, image_path, bbox=None, device="cpu"):
    """Predict a binary facade mask for one image using a single box prompt.

    Args:
        model: SAM model returned by ``load_model``.
        processor: Matching ``SamProcessor``.
        image_path: Path to the input image; it is converted to RGB.
        bbox: Optional ``[x1, y1, x2, y2]`` in original-image pixel coordinates
            (the same frame as the default full-image box). ``None`` prompts
            with the whole image.
        device: Device the processor outputs are moved to; should match the model.

    Returns:
        Tuple ``(mask, image)``: ``mask`` is a ``uint8`` array of shape
        ``(height, width)`` holding 0 or 255; ``image`` is the RGB PIL image.
    """
    # Context manager closes the file handle; convert() returns an independent copy.
    with Image.open(image_path) as src:
        img = src.convert("RGB")
    orig_w, orig_h = img.size

    # Processor expects boxes nested as [image][box][coords].
    if bbox is None:
        bbox = [[[0.0, 0.0, float(orig_w), float(orig_h)]]]
    else:
        bbox = [[[float(b) for b in bbox]]]

    inputs = processor(images=img, input_boxes=bbox, return_tensors="pt").to(device)

    with torch.no_grad():
        outputs = model(
            pixel_values=inputs["pixel_values"],
            input_boxes=inputs["input_boxes"],
            multimask_output=False,
        )

    # With one image, one box and multimask_output=False this squeezes to 2-D.
    pred_mask = outputs.pred_masks.squeeze().cpu().numpy()
    # Threshold at 0 -- presumably mask logits (sigmoid(0) == 0.5); TODO confirm.
    pred_binary = (pred_mask > 0).astype(np.uint8) * 255
    # NOTE(review): this stretches the low-res mask straight to the original
    # size. SAM's processor resizes the longest side and pads the rest, so for
    # non-square images the stock pre-trained model needs
    # processor.image_processor.post_process_masks(...) to crop the padding
    # first. Keep as-is only if the fine-tuning used the same stretched targets.
    pred_binary = Image.fromarray(pred_binary).resize((orig_w, orig_h), Image.NEAREST)
    return np.array(pred_binary), img


def visualize_result(image, mask, save_path):
    """Save a three-panel figure: input image, mask, and a red overlay.

    Args:
        image: Input image (anything ``np.array`` turns into an RGB array).
        mask: 2-D array the same height/width as ``image``; pixels ``> 0``
            are treated as facade.
        save_path: Destination file for the figure (saved at 150 dpi).
    """
    fig, axes = plt.subplots(1, 3, figsize=(15, 5))
    img_arr = np.array(image)

    axes[0].imshow(img_arr)
    axes[0].set_title("Input Image")
    axes[0].axis('off')

    axes[1].imshow(mask, cmap='gray')
    axes[1].set_title("Predicted Facade Mask")
    axes[1].axis('off')

    # Paint masked pixels pure red, then alpha-blend 60/40 with the original so
    # unmasked pixels are unchanged and masked ones are tinted red.
    overlay = img_arr.copy()
    overlay[mask > 0] = [255, 0, 0]
    blended = (img_arr * 0.6 + overlay * 0.4).astype(np.uint8)
    axes[2].imshow(blended)
    axes[2].set_title("Overlay")
    axes[2].axis('off')

    plt.tight_layout()
    plt.savefig(save_path, dpi=150)
    plt.close(fig)  # free the figure explicitly; matters when called in a loop
    print(f"Result saved to {save_path}")


def main():
    """Command-line entry point: segment one image and save figure and mask."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--checkpoint", default="", help="Path to fine-tuned checkpoint")
    parser.add_argument("--image", required=True, help="Path to input image")
    parser.add_argument("--bbox", default=None, type=str, help="Bounding box as x1,y1,x2,y2")
    parser.add_argument("--output", default="outputs/inference_result.png", help="Output path")
    args = parser.parse_args()

    # Validate the box before loading the (large) model so bad input fails fast.
    # parser.error is used instead of assert, which is stripped under `python -O`.
    bbox = None
    if args.bbox:
        try:
            bbox = [float(x.strip()) for x in args.bbox.split(",")]
        except ValueError:
            parser.error("Bounding box must be x1,y1,x2,y2")
        if len(bbox) != 4:
            parser.error("Bounding box must be x1,y1,x2,y2")

    model, processor, device = load_model(args.checkpoint)
    mask, img = predict_facade(model, processor, args.image, bbox, device)

    # dirname() is "" for a bare filename, and os.makedirs("") raises.
    out_dir = os.path.dirname(args.output)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)
    visualize_result(img, mask, args.output)

    # Derive the mask path from the extension: str.replace('.png', ...) would
    # rewrite every ".png" in the path and, for a non-.png output, return the
    # same path and overwrite the figure. The mask is always lossless PNG.
    mask_path = os.path.splitext(args.output)[0] + "_mask.png"
    Image.fromarray(mask).save(mask_path)
    print(f"Mask saved to {mask_path}")


if __name__ == "__main__":
    main()