acd23's picture
Upload folder using huggingface_hub
3cc53ab verified
"""
Inference pipeline for fine-tuned SAM ViT-H facade segmentation.
"""
import os
import argparse
import numpy as np
import torch
from PIL import Image
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
from transformers import SamModel, SamProcessor
def load_model(checkpoint_path=None, device=None):
    """Load SAM ViT-H and optionally overlay fine-tuned weights.

    Args:
        checkpoint_path: Path to a state dict saved with ``torch.save``.
            ``None`` or an empty string selects the stock
            ``facebook/sam-vit-huge`` weights.
        device: Torch device string. When ``None``, ``"cuda"`` is used if
            available, otherwise ``"cpu"``.

    Returns:
        Tuple ``(model, processor, device)``; the model is moved to
        ``device`` and put in eval mode.
    """
    if device is None:
        device = "cuda" if torch.cuda.is_available() else "cpu"

    processor = SamProcessor.from_pretrained("facebook/sam-vit-huge")
    model = SamModel.from_pretrained("facebook/sam-vit-huge").to(device)

    if checkpoint_path and os.path.exists(checkpoint_path):
        # SECURITY: weights_only=False lets torch.load unpickle arbitrary
        # Python objects. Only point this at checkpoints you trust; if the
        # file is a plain tensor state dict, weights_only=True is the safer
        # setting.
        state = torch.load(checkpoint_path, map_location=device, weights_only=False)
        model.load_state_dict(state)
        print(f"Loaded checkpoint from {checkpoint_path}")
    else:
        if checkpoint_path:
            # A path was requested but is missing: keep the best-effort
            # fallback, but make it obvious the fine-tuned weights are NOT
            # in use.
            print(
                f"WARNING: checkpoint not found at {checkpoint_path}; "
                "falling back to pre-trained weights"
            )
        print("Using original pre-trained SAM ViT-H")

    model.eval()
    return model, processor, device
def predict_facade(model, processor, image_path, bbox=None, device="cpu"):
    """Predict a binary facade mask for one image using a single box prompt.

    Args:
        model: Model called with ``pixel_values``, ``input_boxes`` and
            ``multimask_output=False``; its output must expose ``pred_masks``.
        processor: Callable that turns the image and box into model inputs.
        image_path: Path of the image to segment.
        bbox: Optional iterable of four numbers ``x1, y1, x2, y2``. When
            ``None`` the box spans the whole image
            (``0, 0, width, height``).
        device: Device the processed inputs are moved to; it must match the
            device the model lives on.

    Returns:
        Tuple ``(mask, image)``: ``mask`` is a ``uint8`` array of shape
        ``(height, width)`` holding 0 or 255, ``image`` is the RGB PIL image.

    Raises:
        ValueError: If ``bbox`` is given but does not contain four values.
    """
    # Image.open is lazy and keeps the file open; convert() returns an
    # independent in-memory copy, so the source handle can be closed at once.
    with Image.open(image_path) as src:
        img = src.convert("RGB")
    orig_w, orig_h = img.size

    if bbox is None:
        box = [0.0, 0.0, float(orig_w), float(orig_h)]
    else:
        box = [float(b) for b in bbox]
        if len(box) != 4:
            raise ValueError(
                f"bbox must contain exactly 4 values (x1, y1, x2, y2), got {len(box)}"
            )

    # The processor is given one image -> one box, nested as [[box]].
    inputs = processor(images=img, input_boxes=[[box]], return_tensors="pt").to(device)

    with torch.no_grad():
        outputs = model(
            pixel_values=inputs["pixel_values"],
            input_boxes=inputs["input_boxes"],
            multimask_output=False,
        )

    # Threshold the raw mask values at 0 and encode as a 0/255 image.
    pred_mask = outputs.pred_masks.squeeze().cpu().numpy()
    pred_binary = (pred_mask > 0).astype(np.uint8) * 255

    # NOTE(review): the low-res mask is stretched straight to the original
    # size. If the processor pads non-square images, the mask frame may not
    # match the original aspect ratio -- confirm against the training setup
    # (processor.image_processor.post_process_masks may be the right tool).
    resized = Image.fromarray(pred_binary).resize((orig_w, orig_h), Image.NEAREST)
    return np.array(resized), img
def visualize_result(image, mask, save_path):
    """Save a three-panel figure: input image, predicted mask, red overlay.

    Args:
        image: Image convertible with ``np.array`` (indexed by ``mask``).
        mask: Array whose positive entries mark the predicted region.
        save_path: Destination file for the rendered figure.
    """
    pixels = np.array(image)

    # Paint masked pixels pure red, then mix 60% original with 40% painted.
    painted = pixels.copy()
    painted[mask > 0] = [255, 0, 0]
    blended = (pixels * 0.6 + painted * 0.4).astype(np.uint8)

    # (title, data, extra imshow kwargs) for each subplot, left to right.
    panels = (
        ("Input Image", pixels, {}),
        ("Predicted Facade Mask", mask, {"cmap": "gray"}),
        ("Overlay", blended, {}),
    )

    _fig, axes = plt.subplots(1, 3, figsize=(15, 5))
    for ax, (title, data, imshow_kwargs) in zip(axes, panels):
        ax.imshow(data, **imshow_kwargs)
        ax.set_title(title)
        ax.axis('off')

    plt.tight_layout()
    plt.savefig(save_path, dpi=150)
    plt.close()
    print(f"Result saved to {save_path}")
def main():
    """Command-line entry point: segment one image and write the results.

    Writes the three-panel visualization to ``--output`` and the raw mask
    next to it as ``<output stem>_mask.png``. Invalid ``--bbox`` values
    terminate through ``parser.error`` (usage message, exit status 2).
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--checkpoint", default="", help="Path to fine-tuned checkpoint")
    parser.add_argument("--image", required=True, help="Path to input image")
    parser.add_argument("--bbox", default=None, type=str, help="Bounding box as x1,y1,x2,y2")
    parser.add_argument("--output", default="outputs/inference_result.png", help="Output path")
    args = parser.parse_args()

    # Validate the box before the (slow) model load so bad input fails fast.
    # parser.error is used instead of assert, which is stripped under -O.
    bbox = None
    if args.bbox:
        try:
            bbox = [float(x.strip()) for x in args.bbox.split(",")]
        except ValueError:
            parser.error("Bounding box must be x1,y1,x2,y2")
        if len(bbox) != 4:
            parser.error("Bounding box must be x1,y1,x2,y2")

    model, processor, device = load_model(args.checkpoint)
    mask, img = predict_facade(model, processor, args.image, bbox, device)

    # dirname is "" for a bare filename, and os.makedirs("") raises.
    out_dir = os.path.dirname(args.output)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)
    visualize_result(img, mask, args.output)

    # Derive the mask path from the stem so it never collides with the
    # figure (e.g. for a ".jpg" output) and never rewrites ".png" that
    # happens to appear in a directory name.
    stem, _ext = os.path.splitext(args.output)
    mask_path = f"{stem}_mask.png"
    Image.fromarray(mask).save(mask_path)
    print(f"Mask saved to {mask_path}")


# Run the CLI only when executed as a script, not when imported.
if __name__ == "__main__":
    main()