File size: 3,422 Bytes
3cc53ab
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
"""
Inference pipeline for fine-tuned SAM ViT-H facade segmentation.
"""
import os
import argparse
import numpy as np
import torch
from PIL import Image
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
from transformers import SamModel, SamProcessor


def load_model(checkpoint_path=None, device=None):
    """Load SAM ViT-H and its processor, optionally applying fine-tuned weights.

    Args:
        checkpoint_path: Optional path to a checkpoint that is passed straight
            to ``model.load_state_dict``. When empty or ``None`` the
            pre-trained weights are kept. When a path is given but does not
            exist, a warning is printed and the pre-trained weights are kept.
        device: Torch device string. Defaults to ``"cuda"`` when CUDA is
            available, otherwise ``"cpu"``.

    Returns:
        Tuple ``(model, processor, device)``; the model is in eval mode.
    """
    if device is None:
        device = "cuda" if torch.cuda.is_available() else "cpu"

    processor = SamProcessor.from_pretrained("facebook/sam-vit-huge")
    model = SamModel.from_pretrained("facebook/sam-vit-huge").to(device)

    if checkpoint_path and os.path.exists(checkpoint_path):
        # NOTE(security): weights_only=False lets torch.load unpickle arbitrary
        # objects, so only load checkpoints from a trusted source. If the file
        # holds a plain state dict, weights_only=True should be sufficient.
        state = torch.load(checkpoint_path, map_location=device, weights_only=False)
        model.load_state_dict(state)
        print(f"Loaded checkpoint from {checkpoint_path}")
    else:
        if checkpoint_path:
            # A path was requested but is missing: say so explicitly instead
            # of silently presenting base-model results as fine-tuned ones.
            print(f"Warning: checkpoint not found at {checkpoint_path}")
        print("Using original pre-trained SAM ViT-H")

    model.eval()
    return model, processor, device


def predict_facade(model, processor, image_path, bbox=None, device="cpu"):
    """Predict a binary facade mask for one image using a single box prompt.

    Args:
        model: Model called with ``pixel_values``, ``input_boxes`` and
            ``multimask_output=False``; its output must expose ``pred_masks``.
        processor: Callable that converts the image and box prompt into the
            tensors fed to ``model``.
        image_path: Path of the image file to open.
        bbox: Optional iterable of four numbers ``x1, y1, x2, y2``. When
            ``None`` the box spans the whole image, ``0, 0, width, height``.
        device: Device the processed inputs are moved to.

    Returns:
        Tuple ``(mask, img)``: ``mask`` is a ``uint8`` array holding 0 or 255,
        resized to the original image width and height; ``img`` is the RGB
        ``PIL.Image`` that was read.

    Raises:
        ValueError: If ``bbox`` is given but does not contain exactly four values.
    """
    # Validate the prompt first so a malformed box fails fast with a clear
    # message, before any file or model work is done.
    if bbox is not None:
        bbox = [float(b) for b in bbox]
        if len(bbox) != 4:
            raise ValueError("Bounding box must be x1,y1,x2,y2")

    # convert() returns an independent image, so the file handle can be
    # released as soon as the pixel data has been read.
    with Image.open(image_path) as src:
        img = src.convert("RGB")
    orig_w, orig_h = img.size

    if bbox is None:
        bbox = [0.0, 0.0, float(orig_w), float(orig_h)]
    # The processor is given the box nested three levels deep.
    input_boxes = [[bbox]]

    inputs = processor(images=img, input_boxes=input_boxes, return_tensors="pt").to(device)

    with torch.no_grad():
        outputs = model(
            pixel_values=inputs["pixel_values"],
            input_boxes=inputs["input_boxes"],
            multimask_output=False,
        )

    # NOTE(review): the mask is thresholded at the model's output resolution
    # and then stretched to the original size. If the processor pads
    # non-square inputs, this does not undo that padding -- confirm against
    # the training preprocessing (the processor's post_process_masks may be
    # the appropriate replacement).
    pred_mask = outputs.pred_masks.squeeze().cpu().numpy()
    pred_binary = (pred_mask > 0).astype(np.uint8) * 255
    pred_binary = Image.fromarray(pred_binary).resize((orig_w, orig_h), Image.NEAREST)
    return np.array(pred_binary), img


def visualize_result(image, mask, save_path):
    """Write a three-panel figure: input image, predicted mask, red overlay.

    Args:
        image: Input image; converted with ``np.array`` before plotting.
        mask: Array whose positive entries mark predicted pixels. It is used
            to index the image array, so it must match the image's height
            and width.
        save_path: File path the figure is saved to at 150 dpi.
    """
    pixels = np.array(image)

    # Paint predicted pixels pure red on a copy, then mix 60% of the original
    # with 40% of the painted copy so the scene stays visible under the tint.
    painted = pixels.copy()
    painted[mask > 0] = [255, 0, 0]
    tinted = (pixels * 0.6 + painted * 0.4).astype(np.uint8)

    # (data, title, extra imshow keyword arguments) for each panel, left to right.
    panels = (
        (pixels, "Input Image", {}),
        (mask, "Predicted Facade Mask", {"cmap": 'gray'}),
        (tinted, "Overlay", {}),
    )

    fig, axes = plt.subplots(1, 3, figsize=(15, 5))
    for ax, (data, title, imshow_kwargs) in zip(axes, panels):
        ax.imshow(data, **imshow_kwargs)
        ax.set_title(title)
        ax.axis('off')

    plt.tight_layout()
    plt.savefig(save_path, dpi=150)
    plt.close()
    print(f"Result saved to {save_path}")


def main(argv=None):
    """Command-line entry point: segment one image and save the figure and mask.

    Args:
        argv: Optional list of argument strings. ``None`` (the default) makes
            argparse read the process command line.

    The raw mask is written next to the figure as ``<output stem>_mask.png``.
    Invalid ``--bbox`` values end the program through ``parser.error``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--checkpoint", default="", help="Path to fine-tuned checkpoint")
    parser.add_argument("--image", required=True, help="Path to input image")
    parser.add_argument("--bbox", default=None, type=str, help="Bounding box as x1,y1,x2,y2")
    parser.add_argument("--output", default="outputs/inference_result.png", help="Output path")
    args = parser.parse_args(argv)

    # Validate the box before the expensive model load. parser.error is used
    # rather than assert so the check survives ``python -O`` and bad input
    # produces a usage message instead of a traceback.
    bbox = None
    if args.bbox:
        try:
            bbox = [float(x.strip()) for x in args.bbox.split(",")]
        except ValueError:
            parser.error("Bounding box must be x1,y1,x2,y2")
        if len(bbox) != 4:
            parser.error("Bounding box must be x1,y1,x2,y2")

    model, processor, device = load_model(args.checkpoint)

    mask, img = predict_facade(model, processor, args.image, bbox, device)

    # dirname is "" for a bare filename, and os.makedirs("") raises.
    out_dir = os.path.dirname(args.output)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)
    visualize_result(img, mask, args.output)

    # Derive the mask path from the stem so it can never equal the figure
    # path (which would overwrite the figure) and so ".png" occurring in a
    # directory name is left alone.
    stem, _ = os.path.splitext(args.output)
    mask_path = f"{stem}_mask.png"
    Image.fromarray(mask).save(mask_path)
    print(f"Mask saved to {mask_path}")


# Run the command-line interface only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()