# Scraped repository page header (not part of the module):
#   path: segment-monograms / src/evaluate.py
#   last commit: "Deploy monogram segmentation demo" (bcc432f, verified), user: Saranga7
#   size: 11.5 kB
from __future__ import annotations
import argparse
import re
from pathlib import Path
from shutil import copyfile
import numpy as np
import pandas as pd
import torch
from PIL import Image, ImageDraw, ImageFont
from torch.utils.data import DataLoader
from src.data.dataset import MonogramSegmentationDataset, collate_batch
from src.metrics import binary_metrics
from src.models import make_model
def _tensor_rgb(image: torch.Tensor) -> np.ndarray:
return ((image.detach().cpu().clamp(-1, 1) * 0.5 + 0.5).permute(1, 2, 0).numpy() * 255).astype("uint8")
def _tensor_mask(mask: torch.Tensor) -> np.ndarray:
return (mask.detach().cpu().clamp(0, 1).squeeze(0).numpy() * 255).astype("uint8")
def _overlay(image: np.ndarray, mask: np.ndarray, color: tuple[int, int, int]) -> np.ndarray:
color_arr = np.zeros_like(image)
color_arr[..., 0] = color[0]
color_arr[..., 1] = color[1]
color_arr[..., 2] = color[2]
return np.where(mask[..., None] > 0, (0.55 * image + 0.45 * color_arr).astype("uint8"), image)
def _safe_name(value: str) -> str:
return re.sub(r"[^A-Za-z0-9_.-]+", "_", value).strip("_") or "sample"
def _font(size: int, bold: bool = False) -> ImageFont.ImageFont:
    """Load a DejaVu Sans font at ``size``, falling back to Pillow's default.

    With ``bold=True`` the bold face is tried first and the regular face
    second. A face that cannot be opened (``OSError``) is skipped; if none can
    be opened, ``ImageFont.load_default()`` is returned.
    """
    candidates = ["DejaVuSans.ttf"]
    if bold:
        candidates.insert(0, "DejaVuSans-Bold.ttf")

    for candidate in candidates:
        try:
            loaded = ImageFont.truetype(candidate, size=size)
        except OSError:
            continue
        else:
            return loaded
    return ImageFont.load_default()
def _draw_centered(draw: ImageDraw.ImageDraw, box: tuple[int, int, int, int], text: str, font: ImageFont.ImageFont, fill: str) -> None:
left, top, right, bottom = box
text_box = draw.textbbox((0, 0), text, font=font)
text_w = text_box[2] - text_box[0]
text_h = text_box[3] - text_box[1]
draw.text((left + (right - left - text_w) / 2, top + (bottom - top - text_h) / 2), text, font=font, fill=fill)
def _primary_visual_group(setting: str, row: pd.Series) -> str:
quality = str(row.get("quality_label", "") or "unlabeled")
collection = str(row.get("collection", "") or "unlabeled")
if setting == "e1_cross_collection":
return collection
if setting == "e2_quality_stratified":
return quality
if setting == "e3_full_data":
return f"{collection}__{quality}"
return collection
def _selected_visual_sample_ids(frame: pd.DataFrame, setting: str, count: int) -> set[str]:
if count <= 0 or frame.empty:
return set()
group_to_ids: dict[str, list[str]] = {}
for _, row in frame.iterrows():
group_to_ids.setdefault(_primary_visual_group(setting, row), []).append(str(row["sample_id"]))
selected: list[str] = []
max_group_len = max(len(ids) for ids in group_to_ids.values())
for offset in range(max_group_len):
for group in sorted(group_to_ids):
ids = group_to_ids[group]
if offset < len(ids):
selected.append(ids[offset])
if len(selected) >= count:
return set(selected)
return set(selected)
def _sample_visual(
    image: torch.Tensor,
    mask: torch.Tensor,
    pred: torch.Tensor,
    sample_id: str,
    collection: str,
    quality_label: str,
) -> Image.Image:
    """Render one sample as a labelled four-panel strip.

    The strip has a light metadata band on top (sample id, then collection and
    quality label), a dark caption bar naming each panel, and below it the
    panels side by side: input image, ground-truth mask, predicted mask, and
    the input with the predicted mask tinted red.

    Returns:
        A new RGB image four tiles wide.
    """
    rgb = _tensor_rgb(image)
    truth = _tensor_mask(mask)
    predicted = _tensor_mask(pred)

    captions = ("Input", "Ground Truth", "Prediction", "Overlay")
    tiles = (
        Image.fromarray(rgb),
        Image.fromarray(truth).convert("RGB"),
        Image.fromarray(predicted).convert("RGB"),
        Image.fromarray(_overlay(rgb, predicted, (255, 0, 0))),
    )

    # All layout sizes scale with the tile but never drop below a fixed floor.
    tile_w, tile_h = tiles[0].size
    caption_h = max(34, tile_h // 15)
    band_h = max(46, tile_h // 12)
    title_font = _font(max(18, tile_w // 26), bold=True)
    meta_font = _font(max(16, tile_w // 32))

    canvas = Image.new("RGB", (tile_w * len(tiles), tile_h + caption_h + band_h), "white")
    pen = ImageDraw.Draw(canvas)

    # Metadata band: sample id on the first line, provenance on the second.
    pen.rectangle((0, 0, canvas.width, band_h), fill=(245, 247, 250))
    pen.text((12, 8), f"{sample_id}", font=title_font, fill=(18, 24, 38))
    second_line_y = 8 + max(20, tile_w // 24)
    pen.text(
        (12, second_line_y),
        f"collection: {collection} quality: {quality_label}",
        font=meta_font,
        fill=(66, 75, 92),
    )

    tiles_top = band_h + caption_h
    for column, (caption, tile) in enumerate(zip(captions, tiles)):
        x0 = column * tile_w
        caption_box = (x0, band_h, x0 + tile_w, tiles_top)
        pen.rectangle(caption_box, fill=(32, 37, 50))
        _draw_centered(pen, caption_box, caption, title_font, "white")
        canvas.paste(tile, (x0, tiles_top))
    return canvas
def _contact_sheet(visuals: list[Image.Image], path: Path) -> None:
if not visuals:
return
path.parent.mkdir(parents=True, exist_ok=True)
width = max(image.width for image in visuals)
height = sum(image.height for image in visuals)
sheet = Image.new("RGB", (width, height), "white")
top = 0
for image in visuals:
sheet.paste(image, (0, top))
top += image.height
sheet.save(path, quality=95)
def _experiment_groups(setting: str, collection: str, quality_label: str) -> dict[str, str]:
quality = quality_label if quality_label and quality_label.lower() != "nan" else "unlabeled"
if setting == "e1_cross_collection":
return {"by_collection": collection}
if setting == "e2_quality_stratified":
return {"by_quality": quality}
if setting == "e3_full_data":
return {
"by_collection": collection,
"by_quality": quality,
"by_collection_quality": f"{collection}__{quality}",
}
return {"by_collection": collection, "by_quality": quality}
def _save_grouped_visual(
visual: Image.Image,
visual_path: Path,
visual_dir: Path,
visual_name: str,
groups: dict[str, str],
grouped_visuals: dict[tuple[str, str], list[Image.Image]],
) -> None:
for group_name, group_value in groups.items():
group_dir = visual_dir / group_name / _safe_name(group_value)
group_dir.mkdir(parents=True, exist_ok=True)
copyfile(visual_path, group_dir / visual_name)
grouped_visuals.setdefault((group_name, group_value), []).append(visual)
def _save_eval_visuals(visuals: list[Image.Image], grouped_visuals: dict[tuple[str, str], list[Image.Image]], out_dir: Path) -> None:
    """Write the overall contact sheet and one contact sheet per group.

    The overall sheet goes to ``out_dir/visualizations/contact_sheet.jpg``;
    each ``(axis, value)`` group gets
    ``out_dir/visualizations/<axis>/<sanitized value>/contact_sheet.jpg``.
    """
    root = out_dir / "visualizations"
    _contact_sheet(visuals, root / "contact_sheet.jpg")

    for group_key, members in grouped_visuals.items():
        axis, value = group_key
        sheet_path = root.joinpath(axis, _safe_name(value), "contact_sheet.jpg")
        _contact_sheet(members, sheet_path)
def _load_eval_model(args: argparse.Namespace, device: torch.device) -> torch.nn.Module:
    """Rebuild the architecture recorded in the checkpoint and load its weights."""
    import warnings

    # NOTE(review): torch.load unpickles the file - only evaluate checkpoints
    # that come from a trusted source.
    checkpoint = torch.load(args.checkpoint, map_location="cpu")
    model_args = checkpoint.get("args", {})
    # An explicit --model wins over the architecture stored with the weights.
    model_name = args.model or model_args.get("model", "unet")

    model_kwargs = {}
    if model_name == "unet":
        model_kwargs["encoder_name"] = model_args.get("encoder_name", "resnet34")
        # No pretrained download: the checkpoint supplies the weights.
        model_kwargs["encoder_weights"] = None
    elif model_name == "segformer":
        model_kwargs["model_name"] = model_args.get("segformer_model_name", "nvidia/mit-b2")
        model_kwargs["pretrained"] = False
        model_kwargs["freeze_encoder"] = bool(model_args.get("segformer_freeze_encoder", False))
        model_kwargs["unfreeze_last_n_blocks"] = int(model_args.get("segformer_unfreeze_last_n_blocks", 2))

    model = make_model(model_name, **model_kwargs).to(device)
    load_result = model.load_state_dict(checkpoint["model"], strict=False)
    # strict=False tolerates key mismatches, which would otherwise leave
    # layers at their initial values without any hint. Keep the tolerance but
    # surface it. make_model is defined elsewhere, so read the result
    # defensively in case it is not a plain nn.Module.
    missing = list(getattr(load_result, "missing_keys", ()))
    unexpected = list(getattr(load_result, "unexpected_keys", ()))
    if missing or unexpected:
        warnings.warn(
            f"Checkpoint {args.checkpoint} does not fully match the '{model_name}' model: "
            f"{len(missing)} missing key(s) {missing[:5]}, "
            f"{len(unexpected)} unexpected key(s) {unexpected[:5]}",
            stacklevel=2,
        )
    model.eval()
    return model


def _write_metric_tables(rows: list[dict], out_dir: Path) -> None:
    """Write per-sample metrics and their per-collection/quality/overall means."""
    if not rows:
        raise ValueError("No samples were evaluated; cannot write metric tables.")

    metric_columns = ["iou", "dice", "cldice", "precision", "recall"]
    table = pd.DataFrame(rows)
    table.to_csv(out_dir / "metrics.csv", index=False)

    by_collection = table.groupby("collection", dropna=False)[metric_columns].mean()
    by_collection.to_csv(out_dir / "metrics_by_collection.csv")

    # Labels were stringified per sample, so a missing label arrives as the
    # text "nan" rather than a real NaN. Fold it (and "") into "unlabeled",
    # the same way _experiment_groups names the visualization folders.
    quality = table["quality_label"].fillna("").astype(str)
    quality = quality.mask(quality.str.lower().isin(["", "nan"]), "unlabeled")
    by_quality = table.assign(quality_label=quality).groupby("quality_label", dropna=False)[metric_columns].mean()
    by_quality.to_csv(out_dir / "metrics_by_quality.csv")

    table[metric_columns].mean().to_frame("mean").to_csv(out_dir / "metrics_summary.csv")


def evaluate(args: argparse.Namespace) -> None:
    """Evaluate a checkpoint on the samples listed in a CSV and write reports.

    Reads ``args.csv`` (optionally truncated to ``args.limit`` rows), runs the
    model over it in batches of ``args.batch_size``, thresholds the sigmoid of
    the logits at ``args.threshold`` and writes into ``args.output_dir``:

    * ``metrics.csv``, ``metrics_by_collection.csv``,
      ``metrics_by_quality.csv`` and ``metrics_summary.csv``;
    * ``visualizations/`` with one JPEG per selected sample, copies filed per
      experiment group, and contact sheets.

    Raises:
        ValueError: If there are no samples to evaluate.
    """
    # Fall back to CPU for any CUDA request ("cuda", "cuda:0", ...) when CUDA
    # is unavailable, not only for the bare string "cuda".
    requested_device = str(args.device)
    if requested_device.startswith("cuda") and not torch.cuda.is_available():
        requested_device = "cpu"
    device = torch.device(requested_device)

    frame = pd.read_csv(args.csv)
    if args.limit is not None:
        frame = frame.head(args.limit).copy()
    if frame.empty:
        # Fail before loading the model; the aggregation step would otherwise
        # die with an opaque KeyError on a column-less frame.
        raise ValueError(f"No samples to evaluate in {args.csv!r}")

    visual_sample_ids = _selected_visual_sample_ids(frame, args.setting, args.visualize_count)
    dataset = MonogramSegmentationDataset(frame, image_size=args.image_size, augment=False)
    loader = DataLoader(dataset, batch_size=args.batch_size, shuffle=False, num_workers=0, collate_fn=collate_batch)

    model = _load_eval_model(args, device)

    out_dir = Path(args.output_dir)
    out_dir.mkdir(parents=True, exist_ok=True)
    visual_dir = out_dir / "visualizations"
    visual_dir.mkdir(parents=True, exist_ok=True)

    visuals: list[Image.Image] = []
    grouped_visuals: dict[tuple[str, str], list[Image.Image]] = {}
    rows: list[dict] = []
    with torch.no_grad():
        for batch in loader:
            image = batch["image"].to(device)
            mask = batch["mask"].to(device)
            logits = model(image)
            pred = (torch.sigmoid(logits) >= args.threshold).float()
            # binary_metrics is assumed to yield one mapping per sample, in
            # batch order - TODO confirm against src.metrics.
            metric_rows = binary_metrics(logits, mask, args.threshold)
            for idx, sample_metrics in enumerate(metric_rows):
                sample_id = str(batch["sample_id"][idx])
                collection = str(batch["collection"][idx])
                quality_label = str(batch["quality_label"][idx])
                rows.append(
                    {
                        "sample_id": sample_id,
                        "collection": collection,
                        "quality_label": quality_label,
                        **sample_metrics,
                    }
                )
                if sample_id not in visual_sample_ids:
                    continue
                visual = _sample_visual(batch["image"][idx], batch["mask"][idx], pred[idx], sample_id, collection, quality_label)
                visuals.append(visual)
                # The running counter keeps file names unique and ordered.
                visual_name = f"{len(visuals):03d}_{_safe_name(sample_id)}.jpg"
                visual_path = visual_dir / visual_name
                visual.save(visual_path, quality=95)
                _save_grouped_visual(
                    visual,
                    visual_path,
                    visual_dir,
                    visual_name,
                    _experiment_groups(args.setting, collection, quality_label),
                    grouped_visuals,
                )

    _save_eval_visuals(visuals, grouped_visuals, out_dir)
    _write_metric_tables(rows, out_dir)
def main() -> None:
    """Parse the command line and run :func:`evaluate`.

    ``--no-pretrained`` is accepted, but ``evaluate`` as written in this file
    never reads it.
    """
    cli = argparse.ArgumentParser(description="Evaluate a monogram segmentation checkpoint.")
    # (flag, add_argument keyword arguments), registered in this order.
    option_specs = (
        ("--checkpoint", {"required": True}),
        ("--csv", {"default": "splits/e3_full_data/test.csv"}),
        ("--output-dir", {"default": "outputs/eval"}),
        ("--model", {"choices": ["unet", "segformer"]}),
        ("--image-size", {"type": int, "default": 512}),
        ("--batch-size", {"type": int, "default": 4}),
        ("--threshold", {"type": float, "default": 0.5}),
        ("--device", {"default": "cuda"}),
        ("--no-pretrained", {"action": "store_true"}),
        ("--limit", {"type": int}),
        ("--visualize-count", {"type": int, "default": 32}),
        ("--setting", {"default": "e3_full_data"}),
    )
    for flag, spec in option_specs:
        cli.add_argument(flag, **spec)
    evaluate(cli.parse_args())


if __name__ == "__main__":
    main()