| """ |
| Generate side-by-side qualitative comparisons: |
| Input image, GT mask, original SAM prediction, fine-tuned SAM prediction. |
| """ |
| import os |
| import argparse |
| import json |
| import numpy as np |
| import torch |
| from PIL import Image |
| import matplotlib |
| matplotlib.use('Agg') |
| import matplotlib.pyplot as plt |
| from transformers import SamModel, SamProcessor |
| from tqdm import tqdm |
| from torch.utils.data import DataLoader |
|
|
| from dataset import FacadeDataset, collate_fn |
|
|
|
|
def get_predictions(model, dataloader, device):
    """Run ``model`` over every batch of ``dataloader`` and return binary masks.

    Args:
        model: A SAM-style module whose forward pass accepts ``pixel_values``,
            ``input_boxes`` and ``multimask_output`` and returns an object
            exposing ``pred_masks``.
        dataloader: Iterable of batches; each batch is a mapping holding
            ``"pixel_values"`` and ``"input_boxes"`` tensors.
        device: Device the batch tensors are moved to before inference.

    Returns:
        numpy.ndarray of booleans: the per-batch masks thresholded at
        ``sigmoid(logit) > 0.5`` and concatenated along axis 0.
    """
    model.eval()
    batch_masks = []
    with torch.no_grad():
        for batch in tqdm(dataloader, desc="Predicting"):
            outputs = model(
                pixel_values=batch["pixel_values"].to(device),
                input_boxes=batch["input_boxes"].to(device),
                multimask_output=False,
            )
            # Drop the two singleton axes that follow the batch axis
            # (presumably point-batch and mask count with
            # multimask_output=False -- TODO confirm against SamModel docs).
            logits = outputs.pred_masks.squeeze(1).squeeze(1)
            probabilities = torch.sigmoid(logits)
            batch_masks.append((probabilities > 0.5).cpu().numpy())
    return np.concatenate(batch_masks, axis=0)
|
|
|
|
def _overlay_mask(img, mask, color=(255, 0, 0), alpha=0.4):
    """Return ``img`` blended with ``color`` on the pixels where ``mask`` > 0."""
    painted = img.copy()
    painted[mask > 0] = color
    return (img * (1 - alpha) + painted * alpha).astype(np.uint8)


def visualize_comparison(images, gts, preds_baseline, preds_finetuned, indices, save_dir):
    """Save one 2x3 comparison figure per requested sample index.

    Layout (row, col):
        (0, 0) input image         (0, 1) ground truth        (0, 2) fine-tuned overlay
        (1, 0) original SAM mask   (1, 1) fine-tuned mask     (1, 2) original overlay

    Args:
        images: Indexable collection of images; each is expected to be an
            RGB array of shape (H, W, 3), because the overlay writes a
            3-channel colour.
        gts: Indexable collection of ground-truth masks aligned with ``images``.
        preds_baseline: Predicted masks from the original model, same indexing.
        preds_finetuned: Predicted masks from the fine-tuned model, same indexing.
        indices: Sample indices to render.
        save_dir: Output directory (created if missing). Each figure is written
            to ``comparison_{idx:04d}.png`` inside it.
    """
    os.makedirs(save_dir, exist_ok=True)

    for idx in indices:
        img = images[idx]
        gt = gts[idx]
        pred_base = preds_baseline[idx]
        pred_ft = preds_finetuned[idx]
        iou_base = compute_iou(pred_base, gt)
        iou_ft = compute_iou(pred_ft, gt)

        fig, axes = plt.subplots(2, 3, figsize=(15, 10))
        try:
            # (axis, image data, colormap, title) for every panel.
            panels = [
                (axes[0, 0], img, None, "Input Image"),
                (axes[0, 1], gt, "gray", "Ground Truth"),
                (axes[0, 2], _overlay_mask(img, pred_ft), None, "Overlay Fine-tuned"),
                (axes[1, 0], pred_base, "gray", f"Original SAM ViT-H\nIoU={iou_base:.3f}"),
                (axes[1, 1], pred_ft, "gray", f"Fine-tuned SAM ViT-H\nIoU={iou_ft:.3f}"),
                (axes[1, 2], _overlay_mask(img, pred_base), None, "Overlay Original"),
            ]
            for ax, data, cmap, title in panels:
                ax.imshow(data, cmap=cmap)
                ax.set_title(title)
                ax.axis('off')

            fig.suptitle(f"Index {idx}: Original IoU={iou_base:.3f} | Fine-tuned IoU={iou_ft:.3f}", fontsize=14)
            fig.tight_layout()
            save_path = os.path.join(save_dir, f"comparison_{idx:04d}.png")
            fig.savefig(save_path, dpi=150)
        finally:
            # Close this specific figure so an error mid-loop cannot leak open figures.
            plt.close(fig)
        print(f"Saved {save_path}")
|
|
|
|
def compute_iou(pred, gt):
    """Return the intersection-over-union of two masks as a Python float.

    Both arrays are cast to ``bool``, so any non-zero value counts as
    foreground. If neither mask contains foreground (empty union), the pair
    is treated as a perfect match and 1.0 is returned.
    """
    pred_fg, gt_fg = pred.astype(bool), gt.astype(bool)
    union = np.logical_or(pred_fg, gt_fg).sum()
    if union == 0:
        return 1.0
    intersection = np.logical_and(pred_fg, gt_fg).sum()
    return float(intersection / union)
|
|
|
|
def load_raw_images_and_masks(data_dir, split, num_samples, size=256):
    """Load the first ``num_samples`` images and binary GT masks of a split.

    The split's ``metadata.json`` is parsed as a list of items with an
    ``"image"`` path. For each item, the mask is read from
    ``<data_dir>/<split>/masks_binary/<image stem>.png``. Both are resized to
    ``size`` x ``size``: bilinear for images, nearest for masks, so labels stay
    binary. The returned masks must match the spatial size of the predicted
    masks they are later compared against.

    Args:
        data_dir: Dataset root directory.
        split: Split sub-directory name (e.g. "test").
        num_samples: Maximum number of items to load, in metadata order.
        size: Side length in pixels of the square resize (default 256, the
            original hard-coded value).

    Returns:
        Tuple ``(images, gts, items)``:
        - ``images``: uint8 array of shape (N, size, size, 3);
        - ``gts``: bool array of shape (N, size, size);
        - ``items``: the full, untruncated list parsed from ``metadata.json``.
    """
    split_dir = os.path.join(data_dir, split)
    with open(os.path.join(split_dir, "metadata.json"), "r", encoding="utf-8") as f:
        items = json.load(f)

    mask_dir = os.path.join(split_dir, "masks_binary")
    images = []
    gts = []
    # NOTE(review): index i here must match prediction i, which assumes
    # FacadeDataset iterates metadata.json in the same order -- confirm in dataset.py.
    for item in items[:num_samples]:
        image_path = item["image"]
        # Swap only the real extension (handles .jpg/.jpeg/.JPG alike).
        stem = os.path.splitext(os.path.basename(image_path))[0]
        mask_path = os.path.join(mask_dir, stem + ".png")

        with Image.open(image_path) as im:
            img = im.convert("RGB").resize((size, size), Image.BILINEAR)
        with Image.open(mask_path) as mask_im:
            gt = mask_im.convert("L").resize((size, size), Image.NEAREST)

        images.append(np.array(img))
        gts.append(np.array(gt) > 0)

    if not images:
        # Keep the documented shapes even when nothing was loaded.
        return (
            np.zeros((0, size, size, 3), dtype=np.uint8),
            np.zeros((0, size, size), dtype=bool),
            items,
        )
    return np.array(images), np.array(gts), items
|
|
|
|
def _select_indices(ious, num_random=7, max_total=10, seed=42):
    """Pick worst, best and median samples by IoU, then seeded random extras.

    Returns a de-duplicated list of ints in selection order, capped at ``max_total``.
    """
    n = len(ious)
    order = np.argsort(ious)
    selected = [int(order[0]), int(order[-1]), int(order[n // 2])]
    rng = np.random.RandomState(seed)
    extra = rng.choice(n, size=min(num_random, max(0, n - 3)), replace=False)
    # dict.fromkeys removes duplicates while keeping first-seen order.
    return list(dict.fromkeys(selected + extra.tolist()))[:max_total]


def main(args):
    """Predict with the original and a fine-tuned SAM ViT-H and render comparisons.

    Both models start from ``facebook/sam-vit-huge``. The second one has the
    state dict from ``args.checkpoint`` loaded on top. Figures for a selection
    of samples (worst/best/median baseline IoU plus random extras) are written
    to ``args.output_dir``, together with ``comparison_indices.json``.

    Raises:
        ValueError: If the requested split contains no samples.
        RuntimeError: If predictions and ground-truth masks are misaligned.
    """
    device = "cuda" if torch.cuda.is_available() else "cpu"
    processor = SamProcessor.from_pretrained("facebook/sam-vit-huge")

    dataset = FacadeDataset(args.data_dir, split=args.split, processor=processor, augment=False)
    if len(dataset) == 0:
        raise ValueError(f"No samples found for split '{args.split}' in {args.data_dir}")
    # shuffle=False keeps prediction order aligned with the raw images/masks below.
    dataloader = DataLoader(dataset, batch_size=args.batch_size, shuffle=False, collate_fn=collate_fn)

    images_arr, gts_arr, _ = load_raw_images_and_masks(args.data_dir, args.split, len(dataset))

    print("Running baseline predictions...")
    model_base = SamModel.from_pretrained("facebook/sam-vit-huge").to(device)
    preds_base = get_predictions(model_base, dataloader, device)
    # Free the first model before loading the second to limit peak memory.
    del model_base
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

    print("Running fine-tuned predictions...")
    model_ft = SamModel.from_pretrained("facebook/sam-vit-huge").to(device)
    # SECURITY: weights_only=False unpickles arbitrary objects -- only load trusted checkpoints.
    # NOTE(review): assumes the checkpoint is a bare state_dict -- confirm against the training script.
    state = torch.load(args.checkpoint, map_location=device, weights_only=False)
    model_ft.load_state_dict(state)
    preds_ft = get_predictions(model_ft, dataloader, device)

    if preds_base.shape != gts_arr.shape or preds_ft.shape != gts_arr.shape:
        raise RuntimeError(
            f"Prediction/GT mismatch: baseline {preds_base.shape}, "
            f"fine-tuned {preds_ft.shape}, ground truth {gts_arr.shape}"
        )

    # Use compute_iou (empty union -> 1.0) so ranking matches the IoU shown in the figures.
    ious_base = np.array([compute_iou(p, g) for p, g in zip(preds_base, gts_arr)])
    ious_ft = np.array([compute_iou(p, g) for p, g in zip(preds_ft, gts_arr)])

    selected = _select_indices(ious_base)

    visualize_comparison(images_arr, gts_arr, preds_base, preds_ft, selected, args.output_dir)

    with open(os.path.join(args.output_dir, "comparison_indices.json"), "w", encoding="utf-8") as f:
        json.dump(
            {
                "indices": selected,
                "ious_base": ious_base[selected].tolist(),
                "ious_finetuned": ious_ft[selected].tolist(),
            },
            f,
            indent=2,
        )
|
|
|
|
if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Render side-by-side comparisons of original vs. fine-tuned SAM ViT-H predictions."
    )
    parser.add_argument("--checkpoint", required=True,
                        help="Path to the fine-tuned model state dict (loaded with torch.load).")
    parser.add_argument("--data_dir", default="data/cmp_facade",
                        help="Dataset root directory (default: %(default)s).")
    parser.add_argument("--split", default="test",
                        help="Dataset split to evaluate (default: %(default)s).")
    parser.add_argument("--batch_size", type=int, default=2,
                        help="Inference batch size (default: %(default)s).")
    parser.add_argument("--output_dir", default="outputs/comparison",
                        help="Directory for the figures and comparison_indices.json (default: %(default)s).")
    args = parser.parse_args()
    main(args)
|
|