"""
Inference pipeline for fine-tuned SAM ViT-H facade segmentation.
"""
import os
import argparse
import numpy as np
import torch
from PIL import Image
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
from transformers import SamModel, SamProcessor
def load_model(checkpoint_path=None, device=None):
    """Load SAM ViT-H and, if a checkpoint is available, its fine-tuned weights.

    Args:
        checkpoint_path: Optional path to a checkpoint that is passed to
            ``model.load_state_dict``. When it is empty/``None`` or does not
            exist on disk, the pre-trained "facebook/sam-vit-huge" weights
            are used unchanged.
        device: Device to place the model on. Defaults to ``"cuda"`` when
            CUDA is available, otherwise ``"cpu"``.

    Returns:
        Tuple ``(model, processor, device)``; the model is in eval mode.
    """
    if device is None:
        device = "cuda" if torch.cuda.is_available() else "cpu"

    processor = SamProcessor.from_pretrained("facebook/sam-vit-huge")
    model = SamModel.from_pretrained("facebook/sam-vit-huge").to(device)

    if checkpoint_path and os.path.exists(checkpoint_path):
        # SECURITY NOTE(review): weights_only=False lets torch.load unpickle
        # arbitrary objects, so only load checkpoints from a trusted source.
        # If the checkpoint is a plain tensor state dict, consider
        # weights_only=True.
        state = torch.load(checkpoint_path, map_location=device, weights_only=False)
        model.load_state_dict(state)
        print(f"Loaded checkpoint from {checkpoint_path}")
    else:
        if checkpoint_path:
            # A path was requested explicitly but is missing: keep the
            # best-effort fallback, but make it visible so results are not
            # mistaken for fine-tuned output.
            print(f"WARNING: checkpoint not found at {checkpoint_path}; "
                  "falling back to pre-trained weights")
        print("Using original pre-trained SAM ViT-H")

    model.eval()
    return model, processor, device
def predict_facade(model, processor, image_path, bbox=None, device="cpu"):
    """Run box-prompted SAM inference on one image and return a binary mask.

    Args:
        model: Model called with ``pixel_values``, ``input_boxes`` and
            ``multimask_output=False``; its output must expose ``pred_masks``.
        processor: Processor that turns the image and box prompt into model
            inputs.
        image_path: Path of the image to segment.
        bbox: Optional iterable of four numbers ``x1, y1, x2, y2``. When
            ``None``, the full image extent is used as the box prompt.
        device: Device the processed inputs are moved to.

    Returns:
        Tuple ``(mask, image)``: ``mask`` is a ``uint8`` array holding 0 or
        255, resized to the original image size; ``image`` is the RGB PIL
        image that was segmented.

    Raises:
        ValueError: If ``bbox`` is given but does not contain exactly four
            values.
    """
    # Validate the prompt before any I/O so a bad box fails fast and clearly.
    box = None
    if bbox is not None:
        box = [float(b) for b in bbox]
        if len(box) != 4:
            raise ValueError(
                f"bbox must have exactly 4 values (x1, y1, x2, y2), got {len(box)}"
            )

    # convert() returns an independent in-memory copy, so the underlying file
    # handle can be released right away.
    with Image.open(image_path) as raw:
        img = raw.convert("RGB")
    orig_w, orig_h = img.size

    if box is None:
        box = [0.0, 0.0, float(orig_w), float(orig_h)]

    # Nested as [image][box][coords], the layout passed as input_boxes.
    inputs = processor(images=img, input_boxes=[[box]], return_tensors="pt").to(device)

    with torch.no_grad():
        outputs = model(
            pixel_values=inputs["pixel_values"],
            input_boxes=inputs["input_boxes"],
            multimask_output=False,
        )

    # Assumes squeeze() leaves a single 2-D mask (one image, one box, one mask).
    pred_mask = outputs.pred_masks.squeeze().cpu().numpy()
    pred_binary = (pred_mask > 0).astype(np.uint8) * 255

    # NOTE(review): the predicted mask is stretched directly to the original
    # size. The SAM processor presumably pads non-square inputs, in which case
    # this distorts the mask unless fine-tuning targets were built the same
    # way; confirm, or switch to processor.image_processor.post_process_masks.
    pred_binary = Image.fromarray(pred_binary).resize((orig_w, orig_h), Image.NEAREST)
    return np.array(pred_binary), img
def visualize_result(image, mask, save_path):
    """Save a three-panel figure: input image, predicted mask, red overlay.

    Args:
        image: Image convertible with ``np.array`` (indexed as an RGB array).
        mask: Array whose positive entries mark the predicted region.
        save_path: Destination path for the rendered figure.
    """
    _, axes = plt.subplots(1, 3, figsize=(15, 5))
    rgb = np.array(image)

    # Paint masked pixels pure red on a copy, then mix 60% photo / 40% tint
    # so the underlying image stays visible beneath the highlight.
    tinted = rgb.copy()
    tinted[mask > 0] = [255, 0, 0]
    blended = (rgb * 0.6 + tinted * 0.4).astype(np.uint8)

    panels = (
        (rgb, "Input Image", {}),
        (mask, "Predicted Facade Mask", {"cmap": 'gray'}),
        (blended, "Overlay", {}),
    )
    for ax, (data, title, imshow_kwargs) in zip(axes, panels):
        ax.imshow(data, **imshow_kwargs)
        ax.set_title(title)
        ax.axis('off')

    plt.tight_layout()
    plt.savefig(save_path, dpi=150)
    plt.close()
    print(f"Result saved to {save_path}")
def main():
    """Command-line entry point: segment one image and save the results.

    Parses ``--checkpoint``, ``--image``, ``--bbox`` and ``--output``, runs
    the prediction, then writes the three-panel visualisation to ``--output``
    and the raw mask next to it as ``<output stem>_mask.png``.

    Invalid ``--bbox`` values end the program through ``parser.error``
    (usage message, exit status 2).
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--checkpoint", default="", help="Path to fine-tuned checkpoint")
    parser.add_argument("--image", required=True, help="Path to input image")
    parser.add_argument("--bbox", default=None, type=str, help="Bounding box as x1,y1,x2,y2")
    parser.add_argument("--output", default="outputs/inference_result.png", help="Output path")
    args = parser.parse_args()

    # Validate the box before loading the model so bad input fails fast.
    # parser.error is used instead of assert, which is stripped under -O.
    bbox = None
    if args.bbox:
        try:
            bbox = [float(x.strip()) for x in args.bbox.split(",")]
        except ValueError:
            parser.error("Bounding box must be numeric: x1,y1,x2,y2")
        if len(bbox) != 4:
            parser.error("Bounding box must be x1,y1,x2,y2")

    model, processor, device = load_model(args.checkpoint)
    mask, img = predict_facade(model, processor, args.image, bbox, device)

    # dirname is "" for a bare filename, and os.makedirs("") raises.
    out_dir = os.path.dirname(args.output)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)
    visualize_result(img, mask, args.output)

    # Derive the mask path from the stem so a non-.png output can never make
    # the mask overwrite the figure, and ".png" elsewhere in the path is
    # left alone.
    stem, _ = os.path.splitext(args.output)
    mask_path = f"{stem}_mask.png"
    Image.fromarray(mask).save(mask_path)
    print(f"Mask saved to {mask_path}")
# Run the CLI only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()
|