| """ |
| Inference pipeline for fine-tuned SAM ViT-H facade segmentation. |
| """ |
| import os |
| import argparse |
| import numpy as np |
| import torch |
| from PIL import Image |
| import matplotlib |
| matplotlib.use('Agg') |
| import matplotlib.pyplot as plt |
| from transformers import SamModel, SamProcessor |
|
|
|
|
def load_model(checkpoint_path=None, device=None):
    """Load the SAM ViT-H processor and model, optionally with fine-tuned weights.

    Args:
        checkpoint_path: Optional path to a file produced by ``torch.save`` whose
            content is passed directly to ``model.load_state_dict``. If falsy, the
            original pre-trained ``facebook/sam-vit-huge`` weights are used. If a
            path is given but does not exist, a warning is printed and the
            pre-trained weights are used.
        device: Torch device string. Defaults to ``"cuda"`` when available,
            otherwise ``"cpu"``.

    Returns:
        Tuple ``(model, processor, device)``. The model is on ``device`` and in
        eval mode.
    """
    if device is None:
        device = "cuda" if torch.cuda.is_available() else "cpu"

    processor = SamProcessor.from_pretrained("facebook/sam-vit-huge")
    model = SamModel.from_pretrained("facebook/sam-vit-huge").to(device)

    if checkpoint_path and os.path.exists(checkpoint_path):
        # The loaded object is handed straight to load_state_dict, so it only
        # needs to hold tensors. weights_only=True refuses to unpickle
        # arbitrary objects, which could otherwise execute code from an
        # untrusted checkpoint file.
        state = torch.load(checkpoint_path, map_location=device, weights_only=True)
        model.load_state_dict(state)
        print(f"Loaded checkpoint from {checkpoint_path}")
    else:
        if checkpoint_path:
            # A path was requested but is missing. Say so explicitly rather
            # than silently running the un-fine-tuned model.
            print(f"Warning: checkpoint not found at {checkpoint_path}")
        print("Using original pre-trained SAM ViT-H")

    model.eval()
    return model, processor, device
|
|
|
|
def predict_facade(model, processor, image_path, bbox=None, device="cpu"):
    """Segment the facade in a single image using a box prompt.

    Args:
        model: SAM model (e.g. from ``load_model``), already moved to ``device``.
        processor: The matching ``SamProcessor``.
        image_path: Path to the input image; it is converted to RGB.
        bbox: Optional ``[x1, y1, x2, y2]`` box in original-image pixel
            coordinates. Defaults to the full image.
        device: Device the model lives on; the processed inputs are moved there.

    Returns:
        Tuple ``(mask, image)``:

        - ``mask`` is a ``uint8`` array of shape ``(H, W)`` with values 0/255,
          aligned to the original image.
        - ``image`` is the RGB PIL image.

    Raises:
        ValueError: If ``bbox`` does not have exactly four values.
    """
    img = Image.open(image_path).convert("RGB")
    orig_w, orig_h = img.size

    if bbox is None:
        box = [0.0, 0.0, float(orig_w), float(orig_h)]
    else:
        box = [float(b) for b in bbox]
        if len(box) != 4:
            raise ValueError(f"bbox must be [x1, y1, x2, y2], got {bbox!r}")

    # The processor expects boxes nested as [image][box][coords].
    inputs = processor(images=img, input_boxes=[[box]], return_tensors="pt").to(device)

    with torch.no_grad():
        outputs = model(
            pixel_values=inputs["pixel_values"],
            input_boxes=inputs["input_boxes"],
            multimask_output=False,
        )

    # pred_masks are low-resolution logits in the processor's resized-and-padded
    # input frame. Resizing them straight to the original size ignores the
    # padding and misaligns the mask on non-square images. post_process_masks
    # does the following, in order:
    #   1. upsamples the logits to the padded size,
    #   2. crops to the reshaped (pre-padding) size,
    #   3. resizes to the original size,
    #   4. thresholds the logits (> 0.0 by default).
    masks = processor.image_processor.post_process_masks(
        outputs.pred_masks.cpu(),
        inputs["original_sizes"].cpu(),
        inputs["reshaped_input_sizes"].cpu(),
    )
    # One image, one box, one mask (multimask_output=False) -> take [0][0, 0].
    pred = masks[0][0, 0].numpy()
    pred_binary = pred.astype(np.uint8) * 255
    return pred_binary, img
|
|
|
|
def visualize_result(image, mask, save_path, alpha=0.4, color=(255, 0, 0)):
    """Save a three-panel figure: input image, predicted mask, and colored overlay.

    Args:
        image: Input image; converted with ``np.array``. Presumably an RGB PIL
            image, as returned by ``predict_facade`` (assumption).
        mask: ``(H, W)`` array matching the image size. Pixels ``> 0`` are
            treated as facade.
        save_path: Output path passed to ``Figure.savefig``.
        alpha: Weight of ``color`` when tinting masked pixels (default 0.4).
        color: RGB tint applied to masked pixels (default red).
    """
    img_arr = np.array(image)
    fg = np.asarray(mask) > 0

    # Tint only the masked pixels. Unmasked pixels are copied unchanged rather
    # than being pushed through float blending and uint8 truncation.
    blended = img_arr.copy()
    tint = np.asarray(color, dtype=np.float64)
    blended[fg] = (img_arr[fg] * (1.0 - alpha) + tint * alpha).astype(np.uint8)

    fig, axes = plt.subplots(1, 3, figsize=(15, 5))
    try:
        panels = (
            (img_arr, "Input Image", None),
            (mask, "Predicted Facade Mask", "gray"),
            (blended, "Overlay", None),
        )
        for ax, (data, title, cmap) in zip(axes, panels):
            ax.imshow(data, cmap=cmap)
            ax.set_title(title)
            ax.axis('off')
        fig.tight_layout()
        fig.savefig(save_path, dpi=150)
    finally:
        # Close this specific figure even if saving fails, so repeated calls
        # do not leak figures.
        plt.close(fig)
    print(f"Result saved to {save_path}")
|
|
|
|
def main():
    """Command-line entry point: segment one image and save the figure and binary mask.

    The overlay figure is written to ``--output``. The binary mask is written
    next to it as ``<output stem>_mask.png``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--checkpoint", default="", help="Path to fine-tuned checkpoint")
    parser.add_argument("--image", required=True, help="Path to input image")
    parser.add_argument("--bbox", default=None, type=str, help="Bounding box as x1,y1,x2,y2")
    parser.add_argument("--output", default="outputs/inference_result.png", help="Output path")
    args = parser.parse_args()

    # Validate the bbox before loading the large model so bad input fails fast.
    # parser.error is used instead of assert, which is stripped under `python -O`.
    bbox = None
    if args.bbox:
        try:
            bbox = [float(x.strip()) for x in args.bbox.split(",")]
        except ValueError:
            parser.error(f"Bounding box must be numeric x1,y1,x2,y2, got {args.bbox!r}")
        if len(bbox) != 4:
            parser.error("Bounding box must be x1,y1,x2,y2")

    model, processor, device = load_model(args.checkpoint)
    mask, img = predict_facade(model, processor, args.image, bbox, device)

    out_dir = os.path.dirname(args.output)
    if out_dir:  # A bare filename has dirname "" and os.makedirs("") raises.
        os.makedirs(out_dir, exist_ok=True)
    visualize_result(img, mask, args.output)

    # Derive the mask path from the extension only. str.replace('.png', ...)
    # left non-.png outputs unchanged, so the mask overwrote the figure.
    root, _ext = os.path.splitext(args.output)
    mask_path = f"{root}_mask.png"
    Image.fromarray(mask).save(mask_path)
    print(f"Mask saved to {mask_path}")
|
|
|
|
# Run the command-line interface only when executed as a script, not on import.
if __name__ == '__main__':
    main()
|
|