# Source provenance: "Update AEON inference to require sex parameter"
# (upstream commit 0d1b788, author raylim).
"""Aeon model inference module for cancer subtype prediction.
This module provides functionality to run the Aeon deep learning model
for predicting cancer subtypes from H&E whole slide image features.
"""
import json
import pickle # nosec
import sys
from argparse import ArgumentParser
from pathlib import Path
import pandas as pd
import torch
from torch.utils.data import DataLoader
from mosaic.inference.data import (
SiteType,
TileFeatureTensorDataset,
encode_sex,
encode_tissue_site,
)
from loguru import logger
from mosaic.data_directory import get_data_directory
# Cancer types excluded from prediction (too broad or ambiguous)
# These are used to mask out predictions for overly general cancer types
# NOTE(review): the codes below appear to be "unknown/not-otherwise-specified"
# style labels (e.g. CUP = cancer of unknown primary) — confirm against the
# model's target dictionary in metadata/target_dict.tsv.
CANCER_TYPES_TO_DROP = [
    "UDMN",
    "ADNOS",
    "CUP",
    "CUPNOS",
    "NOT",
]
# Default DataLoader batch size (also the CLI --batch-size default).
BATCH_SIZE = 8
# Default number of DataLoader worker processes (CLI --num-workers default).
NUM_WORKERS = 8
def _load_cancer_type_maps():
    """Load the histology class-index mappings from the model metadata file.

    The metadata TSV stores a Python-literal-style dict; single quotes are
    rewritten to double quotes so the content can be parsed as JSON.

    Returns:
        tuple: (int_to_cancer_type, cancer_type_to_int) — dicts mapping
        class indices to histology codes and back.
    """
    data_dir = get_data_directory()
    metadata_path = data_dir / "metadata" / "target_dict.tsv"
    with open(metadata_path) as f:
        target_dict_str = f.read().strip().replace("'", '"')
    target_dict = json.loads(target_dict_str)
    histologies = target_dict["histologies"]
    int_to_cancer_type = dict(enumerate(histologies))
    cancer_type_to_int = {v: k for k, v in int_to_cancer_type.items()}
    return int_to_cancer_type, cancer_type_to_int


def run_with_model(
    features,
    model,
    device,
    metastatic=False,
    batch_size=8,
    num_workers=8,
    sex=None,
    tissue_site_idx=None,
):
    """Run Aeon model inference using a pre-loaded model (for batch processing).

    This function is optimized for batch processing where the model is loaded
    once and reused across multiple slides instead of being reloaded each time.

    Args:
        features: NumPy array of tile features extracted from the WSI
        model: Pre-loaded Aeon model (torch.nn.Module)
        device: torch.device for GPU/CPU placement
        metastatic: Whether the slide is from a metastatic site
        batch_size: Batch size for inference
        num_workers: Number of workers for data loading
        sex: Patient sex (0=Male, 1=Female), optional
        tissue_site_idx: Tissue site index (0-56), optional

    Returns:
        tuple: (results_df, part_embedding)
            - results_df: DataFrame with cancer subtypes and confidence scores,
              sorted by descending confidence
            - part_embedding: Torch tensor of the learned part representation

    Raises:
        ValueError: If the dataloader yields more than one sample in the
            first batch (this entry point handles exactly one slide).
    """
    # Model is already loaded and on device, just set to eval mode
    model.eval()

    # Load the correct mapping from metadata for this model
    int_to_cancer_type, cancer_type_to_int = _load_cancer_type_maps()

    # Indices of overly broad cancer types whose logits get masked out below.
    col_indices_to_drop = [
        cancer_type_to_int[x]
        for x in CANCER_TYPES_TO_DROP
        if x in cancer_type_to_int
    ]

    site_type = SiteType.METASTASIS if metastatic else SiteType.PRIMARY

    # For UI, InferenceDataset will just be a single slide. Sample id is not relevant.
    dataset = TileFeatureTensorDataset(
        site_type=site_type,
        tile_features=features,
        sex=sex,
        tissue_site_idx=tissue_site_idx,
        n_max_tiles=20000,
    )
    dataloader = DataLoader(
        dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers
    )

    # A single slide yields exactly one batch; take it directly.
    batch = next(iter(dataloader))
    with torch.no_grad():
        batch["tile_tensor"] = batch["tile_tensor"].to(device)
        if "SEX" in batch:
            batch["SEX"] = batch["SEX"].to(device)
        if "TISSUE_SITE" in batch:
            batch["TISSUE_SITE"] = batch["TISSUE_SITE"].to(device)
        y = model(batch)

        # Push excluded classes to a huge negative logit so softmax assigns
        # them effectively zero probability.
        y["logits"][:, col_indices_to_drop] = -1e6

        # Use a distinct local; the original reassigned the `batch_size`
        # parameter here, shadowing the caller's value.
        n_samples = y["logits"].shape[0]
        if n_samples != 1:
            raise ValueError(
                f"Expected exactly one sample per inference call, got {n_samples}"
            )

        softmax = torch.nn.functional.softmax(y["logits"][0], dim=0)
        argmax = torch.argmax(softmax, dim=0)
        class_assignment = int_to_cancer_type[argmax.item()]
        max_confidence = softmax[argmax].item()
        mean_confidence = torch.mean(softmax).item()
        logger.info(
            f"class {class_assignment} : confidence {max_confidence:8.5f} "
            f"(mean {mean_confidence:8.5f})"
        )
        part_embedding = y["whole_part_representation"][0].cpu()

    # One (subtype, confidence) row per class, highest confidence first.
    results = [
        (cancer_subtype, softmax[j].item())
        for cancer_subtype, j in sorted(cancer_type_to_int.items())
    ]
    results.sort(key=lambda row: row[1], reverse=True)
    results_df = pd.DataFrame(results, columns=["Cancer Subtype", "Confidence"])
    return results_df, part_embedding
def run(
    features,
    model_path,
    metastatic=False,
    batch_size=8,
    num_workers=8,
    use_cpu=False,
    sex=None,
    tissue_site_idx=None,
):
    """Run Aeon model inference for cancer subtype prediction.

    Loads the pickled model from disk, places it on the selected device,
    and delegates the actual inference to :func:`run_with_model`.

    Args:
        features: NumPy array of tile features extracted from the WSI
        model_path: Path to the pickled Aeon model file
        metastatic: Whether the slide is from a metastatic site
        batch_size: Batch size for inference
        num_workers: Number of workers for data loading
        use_cpu: Force CPU usage instead of GPU
        sex: Patient sex (0=Male, 1=Female), optional
        tissue_site_idx: Tissue site index (0-56), optional

    Returns:
        tuple: (results_df, part_embedding)
            - results_df: DataFrame with cancer subtypes and confidence scores
            - part_embedding: Torch tensor of the learned part representation
    """
    # Prefer CUDA unless explicitly disabled or unavailable.
    if use_cpu or not torch.cuda.is_available():
        device = torch.device("cpu")
    else:
        device = torch.device("cuda")

    # SECURITY: unpickling executes arbitrary code; only load trusted models.
    with open(model_path, "rb") as f:
        model = pickle.load(f)  # nosec
    model.to(device)

    return run_with_model(
        features=features,
        model=model,
        device=device,
        metastatic=metastatic,
        batch_size=batch_size,
        num_workers=num_workers,
        sex=sex,
        tissue_site_idx=tissue_site_idx,
    )
def parse_args():
    """Build the CLI for single-slide Aeon inference and parse sys.argv."""
    arg_parser = ArgumentParser(
        description="Run Aeon inference on a specified set of slides"
    )
    arg_parser.add_argument(
        "-i",
        "--features-path",
        help="Pathname to a .pt file with optimus tile features for this slide",
        required=True,
    )
    arg_parser.add_argument(
        "-o",
        "--output-prediction-path",
        help="The filename for the Aeon predictions file (CSV)",
        required=True,
    )
    arg_parser.add_argument(
        "--output-embedding-path",
        help="The filename for the whole-part representation of the slide (.pt)",
    )
    arg_parser.add_argument(
        "--model-path",
        help="Pathname to the pickle file for an Aeon model",
        type=str,
        required=True,
    )
    arg_parser.add_argument(
        "--metastatic",
        action="store_true",
        help="Tissue is from a metastatic site",
    )
    arg_parser.add_argument(
        "--sex",
        choices=["Male", "Female"],
        type=str,
        default=None,
        help="Patient sex (Male or Female, required for inference)",
    )
    arg_parser.add_argument(
        "--tissue-site",
        default=None,
        type=str,
        help="Tissue site name",
    )
    arg_parser.add_argument(
        "--batch-size", type=int, default=BATCH_SIZE, help="Batch size"
    )
    arg_parser.add_argument(
        "--num-workers", type=int, default=NUM_WORKERS, help="Number of workers"
    )
    arg_parser.add_argument("--use-cpu", action="store_true", help="Use CPU")
    return arg_parser.parse_args()
def main():
    """CLI entry point: run inference on one slide and write the outputs."""
    args = parse_args()

    prediction_path = args.output_prediction_path
    logger.info(f"output_path: '{prediction_path}'")
    embedding_path = args.output_embedding_path
    logger.info(f"part_embedding_path: '{embedding_path}'")

    tile_features = torch.load(args.features_path)

    # Translate the optional CLI strings into the integer encodings the
    # model expects; leave them as None when not supplied.
    sex_encoded = encode_sex(args.sex) if args.sex else None
    if args.sex:
        logger.info(f"Using sex: {args.sex} (encoded as {sex_encoded})")

    tissue_site_idx = (
        encode_tissue_site(args.tissue_site) if args.tissue_site else None
    )
    if args.tissue_site:
        logger.info(
            f"Using tissue site: {args.tissue_site} (encoded as {tissue_site_idx})"
        )

    results_df, part_embedding = run(
        features=tile_features,
        model_path=args.model_path,
        metastatic=args.metastatic,
        batch_size=args.batch_size,
        num_workers=args.num_workers,
        use_cpu=args.use_cpu,
        sex=sex_encoded,
        tissue_site_idx=tissue_site_idx,
    )

    results_df.to_csv(prediction_path, index=False)
    logger.info(f"Wrote {prediction_path}")

    # The embedding output is optional; only write it when a path was given.
    if embedding_path:
        torch.save(part_embedding, embedding_path)
        logger.info(f"Wrote {embedding_path}")
# Allow the module to be executed directly as a command-line script.
if __name__ == "__main__":
    main()