segment-monograms / src /data /dataset.py
Saranga7's picture
Deploy monogram segmentation demo
bcc432f verified
Raw
History Blame Contribute Delete
8.23 kB
from __future__ import annotations
from dataclasses import dataclass
from pathlib import Path
from typing import Any
import pandas as pd
import torch
from PIL import Image
from torch.utils.data import Dataset
from .transforms import (
annotation_path,
apply_pair_augment,
crop_pair,
find_image_by_stem,
image_transform,
list_images,
mask_transform,
quality_key,
)
DEFAULT_ADOLFO_ROOT = Path("/scratch/mahantas/datasets/MonogramSchema_Seal_pairs")
DEFAULT_OTHER_ROOT = Path("/scratch/mahantas/datasets/285_monograms_segmentations")
DEFAULT_METADATA_CSV = Path("/scratch/mahantas/datasets/metadata_merged.csv")


@dataclass(frozen=True)
class DatasetRoots:
    """Filesystem locations of the monogram segmentation data.

    Every attribute is normalized to ``Path`` after construction, so string paths
    (e.g. from a CLI or a config file) behave exactly like ``Path`` objects.
    ``metadata_csv`` may be ``None``; ``build_sample_manifest`` also accepts ``None``
    for the merged-metadata CSV.
    """

    # Root of the Adolfo seal/mask collection.
    adolfo_root: Path = DEFAULT_ADOLFO_ROOT
    # Root containing one sub-directory per additional collection.
    other_root: Path = DEFAULT_OTHER_ROOT
    # Merged metadata CSV, or None when no merged metadata is used.
    metadata_csv: Path | None = DEFAULT_METADATA_CSV

    def __post_init__(self) -> None:
        # frozen=True forbids normal attribute assignment, so the documented
        # object.__setattr__ escape hatch is used to coerce str -> Path.
        object.__setattr__(self, "adolfo_root", Path(self.adolfo_root))
        object.__setattr__(self, "other_root", Path(self.other_root))
        if self.metadata_csv is not None:
            object.__setattr__(self, "metadata_csv", Path(self.metadata_csv))
def _load_metadata(path: Path | None) -> dict[str, dict[str, Any]]:
if path is None or not path.exists():
return {}
frame = pd.read_csv(path)
required = {"monogram_id", "quality_label"}
missing = required.difference(frame.columns)
if missing:
raise ValueError(f"{path} is missing required columns: {sorted(missing)}")
return {str(row["monogram_id"]).strip(): row for row in frame.to_dict("records")}
def _clean_label(value: Any) -> str:
if pd.isna(value):
return ""
return str(value)
def _metadata_value(metadata: dict[str, dict[str, Any]], source_id: str, key: str, default: Any = "") -> Any:
value = metadata.get(source_id, {}).get(key, default)
return default if pd.isna(value) else value
def _adolfo_manifest(
    root: Path,
    merged_metadata: dict[str, dict[str, Any]],
    expected_count: int | None = 332,
) -> pd.DataFrame:
    """Build the sample manifest for the Adolfo seal/mask collection.

    Reads ``root / "metadata.csv"``. For every ``monogram_id`` it resolves the seal
    image under ``root/seals``, the binary mask under ``root/binary_masks`` and the
    annotation under ``root/ann``. Quality, monogram type and broad monogram type
    come from ``merged_metadata`` when present, otherwise from the collection's
    own CSV.

    Rows whose quality key is ``"q3"`` are dropped. Rows whose file lookups raise
    ``FileNotFoundError`` are also dropped.

    Args:
        root: Adolfo collection root.
        merged_metadata: Mapping from monogram id to merged metadata record.
        expected_count: Exact number of rows the manifest must contain. Defaults to
            the expected number of usable q0-q2 masks; ``None`` disables the check.

    Returns:
        One row per usable sample, with ids prefixed ``"adolfo__"``.

    Raises:
        ValueError: If ``expected_count`` is set and the row count differs.
    """
    # Read ids as text: they are used as file stems, and numeric inference would
    # strip leading zeros ("00123" -> 123) or add ".0" when the column has blanks.
    metadata = pd.read_csv(root / "metadata.csv", dtype={"monogram_id": str})
    rows: list[dict[str, Any]] = []
    for row in metadata.to_dict("records"):
        sample_id = str(row["monogram_id"]).strip()
        # Merged metadata takes precedence; the collection CSV is the fallback.
        quality = quality_key(
            _metadata_value(merged_metadata, sample_id, "quality_label", row.get("quality_label"))
        )
        if quality == "q3":
            continue
        try:
            seal_path = find_image_by_stem(root / "seals", sample_id)
            mask_path = find_image_by_stem(root / "binary_masks", sample_id)
            ann_path = annotation_path(root / "ann", seal_path)
        except FileNotFoundError:
            # Samples missing any of their files are skipped, not treated as fatal.
            continue
        rows.append(
            {
                "sample_id": f"adolfo__{sample_id}",
                "source_id": sample_id,
                "collection": "adolfo",
                "seal_path": str(seal_path),
                "mask_path": str(mask_path),
                "ann_path": str(ann_path),
                "quality_label": quality,
                "monogram_type": _metadata_value(
                    merged_metadata, sample_id, "monogram_type", row.get("monogram_type", "")
                ),
                "broad_monogram_type": _metadata_value(
                    merged_metadata,
                    sample_id,
                    "broad_monogram_type",
                    row.get("broad_monogram_type", ""),
                ),
            }
        )
    frame = pd.DataFrame(rows)
    if expected_count is not None and len(frame) != expected_count:
        raise ValueError(f"Expected {expected_count} usable Adolfo q0-q2 masks, found {len(frame)}")
    return frame
def _other_manifest(
    root: Path,
    merged_metadata: dict[str, dict[str, Any]],
    expected_count: int | None = 273,
) -> pd.DataFrame:
    """Build the sample manifest for all non-Adolfo collections under ``root``.

    Each sub-directory of ``root`` is one collection, visited in sorted order. Inside
    it, ``img/`` holds the images, ``binary_masks/`` the masks and ``ann/`` the
    annotations. The image file stem is the id used to look up ``merged_metadata``.
    Samples whose quality key is ``"q3"`` are dropped.

    Args:
        root: Directory containing one folder per collection.
        merged_metadata: Mapping from monogram id to merged metadata record.
        expected_count: Exact number of rows the manifest must contain. Defaults to
            the expected number of q0-q2 masks; ``None`` disables the check.

    Returns:
        One row per sample, with ids of the form ``"<collection>__<stem>"``.

    Raises:
        ValueError: If ``expected_count`` is set and the row count differs.
    """
    rows: list[dict[str, Any]] = []
    collection_dirs = sorted(entry for entry in root.iterdir() if entry.is_dir())
    for collection_dir in collection_dirs:
        collection = collection_dir.name
        image_dir = collection_dir / "img"
        mask_dir = collection_dir / "binary_masks"
        ann_dir = collection_dir / "ann"
        for image_path in list_images(image_dir):
            source_id = image_path.stem
            quality = quality_key(_metadata_value(merged_metadata, source_id, "quality_label", ""))
            if quality == "q3":
                continue
            # NOTE(review): unlike the Adolfo manifest, lookup failures are not caught
            # here, so a missing mask presumably aborts the build (find_image_by_stem
            # appears to raise FileNotFoundError) — confirm this strictness is intended.
            mask_path = find_image_by_stem(mask_dir, source_id)
            ann_path = annotation_path(ann_dir, image_path)
            rows.append(
                {
                    "sample_id": f"{collection}__{source_id}",
                    "source_id": source_id,
                    "collection": collection,
                    "seal_path": str(image_path),
                    "mask_path": str(mask_path),
                    "ann_path": str(ann_path),
                    "quality_label": quality,
                    "monogram_type": _metadata_value(merged_metadata, source_id, "monogram_type", ""),
                    "broad_monogram_type": _metadata_value(merged_metadata, source_id, "broad_monogram_type", ""),
                }
            )
    frame = pd.DataFrame(rows)
    if expected_count is not None and len(frame) != expected_count:
        raise ValueError(f"Expected {expected_count} other-collection q0-q2 masks, found {len(frame)}")
    return frame
def build_sample_manifest(
    adolfo_root: str | Path = DEFAULT_ADOLFO_ROOT,
    other_root: str | Path = DEFAULT_OTHER_ROOT,
    metadata_csv: str | Path | None = DEFAULT_METADATA_CSV,
) -> pd.DataFrame:
    """Combine the Adolfo and other-collection manifests into one frame.

    Args:
        adolfo_root: Root of the Adolfo collection.
        other_root: Directory holding the other collections.
        metadata_csv: Merged metadata CSV, or ``None`` to use none.

    Returns:
        The concatenated manifest with a fresh 0..n-1 index; Adolfo rows come first.

    Raises:
        ValueError: If the combined manifest does not have exactly 605 rows, or if
            any ``sample_id`` occurs more than once (up to ten duplicates are listed).
    """
    merged_metadata = _load_metadata(None if metadata_csv is None else Path(metadata_csv))
    parts = [
        _adolfo_manifest(Path(adolfo_root), merged_metadata),
        _other_manifest(Path(other_root), merged_metadata),
    ]
    manifest = pd.concat(parts, ignore_index=True, sort=False)
    total = len(manifest)
    if total != 605:
        raise ValueError(f"Expected 605 total q0-q2 supervised masks, found {total}")
    repeated = manifest["sample_id"].duplicated()
    if repeated.any():
        dupes = manifest.loc[repeated, "sample_id"].tolist()
        raise ValueError(f"Duplicate sample IDs: {dupes[:10]}")
    return manifest
class MonogramSegmentationDataset(Dataset):
    """Dataset that turns manifest rows into cropped image/mask tensors.

    For each item it loads the row's seal image as RGB and its mask as mode ``"L"``.
    It then crops both together with ``crop_pair``, using the row's annotation path,
    and applies ``apply_pair_augment`` when augmentation is enabled. Finally it
    converts them with the transforms from ``image_transform()`` and
    ``mask_transform()``. The row's identifying metadata is returned alongside the
    tensors as strings.
    """

    def __init__(
        self,
        frame: pd.DataFrame,
        image_size: int = 512,
        crop_padding: float = 0.08,
        augment: bool = False,
        *,
        crop_jitter: float = 0.06,
        crop_scale_jitter: float = 0.05,
    ):
        """Store the manifest and the crop/augmentation settings.

        Args:
            frame: Sample manifest. Items read the ``sample_id``, ``collection``,
                ``seal_path``, ``mask_path`` and ``ann_path`` columns. A copy with a
                fresh 0..n-1 index is stored.
            image_size: Forwarded to ``crop_pair``.
            crop_padding: Forwarded to ``crop_pair``.
            augment: When true, crop jitter is enabled and ``apply_pair_augment``
                runs on every item.
            crop_jitter: Crop jitter forwarded to ``crop_pair`` when ``augment`` is
                true; 0.0 is used otherwise.
            crop_scale_jitter: Crop scale jitter forwarded to ``crop_pair`` when
                ``augment`` is true; 0.0 is used otherwise.
        """
        self.frame = frame.reset_index(drop=True)
        self.image_size = image_size
        self.crop_padding = crop_padding
        self.augment = augment
        self.crop_jitter = crop_jitter
        self.crop_scale_jitter = crop_scale_jitter
        self.image_t = image_transform()
        self.mask_t = mask_transform()

    def __len__(self) -> int:
        """Return the number of manifest rows."""
        return len(self.frame)

    def __getitem__(self, idx: int) -> dict[str, Any]:
        """Load, crop, optionally augment and transform the sample at position ``idx``.

        Returns:
            A dict with ``image`` and ``mask`` (outputs of the image/mask transforms)
            plus ``sample_id``, ``collection``, ``quality_label``, ``monogram_type``,
            ``broad_monogram_type``, ``seal_path``, ``mask_path`` and ``ann_path`` as
            strings. Missing label values become ``""``.
        """
        row = self.frame.iloc[idx].to_dict()
        # Context managers release the file handles as soon as the pixels have been
        # converted, instead of leaving them open until garbage collection.
        with Image.open(row["seal_path"]) as seal_file:
            image = seal_file.convert("RGB")
        with Image.open(row["mask_path"]) as mask_file:
            mask = mask_file.convert("L")
        # Crop jitter is only used when augmenting.
        crop_jitter = self.crop_jitter if self.augment else 0.0
        crop_scale_jitter = self.crop_scale_jitter if self.augment else 0.0
        image, mask = crop_pair(
            image,
            mask,
            row.get("ann_path"),
            self.image_size,
            self.crop_padding,
            crop_jitter=crop_jitter,
            crop_scale_jitter=crop_scale_jitter,
        )
        if self.augment:
            image, mask = apply_pair_augment(image, mask)
        image_tensor = self.image_t(image)
        mask_tensor = self.mask_t(mask)
        return {
            "image": image_tensor,
            "mask": mask_tensor,
            "sample_id": str(row["sample_id"]),
            "collection": str(row["collection"]),
            "quality_label": _clean_label(row.get("quality_label", "")),
            "monogram_type": _clean_label(row.get("monogram_type", "")),
            "broad_monogram_type": _clean_label(row.get("broad_monogram_type", "")),
            "seal_path": str(row["seal_path"]),
            "mask_path": str(row["mask_path"]),
            "ann_path": str(row["ann_path"]),
        }
def collate_batch(batch: list[dict[str, Any]]) -> dict[str, Any]:
    """Merge per-sample dicts into one batch dict.

    ``image`` and ``mask`` tensors are stacked along a new leading dimension. Every
    other field is collected into a list, in batch order.
    """
    stacked_fields = ("image", "mask")
    listed_fields = (
        "sample_id",
        "collection",
        "quality_label",
        "monogram_type",
        "broad_monogram_type",
        "seal_path",
        "mask_path",
        "ann_path",
    )
    collated: dict[str, Any] = {
        field_name: torch.stack([sample[field_name] for sample in batch]) for field_name in stacked_fields
    }
    collated.update({field_name: [sample[field_name] for sample in batch] for field_name in listed_fields})
    return collated