Spaces:
Sleeping
Sleeping
| from enum import Enum | |
| from typing import List | |
| import torch | |
| from torch.utils.data import Dataset | |
| import numpy as np | |
# Cancer-type codes in label order: the position of a code in this tuple is
# its integer label (ten codes per row, so each row starts at a multiple of 10).
# Reordering or inserting entries changes the integer assigned to every later
# code.
_CANCER_TYPE_CODES = (
    # 0-9
    "AASTR", "ACC", "ACRM", "ACYC", "ADNOS", "ALUCA", "AMPCA", "ANGS", "ANSC", "AODG",
    # 10-19
    "APAD", "ARMM", "ARMS", "ASTR", "ATM", "BA", "BCC", "BLAD", "BLCA", "BMGCT",
    # 20-29
    "BRCA", "BRCANOS", "BRCNOS", "CCOV", "CCRCC", "CESC", "CHDM", "CHOL", "CHRCC", "CHS",
    # 30-39
    "COAD", "COADREAD", "CSCC", "CSCLC", "CUP", "CUPNOS", "DA", "DASTR", "DDLS", "DES",
    # 40-49
    "DIFG", "DSRCT", "DSTAD", "ECAD", "EGC", "EHAE", "EHCH", "EMPD", "EOV", "EPDCA",
    # 50-59
    "EPIS", "EPM", "ERMS", "ES", "ESCA", "ESCC", "GB", "GBAD", "GBC", "GBM",
    # 60-69
    "GCCAP", "GEJ", "GINET", "GIST", "GNOS", "GRCT", "HCC", "HGGNOS", "HGNEC", "HGSFT",
    # 70-79
    "HGSOC", "HNMUCM", "HNSC", "IDC", "IHCH", "ILC", "LGGNOS", "LGSOC", "LMS", "LNET",
    # 80-89
    "LUAD", "LUAS", "LUCA", "LUNE", "LUPC", "LUSC", "LXSC", "MAAP", "MACR", "MBC",
    # 90-99
    "MCC", "MDLC", "MEL", "MFH", "MFS", "MGCT", "MNG", "MOV", "MPNST", "MRLS",
    # 100-109
    "MUP", "MXOV", "NBL", "NECNOS", "NETNOS", "NOT", "NPC", "NSCLC", "NSCLCPD", "NSGCT",
    # 110-119
    "OCS", "OCSC", "ODG", "OOVC", "OPHSC", "OS", "PAAC", "PAAD", "PAASC", "PAMPCA",
    # 120-129
    "PANET", "PAST", "PDC", "PECOMA", "PEMESO", "PHC", "PLBMESO", "PLEMESO", "PLMESO", "PRAD",
    # 130-139
    "PRCC", "PSEC", "PTAD", "RBL", "RCC", "RCSNOS", "READ", "RMS", "SARCNOS", "SBC",
    # 140-149
    "SBOV", "SBWDNET", "SCBC", "SCCNOS", "SCHW", "SCLC", "SCUP", "SDCA", "SEM", "SFT",
    # 150-159
    "SKCM", "SOC", "SPDAC", "SSRCC", "STAD", "SYNS", "TAC", "THAP", "THHC", "THME",
    # 160-169
    "THPA", "THPD", "THYC", "THYM", "TYST", "UCCC", "UCEC", "UCP", "UCS", "UCU",
    # 170-179
    "UDMN", "UEC", "ULMS", "UM", "UMEC", "URCC", "USARC", "USC", "UTUC", "VMM",
    # 180-182
    "VSC", "WDLS", "WT",
)

# Cancer-type code -> integer label (0..182, in tuple order).
CANCER_TYPE_TO_INT_MAP = {
    code: label for label, code in enumerate(_CANCER_TYPE_CODES)
}

# Integer label -> cancer-type code (inverse of CANCER_TYPE_TO_INT_MAP).
INT_TO_CANCER_TYPE_MAP = dict(enumerate(_CANCER_TYPE_CODES))
class SiteType(Enum):
    """Site type of a sample; each member's value is the string emitted in items."""

    PRIMARY = "Primary"
    METASTASIS = "Metastasis"


class TileFeatureTensorDataset(Dataset):
    """Single-item dataset exposing one fixed-size tile feature tensor.

    The tile features are converted once, at construction time, into a
    float32 tensor with exactly ``n_max_tiles`` rows (randomly subsampled or
    zero-padded as needed). The dataset always has length 1.
    """

    def __init__(
        self,
        site_type: SiteType,
        tile_features: np.ndarray,
        n_max_tiles: int = 20000,
    ) -> None:
        """Initialize the dataset.

        Args:
            site_type: the site type, either a ``SiteType`` member or its
                string value ("Primary" or "Metastasis")
            tile_features: the tile feature array, one row per tile
            n_max_tiles: the number of tile rows in the stored tensor

        Raises:
            ValueError: if ``site_type`` is not a valid ``SiteType`` or value
        """
        super().__init__()
        # SiteType(member) returns the member itself and SiteType("Primary")
        # looks the member up by value, so both spellings are normalized here
        # and an invalid value fails now rather than later in __getitem__.
        self.site_type = SiteType(site_type)
        self.n_max_tiles = n_max_tiles
        self.features = self._get_features(tile_features)

    def __len__(self) -> int:
        """Return the length of the dataset.

        Returns:
            int: always 1, the dataset holds a single item
        """
        return 1

    def _get_features(self, features: np.ndarray) -> torch.Tensor:
        """Convert the tile features to a tensor with ``n_max_tiles`` rows.

        Args:
            features: the tile features; padding reads ``shape[1]``, so the
                input is expected to be 2-D (tiles x feature dimension)

        Returns:
            torch.Tensor: float32 tensor with ``n_max_tiles`` rows. Extra
            tiles are dropped by random subsampling (so the result is not
            deterministic unless the torch RNG is seeded); missing tiles are
            filled with zero rows appended at the end.
        """
        features = torch.tensor(features, dtype=torch.float32)
        n_tiles = features.shape[0]
        if n_tiles > self.n_max_tiles:
            # Keep a random subset of tiles, without replacement.
            keep = torch.randperm(n_tiles)[: self.n_max_tiles]
            features = features[keep]
        elif n_tiles < self.n_max_tiles:
            padding = torch.zeros(
                self.n_max_tiles - n_tiles,
                features.shape[1],
                dtype=features.dtype,
            )
            features = torch.cat([features, padding], dim=0)
        return features

    def __getitem__(self, idx: int) -> dict:
        """Return the single item of the dataset.

        Args:
            idx: the index of the item to return; only 0 (or -1) is valid

        Returns:
            dict: ``{"site": <site type string>, "tile_tensor": <features>}``

        Raises:
            IndexError: if ``idx`` is out of range. Without this, plain
                iteration (``for item in dataset`` / ``list(dataset)``), which
                calls ``__getitem__`` with 0, 1, 2, ... until IndexError,
                would never terminate.
        """
        n_items = len(self)
        if not -n_items <= idx < n_items:
            raise IndexError(
                f"index {idx} is out of range for dataset of length {n_items}"
            )
        return {
            "site": self.site_type.value,
            "tile_tensor": self.features,
        }