Spaces:
Sleeping
Sleeping
| from __future__ import annotations | |
| from dataclasses import dataclass | |
| from pathlib import Path | |
| from typing import Any | |
| import pandas as pd | |
| import torch | |
| from PIL import Image | |
| from torch.utils.data import Dataset | |
| from .transforms import ( | |
| annotation_path, | |
| apply_pair_augment, | |
| crop_pair, | |
| find_image_by_stem, | |
| image_transform, | |
| list_images, | |
| mask_transform, | |
| quality_key, | |
| ) | |
# Shared parent directory of every default dataset location used below.
_DATASETS_DIR = Path("/scratch/mahantas/datasets")
DEFAULT_ADOLFO_ROOT = _DATASETS_DIR / "MonogramSchema_Seal_pairs"
DEFAULT_OTHER_ROOT = _DATASETS_DIR / "285_monograms_segmentations"
DEFAULT_METADATA_CSV = _DATASETS_DIR / "metadata_merged.csv"


@dataclass
class DatasetRoots:
    """Filesystem locations of the two mask collections and the merged metadata CSV.

    Declared as a dataclass so callers can override individual locations,
    e.g. ``DatasetRoots(other_root=Path("/data/other"))``; the defaults are
    still available as class attributes (``DatasetRoots.adolfo_root``).
    """

    adolfo_root: Path = DEFAULT_ADOLFO_ROOT  # Adolfo collection (seals/, binary_masks/, ann/, metadata.csv).
    other_root: Path = DEFAULT_OTHER_ROOT  # Parent of the per-collection subdirectories.
    metadata_csv: Path = DEFAULT_METADATA_CSV  # Merged per-monogram metadata table.
def _load_metadata(path: Path | None) -> dict[str, dict[str, Any]]:
    """Load the merged metadata CSV into a lookup keyed by ``monogram_id``.

    Args:
        path: CSV location, or ``None`` to run without merged metadata.

    Returns:
        Mapping from the whitespace-stripped ``monogram_id`` string to that
        row as a ``{column: value}`` dict. Empty when ``path`` is ``None`` or
        does not exist, because merged metadata is optional. If an ID occurs
        more than once, the last row wins.

    Raises:
        ValueError: If the CSV lacks the ``monogram_id`` or ``quality_label``
            column.
    """
    if path is None or not path.exists():
        return {}
    # Read IDs as text. Numeric-looking IDs would otherwise be parsed as int
    # (losing leading zeros) or as float ("12.0" once any ID is blank).
    # They would then fail to match the string IDs / file stems used for lookup.
    frame = pd.read_csv(path, dtype={"monogram_id": str})
    required = {"monogram_id", "quality_label"}
    missing = required.difference(frame.columns)
    if missing:
        raise ValueError(f"{path} is missing required columns: {sorted(missing)}")
    return {str(row["monogram_id"]).strip(): row for row in frame.to_dict("records")}
def _clean_label(value: Any) -> str:
    """Render a metadata cell as text, turning missing values (None/NaN) into ``""``."""
    return "" if pd.isna(value) else str(value)
def _metadata_value(metadata: dict[str, dict[str, Any]], source_id: str, key: str, default: Any = "") -> Any:
    """Fetch ``key`` for ``source_id`` from merged metadata, or ``default``.

    ``default`` is returned when the ID is unknown, when its row has no such
    column, or when the stored value is missing (None/NaN).
    """
    if source_id in metadata:
        candidate = metadata[source_id].get(key, default)
    else:
        candidate = default
    if pd.isna(candidate):
        return default
    return candidate
def _adolfo_manifest(
    root: Path,
    merged_metadata: dict[str, dict[str, Any]],
    *,
    expected_count: int | None = 332,
) -> pd.DataFrame:
    """Build manifest rows for the Adolfo seal/mask pair collection.

    Reads ``root / "metadata.csv"``. For each ``monogram_id`` it resolves:
    - the seal image in ``seals/``,
    - the binary mask in ``binary_masks/``,
    - the annotation in ``ann/``.

    Rows are dropped when their quality key is ``"q3"``, or when any of the
    three lookups raises ``FileNotFoundError``. Quality label, monogram type
    and broad monogram type are taken from ``merged_metadata`` when available
    there, otherwise from this collection's own ``metadata.csv``.

    Args:
        root: Adolfo collection root directory.
        merged_metadata: Lookup produced by ``_load_metadata``.
        expected_count: Number of usable rows the collection must yield; pass
            ``None`` to skip the check (e.g. for a subset or an updated copy).

    Returns:
        One row per usable sample, with ``sample_id`` prefixed ``"adolfo__"``.

    Raises:
        ValueError: If ``expected_count`` is set and the row count differs.
    """
    # Keep IDs as text so numeric-looking IDs keep leading zeros and never
    # become floats ("12.0"). They are matched against file stems below.
    metadata = pd.read_csv(root / "metadata.csv", dtype={"monogram_id": str})
    rows: list[dict[str, Any]] = []
    for row in metadata.to_dict("records"):
        sample_id = str(row["monogram_id"]).strip()
        quality = quality_key(_metadata_value(merged_metadata, sample_id, "quality_label", row.get("quality_label")))
        if quality == "q3":
            continue
        try:
            seal_path = find_image_by_stem(root / "seals", sample_id)
            mask_path = find_image_by_stem(root / "binary_masks", sample_id)
            ann_path = annotation_path(root / "ann", seal_path)
        except FileNotFoundError:
            # Incomplete pairs are skipped silently; the count check below is
            # what surfaces unexpected losses.
            continue
        rows.append(
            {
                "sample_id": f"adolfo__{sample_id}",
                "source_id": sample_id,
                "collection": "adolfo",
                "seal_path": str(seal_path),
                "mask_path": str(mask_path),
                "ann_path": str(ann_path),
                "quality_label": quality,
                "monogram_type": _metadata_value(merged_metadata, sample_id, "monogram_type", row.get("monogram_type", "")),
                "broad_monogram_type": _metadata_value(
                    merged_metadata,
                    sample_id,
                    "broad_monogram_type",
                    row.get("broad_monogram_type", ""),
                ),
            }
        )
    frame = pd.DataFrame(rows)
    if expected_count is not None and len(frame) != expected_count:
        raise ValueError(f"Expected {expected_count} usable Adolfo q0-q2 masks, found {len(frame)}")
    return frame
def _other_manifest(
    root: Path,
    merged_metadata: dict[str, dict[str, Any]],
    *,
    expected_count: int | None = 273,
) -> pd.DataFrame:
    """Build manifest rows for every collection subdirectory of ``root``.

    Each subdirectory (processed in sorted order) is read through its
    ``img/``, ``binary_masks/`` and ``ann/`` folders. Every image listed in
    ``img/`` becomes a row whose ``sample_id`` is ``"<collection>__<stem>"``.
    Metadata is looked up in ``merged_metadata`` by the bare stem, and rows
    whose quality key is ``"q3"`` are dropped.

    Unlike ``_adolfo_manifest``, mask/annotation lookup failures are not
    caught here and propagate to the caller.

    NOTE(review): metadata is keyed by stem only, so identical stems in two
    collections share one metadata row. Confirm that this is intended.

    Args:
        root: Directory whose subdirectories are the individual collections.
        merged_metadata: Lookup produced by ``_load_metadata``.
        expected_count: Number of usable rows across all collections; pass
            ``None`` to skip the check.

    Returns:
        One row per usable sample across all collections.

    Raises:
        ValueError: If ``expected_count`` is set and the row count differs.
    """
    rows: list[dict[str, Any]] = []
    collection_dirs = sorted(path for path in root.iterdir() if path.is_dir())
    for collection_dir in collection_dirs:
        collection = collection_dir.name
        image_dir = collection_dir / "img"
        mask_dir = collection_dir / "binary_masks"
        ann_dir = collection_dir / "ann"
        for image_path in list_images(image_dir):
            source_id = image_path.stem
            quality = quality_key(_metadata_value(merged_metadata, source_id, "quality_label", ""))
            if quality == "q3":
                continue
            mask_path = find_image_by_stem(mask_dir, source_id)
            ann_path = annotation_path(ann_dir, image_path)
            rows.append(
                {
                    "sample_id": f"{collection}__{source_id}",
                    "source_id": source_id,
                    "collection": collection,
                    "seal_path": str(image_path),
                    "mask_path": str(mask_path),
                    "ann_path": str(ann_path),
                    "quality_label": quality,
                    "monogram_type": _metadata_value(merged_metadata, source_id, "monogram_type", ""),
                    "broad_monogram_type": _metadata_value(merged_metadata, source_id, "broad_monogram_type", ""),
                }
            )
    frame = pd.DataFrame(rows)
    if expected_count is not None and len(frame) != expected_count:
        raise ValueError(f"Expected {expected_count} other-collection q0-q2 masks, found {len(frame)}")
    return frame
def build_sample_manifest(
    adolfo_root: str | Path = DEFAULT_ADOLFO_ROOT,
    other_root: str | Path = DEFAULT_OTHER_ROOT,
    metadata_csv: str | Path | None = DEFAULT_METADATA_CSV,
) -> pd.DataFrame:
    """Assemble the supervised-mask manifest from both dataset collections.

    The Adolfo rows come first, followed by the other collections. The result
    has a fresh 0..n-1 index, and the column order is not re-sorted.

    Args:
        adolfo_root: Root of the Adolfo seal/mask pair collection.
        other_root: Parent directory of the other collections.
        metadata_csv: Merged metadata CSV, or ``None`` to run without it.

    Returns:
        The combined manifest.

    Raises:
        ValueError: If the combined manifest does not contain exactly 605 rows,
            or if any ``sample_id`` occurs more than once.
    """
    merged_metadata = _load_metadata(None if metadata_csv is None else Path(metadata_csv))
    parts = [
        _adolfo_manifest(Path(adolfo_root), merged_metadata),
        _other_manifest(Path(other_root), merged_metadata),
    ]
    frame = pd.concat(parts, ignore_index=True, sort=False)
    total = len(frame)
    if total != 605:
        raise ValueError(f"Expected 605 total q0-q2 supervised masks, found {total}")
    is_repeat = frame["sample_id"].duplicated()
    if is_repeat.any():
        repeated_ids = frame.loc[is_repeat, "sample_id"].tolist()
        raise ValueError(f"Duplicate sample IDs: {repeated_ids[:10]}")
    return frame
class MonogramSegmentationDataset(Dataset):
    """Torch dataset yielding cropped image/mask tensor pairs from a manifest frame.

    Each row of ``frame`` must provide the columns ``seal_path``, ``mask_path``,
    ``ann_path``, ``sample_id`` and ``collection``. The columns
    ``quality_label``, ``monogram_type`` and ``broad_monogram_type`` are
    optional and default to ``""``. ``build_sample_manifest`` produces frames
    with all of these columns.

    Args:
        frame: Manifest rows; the index is reset so positions 0..len-1 address rows.
        image_size: Forwarded to ``crop_pair``. Presumably the output side
            length in pixels — TODO confirm in ``transforms``.
        crop_padding: Forwarded to ``crop_pair``. Presumably the fractional
            padding around the annotated region — TODO confirm.
        augment: When True, adds crop jitter (0.06, scale jitter 0.05) and runs
            ``apply_pair_augment`` on each cropped pair.
    """

    def __init__(
        self,
        frame: pd.DataFrame,
        image_size: int = 512,
        crop_padding: float = 0.08,
        augment: bool = False,
    ):
        self.frame = frame.reset_index(drop=True)
        self.image_size = image_size
        self.crop_padding = crop_padding
        self.augment = augment
        self.image_t = image_transform()
        self.mask_t = mask_transform()

    def __len__(self) -> int:
        """Return the number of manifest rows."""
        return len(self.frame)

    def __getitem__(self, idx: int) -> dict[str, Any]:
        """Load, crop, optionally augment and tensorize the sample at position ``idx``.

        Returns:
            Dict with ``image`` and ``mask`` tensors plus string metadata
            (``sample_id``, ``collection``, labels and the three file paths).
        """
        row = self.frame.iloc[idx].to_dict()
        # Decode inside ``with`` so each file handle is closed immediately.
        # ``Image.open`` is lazy and would otherwise keep the file open until
        # garbage collection; convert() returns an independent in-memory copy.
        with Image.open(row["seal_path"]) as seal_file:
            image = seal_file.convert("RGB")
        with Image.open(row["mask_path"]) as mask_file:
            mask = mask_file.convert("L")
        # Jitter the crop only while augmenting so evaluation crops are deterministic.
        crop_jitter = 0.06 if self.augment else 0.0
        crop_scale_jitter = 0.05 if self.augment else 0.0
        image, mask = crop_pair(
            image,
            mask,
            row.get("ann_path"),
            self.image_size,
            self.crop_padding,
            crop_jitter=crop_jitter,
            crop_scale_jitter=crop_scale_jitter,
        )
        if self.augment:
            image, mask = apply_pair_augment(image, mask)
        image_tensor = self.image_t(image)
        mask_tensor = self.mask_t(mask)
        return {
            "image": image_tensor,
            "mask": mask_tensor,
            "sample_id": str(row["sample_id"]),
            "collection": str(row["collection"]),
            "quality_label": _clean_label(row.get("quality_label", "")),
            "monogram_type": _clean_label(row.get("monogram_type", "")),
            "broad_monogram_type": _clean_label(row.get("broad_monogram_type", "")),
            "seal_path": str(row["seal_path"]),
            "mask_path": str(row["mask_path"]),
            "ann_path": str(row["ann_path"]),
        }
def collate_batch(batch: list[dict[str, Any]]) -> dict[str, Any]:
    """Merge per-sample dicts from ``MonogramSegmentationDataset`` into one batch dict.

    ``image`` and ``mask`` are stacked along a new leading dimension. Every
    other field becomes a list in batch order.
    """
    stacked_fields = ("image", "mask")
    listed_fields = (
        "sample_id",
        "collection",
        "quality_label",
        "monogram_type",
        "broad_monogram_type",
        "seal_path",
        "mask_path",
        "ann_path",
    )
    collated: dict[str, Any] = {
        field: torch.stack([sample[field] for sample in batch]) for field in stacked_fields
    }
    collated.update({field: [sample[field] for sample in batch] for field in listed_fields})
    return collated