"""Evaluate trained CLIPSeg model and generate prediction masks + visuals.""" import json import time from pathlib import Path import matplotlib matplotlib.use("Agg") import matplotlib.pyplot as plt import numpy as np import torch import yaml from PIL import Image from torch.utils.data import DataLoader from tqdm import tqdm from src.data.dataset import DrywallSegDataset, collate_fn from src.model.clipseg_wrapper import load_model_and_processor from src.train import compute_metrics, get_device PROJECT_ROOT = Path(__file__).resolve().parents[1] def evaluate(config_path: str | None = None): config_path = config_path or str(PROJECT_ROOT / "configs" / "train_config.yaml") with open(config_path) as f: config = yaml.safe_load(f) device = get_device() threshold = config["evaluation"]["threshold"] # Load model with best checkpoint model, processor = load_model_and_processor(config["model"]["name"], config["model"]["freeze_backbone"]) ckpt_path = PROJECT_ROOT / "outputs" / "checkpoints" / "best_model.pt" model.load_state_dict(torch.load(ckpt_path, map_location="cpu", weights_only=True)) model = model.to(device) model.eval() # Model size model_size_mb = sum(p.numel() * p.element_size() for p in model.parameters()) / (1024 * 1024) # Test data splits_dir = PROJECT_ROOT / "data" / "splits" test_ds = DrywallSegDataset(str(splits_dir / "test.json"), processor, config["data"]["image_size"]) test_loader = DataLoader(test_ds, batch_size=config["training"]["batch_size"], shuffle=False, collate_fn=collate_fn, num_workers=0) # Run evaluation masks_dir = PROJECT_ROOT / "outputs" / "masks" masks_dir.mkdir(parents=True, exist_ok=True) all_metrics = {"taping": {"miou": [], "dice": []}, "cracks": {"miou": [], "dice": []}} inference_times = [] visual_examples = [] # Collect for visualization total_samples = 0 with torch.no_grad(): for batch in tqdm(test_loader, desc="Evaluating"): pixel_values = batch["pixel_values"].to(device) input_ids = batch["input_ids"].to(device) attention_mask = batch["attention_mask"].to(device) labels = batch["labels"].to(device) t0 = time.time() outputs = model(pixel_values=pixel_values, input_ids=input_ids, attention_mask=attention_mask) inference_times.append((time.time() - t0) / pixel_values.size(0)) logits = outputs.logits metrics = compute_metrics(logits, labels, threshold) preds = (torch.sigmoid(logits) > threshold).cpu().numpy().astype(np.uint8) for i in range(pixel_values.size(0)): ds_name = batch["dataset"][i] all_metrics[ds_name]["miou"].append(metrics["miou"]) all_metrics[ds_name]["dice"].append(metrics["dice"]) # Save prediction mask at original resolution orig_w, orig_h = batch["orig_width"][i], batch["orig_height"][i] pred_mask = Image.fromarray(preds[i] * 255, mode="L") pred_mask = pred_mask.resize((orig_w, orig_h), Image.NEAREST) prompt_slug = batch["prompt"][i].replace(" ", "_") img_stem = Path(batch["image_path"][i]).stem mask_filename = f"{img_stem}__{prompt_slug}.png" pred_mask.save(masks_dir / mask_filename) total_samples += 1 # Collect visual examples if len(visual_examples) < config["evaluation"]["num_visual_examples"]: visual_examples.append({ "image_path": batch["image_path"][i], "mask_path": batch["mask_path"][i], "pred_mask": preds[i], "prompt": batch["prompt"][i], "dataset": ds_name, }) # Aggregate metrics results = {"per_class": {}, "overall": {}} all_miou, all_dice = [], [] for ds_name in ["taping", "cracks"]: m = all_metrics[ds_name] if m["miou"]: results["per_class"][ds_name] = { "miou": round(float(np.mean(m["miou"])), 4), "dice": 
round(float(np.mean(m["dice"])), 4), "samples": len(m["miou"]), } all_miou.extend(m["miou"]) all_dice.extend(m["dice"]) results["overall"] = { "miou": round(float(np.mean(all_miou)), 4) if all_miou else 0, "dice": round(float(np.mean(all_dice)), 4) if all_dice else 0, "total_samples": total_samples, } results["runtime"] = { "avg_inference_ms": round(float(np.mean(inference_times)) * 1000, 1), "model_size_mb": round(model_size_mb, 1), } # Save results log_dir = PROJECT_ROOT / "outputs" / "logs" log_dir.mkdir(parents=True, exist_ok=True) with open(log_dir / "test_results.json", "w") as f: json.dump(results, f, indent=2) print(f"\n{'='*60}") print(f"Test Results") print(f"{'='*60}") for ds_name, m in results["per_class"].items(): print(f" {ds_name:>10s}: mIoU={m['miou']:.4f} Dice={m['dice']:.4f} (n={m['samples']})") print(f" {'overall':>10s}: mIoU={results['overall']['miou']:.4f} Dice={results['overall']['dice']:.4f}") print(f" Avg inference: {results['runtime']['avg_inference_ms']:.1f} ms/image") print(f" Model size: {results['runtime']['model_size_mb']:.1f} MB") # Generate visual comparison figures _generate_visuals(visual_examples, PROJECT_ROOT / "reports" / "figures") return results def _generate_visuals(examples: list[dict], output_dir: Path): """Generate original | GT | prediction comparison figures.""" output_dir.mkdir(parents=True, exist_ok=True) if not examples: return fig, axes = plt.subplots(len(examples), 3, figsize=(12, 4 * len(examples))) if len(examples) == 1: axes = [axes] for i, ex in enumerate(examples): img = Image.open(ex["image_path"]).convert("RGB") gt = Image.open(ex["mask_path"]).convert("L") pred = Image.fromarray(ex["pred_mask"] * 255, mode="L") axes[i][0].imshow(img) axes[i][0].set_title(f"Original ({ex['dataset']})") axes[i][0].axis("off") axes[i][1].imshow(gt, cmap="gray", vmin=0, vmax=255) axes[i][1].set_title("Ground Truth") axes[i][1].axis("off") axes[i][2].imshow(pred, cmap="gray", vmin=0, vmax=255) axes[i][2].set_title(f"Prediction: \"{ex['prompt']}\"") axes[i][2].axis("off") plt.tight_layout() plt.savefig(output_dir / "visual_comparison.png", dpi=150, bbox_inches="tight") plt.close() print(f"Saved visual comparison to {output_dir / 'visual_comparison.png'}") if __name__ == "__main__": evaluate()
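
# Usage sketch (hypothetical paths -- PROJECT_ROOT = parents[1] implies this file
# lives one directory below the repository root, e.g. src/evaluate.py):
#
#   python -m src.evaluate          # evaluates with configs/train_config.yaml
#
# or programmatically, with an alternative config path:
#
#   from src.evaluate import evaluate
#   results = evaluate("configs/my_eval_config.yaml")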