mosaic-test / data.py
raylim's picture
add aeon/paladin
6db96fd
from enum import Enum
from typing import List
import torch
from torch.utils.data import Dataset
import numpy as np
# Lookup table from cancer-type code to integer index.
# The 183 codes are assigned the contiguous indices 0..182; every index is
# written out explicitly so that adding or removing a code never silently
# renumbers the others.
# NOTE(review): the codes look like OncoTree-style tumor type codes and the
# indices are presumably model class labels — confirm against the consumer
# before changing any value.
CANCER_TYPE_TO_INT_MAP = {
    "AASTR": 0,
    "ACC": 1,
    "ACRM": 2,
    "ACYC": 3,
    "ADNOS": 4,
    "ALUCA": 5,
    "AMPCA": 6,
    "ANGS": 7,
    "ANSC": 8,
    "AODG": 9,
    "APAD": 10,
    "ARMM": 11,
    "ARMS": 12,
    "ASTR": 13,
    "ATM": 14,
    "BA": 15,
    "BCC": 16,
    "BLAD": 17,
    "BLCA": 18,
    "BMGCT": 19,
    "BRCA": 20,
    "BRCANOS": 21,
    "BRCNOS": 22,
    "CCOV": 23,
    "CCRCC": 24,
    "CESC": 25,
    "CHDM": 26,
    "CHOL": 27,
    "CHRCC": 28,
    "CHS": 29,
    "COAD": 30,
    "COADREAD": 31,
    "CSCC": 32,
    "CSCLC": 33,
    "CUP": 34,
    "CUPNOS": 35,
    "DA": 36,
    "DASTR": 37,
    "DDLS": 38,
    "DES": 39,
    "DIFG": 40,
    "DSRCT": 41,
    "DSTAD": 42,
    "ECAD": 43,
    "EGC": 44,
    "EHAE": 45,
    "EHCH": 46,
    "EMPD": 47,
    "EOV": 48,
    "EPDCA": 49,
    "EPIS": 50,
    "EPM": 51,
    "ERMS": 52,
    "ES": 53,
    "ESCA": 54,
    "ESCC": 55,
    "GB": 56,
    "GBAD": 57,
    "GBC": 58,
    "GBM": 59,
    "GCCAP": 60,
    "GEJ": 61,
    "GINET": 62,
    "GIST": 63,
    "GNOS": 64,
    "GRCT": 65,
    "HCC": 66,
    "HGGNOS": 67,
    "HGNEC": 68,
    "HGSFT": 69,
    "HGSOC": 70,
    "HNMUCM": 71,
    "HNSC": 72,
    "IDC": 73,
    "IHCH": 74,
    "ILC": 75,
    "LGGNOS": 76,
    "LGSOC": 77,
    "LMS": 78,
    "LNET": 79,
    "LUAD": 80,
    "LUAS": 81,
    "LUCA": 82,
    "LUNE": 83,
    "LUPC": 84,
    "LUSC": 85,
    "LXSC": 86,
    "MAAP": 87,
    "MACR": 88,
    "MBC": 89,
    "MCC": 90,
    "MDLC": 91,
    "MEL": 92,
    "MFH": 93,
    "MFS": 94,
    "MGCT": 95,
    "MNG": 96,
    "MOV": 97,
    "MPNST": 98,
    "MRLS": 99,
    "MUP": 100,
    "MXOV": 101,
    "NBL": 102,
    "NECNOS": 103,
    "NETNOS": 104,
    "NOT": 105,
    "NPC": 106,
    "NSCLC": 107,
    "NSCLCPD": 108,
    "NSGCT": 109,
    "OCS": 110,
    "OCSC": 111,
    "ODG": 112,
    "OOVC": 113,
    "OPHSC": 114,
    "OS": 115,
    "PAAC": 116,
    "PAAD": 117,
    "PAASC": 118,
    "PAMPCA": 119,
    "PANET": 120,
    "PAST": 121,
    "PDC": 122,
    "PECOMA": 123,
    "PEMESO": 124,
    "PHC": 125,
    "PLBMESO": 126,
    "PLEMESO": 127,
    "PLMESO": 128,
    "PRAD": 129,
    "PRCC": 130,
    "PSEC": 131,
    "PTAD": 132,
    "RBL": 133,
    "RCC": 134,
    "RCSNOS": 135,
    "READ": 136,
    "RMS": 137,
    "SARCNOS": 138,
    "SBC": 139,
    "SBOV": 140,
    "SBWDNET": 141,
    "SCBC": 142,
    "SCCNOS": 143,
    "SCHW": 144,
    "SCLC": 145,
    "SCUP": 146,
    "SDCA": 147,
    "SEM": 148,
    "SFT": 149,
    "SKCM": 150,
    "SOC": 151,
    "SPDAC": 152,
    "SSRCC": 153,
    "STAD": 154,
    "SYNS": 155,
    "TAC": 156,
    "THAP": 157,
    "THHC": 158,
    "THME": 159,
    "THPA": 160,
    "THPD": 161,
    "THYC": 162,
    "THYM": 163,
    "TYST": 164,
    "UCCC": 165,
    "UCEC": 166,
    "UCP": 167,
    "UCS": 168,
    "UCU": 169,
    "UDMN": 170,
    "UEC": 171,
    "ULMS": 172,
    "UM": 173,
    "UMEC": 174,
    "URCC": 175,
    "USARC": 176,
    "USC": 177,
    "UTUC": 178,
    "VMM": 179,
    "VSC": 180,
    "WDLS": 181,
    "WT": 182,
}
# Reverse lookup (integer index -> cancer-type code). Built by inverting the
# table above, so it is only lossless while every index there is unique.
INT_TO_CANCER_TYPE_MAP = {v: k for k, v in CANCER_TYPE_TO_INT_MAP.items()}
class SiteType(Enum):
    """Site type of a sample; each member's value is its display string."""

    # The values are the exact strings emitted under the "site" key of a
    # dataset item, so they must not be re-cased or renamed casually.
    PRIMARY = "Primary"
    METASTASIS = "Metastasis"
class TileFeatureTensorDataset(Dataset):
    """Single-item dataset holding one fixed-size tensor of tile features.

    The tile features passed to the constructor are converted to a float32
    tensor whose first dimension is exactly ``n_max_tiles``: larger inputs
    are randomly subsampled (without replacement), smaller inputs are
    zero-padded at the end. The dataset always has length 1 and its only
    item is a dict with the site type string and that tensor.
    """

    def __init__(
        self,
        site_type: "SiteType | str",
        tile_features: np.ndarray,
        n_max_tiles: int = 20000,
    ) -> None:
        """Initialize the dataset.

        Args:
            site_type: the site type, either a ``SiteType`` member or its
                string value ("Primary" or "Metastasis")
            tile_features: the tile feature array, one row per tile
            n_max_tiles: the number of tile rows the stored tensor is
                subsampled or zero-padded to

        Raises:
            ValueError: if ``site_type`` is neither a ``SiteType`` member
                nor one of its string values
        """
        # Enum lookup by value: members pass through unchanged, the plain
        # strings are converted, and anything else is rejected here instead
        # of failing later on ``.value`` in ``__getitem__``.
        self.site_type = SiteType(site_type)
        self.n_max_tiles = n_max_tiles
        self.features = self._get_features(tile_features)

    def __len__(self) -> int:
        """Return the length of the dataset.

        Returns:
            int: always 1, the dataset wraps a single set of tile features
        """
        return 1

    def _get_features(self, features: np.ndarray) -> torch.Tensor:
        """Convert the tile features to a tensor with ``n_max_tiles`` rows.

        Args:
            features: the tile features as a numpy array (an existing
                ``torch.Tensor`` is accepted as well), one row per tile

        Returns:
            torch.Tensor: float32 copy of the features whose first dimension
                is exactly ``self.n_max_tiles``
        """
        if isinstance(features, torch.Tensor):
            # torch.tensor(tensor) warns about copy-construction; copy
            # explicitly so the caller's tensor is never aliased.
            features = features.detach().clone().to(torch.float32)
        else:
            features = torch.tensor(features, dtype=torch.float32)

        n_tiles = features.shape[0]
        if n_tiles > self.n_max_tiles:
            # Random subset of tiles, drawn without replacement.
            indices = torch.randperm(n_tiles)[: self.n_max_tiles]
            features = features[indices]
        elif n_tiles < self.n_max_tiles:
            # new_zeros keeps dtype/device and follows whatever trailing
            # shape the features have, so the concatenation always lines up.
            padding = features.new_zeros(
                (self.n_max_tiles - n_tiles, *features.shape[1:])
            )
            features = torch.cat([features, padding], dim=0)
        return features

    def __getitem__(self, idx: int) -> dict:
        """Return an item from the dataset.

        Args:
            idx: the index of the item to return; only 0 (or -1) is valid
                because the dataset has a single item

        Returns:
            dict: ``{"site": <site type string>, "tile_tensor": <features>}``

        Raises:
            IndexError: if ``idx`` is out of range
        """
        n_items = len(self)
        # Out-of-range indices must raise: Dataset defines no __iter__, so
        # plain iteration relies on IndexError to know where to stop.
        if not -n_items <= idx < n_items:
            raise IndexError(
                f"index {idx} is out of range for a dataset of length {n_items}"
            )
        return {
            "site": self.site_type.value,
            "tile_tensor": self.features,
        }