"""Data structures and utilities for inference modules. This module provides: - Cancer type to integer mappings for model inputs/outputs - SiteType enum for primary vs metastatic classification - TileFeatureTensorDataset for feeding features to PyTorch models """ from enum import Enum from typing import List import torch from torch.utils.data import Dataset import numpy as np from mosaic.data_directory import get_data_directory CANCER_TYPE_TO_INT_MAP = { "AASTR": 0, "ACC": 1, "ACRM": 2, "ACYC": 3, "ADNOS": 4, "ALUCA": 5, "AMPCA": 6, "ANGS": 7, "ANSC": 8, "AODG": 9, "APAD": 10, "ARMM": 11, "ARMS": 12, "ASTR": 13, "ATM": 14, "BA": 15, "BCC": 16, "BLAD": 17, "BLCA": 18, "BMGCT": 19, "BRCA": 20, "BRCANOS": 21, "BRCNOS": 22, "CCOV": 23, "CCRCC": 24, "CESC": 25, "CHDM": 26, "CHOL": 27, "CHRCC": 28, "CHS": 29, "COAD": 30, "COADREAD": 31, "CSCC": 32, "CSCLC": 33, "CUP": 34, "CUPNOS": 35, "DA": 36, "DASTR": 37, "DDLS": 38, "DES": 39, "DIFG": 40, "DSRCT": 41, "DSTAD": 42, "ECAD": 43, "EGC": 44, "EHAE": 45, "EHCH": 46, "EMPD": 47, "EOV": 48, "EPDCA": 49, "EPIS": 50, "EPM": 51, "ERMS": 52, "ES": 53, "ESCA": 54, "ESCC": 55, "GB": 56, "GBAD": 57, "GBC": 58, "GBM": 59, "GCCAP": 60, "GEJ": 61, "GINET": 62, "GIST": 63, "GNOS": 64, "GRCT": 65, "HCC": 66, "HGGNOS": 67, "HGNEC": 68, "HGSFT": 69, "HGSOC": 70, "HNMUCM": 71, "HNSC": 72, "IDC": 73, "IHCH": 74, "ILC": 75, "LGGNOS": 76, "LGSOC": 77, "LMS": 78, "LNET": 79, "LUAD": 80, "LUAS": 81, "LUCA": 82, "LUNE": 83, "LUPC": 84, "LUSC": 85, "LXSC": 86, "MAAP": 87, "MACR": 88, "MBC": 89, "MCC": 90, "MDLC": 91, "MEL": 92, "MFH": 93, "MFS": 94, "MGCT": 95, "MNG": 96, "MOV": 97, "MPNST": 98, "MRLS": 99, "MUP": 100, "MXOV": 101, "NBL": 102, "NECNOS": 103, "NETNOS": 104, "NOT": 105, "NPC": 106, "NSCLC": 107, "NSCLCPD": 108, "NSGCT": 109, "OCS": 110, "OCSC": 111, "ODG": 112, "OOVC": 113, "OPHSC": 114, "OS": 115, "PAAC": 116, "PAAD": 117, "PAASC": 118, "PAMPCA": 119, "PANET": 120, "PAST": 121, "PDC": 122, "PECOMA": 123, "PEMESO": 124, "PHC": 125, "PLBMESO": 126, "PLEMESO": 127, "PLMESO": 128, "PRAD": 129, "PRCC": 130, "PSEC": 131, "PTAD": 132, "RBL": 133, "RCC": 134, "RCSNOS": 135, "READ": 136, "RMS": 137, "SARCNOS": 138, "SBC": 139, "SBOV": 140, "SBWDNET": 141, "SCBC": 142, "SCCNOS": 143, "SCHW": 144, "SCLC": 145, "SCUP": 146, "SDCA": 147, "SEM": 148, "SFT": 149, "SKCM": 150, "SOC": 151, "SPDAC": 152, "SSRCC": 153, "STAD": 154, "SYNS": 155, "TAC": 156, "THAP": 157, "THHC": 158, "THME": 159, "THPA": 160, "THPD": 161, "THYC": 162, "THYM": 163, "TYST": 164, "UCCC": 165, "UCEC": 166, "UCP": 167, "UCS": 168, "UCU": 169, "UDMN": 170, "UEC": 171, "ULMS": 172, "UM": 173, "UMEC": 174, "URCC": 175, "USARC": 176, "USC": 177, "UTUC": 178, "VMM": 179, "VSC": 180, "WDLS": 181, "WT": 182, } INT_TO_CANCER_TYPE_MAP = {v: k for k, v in CANCER_TYPE_TO_INT_MAP.items()} # Tissue site mapping (module-level cache) _TISSUE_SITE_MAP = None # Default tissue site index for "Not Applicable" DEFAULT_TISSUE_SITE_IDX = 8 def get_tissue_site_map(): """Load tissue site name → index mapping from CSV. Returns: dict: Mapping of tissue site names to indices (0-56) Raises: FileNotFoundError: If the tissue site CSV file is not found """ global _TISSUE_SITE_MAP if _TISSUE_SITE_MAP is None: import pandas as pd data_dir = get_data_directory() csv_path = data_dir / "tissue_site_original_to_idx.csv" try: df = pd.read_csv(csv_path) except FileNotFoundError as e: raise FileNotFoundError( f"Tissue site mapping file not found at {csv_path}. " f"Please ensure the data directory contains 'tissue_site_original_to_idx.csv'." ) from e _TISSUE_SITE_MAP = {} for _, row in df.iterrows(): _TISSUE_SITE_MAP[row["TISSUE_SITE"]] = int(row["idx"]) return _TISSUE_SITE_MAP def get_tissue_site_options(): """Get sorted unique tissue site names for UI dropdowns. Returns: list: Sorted list of unique tissue site names """ site_map = get_tissue_site_map() return sorted(set(site_map.keys())) _SEX_MAP = None def get_sex_map(): """Get the sex to index mapping from CSV file. Returns: dict: Mapping of sex values to indices (0-2) Raises: FileNotFoundError: If the sex mapping CSV file is not found """ global _SEX_MAP if _SEX_MAP is None: import pandas as pd data_dir = get_data_directory() csv_path = data_dir / "sex_original_to_idx.csv" try: df = pd.read_csv(csv_path) except FileNotFoundError as e: raise FileNotFoundError( f"Sex mapping file not found at {csv_path}. " f"Please ensure the data directory contains 'sex_original_to_idx.csv'." ) from e _SEX_MAP = {} for _, row in df.iterrows(): _SEX_MAP[row["SEX"]] = int(row["idx"]) return _SEX_MAP def encode_sex(sex): """Convert sex to numeric encoding. Args: sex: "Male" or "Female" (required, case insensitive) Returns: int: 0 for Male, 1 for Female Raises: ValueError: If sex is not "Male" or "Female" """ sex_map = get_sex_map() if sex not in sex_map: raise ValueError(f"Sex must be 'Male' or 'Female', got: {sex}") return sex_map[sex] def encode_tissue_site(site_name): """Convert tissue site name to index (0-56). Args: site_name: Tissue site name from CSV Returns: int: Tissue site index, defaults to DEFAULT_TISSUE_SITE_IDX ("Not Applicable") """ site_map = get_tissue_site_map() return site_map.get(site_name, DEFAULT_TISSUE_SITE_IDX) def tissue_site_to_one_hot(site_idx, num_classes=57): """Convert tissue site index to one-hot vector. Args: site_idx: Index value (0-56 for tissue site, 0-2 for sex) num_classes: Number of classes (57 for tissue site, 3 for sex) Returns: list: One-hot encoded vector """ one_hot = [0] * num_classes if 0 <= site_idx < num_classes: one_hot[site_idx] = 1 return one_hot class SiteType(Enum): PRIMARY = "Primary" METASTASIS = "Metastasis" class TileFeatureTensorDataset(Dataset): def __init__( self, site_type: SiteType, tile_features: np.ndarray, sex: int = None, tissue_site_idx: int = None, n_max_tiles: int = 20000, ) -> None: """Initialize the dataset. Args: site_type: the site type as str, either "Primary" or "Metastasis" tile_features: the tile feature array sex: patient sex (0=Male, 1=Female), optional for Aeon tissue_site_idx: tissue site index (0-56), optional for Aeon n_max_tiles: the maximum number of tiles to use as int Returns: None """ self.site_type = site_type self.sex = sex self.tissue_site_idx = tissue_site_idx self.n_max_tiles = n_max_tiles self.features = self._get_features(tile_features) def __len__(self) -> int: """Return the length of the dataset. Returns: int: the length of the dataset """ return 1 def _get_features(self, features) -> torch.Tensor: """Get the tile features Args: features: the tile features as a numpy array Returns: torch.Tensor: the tile tensor """ features = torch.tensor(features, dtype=torch.float32) if features.shape[0] > self.n_max_tiles: indices = torch.randperm(features.shape[0])[: self.n_max_tiles] features = features[indices] if features.shape[0] < self.n_max_tiles: padding = torch.zeros( self.n_max_tiles - features.shape[0], features.shape[1] ) features = torch.cat([features, padding], dim=0) return features def __getitem__(self, idx: int) -> dict: """Return an item from the dataset. Args: idx: the index of the item to return Returns: dict: the item """ result = {"site": self.site_type.value, "tile_tensor": self.features} # Add sex and tissue_site if provided (for Aeon) if self.sex is not None: result["SEX"] = torch.tensor( tissue_site_to_one_hot(self.sex, num_classes=3), dtype=torch.float32 ) if self.tissue_site_idx is not None: result["TISSUE_SITE"] = torch.tensor( tissue_site_to_one_hot(self.tissue_site_idx, num_classes=57), dtype=torch.float32, ) return result