raylim's picture
Make sex parameter required in encode_sex function
466e964 unverified
"""Data structures and utilities for inference modules.

This module provides:
- Cancer type to integer mappings for model inputs/outputs
- Tissue site and sex encoding helpers (CSV-backed lookups and one-hot vectors)
- SiteType enum for primary vs metastatic classification
- TileFeatureTensorDataset for feeding features to PyTorch models
"""
from enum import Enum
from typing import List, Optional

import numpy as np
import torch
from torch.utils.data import Dataset

from mosaic.data_directory import get_data_directory
# Cancer-type code -> integer class index used for model inputs/outputs.
# The position of each code in the tuple below IS its index (0-182), so the
# tuple order is part of the model contract: never reorder or insert codes.
CANCER_TYPE_TO_INT_MAP = {
    code: position
    for position, code in enumerate(
        (
            "AASTR", "ACC", "ACRM", "ACYC", "ADNOS",  # 0-4
            "ALUCA", "AMPCA", "ANGS", "ANSC", "AODG",  # 5-9
            "APAD", "ARMM", "ARMS", "ASTR", "ATM",  # 10-14
            "BA", "BCC", "BLAD", "BLCA", "BMGCT",  # 15-19
            "BRCA", "BRCANOS", "BRCNOS", "CCOV", "CCRCC",  # 20-24
            "CESC", "CHDM", "CHOL", "CHRCC", "CHS",  # 25-29
            "COAD", "COADREAD", "CSCC", "CSCLC", "CUP",  # 30-34
            "CUPNOS", "DA", "DASTR", "DDLS", "DES",  # 35-39
            "DIFG", "DSRCT", "DSTAD", "ECAD", "EGC",  # 40-44
            "EHAE", "EHCH", "EMPD", "EOV", "EPDCA",  # 45-49
            "EPIS", "EPM", "ERMS", "ES", "ESCA",  # 50-54
            "ESCC", "GB", "GBAD", "GBC", "GBM",  # 55-59
            "GCCAP", "GEJ", "GINET", "GIST", "GNOS",  # 60-64
            "GRCT", "HCC", "HGGNOS", "HGNEC", "HGSFT",  # 65-69
            "HGSOC", "HNMUCM", "HNSC", "IDC", "IHCH",  # 70-74
            "ILC", "LGGNOS", "LGSOC", "LMS", "LNET",  # 75-79
            "LUAD", "LUAS", "LUCA", "LUNE", "LUPC",  # 80-84
            "LUSC", "LXSC", "MAAP", "MACR", "MBC",  # 85-89
            "MCC", "MDLC", "MEL", "MFH", "MFS",  # 90-94
            "MGCT", "MNG", "MOV", "MPNST", "MRLS",  # 95-99
            "MUP", "MXOV", "NBL", "NECNOS", "NETNOS",  # 100-104
            "NOT", "NPC", "NSCLC", "NSCLCPD", "NSGCT",  # 105-109
            "OCS", "OCSC", "ODG", "OOVC", "OPHSC",  # 110-114
            "OS", "PAAC", "PAAD", "PAASC", "PAMPCA",  # 115-119
            "PANET", "PAST", "PDC", "PECOMA", "PEMESO",  # 120-124
            "PHC", "PLBMESO", "PLEMESO", "PLMESO", "PRAD",  # 125-129
            "PRCC", "PSEC", "PTAD", "RBL", "RCC",  # 130-134
            "RCSNOS", "READ", "RMS", "SARCNOS", "SBC",  # 135-139
            "SBOV", "SBWDNET", "SCBC", "SCCNOS", "SCHW",  # 140-144
            "SCLC", "SCUP", "SDCA", "SEM", "SFT",  # 145-149
            "SKCM", "SOC", "SPDAC", "SSRCC", "STAD",  # 150-154
            "SYNS", "TAC", "THAP", "THHC", "THME",  # 155-159
            "THPA", "THPD", "THYC", "THYM", "TYST",  # 160-164
            "UCCC", "UCEC", "UCP", "UCS", "UCU",  # 165-169
            "UDMN", "UEC", "ULMS", "UM", "UMEC",  # 170-174
            "URCC", "USARC", "USC", "UTUC", "VMM",  # 175-179
            "VSC", "WDLS", "WT",  # 180-182
        )
    )
}
# Reverse lookup: integer class index -> cancer-type code.
INT_TO_CANCER_TYPE_MAP = {
    position: code for code, position in CANCER_TYPE_TO_INT_MAP.items()
}
# Index of the "Not Applicable" tissue site; encode_tissue_site falls back to
# it when a site name is not present in the mapping.
DEFAULT_TISSUE_SITE_IDX = 8
# Module-level cache filled by get_tissue_site_map(); None until first load.
_TISSUE_SITE_MAP = None
def get_tissue_site_map():
    """Load (once) and return the tissue site name -> index mapping.

    The mapping is read from ``tissue_site_original_to_idx.csv`` in the data
    directory, using its ``TISSUE_SITE`` and ``idx`` columns. It is cached in
    the module-level ``_TISSUE_SITE_MAP``; later calls return the cached dict.

    Returns:
        dict: Mapping of tissue site names to integer indices (documented
        range 0-56).

    Raises:
        FileNotFoundError: If the tissue site CSV file is not found.
        KeyError: If the CSV has no ``TISSUE_SITE`` or ``idx`` column.
        ValueError: If an ``idx`` value cannot be converted to ``int``.
    """
    global _TISSUE_SITE_MAP
    if _TISSUE_SITE_MAP is None:
        # pandas is imported lazily, only on the cache-miss path.
        import pandas as pd

        data_dir = get_data_directory()
        csv_path = data_dir / "tissue_site_original_to_idx.csv"
        try:
            df = pd.read_csv(csv_path)
        except FileNotFoundError as e:
            raise FileNotFoundError(
                f"Tissue site mapping file not found at {csv_path}. "
                f"Please ensure the data directory contains 'tissue_site_original_to_idx.csv'."
            ) from e
        # Build the mapping in a local and publish it only once it is
        # complete: a malformed CSV then raises on every call instead of
        # leaving an empty/partial dict cached for all later callers.
        site_map = {
            site: int(idx) for site, idx in zip(df["TISSUE_SITE"], df["idx"])
        }
        _TISSUE_SITE_MAP = site_map
    return _TISSUE_SITE_MAP
def get_tissue_site_options():
    """Get the tissue site names, sorted, for UI dropdowns.

    Returns:
        list: Sorted list of unique tissue site names (the keys of the
        tissue site mapping, which are unique by construction).
    """
    # Dict keys are already unique, so sorting them directly is enough.
    return sorted(get_tissue_site_map())
# Module-level cache filled by get_sex_map(); None until the CSV is first loaded.
_SEX_MAP = None
def get_sex_map():
    """Load (once) and return the sex value -> index mapping.

    The mapping is read from ``sex_original_to_idx.csv`` in the data
    directory, using its ``SEX`` and ``idx`` columns. It is cached in the
    module-level ``_SEX_MAP``; later calls return the cached dict.

    Returns:
        dict: Mapping of sex values to integer indices (documented range 0-2).

    Raises:
        FileNotFoundError: If the sex mapping CSV file is not found.
        KeyError: If the CSV has no ``SEX`` or ``idx`` column.
        ValueError: If an ``idx`` value cannot be converted to ``int``.
    """
    global _SEX_MAP
    if _SEX_MAP is None:
        # pandas is imported lazily, only on the cache-miss path.
        import pandas as pd

        data_dir = get_data_directory()
        csv_path = data_dir / "sex_original_to_idx.csv"
        try:
            df = pd.read_csv(csv_path)
        except FileNotFoundError as e:
            raise FileNotFoundError(
                f"Sex mapping file not found at {csv_path}. "
                f"Please ensure the data directory contains 'sex_original_to_idx.csv'."
            ) from e
        # Build locally, then publish: a failure part-way through must not
        # leave an empty/partial dict cached for all subsequent calls.
        sex_map = {sex: int(idx) for sex, idx in zip(df["SEX"], df["idx"])}
        _SEX_MAP = sex_map
    return _SEX_MAP
def encode_sex(sex):
    """Convert a sex label to its numeric encoding.

    Matching is case-insensitive, as documented. An exact match against the
    sex mapping is tried first; otherwise string labels are compared after
    ``str.casefold()``, so "male", "MALE" and "Male" resolve to the same index.

    Args:
        sex: "Male" or "Female" (required, case insensitive).

    Returns:
        int: The index stored for this label in the sex mapping CSV
        (expected 0 for Male, 1 for Female).

    Raises:
        ValueError: If sex does not match any label in the sex mapping.
    """
    sex_map = get_sex_map()
    if sex in sex_map:
        return sex_map[sex]
    # Fall back to a case-insensitive comparison; the previous exact-only
    # lookup contradicted the documented "case insensitive" contract.
    if isinstance(sex, str):
        wanted = sex.casefold()
        for label, idx in sex_map.items():
            if isinstance(label, str) and label.casefold() == wanted:
                return idx
    raise ValueError(f"Sex must be 'Male' or 'Female', got: {sex}")
def encode_tissue_site(site_name):
    """Map a tissue site name to its integer index.

    Args:
        site_name: Tissue site name as it appears in the tissue site CSV.

    Returns:
        int: The site's index (documented range 0-56), or
        DEFAULT_TISSUE_SITE_IDX ("Not Applicable") when the name is unknown.
    """
    # Fetch the mapping outside the try so loader errors are not mistaken
    # for an unknown site name.
    lookup = get_tissue_site_map()
    try:
        return lookup[site_name]
    except KeyError:
        return DEFAULT_TISSUE_SITE_IDX
def tissue_site_to_one_hot(site_idx, num_classes=57):
    """Encode a class index as a one-hot list.

    Despite its name this is a generic encoder: it is used for tissue sites
    (57 classes, the default) and for sex (called with ``num_classes=3``).

    Args:
        site_idx: Integer class index (0-56 for tissue site, 0-2 for sex).
        num_classes: Length of the output vector.

    Returns:
        list: ``num_classes`` ints, all 0 except a 1 at ``site_idx``. An index
        outside ``[0, num_classes)`` yields an all-zero vector instead of
        raising.
    """
    encoded = [0] * num_classes
    if not 0 <= site_idx < num_classes:
        # Out-of-range index: leave every slot at zero.
        return encoded
    encoded[site_idx] = 1
    return encoded
class SiteType(Enum):
    """Primary vs. metastatic site classification.

    Member values are the strings placed under the "site" key of the items
    returned by TileFeatureTensorDataset.
    """

    PRIMARY = "Primary"
    METASTASIS = "Metastasis"
class TileFeatureTensorDataset(Dataset):
    """Single-item dataset wrapping one slide's tile features for inference.

    The features are converted once, at construction time, into a fixed-size
    float32 tensor of ``n_max_tiles`` rows. Slides with more tiles are
    randomly subsampled and slides with fewer are zero-padded. Optional sex
    and tissue-site indices (used by Aeon) are emitted as one-hot tensors.
    """

    # One-hot widths for the optional clinical inputs.
    _NUM_SEX_CLASSES = 3
    _NUM_TISSUE_SITE_CLASSES = 57

    def __init__(
        self,
        site_type: SiteType,
        tile_features: np.ndarray,
        sex: Optional[int] = None,
        tissue_site_idx: Optional[int] = None,
        n_max_tiles: int = 20000,
    ) -> None:
        """Initialize the dataset.

        Args:
            site_type: the site type, either a SiteType member or its string
                value ("Primary" or "Metastasis").
            tile_features: the tile feature array; expected to be 2-D
                (n_tiles, feature_dim) since padding reads ``shape[1]``.
            sex: patient sex index (0=Male, 1=Female), optional for Aeon
            tissue_site_idx: tissue site index (0-56), optional for Aeon
            n_max_tiles: the maximum number of tiles to use

        Raises:
            ValueError: If site_type is a string that is not a SiteType value.
        """
        # __getitem__ reads ``.value``, so the plain string form must be
        # converted to the enum member up front (SiteType("Primary") etc.).
        if isinstance(site_type, str):
            site_type = SiteType(site_type)
        self.site_type = site_type
        self.sex = sex
        self.tissue_site_idx = tissue_site_idx
        self.n_max_tiles = n_max_tiles
        self.features = self._get_features(tile_features)

    def __len__(self) -> int:
        """Return the length of the dataset (always 1: a single slide).

        Returns:
            int: the length of the dataset
        """
        return 1

    def _get_features(self, features) -> torch.Tensor:
        """Convert tile features to a float32 tensor of exactly n_max_tiles rows.

        Args:
            features: the tile features as a numpy array

        Returns:
            torch.Tensor: the tile tensor; randomly subsampled (without
            replacement, in shuffled order) when there are too many tiles,
            zero-padded at the end when there are too few.
        """
        features = torch.tensor(features, dtype=torch.float32)
        n_tiles = features.shape[0]
        if n_tiles > self.n_max_tiles:
            indices = torch.randperm(n_tiles)[: self.n_max_tiles]
            features = features[indices]
        elif n_tiles < self.n_max_tiles:
            # new_zeros inherits dtype/device from the features, so padding
            # stays float32 even if torch's global default dtype was changed.
            padding = features.new_zeros(
                (self.n_max_tiles - n_tiles, features.shape[1])
            )
            features = torch.cat([features, padding], dim=0)
        return features

    def __getitem__(self, idx: int) -> dict:
        """Return the single item of the dataset.

        Args:
            idx: ignored; the dataset always holds exactly one item.

        Returns:
            dict: "site" (the SiteType value string) and "tile_tensor" (the
            fixed-size feature tensor), plus float32 one-hot "SEX" and
            "TISSUE_SITE" tensors when those inputs were provided.
        """
        result = {"site": self.site_type.value, "tile_tensor": self.features}
        # Optional clinical inputs (for Aeon). tissue_site_to_one_hot returns
        # an all-zero vector for indices outside [0, num_classes).
        if self.sex is not None:
            result["SEX"] = torch.tensor(
                tissue_site_to_one_hot(
                    self.sex, num_classes=self._NUM_SEX_CLASSES
                ),
                dtype=torch.float32,
            )
        if self.tissue_site_idx is not None:
            result["TISSUE_SITE"] = torch.tensor(
                tissue_site_to_one_hot(
                    self.tissue_site_idx,
                    num_classes=self._NUM_TISSUE_SITE_CLASSES,
                ),
                dtype=torch.float32,
            )
        return result