| """ |
| Evaluate a SAM ViT-H checkpoint (original or fine-tuned) on facade segmentation. |
| """ |
| import os |
| import argparse |
| import json |
| import numpy as np |
| import torch |
| from torch.utils.data import DataLoader |
| from transformers import SamModel, SamProcessor |
| from tqdm import tqdm |
|
|
| from dataset import FacadeDataset, collate_fn |
| from metrics import compute_all_metrics |
|
|
|
|
def _binary_iou(pred, gt):
    """Return the IoU of two binary masks, or 1.0 when both masks are empty."""
    intersection = np.logical_and(pred, gt).sum()
    union = np.logical_or(pred, gt).sum()
    return float(intersection / union) if union > 0 else 1.0


@torch.no_grad()
def evaluate_checkpoint(checkpoint_path, data_dir, split="test", batch_size=1, output_dir="outputs/eval"):
    """Evaluate SAM ViT-H (pre-trained or fine-tuned) on one dataset split.

    Loads ``facebook/sam-vit-huge``, optionally overrides its weights with the
    state dict stored at ``checkpoint_path``, runs box-prompted inference over
    ``FacadeDataset(data_dir, split=split)`` and thresholds the sigmoid of the
    predicted mask logits at 0.5.

    Args:
        checkpoint_path: Path to a file readable by ``torch.load`` whose
            content is passed to ``model.load_state_dict``. An empty string or
            ``None`` evaluates the original pre-trained weights.
        data_dir: Dataset root forwarded to ``FacadeDataset``.
        split: Split name forwarded to ``FacadeDataset``; also used in the
            output file names.
        batch_size: Batch size for the ``DataLoader``.
        output_dir: Directory (created if missing) that receives
            ``metrics_{split}.json``, ``preds_{split}.npy`` and
            ``gts_{split}.npy``.

    Returns:
        The dict produced by ``compute_all_metrics`` extended with
        ``"mean_iou_list"`` (the per-sample IoU values, despite the key name),
        ``"split"`` and ``"checkpoint"``.

    Raises:
        FileNotFoundError: If ``checkpoint_path`` is non-empty but does not
            exist. Previously this silently evaluated the baseline while still
            recording the missing path as the evaluated checkpoint.
        ValueError: If the selected split contains no samples.
    """
    # Fail fast, before creating directories or loading multi-GB weights:
    # a mistyped path must never yield baseline numbers labelled as fine-tuned.
    if checkpoint_path and not os.path.exists(checkpoint_path):
        raise FileNotFoundError(f"Checkpoint not found: {checkpoint_path}")

    os.makedirs(output_dir, exist_ok=True)

    device = "cuda" if torch.cuda.is_available() else "cpu"
    print(f"Device: {device}")

    processor = SamProcessor.from_pretrained("facebook/sam-vit-huge")
    model = SamModel.from_pretrained("facebook/sam-vit-huge").to(device)

    if checkpoint_path:
        print(f"Loading checkpoint: {checkpoint_path}")
        # NOTE(security): weights_only=False unpickles arbitrary objects and can
        # execute code embedded in the file. Only load checkpoints you trust;
        # switch to weights_only=True if the file is a plain state dict.
        state = torch.load(checkpoint_path, map_location=device, weights_only=False)
        model.load_state_dict(state)
    else:
        print("Evaluating original pre-trained SAM ViT-H (no checkpoint loaded)")

    model.eval()

    dataset = FacadeDataset(data_dir, split=split, processor=processor, augment=False)
    if len(dataset) == 0:
        # np.concatenate on an empty list would fail later with an opaque message.
        raise ValueError(f"No samples found for split '{split}' in {data_dir}")
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=False, collate_fn=collate_fn)

    all_preds = []
    all_gts = []
    all_ious = []

    print(f"Evaluating on {len(dataset)} samples...")
    for batch in tqdm(dataloader, desc="Eval"):
        pixel_values = batch["pixel_values"].to(device)
        input_boxes = batch["input_boxes"].to(device)

        outputs = model(
            pixel_values=pixel_values,
            input_boxes=input_boxes,
            multimask_output=False,
        )

        # Drop the two singleton dims after the batch dim; assumes one box per
        # image and a single mask per box (multimask_output=False) - TODO confirm
        # against the dataset's input_boxes layout.
        pred_logits = outputs.pred_masks.squeeze(1).squeeze(1)
        pred_binary = (torch.sigmoid(pred_logits) > 0.5).float()

        # Convert each batch to NumPy once and reuse it for both the stored
        # arrays and the per-sample IoU. Ground truth is only consumed as
        # NumPy, so it is not moved to the compute device first.
        preds_np = pred_binary.cpu().numpy()
        gts_np = batch["ground_truth_mask"].cpu().numpy()

        all_preds.append(preds_np)
        all_gts.append(gts_np)
        all_ious.extend(_binary_iou(p, g) for p, g in zip(preds_np, gts_np))

    all_preds = np.concatenate(all_preds, axis=0)
    all_gts = np.concatenate(all_gts, axis=0)

    metrics = compute_all_metrics(all_preds, all_gts)
    metrics["mean_iou_list"] = all_ious
    metrics["split"] = split
    metrics["checkpoint"] = checkpoint_path if checkpoint_path else "original"

    print("\n=== Evaluation Results ===")
    for k, v in metrics.items():
        if isinstance(v, float):
            print(f"{k}: {v:.4f}")

    out_file = os.path.join(output_dir, f"metrics_{split}.json")
    with open(out_file, "w", encoding="utf-8") as f:
        json.dump(metrics, f, indent=2)

    np.save(os.path.join(output_dir, f"preds_{split}.npy"), all_preds)
    np.save(os.path.join(output_dir, f"gts_{split}.npy"), all_gts)
    print(f"Saved to {output_dir}")

    return metrics
|
|
|
|
if __name__ == "__main__":
    # Command-line entry point: collect the evaluation settings and run once.
    cli = argparse.ArgumentParser()
    option_specs = (
        ("--checkpoint", {"default": "", "help": "Path to fine-tuned checkpoint. Leave empty for original baseline."}),
        ("--data_dir", {"default": "data/cmp_facade"}),
        ("--split", {"default": "test"}),
        ("--batch_size", {"type": int, "default": 2}),
        ("--output_dir", {"default": "outputs/eval"}),
    )
    for flag, spec in option_specs:
        cli.add_argument(flag, **spec)
    opts = cli.parse_args()

    evaluate_checkpoint(
        checkpoint_path=opts.checkpoint,
        data_dir=opts.data_dir,
        split=opts.split,
        batch_size=opts.batch_size,
        output_dir=opts.output_dir,
    )
|
|