File size: 3,732 Bytes
3cc53ab | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 | """
Evaluate a SAM ViT-H checkpoint (original or fine-tuned) on facade segmentation.
"""
import os
import argparse
import json
import numpy as np
import torch
from torch.utils.data import DataLoader
from transformers import SamModel, SamProcessor
from tqdm import tqdm
from dataset import FacadeDataset, collate_fn
from metrics import compute_all_metrics
def _load_sam(checkpoint_path, device):
    """Build the SAM ViT-H processor and model, optionally loading fine-tuned weights.

    Args:
        checkpoint_path: Path to a ``state_dict`` saved for ``SamModel``, or a
            falsy value ("" / None) to keep the pre-trained weights.
        device: Device string the model and loaded tensors are placed on.

    Returns:
        ``(processor, model)`` with ``model`` on ``device`` and in eval mode.

    Raises:
        FileNotFoundError: if ``checkpoint_path`` is non-empty but missing. This
            prevents silently evaluating the baseline model while labelling the
            results with the checkpoint path.
    """
    # Fail fast, before downloading/loading the large ViT-H weights.
    if checkpoint_path and not os.path.exists(checkpoint_path):
        raise FileNotFoundError(f"Checkpoint not found: {checkpoint_path}")

    processor = SamProcessor.from_pretrained("facebook/sam-vit-huge")
    model = SamModel.from_pretrained("facebook/sam-vit-huge").to(device)
    if checkpoint_path:
        print(f"Loading checkpoint: {checkpoint_path}")
        # SECURITY: weights_only=False lets torch.load unpickle arbitrary
        # objects -- only load checkpoints from trusted sources. A plain
        # state_dict of tensors should also load with weights_only=True.
        state = torch.load(checkpoint_path, map_location=device, weights_only=False)
        model.load_state_dict(state)
    else:
        print("Evaluating original pre-trained SAM ViT-H (no checkpoint loaded)")
    model.eval()
    return processor, model


def _per_sample_ious(pred_np, gt_np):
    """Return one IoU per (prediction, ground truth) pair along axis 0.

    Masks are treated as boolean (non-zero = foreground). A pair in which both
    masks are empty scores 1.0, because the prediction matches exactly.

    Raises:
        ValueError: if the two arrays hold a different number of samples.
    """
    if len(pred_np) != len(gt_np):
        raise ValueError(
            f"Prediction/ground-truth batch size mismatch: {len(pred_np)} vs {len(gt_np)}"
        )
    ious = []
    for p, g in zip(pred_np, gt_np):
        intersection = np.logical_and(p, g).sum()
        union = np.logical_or(p, g).sum()
        ious.append(float(intersection / union) if union > 0 else 1.0)
    return ious


@torch.no_grad()
def evaluate_checkpoint(checkpoint_path, data_dir, split="test", batch_size=1, output_dir="outputs/eval"):
    """Evaluate SAM ViT-H on one split of the facade dataset and save the results.

    Runs box-prompted inference with ``multimask_output=False`` and thresholds
    the sigmoid of the predicted mask logits at 0.5. It then computes
    dataset-level metrics via ``compute_all_metrics`` plus a per-sample IoU
    list.

    Args:
        checkpoint_path: Path to a fine-tuned ``state_dict``. "" or None
            evaluates the original pre-trained weights.
        data_dir: Root directory passed to ``FacadeDataset``.
        split: Dataset split name; also used in the output file names.
        batch_size: DataLoader batch size.
        output_dir: Directory (created if missing) that receives
            ``metrics_{split}.json``, ``preds_{split}.npy`` and ``gts_{split}.npy``.

    Returns:
        The metrics dict that is also written to ``metrics_{split}.json``.

    Raises:
        FileNotFoundError: if ``checkpoint_path`` is given but does not exist.
        ValueError: if the selected split contains no samples.
    """
    os.makedirs(output_dir, exist_ok=True)
    device = "cuda" if torch.cuda.is_available() else "cpu"
    print(f"Device: {device}")
    processor, model = _load_sam(checkpoint_path, device)

    dataset = FacadeDataset(data_dir, split=split, processor=processor, augment=False)
    if len(dataset) == 0:
        # np.concatenate below would otherwise fail with an opaque error.
        raise ValueError(f"No samples found for split '{split}' in {data_dir}")
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=False, collate_fn=collate_fn)

    all_preds = []
    all_gts = []
    all_ious = []
    print(f"Evaluating on {len(dataset)} samples...")
    for batch in tqdm(dataloader, desc="Eval"):
        outputs = model(
            pixel_values=batch["pixel_values"].to(device),
            input_boxes=batch["input_boxes"].to(device),
            multimask_output=False,
        )
        # NOTE(review): pred_masks presumably has shape
        # (batch, point_batch, num_masks, H, W). The two squeeze(1) calls drop
        # the singleton prompt and mask axes -- confirm against the HF SAM
        # output spec.
        pred_masks = outputs.pred_masks.squeeze(1).squeeze(1)
        pred_binary = (torch.sigmoid(pred_masks) > 0.5).float()

        # Transfer each batch to host memory once and reuse it for both the
        # accumulators and the per-sample IoU computation.
        pred_np = pred_binary.cpu().numpy()
        gt_np = batch["ground_truth_mask"].cpu().numpy()
        all_preds.append(pred_np)
        all_gts.append(gt_np)
        all_ious.extend(_per_sample_ious(pred_np, gt_np))

    all_preds = np.concatenate(all_preds, axis=0)
    all_gts = np.concatenate(all_gts, axis=0)
    metrics = compute_all_metrics(all_preds, all_gts)
    # Per-sample IoU values (the key name is kept for compatibility with
    # existing result files).
    metrics["mean_iou_list"] = [float(v) for v in all_ious]
    metrics["split"] = split
    metrics["checkpoint"] = checkpoint_path if checkpoint_path else "original"

    print("\n=== Evaluation Results ===")
    for k, v in metrics.items():
        if isinstance(v, float):
            print(f"{k}: {v:.4f}")

    out_file = os.path.join(output_dir, f"metrics_{split}.json")
    with open(out_file, "w") as f:
        json.dump(metrics, f, indent=2)
    np.save(os.path.join(output_dir, f"preds_{split}.npy"), all_preds)
    np.save(os.path.join(output_dir, f"gts_{split}.npy"), all_gts)
    print(f"Saved to {output_dir}")
    return metrics
if __name__ == "__main__":
    # Command-line entry point. Each option maps one-to-one onto a keyword
    # argument of evaluate_checkpoint(); an empty --checkpoint evaluates the
    # pre-trained baseline.
    cli = argparse.ArgumentParser()
    cli.add_argument(
        "--checkpoint",
        default="",
        help="Path to fine-tuned checkpoint. Leave empty for original baseline.",
    )
    cli.add_argument("--data_dir", default="data/cmp_facade")
    cli.add_argument("--split", default="test")
    cli.add_argument("--batch_size", default=2, type=int)
    cli.add_argument("--output_dir", default="outputs/eval")
    opts = cli.parse_args()

    evaluate_checkpoint(
        checkpoint_path=opts.checkpoint,
        data_dir=opts.data_dir,
        split=opts.split,
        batch_size=opts.batch_size,
        output_dir=opts.output_dir,
    )
|