"""Evaluate a monogram segmentation checkpoint and export metrics and visualizations.

Running the module reads a CSV split, runs the model over it, writes per-sample
and aggregated metric CSV files, and saves four-panel visualizations (input,
ground truth, prediction, overlay) for a subset of samples, both flat and
grouped by collection and/or quality depending on the experiment setting.
"""

from __future__ import annotations

import argparse
import re
import warnings
from pathlib import Path
from shutil import copyfile

import numpy as np
import pandas as pd
import torch
from PIL import Image, ImageDraw, ImageFont
from torch.utils.data import DataLoader

from src.data.dataset import MonogramSegmentationDataset, collate_batch
from src.metrics import binary_metrics
from src.models import make_model

# Metric columns that are averaged in the aggregated CSV reports.
_METRIC_COLUMNS = ["iou", "dice", "cldice", "precision", "recall"]

# Group name used whenever a collection/quality value is missing.
_UNLABELED = "unlabeled"


def _normalize_label(value: object) -> str:
    """Return ``value`` as a string, mapping missing values to ``"unlabeled"``.

    ``None``, the empty string and any spelling of ``"nan"`` (which is what
    ``str()`` yields for a pandas/NumPy NaN) are treated as missing. Every
    other value is returned as ``str(value)`` unchanged.
    """
    if value is None:
        return _UNLABELED
    text = str(value)
    if not text or text.lower() == "nan":
        return _UNLABELED
    return text


def _tensor_rgb(image: torch.Tensor) -> np.ndarray:
    """Convert an image tensor to a ``uint8`` array with the first axis moved last.

    Values are clamped to ``[-1, 1]`` and mapped linearly onto ``[0, 255]``.
    NOTE(review): presumably the input is a (C, H, W) tensor normalized to
    [-1, 1] by the dataset -- confirm against ``MonogramSegmentationDataset``.
    """
    unit = image.detach().cpu().clamp(-1, 1) * 0.5 + 0.5
    return (unit.permute(1, 2, 0).numpy() * 255).astype("uint8")


def _tensor_mask(mask: torch.Tensor) -> np.ndarray:
    """Convert a mask tensor to a ``uint8`` array scaled to ``[0, 255]``.

    Values are clamped to ``[0, 1]`` and a leading singleton axis, if present,
    is dropped.
    """
    return (mask.detach().cpu().clamp(0, 1).squeeze(0).numpy() * 255).astype("uint8")


def _overlay(image: np.ndarray, mask: np.ndarray, color: tuple[int, int, int]) -> np.ndarray:
    """Blend ``color`` into ``image`` wherever ``mask`` is positive.

    Masked pixels become ``0.55 * image + 0.45 * color``; all other pixels are
    returned unchanged.
    """
    color_arr = np.zeros_like(image)
    color_arr[..., 0] = color[0]
    color_arr[..., 1] = color[1]
    color_arr[..., 2] = color[2]
    blended = (0.55 * image + 0.45 * color_arr).astype("uint8")
    # mask gets a trailing axis so it broadcasts across the colour channels.
    return np.where(mask[..., None] > 0, blended, image)


def _safe_name(value: str) -> str:
    """Return ``value`` reduced to characters that are safe in a file name.

    Runs of characters outside ``[A-Za-z0-9_.-]`` collapse to one underscore,
    leading/trailing underscores are stripped, and an empty result falls back
    to ``"sample"``.
    """
    return re.sub(r"[^A-Za-z0-9_.-]+", "_", value).strip("_") or "sample"


def _font(size: int, bold: bool = False) -> ImageFont.ImageFont:
    """Load a DejaVu Sans font at ``size``, falling back to Pillow's default font."""
    names = ["DejaVuSans-Bold.ttf", "DejaVuSans.ttf"] if bold else ["DejaVuSans.ttf"]
    for name in names:
        try:
            return ImageFont.truetype(name, size=size)
        except OSError:
            # Font file not available on this machine; try the next candidate.
            continue
    return ImageFont.load_default()


def _draw_centered(
    draw: ImageDraw.ImageDraw,
    box: tuple[int, int, int, int],
    text: str,
    font: ImageFont.ImageFont,
    fill: str,
) -> None:
    """Draw ``text`` centred inside ``box`` given as ``(left, top, right, bottom)``."""
    left, top, right, bottom = box
    text_box = draw.textbbox((0, 0), text, font=font)
    text_w = text_box[2] - text_box[0]
    text_h = text_box[3] - text_box[1]
    # textbbox reports where the ink lands relative to the drawing origin, and
    # its left/top are usually non-zero. Subtracting them centres the visible
    # glyphs instead of the origin, which otherwise leaves labels off-centre.
    x = left + (right - left - text_w) / 2 - text_box[0]
    y = top + (bottom - top - text_h) / 2 - text_box[1]
    draw.text((x, y), text, font=font, fill=fill)


def _primary_visual_group(setting: str, row: pd.Series) -> str:
    """Return the group key used to balance visualization sampling for ``row``.

    The key is the collection for ``e1_cross_collection``, the quality label
    for ``e2_quality_stratified``, ``"<collection>__<quality>"`` for
    ``e3_full_data`` and the collection for any other setting. Missing values
    are reported as ``"unlabeled"``.
    """
    quality = _normalize_label(row.get("quality_label"))
    collection = _normalize_label(row.get("collection"))
    if setting == "e1_cross_collection":
        return collection
    if setting == "e2_quality_stratified":
        return quality
    if setting == "e3_full_data":
        return f"{collection}__{quality}"
    return collection


def _selected_visual_sample_ids(frame: pd.DataFrame, setting: str, count: int) -> set[str]:
    """Pick up to ``count`` sample ids, spread evenly across visual groups.

    Rows are bucketed by ``_primary_visual_group`` in frame order, then ids are
    taken round-robin (first of every group, second of every group, ...) with
    groups visited in sorted order, so no single group dominates the selection.
    Returns an empty set when ``count <= 0`` or ``frame`` is empty.
    """
    if count <= 0 or frame.empty:
        return set()

    group_to_ids: dict[str, list[str]] = {}
    for _, row in frame.iterrows():
        group_to_ids.setdefault(_primary_visual_group(setting, row), []).append(str(row["sample_id"]))

    ordered_groups = sorted(group_to_ids)
    selected: list[str] = []
    max_group_len = max(len(ids) for ids in group_to_ids.values())
    for offset in range(max_group_len):
        for group in ordered_groups:
            ids = group_to_ids[group]
            if offset < len(ids):
                selected.append(ids[offset])
                if len(selected) >= count:
                    return set(selected)
    return set(selected)


def _sample_visual(
    image: torch.Tensor,
    mask: torch.Tensor,
    pred: torch.Tensor,
    sample_id: str,
    collection: str,
    quality_label: str,
) -> Image.Image:
    """Render one sample as a row of four labelled panels.

    The panels are Input, Ground Truth, Prediction and Overlay (prediction
    blended in red over the input). A metadata strip with the sample id,
    collection and quality label sits at the top, followed by a header bar
    with the panel titles and then the panels themselves.
    """
    rgb = _tensor_rgb(image)
    gt = _tensor_mask(mask)
    pred_mask = _tensor_mask(pred)
    panels = [
        ("Input", Image.fromarray(rgb)),
        ("Ground Truth", Image.fromarray(gt).convert("RGB")),
        ("Prediction", Image.fromarray(pred_mask).convert("RGB")),
        ("Overlay", Image.fromarray(_overlay(rgb, pred_mask, (255, 0, 0)))),
    ]

    # All layout sizes derive from the first panel; the panels are pasted on a
    # grid of that tile size.
    tile_w, tile_h = panels[0][1].size
    header_h = max(34, tile_h // 15)
    meta_h = max(46, tile_h // 12)
    title_font = _font(max(18, tile_w // 26), bold=True)
    meta_font = _font(max(16, tile_w // 32))

    row = Image.new("RGB", (tile_w * len(panels), tile_h + header_h + meta_h), "white")
    draw = ImageDraw.Draw(row)

    # Metadata strip across the full width at the top.
    draw.rectangle((0, 0, row.width, meta_h), fill=(245, 247, 250))
    draw.text((12, 8), f"{sample_id}", font=title_font, fill=(18, 24, 38))
    draw.text(
        (12, 8 + max(20, tile_w // 24)),
        f"collection: {collection} quality: {quality_label}",
        font=meta_font,
        fill=(66, 75, 92),
    )

    # One dark header cell plus the panel image per column.
    for panel_idx, (label, panel) in enumerate(panels):
        left = panel_idx * tile_w
        header_box = (left, meta_h, left + tile_w, meta_h + header_h)
        draw.rectangle(header_box, fill=(32, 37, 50))
        _draw_centered(draw, header_box, label, title_font, "white")
        row.paste(panel, (left, meta_h + header_h))
    return row


def _contact_sheet(visuals: list[Image.Image], path: Path) -> None:
    """Stack ``visuals`` vertically and save the result to ``path``.

    Does nothing when ``visuals`` is empty. Parent directories are created as
    needed.
    """
    if not visuals:
        return
    path.parent.mkdir(parents=True, exist_ok=True)
    width = max(image.width for image in visuals)
    height = sum(image.height for image in visuals)
    sheet = Image.new("RGB", (width, height), "white")
    top = 0
    for image in visuals:
        sheet.paste(image, (0, top))
        top += image.height
    sheet.save(path, quality=95)


def _experiment_groups(setting: str, collection: str, quality_label: str) -> dict[str, str]:
    """Map an experiment setting to the output groups a sample belongs to.

    Returns ``{group_name: group_value}``. A missing quality label (empty or
    ``"nan"``) is reported as ``"unlabeled"``; the collection is used as given.
    """
    quality = _normalize_label(quality_label)
    if setting == "e1_cross_collection":
        return {"by_collection": collection}
    if setting == "e2_quality_stratified":
        return {"by_quality": quality}
    if setting == "e3_full_data":
        return {
            "by_collection": collection,
            "by_quality": quality,
            "by_collection_quality": f"{collection}__{quality}",
        }
    return {"by_collection": collection, "by_quality": quality}


def _save_grouped_visual(
    visual: Image.Image,
    visual_path: Path,
    visual_dir: Path,
    visual_name: str,
    groups: dict[str, str],
    grouped_visuals: dict[tuple[str, str], list[Image.Image]],
) -> None:
    """Copy a saved visual into every group directory and record it per group.

    For each ``(group_name, group_value)`` the file at ``visual_path`` is
    copied to ``visual_dir / group_name / <safe group_value> / visual_name``
    and ``visual`` is appended to ``grouped_visuals`` (mutated in place) so a
    per-group contact sheet can be built later.
    """
    for group_name, group_value in groups.items():
        group_dir = visual_dir / group_name / _safe_name(group_value)
        group_dir.mkdir(parents=True, exist_ok=True)
        copyfile(visual_path, group_dir / visual_name)
        grouped_visuals.setdefault((group_name, group_value), []).append(visual)


def _save_eval_visuals(
    visuals: list[Image.Image],
    grouped_visuals: dict[tuple[str, str], list[Image.Image]],
    out_dir: Path,
) -> None:
    """Write the overall contact sheet and one contact sheet per group."""
    visual_dir = out_dir / "visualizations"
    _contact_sheet(visuals, visual_dir / "contact_sheet.jpg")
    for (group_name, group_value), group_images in grouped_visuals.items():
        _contact_sheet(group_images, visual_dir / group_name / _safe_name(group_value) / "contact_sheet.jpg")


def evaluate(args: argparse.Namespace) -> None:
    """Run a checkpoint over a CSV split and write metrics and visualizations.

    Reads ``args.csv`` (optionally truncated to ``args.limit`` rows), rebuilds
    the model from the checkpoint's stored arguments, and writes into
    ``args.output_dir``:

    * ``metrics.csv`` -- one row per sample,
    * ``metrics_by_collection.csv`` / ``metrics_by_quality.csv`` -- group means,
    * ``metrics_summary.csv`` -- overall means,
    * ``visualizations/`` -- per-sample images and contact sheets.

    A warning is issued (evaluation continues) when the checkpoint's state dict
    does not fully match the model, and when no samples were evaluated, in
    which case only ``metrics.csv`` is written.
    """
    # Fall back to the CPU only when plain "cuda" was requested but is unavailable.
    use_requested_device = args.device != "cuda" or torch.cuda.is_available()
    device = torch.device(args.device if use_requested_device else "cpu")

    frame = pd.read_csv(args.csv)
    if args.limit is not None:
        frame = frame.head(args.limit).copy()
    visual_sample_ids = _selected_visual_sample_ids(frame, args.setting, args.visualize_count)

    dataset = MonogramSegmentationDataset(frame, image_size=args.image_size, augment=False)
    loader = DataLoader(
        dataset,
        batch_size=args.batch_size,
        shuffle=False,
        num_workers=0,
        collate_fn=collate_batch,
    )

    # NOTE(security): torch.load unpickles the file; only load checkpoints
    # from a trusted source.
    checkpoint = torch.load(args.checkpoint, map_location="cpu")
    model_args = checkpoint.get("args", {})
    model_name = args.model or model_args.get("model", "unet")
    model_kwargs = {}
    if model_name == "unet":
        model_kwargs["encoder_name"] = model_args.get("encoder_name", "resnet34")
        # Weights come from the checkpoint, so no pretrained encoder is requested.
        model_kwargs["encoder_weights"] = None
    elif model_name == "segformer":
        model_kwargs["model_name"] = model_args.get("segformer_model_name", "nvidia/mit-b2")
        model_kwargs["pretrained"] = False
        model_kwargs["freeze_encoder"] = bool(model_args.get("segformer_freeze_encoder", False))
        model_kwargs["unfreeze_last_n_blocks"] = int(model_args.get("segformer_unfreeze_last_n_blocks", 2))
    model = make_model(model_name, **model_kwargs).to(device)

    # strict=False tolerates key mismatches, which would otherwise leave part
    # of the model at its initial weights without any sign. Surface them.
    load_result = model.load_state_dict(checkpoint["model"], strict=False)
    missing_keys = list(getattr(load_result, "missing_keys", []))
    unexpected_keys = list(getattr(load_result, "unexpected_keys", []))
    if missing_keys or unexpected_keys:
        warnings.warn(
            f"Checkpoint does not fully match model '{model_name}': "
            f"{len(missing_keys)} missing key(s) {missing_keys[:5]}, "
            f"{len(unexpected_keys)} unexpected key(s) {unexpected_keys[:5]}.",
            stacklevel=2,
        )
    model.eval()

    out_dir = Path(args.output_dir)
    out_dir.mkdir(parents=True, exist_ok=True)
    visual_dir = out_dir / "visualizations"
    visual_dir.mkdir(parents=True, exist_ok=True)

    visuals: list[Image.Image] = []
    grouped_visuals: dict[tuple[str, str], list[Image.Image]] = {}
    rows: list[dict] = []
    with torch.no_grad():
        for batch in loader:
            image = batch["image"].to(device)
            mask = batch["mask"].to(device)
            logits = model(image)
            pred = (torch.sigmoid(logits) >= args.threshold).float()
            # NOTE(review): binary_metrics is consumed here as one mapping of
            # metric name -> value per sample in the batch -- confirm in src.metrics.
            metric_rows = binary_metrics(logits, mask, args.threshold)
            for idx, sample_metrics in enumerate(metric_rows):
                sample_id = str(batch["sample_id"][idx])
                collection = str(batch["collection"][idx])
                quality_label = str(batch["quality_label"][idx])
                rows.append(
                    {
                        "sample_id": sample_id,
                        "collection": collection,
                        "quality_label": quality_label,
                        **sample_metrics,
                    }
                )
                if sample_id not in visual_sample_ids:
                    continue
                visual = _sample_visual(
                    batch["image"][idx],
                    batch["mask"][idx],
                    pred[idx],
                    sample_id,
                    collection,
                    quality_label,
                )
                visuals.append(visual)
                # The running count prefix keeps file names unique and ordered.
                visual_name = f"{len(visuals):03d}_{_safe_name(sample_id)}.jpg"
                visual_path = visual_dir / visual_name
                visual.save(visual_path, quality=95)
                _save_grouped_visual(
                    visual,
                    visual_path,
                    visual_dir,
                    visual_name,
                    _experiment_groups(args.setting, collection, quality_label),
                    grouped_visuals,
                )

    _save_eval_visuals(visuals, grouped_visuals, out_dir)

    metrics_frame = pd.DataFrame(rows)
    metrics_frame.to_csv(out_dir / "metrics.csv", index=False)
    if metrics_frame.empty:
        # Without rows there are no columns to group by; the aggregations
        # below would fail with a KeyError.
        warnings.warn("No samples were evaluated; aggregated metric files were not written.", stacklevel=2)
        return

    metrics_frame.groupby("collection", dropna=False)[_METRIC_COLUMNS].mean().to_csv(
        out_dir / "metrics_by_collection.csv"
    )
    # quality_label was stringified above, so a missing label arrives as "nan"
    # (or ""); normalise it the same way the visualization groups do.
    by_quality = metrics_frame.assign(quality_label=metrics_frame["quality_label"].map(_normalize_label))
    by_quality.groupby("quality_label", dropna=False)[_METRIC_COLUMNS].mean().to_csv(
        out_dir / "metrics_by_quality.csv"
    )
    metrics_frame[_METRIC_COLUMNS].mean().to_frame("mean").to_csv(out_dir / "metrics_summary.csv")


def main() -> None:
    """Parse command-line arguments and run the evaluation."""
    parser = argparse.ArgumentParser(description="Evaluate a monogram segmentation checkpoint.")
    parser.add_argument("--checkpoint", required=True)
    parser.add_argument("--csv", default="splits/e3_full_data/test.csv")
    parser.add_argument("--output-dir", default="outputs/eval")
    parser.add_argument("--model", choices=["unet", "segformer"])
    parser.add_argument("--image-size", type=int, default=512)
    parser.add_argument("--batch-size", type=int, default=4)
    parser.add_argument("--threshold", type=float, default=0.5)
    parser.add_argument("--device", default="cuda")
    # NOTE: --no-pretrained is accepted but evaluate() never reads it; it is
    # kept so existing command lines keep parsing.
    parser.add_argument("--no-pretrained", action="store_true")
    parser.add_argument("--limit", type=int)
    parser.add_argument("--visualize-count", type=int, default=32)
    parser.add_argument("--setting", default="e3_full_data")
    evaluate(parser.parse_args())


if __name__ == "__main__":
    main()