from enum import Enum
from typing import List, Union

import torch
from torch.utils.data import Dataset
import numpy as np

# Mapping from cancer-type code to its integer class index (0..182, contiguous).
CANCER_TYPE_TO_INT_MAP = {
    "AASTR": 0,
    "ACC": 1,
    "ACRM": 2,
    "ACYC": 3,
    "ADNOS": 4,
    "ALUCA": 5,
    "AMPCA": 6,
    "ANGS": 7,
    "ANSC": 8,
    "AODG": 9,
    "APAD": 10,
    "ARMM": 11,
    "ARMS": 12,
    "ASTR": 13,
    "ATM": 14,
    "BA": 15,
    "BCC": 16,
    "BLAD": 17,
    "BLCA": 18,
    "BMGCT": 19,
    "BRCA": 20,
    "BRCANOS": 21,
    "BRCNOS": 22,
    "CCOV": 23,
    "CCRCC": 24,
    "CESC": 25,
    "CHDM": 26,
    "CHOL": 27,
    "CHRCC": 28,
    "CHS": 29,
    "COAD": 30,
    "COADREAD": 31,
    "CSCC": 32,
    "CSCLC": 33,
    "CUP": 34,
    "CUPNOS": 35,
    "DA": 36,
    "DASTR": 37,
    "DDLS": 38,
    "DES": 39,
    "DIFG": 40,
    "DSRCT": 41,
    "DSTAD": 42,
    "ECAD": 43,
    "EGC": 44,
    "EHAE": 45,
    "EHCH": 46,
    "EMPD": 47,
    "EOV": 48,
    "EPDCA": 49,
    "EPIS": 50,
    "EPM": 51,
    "ERMS": 52,
    "ES": 53,
    "ESCA": 54,
    "ESCC": 55,
    "GB": 56,
    "GBAD": 57,
    "GBC": 58,
    "GBM": 59,
    "GCCAP": 60,
    "GEJ": 61,
    "GINET": 62,
    "GIST": 63,
    "GNOS": 64,
    "GRCT": 65,
    "HCC": 66,
    "HGGNOS": 67,
    "HGNEC": 68,
    "HGSFT": 69,
    "HGSOC": 70,
    "HNMUCM": 71,
    "HNSC": 72,
    "IDC": 73,
    "IHCH": 74,
    "ILC": 75,
    "LGGNOS": 76,
    "LGSOC": 77,
    "LMS": 78,
    "LNET": 79,
    "LUAD": 80,
    "LUAS": 81,
    "LUCA": 82,
    "LUNE": 83,
    "LUPC": 84,
    "LUSC": 85,
    "LXSC": 86,
    "MAAP": 87,
    "MACR": 88,
    "MBC": 89,
    "MCC": 90,
    "MDLC": 91,
    "MEL": 92,
    "MFH": 93,
    "MFS": 94,
    "MGCT": 95,
    "MNG": 96,
    "MOV": 97,
    "MPNST": 98,
    "MRLS": 99,
    "MUP": 100,
    "MXOV": 101,
    "NBL": 102,
    "NECNOS": 103,
    "NETNOS": 104,
    "NOT": 105,
    "NPC": 106,
    "NSCLC": 107,
    "NSCLCPD": 108,
    "NSGCT": 109,
    "OCS": 110,
    "OCSC": 111,
    "ODG": 112,
    "OOVC": 113,
    "OPHSC": 114,
    "OS": 115,
    "PAAC": 116,
    "PAAD": 117,
    "PAASC": 118,
    "PAMPCA": 119,
    "PANET": 120,
    "PAST": 121,
    "PDC": 122,
    "PECOMA": 123,
    "PEMESO": 124,
    "PHC": 125,
    "PLBMESO": 126,
    "PLEMESO": 127,
    "PLMESO": 128,
    "PRAD": 129,
    "PRCC": 130,
    "PSEC": 131,
    "PTAD": 132,
    "RBL": 133,
    "RCC": 134,
    "RCSNOS": 135,
    "READ": 136,
    "RMS": 137,
    "SARCNOS": 138,
    "SBC": 139,
    "SBOV": 140,
    "SBWDNET": 141,
    "SCBC": 142,
    "SCCNOS": 143,
    "SCHW": 144,
    "SCLC": 145,
    "SCUP": 146,
    "SDCA": 147,
    "SEM": 148,
    "SFT": 149,
    "SKCM": 150,
    "SOC": 151,
    "SPDAC": 152,
    "SSRCC": 153,
    "STAD": 154,
    "SYNS": 155,
    "TAC": 156,
    "THAP": 157,
    "THHC": 158,
    "THME": 159,
    "THPA": 160,
    "THPD": 161,
    "THYC": 162,
    "THYM": 163,
    "TYST": 164,
    "UCCC": 165,
    "UCEC": 166,
    "UCP": 167,
    "UCS": 168,
    "UCU": 169,
    "UDMN": 170,
    "UEC": 171,
    "ULMS": 172,
    "UM": 173,
    "UMEC": 174,
    "URCC": 175,
    "USARC": 176,
    "USC": 177,
    "UTUC": 178,
    "VMM": 179,
    "VSC": 180,
    "WDLS": 181,
    "WT": 182,
}

# Inverse lookup: integer class index -> cancer-type code.
INT_TO_CANCER_TYPE_MAP = {v: k for k, v in CANCER_TYPE_TO_INT_MAP.items()}


class SiteType(Enum):
    """Sample site type; the value is the string emitted by the dataset."""

    PRIMARY = "Primary"
    METASTASIS = "Metastasis"


class TileFeatureTensorDataset(Dataset):
    """Single-item dataset wrapping the tile features of one sample.

    The features are converted to a float32 tensor with exactly
    ``n_max_tiles`` rows. If there are more rows, they are randomly
    subsampled without replacement. If there are fewer, they are
    zero-padded at the end.
    """

    def __init__(
        self,
        site_type: Union[SiteType, str],
        tile_features: np.ndarray,
        n_max_tiles: int = 20000,
    ) -> None:
        """Initialize the dataset.

        Args:
            site_type: the site type, either a ``SiteType`` member or its
                string value ("Primary" or "Metastasis").
            tile_features: the tile feature array, 2-D with one row per tile.
            n_max_tiles: the number of rows the output tile tensor is forced to.

        Returns:
            None

        Raises:
            ValueError: if ``site_type`` is not a valid ``SiteType`` value, or
                ``tile_features`` is not 2-D.
        """
        # SiteType(member) returns the member itself, so this accepts both the
        # enum and its raw string value, and rejects invalid input up front
        # instead of failing later in __getitem__ on `.value`.
        self.site_type = SiteType(site_type)
        self.n_max_tiles = n_max_tiles
        self.features = self._get_features(tile_features)

    def __len__(self) -> int:
        """Return the length of the dataset.

        Returns:
            int: always 1; the dataset holds a single sample.
        """
        return 1

    def _get_features(self, features) -> torch.Tensor:
        """Get the tile features as a fixed-size tensor.

        Args:
            features: the tile features as a 2-D numpy array.

        Returns:
            torch.Tensor: float32 tensor of shape (n_max_tiles, n_features).

        Raises:
            ValueError: if ``features`` is not 2-D.
        """
        # torch.tensor copies, so the returned tensor never aliases the caller's array.
        features = torch.tensor(features, dtype=torch.float32)
        if features.ndim != 2:
            raise ValueError(
                "tile_features must be 2-D (n_tiles, n_features), "
                f"got shape {tuple(features.shape)}"
            )
        n_tiles = features.shape[0]
        if n_tiles > self.n_max_tiles:
            # Random subset of rows (uses torch's global RNG).
            indices = torch.randperm(n_tiles)[: self.n_max_tiles]
            features = features[indices]
        elif n_tiles < self.n_max_tiles:
            # new_zeros keeps the padding on the same dtype/device as features.
            padding = features.new_zeros(
                (self.n_max_tiles - n_tiles, features.shape[1])
            )
            features = torch.cat([features, padding], dim=0)
        return features

    def __getitem__(self, idx: int) -> dict:
        """Return the single item of the dataset.

        Args:
            idx: the index of the item to return; only 0 (or -1) is valid.

        Returns:
            dict: ``{"site": <site string>, "tile_tensor": <features tensor>}``.

        Raises:
            IndexError: for any other index. Without this, iterating the
                dataset via the sequence protocol never terminates.
        """
        if idx not in (0, -1):
            raise IndexError(
                f"index {idx} is out of range for a dataset of length 1"
            )
        return {
            "site": self.site_type.value,
            "tile_tensor": self.features,
        }