| """Evaluate trained CLIPSeg model and generate prediction masks + visuals.""" |
|
|
| import json |
| import time |
| from pathlib import Path |
|
|
| import matplotlib |
| matplotlib.use("Agg") |
| import matplotlib.pyplot as plt |
| import numpy as np |
| import torch |
| import yaml |
| from PIL import Image |
| from torch.utils.data import DataLoader |
| from tqdm import tqdm |
|
|
| from src.data.dataset import DrywallSegDataset, collate_fn |
| from src.model.clipseg_wrapper import load_model_and_processor |
| from src.train import compute_metrics, get_device |
|
|
| PROJECT_ROOT = Path(__file__).resolve().parents[1] |
|
|
|
|
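# evaluate() produces three kinds of artifacts (paths come from the code below):
#   - per-prompt prediction masks to outputs/masks/
#   - aggregated test metrics (JSON) to outputs/logs/test_results.json
#   - a side-by-side comparison figure to reports/figures/visual_comparison.png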
def evaluate(config_path: str | None = None):
    config_path = config_path or str(PROJECT_ROOT / "configs" / "train_config.yaml")
    with open(config_path) as f:
        config = yaml.safe_load(f)

    device = get_device()
    threshold = config["evaluation"]["threshold"]

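    # Load the model/processor and restore the best checkpoint.
    # Note: the weights_only=True flag of torch.load is only available in newer
    # PyTorch releases; drop it if an older version has to be supported.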
    model, processor = load_model_and_processor(config["model"]["name"], config["model"]["freeze_backbone"])
    ckpt_path = PROJECT_ROOT / "outputs" / "checkpoints" / "best_model.pt"
    model.load_state_dict(torch.load(ckpt_path, map_location="cpu", weights_only=True))
    model = model.to(device)
    model.eval()

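    # Parameter footprint in MB (parameters only; buffers are not counted).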
    model_size_mb = sum(p.numel() * p.element_size() for p in model.parameters()) / (1024 * 1024)

    # Test split and loader.
    splits_dir = PROJECT_ROOT / "data" / "splits"
    test_ds = DrywallSegDataset(str(splits_dir / "test.json"), processor, config["data"]["image_size"])
    test_loader = DataLoader(test_ds, batch_size=config["training"]["batch_size"], shuffle=False,
                             collate_fn=collate_fn, num_workers=0)

    # Output directory for the predicted masks.
    masks_dir = PROJECT_ROOT / "outputs" / "masks"
    masks_dir.mkdir(parents=True, exist_ok=True)

    all_metrics = {"taping": {"miou": [], "dice": []}, "cracks": {"miou": [], "dice": []}}
    inference_times = []
    visual_examples = []
    total_samples = 0

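    # Run the test set through the model. Timing caveat: CUDA (and MPS) kernels are
    # launched asynchronously, so the per-image wall-clock time measured below can
    # under-report true latency unless the device is synchronized (e.g. via
    # torch.cuda.synchronize()) before reading the clock.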
    with torch.no_grad():
        for batch in tqdm(test_loader, desc="Evaluating"):
            pixel_values = batch["pixel_values"].to(device)
            input_ids = batch["input_ids"].to(device)
            attention_mask = batch["attention_mask"].to(device)
            labels = batch["labels"].to(device)

            t0 = time.time()
            outputs = model(pixel_values=pixel_values, input_ids=input_ids, attention_mask=attention_mask)
            inference_times.append((time.time() - t0) / pixel_values.size(0))

            logits = outputs.logits
            preds = (torch.sigmoid(logits) > threshold).cpu().numpy().astype(np.uint8)

            for i in range(pixel_values.size(0)):
                ds_name = batch["dataset"][i]
                # Score each sample individually so per-class averages are not
                # mixed with metrics from other samples in the same batch.
                metrics = compute_metrics(logits[i:i+1], labels[i:i+1], threshold)
                all_metrics[ds_name]["miou"].append(metrics["miou"])
                all_metrics[ds_name]["dice"].append(metrics["dice"])

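                # Resize the binary prediction back to the original image size and
                # save it as <image_stem>__<prompt_slug>.png.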
                orig_w, orig_h = batch["orig_width"][i], batch["orig_height"][i]
                pred_mask = Image.fromarray(preds[i] * 255, mode="L")
                pred_mask = pred_mask.resize((orig_w, orig_h), Image.NEAREST)

                prompt_slug = batch["prompt"][i].replace(" ", "_")
                img_stem = Path(batch["image_path"][i]).stem
                mask_filename = f"{img_stem}__{prompt_slug}.png"
                pred_mask.save(masks_dir / mask_filename)

                total_samples += 1

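                # Keep the first few samples for the qualitative comparison figure.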
                if len(visual_examples) < config["evaluation"]["num_visual_examples"]:
                    visual_examples.append({
                        "image_path": batch["image_path"][i],
                        "mask_path": batch["mask_path"][i],
                        "pred_mask": preds[i],
                        "prompt": batch["prompt"][i],
                        "dataset": ds_name,
                    })

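    # Aggregate per-class and overall metrics.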
| results = {"per_class": {}, "overall": {}} |
| all_miou, all_dice = [], [] |
| for ds_name in ["taping", "cracks"]: |
| m = all_metrics[ds_name] |
| if m["miou"]: |
| results["per_class"][ds_name] = { |
| "miou": round(float(np.mean(m["miou"])), 4), |
| "dice": round(float(np.mean(m["dice"])), 4), |
| "samples": len(m["miou"]), |
| } |
| all_miou.extend(m["miou"]) |
| all_dice.extend(m["dice"]) |
|
|
| results["overall"] = { |
| "miou": round(float(np.mean(all_miou)), 4) if all_miou else 0, |
| "dice": round(float(np.mean(all_dice)), 4) if all_dice else 0, |
| "total_samples": total_samples, |
| } |
| results["runtime"] = { |
| "avg_inference_ms": round(float(np.mean(inference_times)) * 1000, 1), |
| "model_size_mb": round(model_size_mb, 1), |
| } |
|
|
| |
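    # Persist the metrics to outputs/logs/test_results.json.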
    log_dir = PROJECT_ROOT / "outputs" / "logs"
    log_dir.mkdir(parents=True, exist_ok=True)
    with open(log_dir / "test_results.json", "w") as f:
        json.dump(results, f, indent=2)

| print(f"\n{'='*60}") |
| print(f"Test Results") |
| print(f"{'='*60}") |
| for ds_name, m in results["per_class"].items(): |
| print(f" {ds_name:>10s}: mIoU={m['miou']:.4f} Dice={m['dice']:.4f} (n={m['samples']})") |
| print(f" {'overall':>10s}: mIoU={results['overall']['miou']:.4f} Dice={results['overall']['dice']:.4f}") |
| print(f" Avg inference: {results['runtime']['avg_inference_ms']:.1f} ms/image") |
| print(f" Model size: {results['runtime']['model_size_mb']:.1f} MB") |
|
|
| |
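    # Qualitative side-by-side figures (original | ground truth | prediction).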
    _generate_visuals(visual_examples, PROJECT_ROOT / "reports" / "figures")

    return results


def _generate_visuals(examples: list[dict], output_dir: Path):
    """Generate original | GT | prediction comparison figures."""
    output_dir.mkdir(parents=True, exist_ok=True)

    if not examples:
        return

    fig, axes = plt.subplots(len(examples), 3, figsize=(12, 4 * len(examples)))
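    # plt.subplots returns a 1-D array of Axes when nrows == 1; wrap it in a list
    # so the axes[i][j] indexing below also works for a single example.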
    if len(examples) == 1:
        axes = [axes]

    for i, ex in enumerate(examples):
        img = Image.open(ex["image_path"]).convert("RGB")
        gt = Image.open(ex["mask_path"]).convert("L")
        pred = Image.fromarray(ex["pred_mask"] * 255, mode="L")

        axes[i][0].imshow(img)
        axes[i][0].set_title(f"Original ({ex['dataset']})")
        axes[i][0].axis("off")

        axes[i][1].imshow(gt, cmap="gray", vmin=0, vmax=255)
        axes[i][1].set_title("Ground Truth")
        axes[i][1].axis("off")

        axes[i][2].imshow(pred, cmap="gray", vmin=0, vmax=255)
        axes[i][2].set_title(f"Prediction: \"{ex['prompt']}\"")
        axes[i][2].axis("off")

    plt.tight_layout()
    plt.savefig(output_dir / "visual_comparison.png", dpi=150, bbox_inches="tight")
    plt.close()
    print(f"Saved visual comparison to {output_dir / 'visual_comparison.png'}")


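# evaluate() returns the results dict, so it can also be imported and called from
# other code (e.g. a notebook) instead of running this file as a script.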
if __name__ == "__main__":
    evaluate()