"""
Evaluate a SAM ViT-H checkpoint (original or fine-tuned) on facade segmentation.
"""

import os
import argparse
import json

import numpy as np
import torch
from torch.utils.data import DataLoader
from transformers import SamModel, SamProcessor
from tqdm import tqdm

from dataset import FacadeDataset, collate_fn
from metrics import compute_all_metrics


def _per_sample_ious(preds, gts):
    """Return the IoU of each prediction / ground-truth pair along axis 0.

    Any non-zero value counts as foreground. A pair whose union is empty
    (both masks entirely background) scores 1.0.
    """
    ious = []
    for idx, pred in enumerate(preds):
        gt = gts[idx]
        intersection = np.logical_and(pred, gt).sum()
        union = np.logical_or(pred, gt).sum()
        ious.append(float(intersection / union) if union > 0 else 1.0)
    return ious


@torch.no_grad()
def evaluate_checkpoint(checkpoint_path, data_dir, split="test", batch_size=1,
                        output_dir="outputs/eval"):
    """Evaluate SAM ViT-H on one split of the facade dataset and save the results.

    Args:
        checkpoint_path: Path to a state dict to load on top of the
            pre-trained "facebook/sam-vit-huge" weights. An empty value
            evaluates the pre-trained model unchanged.
        data_dir: Dataset root handed to ``FacadeDataset``.
        split: Dataset split name; also used in the output file names.
        batch_size: Batch size of the evaluation ``DataLoader``.
        output_dir: Directory receiving ``metrics_<split>.json``,
            ``preds_<split>.npy`` and ``gts_<split>.npy`` (created if missing).

    Returns:
        The mapping returned by ``compute_all_metrics``, extended with the
        keys ``"mean_iou_list"`` (per-sample IoUs), ``"split"`` and
        ``"checkpoint"``.

    Raises:
        FileNotFoundError: If ``checkpoint_path`` is non-empty but does not
            exist. Falling back to the baseline here would store baseline
            numbers labelled with the fine-tuned checkpoint's path.
        ValueError: If the dataset split yields no batches.
    """
    # Validate before the expensive model load so a mistyped path fails fast.
    if checkpoint_path and not os.path.exists(checkpoint_path):
        raise FileNotFoundError(f"Checkpoint not found: {checkpoint_path}")

    os.makedirs(output_dir, exist_ok=True)
    device = "cuda" if torch.cuda.is_available() else "cpu"
    print(f"Device: {device}")

    processor = SamProcessor.from_pretrained("facebook/sam-vit-huge")
    model = SamModel.from_pretrained("facebook/sam-vit-huge").to(device)

    if checkpoint_path:
        print(f"Loading checkpoint: {checkpoint_path}")
        # NOTE(security): weights_only=False unpickles arbitrary Python
        # objects. Only load checkpoints from a trusted source.
        state = torch.load(checkpoint_path, map_location=device, weights_only=False)
        model.load_state_dict(state)
    else:
        print("Evaluating original pre-trained SAM ViT-H (no checkpoint loaded)")

    model.eval()

    dataset = FacadeDataset(data_dir, split=split, processor=processor, augment=False)
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=False,
                            collate_fn=collate_fn)

    all_preds = []
    all_gts = []
    all_ious = []

    print(f"Evaluating on {len(dataset)} samples...")
    for batch in tqdm(dataloader, desc="Eval"):
        pixel_values = batch["pixel_values"].to(device)
        input_boxes = batch["input_boxes"].to(device)

        outputs = model(
            pixel_values=pixel_values,
            input_boxes=input_boxes,
            multimask_output=False,
        )

        # Drop the two size-1 axes after the batch axis (presumably the
        # per-image box axis and the mask-candidate axis -- TODO confirm
        # against the dataset's box layout).
        pred_masks = outputs.pred_masks.squeeze(1).squeeze(1)
        pred_probs = torch.sigmoid(pred_masks)
        pred_binary = (pred_probs > 0.5).float()

        # Move each batch to host memory once and reuse the arrays for both
        # the accumulated outputs and the per-sample IoUs.
        pred_np = pred_binary.cpu().numpy()
        gt_np = batch["ground_truth_mask"].cpu().numpy()

        all_preds.append(pred_np)
        all_gts.append(gt_np)
        all_ious.extend(_per_sample_ious(pred_np, gt_np))

    if not all_preds:
        raise ValueError(
            f"No samples found for split '{split}' in '{data_dir}'; nothing to evaluate."
        )

    all_preds = np.concatenate(all_preds, axis=0)
    all_gts = np.concatenate(all_gts, axis=0)

    metrics = compute_all_metrics(all_preds, all_gts)
    metrics["mean_iou_list"] = [float(v) for v in all_ious]
    metrics["split"] = split
    metrics["checkpoint"] = checkpoint_path if checkpoint_path else "original"

    print("\n=== Evaluation Results ===")
    for k, v in metrics.items():
        if isinstance(v, float):
            print(f"{k}: {v:.4f}")

    out_file = os.path.join(output_dir, f"metrics_{split}.json")
    with open(out_file, "w", encoding="utf-8") as f:
        json.dump(metrics, f, indent=2)

    np.save(os.path.join(output_dir, f"preds_{split}.npy"), all_preds)
    np.save(os.path.join(output_dir, f"gts_{split}.npy"), all_gts)

    print(f"Saved to {output_dir}")
    return metrics


def main():
    """Parse the command-line arguments and run the evaluation."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--checkpoint", default="",
                        help="Path to fine-tuned checkpoint. Leave empty for original baseline.")
    parser.add_argument("--data_dir", default="data/cmp_facade")
    parser.add_argument("--split", default="test")
    parser.add_argument("--batch_size", type=int, default=2)
    parser.add_argument("--output_dir", default="outputs/eval")
    args = parser.parse_args()

    evaluate_checkpoint(args.checkpoint, args.data_dir, args.split,
                        args.batch_size, args.output_dir)


if __name__ == "__main__":
    main()