""" Generate side-by-side qualitative comparisons: Input image, GT mask, original SAM prediction, fine-tuned SAM prediction. """ import os import argparse import json import numpy as np import torch from PIL import Image import matplotlib matplotlib.use('Agg') import matplotlib.pyplot as plt from transformers import SamModel, SamProcessor from tqdm import tqdm from torch.utils.data import DataLoader from dataset import FacadeDataset, collate_fn def get_predictions(model, dataloader, device): model.eval() preds = [] with torch.no_grad(): for batch in tqdm(dataloader, desc="Predicting"): pixel_values = batch["pixel_values"].to(device) input_boxes = batch["input_boxes"].to(device) outputs = model( pixel_values=pixel_values, input_boxes=input_boxes, multimask_output=False, ) pred_masks = outputs.pred_masks.squeeze(1).squeeze(1) pred_binary = (torch.sigmoid(pred_masks) > 0.5).cpu().numpy() preds.append(pred_binary) return np.concatenate(preds, axis=0) def visualize_comparison(images, gts, preds_baseline, preds_finetuned, indices, save_dir): os.makedirs(save_dir, exist_ok=True) for idx in indices: img = images[idx] gt = gts[idx] pred_base = preds_baseline[idx] pred_ft = preds_finetuned[idx] fig, axes = plt.subplots(2, 3, figsize=(15, 10)) axes[0, 0].imshow(img) axes[0, 0].set_title("Input Image") axes[0, 0].axis('off') axes[0, 1].imshow(gt, cmap='gray') axes[0, 1].set_title("Ground Truth") axes[0, 1].axis('off') axes[0, 2].axis('off') iou_base = compute_iou(pred_base, gt) axes[1, 0].imshow(pred_base, cmap='gray') axes[1, 0].set_title(f"Original SAM ViT-H\nIoU={iou_base:.3f}") axes[1, 0].axis('off') iou_ft = compute_iou(pred_ft, gt) axes[1, 1].imshow(pred_ft, cmap='gray') axes[1, 1].set_title(f"Fine-tuned SAM ViT-H\nIoU={iou_ft:.3f}") axes[1, 1].axis('off') overlay_base = img.copy() overlay_base[pred_base > 0] = [255, 0, 0] blended_base = (img * 0.6 + overlay_base * 0.4).astype(np.uint8) axes[1, 2].imshow(blended_base) axes[1, 2].set_title("Overlay Original") axes[1, 2].axis('off') plt.suptitle(f"Index {idx}: Original IoU={iou_base:.3f} | Fine-tuned IoU={iou_ft:.3f}", fontsize=14) plt.tight_layout() save_path = os.path.join(save_dir, f"comparison_{idx:04d}.png") plt.savefig(save_path, dpi=150) plt.close() print(f"Saved {save_path}") def compute_iou(pred, gt): pred = pred.astype(bool) gt = gt.astype(bool) inter = np.logical_and(pred, gt).sum() union = np.logical_or(pred, gt).sum() return 1.0 if union == 0 else float(inter / union) def load_raw_images_and_masks(data_dir, split, num_samples): split_dir = os.path.join(data_dir, split) with open(os.path.join(split_dir, "metadata.json"), "r") as f: items = json.load(f) images = [] gts = [] for item in items[:num_samples]: img = Image.open(item["image"]).convert("RGB").resize((256, 256), Image.BILINEAR) mask_path = os.path.join(split_dir, "masks_binary", os.path.basename(item["image"]).replace(".jpg", ".png")) gt = Image.open(mask_path).convert("L").resize((256, 256), Image.NEAREST) images.append(np.array(img)) gts.append(np.array(gt) > 0) return np.array(images), np.array(gts), items def main(args): device = "cuda" if torch.cuda.is_available() else "cpu" processor = SamProcessor.from_pretrained("facebook/sam-vit-huge") dataset = FacadeDataset(args.data_dir, split=args.split, processor=processor, augment=False) dataloader = DataLoader(dataset, batch_size=args.batch_size, shuffle=False, collate_fn=collate_fn) images_arr, gts_arr, items = load_raw_images_and_masks(args.data_dir, args.split, len(dataset)) print("Running baseline predictions...") model_base = SamModel.from_pretrained("facebook/sam-vit-huge").to(device) preds_base = get_predictions(model_base, dataloader, device) del model_base if torch.cuda.is_available(): torch.cuda.empty_cache() print("Running fine-tuned predictions...") model_ft = SamModel.from_pretrained("facebook/sam-vit-huge").to(device) state = torch.load(args.checkpoint, map_location=device, weights_only=False) model_ft.load_state_dict(state) preds_ft = get_predictions(model_ft, dataloader, device) ious_base = [] for i in range(len(preds_base)): p = preds_base[i] g = gts_arr[i] iou = np.logical_and(p, g).sum() / (np.logical_or(p, g).sum() + 1e-6) ious_base.append(iou) ious_base = np.array(ious_base) sorted_idx = np.argsort(ious_base) selected = [ int(sorted_idx[0]), int(sorted_idx[-1]), int(sorted_idx[len(sorted_idx)//2]), ] rng = np.random.RandomState(42) extra = rng.choice(len(dataset), size=min(7, max(0, len(dataset)-3)), replace=False) selected = list(dict.fromkeys(selected + extra.tolist()))[:10] visualize_comparison(images_arr, gts_arr, preds_base, preds_ft, selected, args.output_dir) with open(os.path.join(args.output_dir, "comparison_indices.json"), "w") as f: json.dump({"indices": [int(x) for x in selected], "ious_base": ious_base[selected].tolist()}, f, indent=2) if __name__ == "__main__": parser = argparse.ArgumentParser() parser.add_argument("--checkpoint", required=True) parser.add_argument("--data_dir", default="data/cmp_facade") parser.add_argument("--split", default="test") parser.add_argument("--batch_size", type=int, default=2) parser.add_argument("--output_dir", default="outputs/comparison") args = parser.parse_args() main(args)