raylim committed on
Commit 6bd4e42 · unverified · 0 Parent(s)

initial commit

.gitignore ADDED
@@ -0,0 +1,14 @@
+ **/.DS_Store
+ __pycache__/
+ *.log
+ .venv*/
+ venv*/
+ *.pyc
+ outputs/
+ *.env
+ tmp*
+ *~
+ *.swp
+ .idea/
+ .vscode/
+ data/
.python-version ADDED
@@ -0,0 +1 @@
+ 3.11
README.md ADDED
@@ -0,0 +1,91 @@
+ # Mosaic: H&E Whole Slide Image Cancer Subtype and Biomarker Inference
+
+ Mosaic is a deep learning model for predicting cancer subtypes and biomarkers from Hematoxylin and Eosin (H&E) stained whole slide images (WSIs). This repository provides the code, pre-trained models, and instructions for using Mosaic on your own datasets.
+
+ ## Table of Contents
+
+ - [Installation](#installation)
+ - [Usage](#usage)
+
+ ### System requirements
+
+ Supported systems:
+
+ - Linux (x86) with GPU (NVIDIA CUDA)
+
+ ### Pre-requisites
+
+ - [python3.11+](https://www.python.org/)
+ - [uv](https://docs.astral.sh/uv/)
+
+ ```bash
+ curl -LsSf https://astral.sh/uv/install.sh | sh
+ ```
+
+ ## Installation
+
+ ```bash
+ uv pip install git+ssh://git@github.com/pathology-data-mining/paladin_webapp.git@dev
+ ```
+
+ ## Usage
+
+ ### Initial Setup
+
+ <b>NOTE</b>: To run this app, you must be added to the [PDM Group](https://huggingface.co/PDM-Group) and set the following environment variable. The token can be obtained by clicking the user icon at the top right of the HuggingFace website and selecting "Access Tokens". When creating the token, select all read options for your private space and the PDM-Group space.
+
+ ```bash
+ export HF_TOKEN="TOKEN-FROM-HUGGINGFACE"
+ ```
+
+ Additionally, set the HuggingFace home location, where models and other data from HuggingFace will be downloaded.
+
+ ```bash
+ export HF_HOME="PATH-TO-HUGGINGFACE-HOME"
+ ```
+
+ ### Web Application
+
+ Run the web application with:
+
+ ```bash
+ mosaic_app
+ ```
+
+ It starts a web server on port 7860 by default. You can access the web interface by navigating to `http://localhost:7860` in your web browser.
+
+ ### Command Line Interface
+
+ To process a single WSI, use the following command:
+
+ ```bash
+ mosaic_app --slide-path /path/to/your/wsi.svs --output-dir /path/to/output/directory
+ ```
+
+ To process a batch of WSIs, use:
+
+ ```bash
+ mosaic_app --slide-csv /path/to/your/wsi_list.csv --output-dir /path/to/output/directory
+ ```
+
+ The CSV file must contain at least the columns `Slide` and `Site Type`.
+ Optionally, it can also contain `Cancer Subtype`, `Segmentation Config`, and `IHC Subtype`.
+
+ - `Slide` should contain the full path to the WSI file.
+ - `Site Type` should be one of `Primary` or `Metastatic`.
+ - `Cancer Subtype` should be the OncoTree code for the cancer subtype.
+ - `Segmentation Config` should be one of `Biopsy`, `Resection`, or `TCGA`.
+ - `IHC Subtype` should be one of `HR+/HER2+`, `HR+/HER2-`, `HR-/HER2+`, or `HR-/HER2-`.
+
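As a concrete illustration of the settings CSV described above, the snippet below builds a minimal two-slide `wsi_list.csv` with pandas. The slide paths and the `IDC` OncoTree code are hypothetical placeholders, not values taken from this repository:

```python
import pandas as pd

# Two example rows following the column schema from the README.
# Paths and the "IDC" subtype code are illustrative placeholders.
rows = [
    {
        "Slide": "/data/slides/case_001.svs",
        "Site Type": "Primary",
        "Cancer Subtype": "Unknown",   # let Aeon infer the subtype
        "Segmentation Config": "Biopsy",
        "IHC Subtype": "",
    },
    {
        "Slide": "/data/slides/case_002.svs",
        "Site Type": "Metastatic",
        "Cancer Subtype": "IDC",       # OncoTree code supplied by the user
        "Segmentation Config": "Resection",
        "IHC Subtype": "HR+/HER2-",
    },
]
pd.DataFrame(rows).to_csv("wsi_list.csv", index=False)
```

Only `Slide` and `Site Type` are required; the optional columns fall back to their defaults when omitted.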
+ See additional options with the help command (it may take a few seconds to run):
+
+ ```bash
+ mosaic_app --help
+ ```
+
+ If running in server mode on a specific port, you can check whether a port is free with `ss -tuln | grep :PORT`, where PORT is the port number to check; no output indicates the port is likely available. If the port is available, set the environment variable `export GRADIO_SERVER_PORT="PORT"`.
+
+ ### Notes
+
+ - The first time you run the application, it will download the necessary models from HuggingFace. This may take some time depending on your internet connection.
+ - The models are downloaded to a subdirectory named `data`, relative to where you run the application.
pyproject.toml ADDED
@@ -0,0 +1,36 @@
+ [build-system]
+ requires = ["uv_build>=0.8.1,<0.9.0"]
+ build-backend = "uv_build"
+
+ [project]
+ name = "mosaic"
+ version = "0.1.0"
+ description = "Mussel-Aeon-Paladin workflow for digital pathology"
+ readme = "README.md"
+ requires-python = ">=3.10"
+ dependencies = [
+     "gradio>=5.49.0",
+     "loguru>=0.7.3",
+     "memory-profiler>=0.61.0",
+     "mussel[torch-gpu]",
+     "paladin",
+ ]
+
+ [project.scripts]
+ aeon_inference = "mosaic.inference.aeon:main"
+ paladin_inference = "mosaic.inference.paladin:main"
+ mosaic = "mosaic.gradio_app:main"
+
+ [dependency-groups]
+ dev = ["black>=25.1.0", "pylint>=3.3.6"]
+
+ [tool.pylint."messages control"]
+ disable = [
+     "logging-fstring-interpolation",
+     "broad-exception-caught",
+     "unspecified-encoding",
+ ]
+
+ [tool.uv.sources]
+ paladin = { git = "ssh://git@github.com/pathology-data-mining/paladin.git", rev = "dev" }
+ mussel = { git = "https://github.com/pathology-data-mining/Mussel.git", rev = "ray-dev" }
src/mosaic/__init__.py ADDED
File without changes
src/mosaic/favicon.svg ADDED
src/mosaic/gradio_app.py ADDED
@@ -0,0 +1,843 @@
+ from argparse import ArgumentParser
+ import gradio as gr
+ import pandas as pd
+ import pickle
+ from mussel.models import ModelType
+ from mussel.utils import get_features, segment_tissue, filter_features
+ from mussel.utils.segment import draw_slide_mask
+ from mussel.cli.tessellate import BiopsySegConfig, ResectionSegConfig, TcgaSegConfig
+ import torch
+ from pathlib import Path
+ from huggingface_hub import snapshot_download
+ import tempfile
+ import requests
+
+ from mosaic.inference import run_aeon, run_paladin
+ from loguru import logger
+
+ current_dir = Path(__file__).parent
+
+ # This path should be outside your project directory if running locally
+ TEMP_USER_DATA_DIR = Path(tempfile.gettempdir()) / "mosaic_user_data"
+
+ IHC_SUBTYPES = ["", "HR+/HER2+", "HR+/HER2-", "HR-/HER2+", "HR-/HER2-"]
+
+ SETTINGS_COLUMNS = [
+     "Slide",
+     "Site Type",
+     "Cancer Subtype",
+     "IHC Subtype",
+     "Segmentation Config",
+ ]
+
+ oncotree_code_map = {}
+
+
+ def get_oncotree_code_name(code):
+     global oncotree_code_map
+     if code in oncotree_code_map:
+         return oncotree_code_map[code]
+
+     url = f"https://oncotree.mskcc.org/api/tumorTypes/search/code/{code}?exactMatch=true&version=oncotree_2025_04_08"
+     response = requests.get(url)
+     code_name = "Unknown"
+     if response.status_code == 200:
+         data = response.json()
+         if data:
+             code_name = data[0]["name"]
+
+     oncotree_code_map[code] = code_name
+     return code_name
+
+
+ def download_and_process_models():
+     global cancer_subtype_name_map, reversed_cancer_subtype_name_map, cancer_subtypes
+     snapshot_download(repo_id="PDM-Group/paladin-aeon-models", local_dir="data")
+
+     model_map = pd.read_csv(
+         "data/paladin_model_map.csv",
+     )
+     cancer_subtypes = model_map["cancer_subtype"].unique().tolist()
+     cancer_subtype_name_map = {
+         f"{get_oncotree_code_name(code)} ({code})": code for code in cancer_subtypes
+     }
+     cancer_subtype_name_map["Unknown"] = "UNK"
+     reversed_cancer_subtype_name_map = {
+         value: key for key, value in cancer_subtype_name_map.items()
+     }
+
+
+ def create_user_directory(state, request: gr.Request):
+     """Create a unique directory for each user session."""
+     session_hash = request.session_hash
+     if session_hash is None:
+         return None
+
+     user_dir = TEMP_USER_DATA_DIR / session_hash
+     user_dir.mkdir(parents=True, exist_ok=True)
+     return user_dir
+
+
+ def load_settings(slide_csv_path):
+     """Load settings from CSV file and validate columns."""
+     settings_df = pd.read_csv(slide_csv_path, na_filter=False)
+     if "Segmentation Config" not in settings_df.columns:
+         settings_df["Segmentation Config"] = "Biopsy"
+     if "Cancer Subtype" not in settings_df.columns:
+         settings_df["Cancer Subtype"] = "Unknown"
+     if "IHC Subtype" not in settings_df.columns:
+         settings_df["IHC Subtype"] = ""
+     if not set(SETTINGS_COLUMNS).issubset(settings_df.columns):
+         raise ValueError("Missing required column in CSV file")
+     settings_df = settings_df[SETTINGS_COLUMNS]
+     return settings_df
+
+
+ def validate_settings(settings_df):
+     """Validate settings DataFrame and provide warnings for invalid entries."""
+     settings_df.columns = SETTINGS_COLUMNS
+     warnings = []
+     for idx, row in settings_df.iterrows():
+         slide_name = row["Slide"]
+         subtype = row["Cancer Subtype"]
+         if subtype in cancer_subtypes:
+             settings_df.at[idx, "Cancer Subtype"] = reversed_cancer_subtype_name_map[
+                 subtype
+             ]
+         if settings_df.at[idx, "Cancer Subtype"] not in cancer_subtype_name_map:
+             warnings.append(
+                 f"Slide {slide_name}: Unknown cancer subtype. Valid subtypes are: {', '.join(cancer_subtype_name_map.keys())}. "
+             )
+             settings_df.at[idx, "Cancer Subtype"] = "Unknown"
+         if row["Site Type"] not in ["Metastatic", "Primary"]:
+             warnings.append(
+                 f"Slide {slide_name}: Unknown site type. Valid types are: Metastatic, Primary. "
+             )
+             settings_df.at[idx, "Site Type"] = "Primary"
+         if (
+             "Breast" not in settings_df.at[idx, "Cancer Subtype"]
+             and row["IHC Subtype"] != ""
+         ):
+             warnings.append(
+                 f"Slide {slide_name}: IHC subtype should be empty for non-breast cancer subtypes. "
+             )
+             settings_df.at[idx, "IHC Subtype"] = ""
+         if row["IHC Subtype"] not in IHC_SUBTYPES:
+             warnings.append(
+                 f"Slide {slide_name}: Unknown IHC subtype. Valid subtypes are: {', '.join(IHC_SUBTYPES)}. "
+             )
+             settings_df.at[idx, "IHC Subtype"] = ""
+         if row["Segmentation Config"] not in ["Biopsy", "Resection", "TCGA"]:
+             warnings.append(
+                 f"Slide {slide_name}: Unknown segmentation config. Valid configs are: Biopsy, Resection, TCGA. "
+             )
+             settings_df.at[idx, "Segmentation Config"] = "Biopsy"
+
+     if warnings:
+         gr.Warning("\n".join(warnings))
+
+     return settings_df
+
+
+ def export_to_csv(df):
+     if df is None or df.empty:
+         raise gr.Error("No data to export.")
+     csv_path = "paladin_results.csv"
+     df.to_csv(csv_path, index=False)
+     return csv_path
+
+
+ def analyze_slides(
+     slides,
+     settings_input,
+     user_dir,
+     progress=gr.Progress(track_tqdm=True),
+ ):
+     if slides is None or len(slides) == 0:
+         raise gr.Error("Please upload at least one slide.")
+     if user_dir is None:
+         user_dir = create_user_directory(None, gr.Request())
+     settings_input = validate_settings(settings_input)
+     if len(slides) != len(settings_input):
+         raise gr.Error("Missing settings for uploaded slides")
+
+     all_slide_masks = []
+     all_aeon_results = []
+     all_paladin_results = []
+
+     progress(0.0, desc="Starting analysis")
+     for idx, row in settings_input.iterrows():
+         slide_name = row["Slide"]
+         progress(
+             idx / len(settings_input),
+             desc=f"Analyzing {slide_name}, slide {idx + 1} of {len(settings_input)}",
+         )
+         for x in slides:
+             s = x.split("/")[-1]
+             if s == slide_name:
+                 slide_mask = x
+
+         (
+             slide_mask,
+             aeon_results,
+             paladin_results,
+         ) = analyze_slide(
+             slides[idx],
+             row["Segmentation Config"],
+             row["Site Type"],
+             row["Cancer Subtype"],
+             row["IHC Subtype"],
+             progress=progress,
+         )
+         if aeon_results is not None:
+             if len(slides) > 1:
+                 aeon_results.columns = [f"{slide_name}"]
+             if row["Cancer Subtype"] == "Unknown":
+                 all_aeon_results.append(aeon_results)
+         if paladin_results is not None:
+             paladin_results.insert(
+                 0, "Slide", pd.Series([slide_name] * len(paladin_results))
+             )
+             all_paladin_results.append(paladin_results)
+         if slide_mask is not None:
+             all_slide_masks.append((slide_mask, slide_name))
+         # yield slide_mask, None, None, None  # Yield intermediate results
+     progress(0.99, desc="Analysis complete, wrapping up results")
+
+     timestamp = pd.Timestamp.now().strftime("%Y%m%d-%H%M%S")
+     combined_paladin_results = (
+         pd.concat(all_paladin_results, ignore_index=True)
+         if all_paladin_results
+         else pd.DataFrame()
+     )
+     combined_aeon_results = gr.DataFrame(visible=False)
+     aeon_output = gr.DownloadButton(visible=False)
+     if all_aeon_results:
+         combined_aeon_results = pd.concat(all_aeon_results, axis=1)
+         combined_aeon_results.reset_index(inplace=True)
+
+         combined_aeon_results = combined_aeon_results.round(3)
+         cancer_subtype_names = [
+             f"{get_oncotree_code_name(code)} ({code})"
+             for code in combined_aeon_results["Cancer Subtype"]
+         ]
+         combined_aeon_results["Cancer Subtype"] = cancer_subtype_names
+
+         aeon_output_path = user_dir / f"aeon_results-{timestamp}.csv"
+         combined_aeon_results.to_csv(aeon_output_path)
+
+         combined_aeon_results = gr.DataFrame(
+             combined_aeon_results,
+             visible=True,
+             column_widths=["4px"] + ["2px"] * (combined_aeon_results.shape[1] - 1),
+         )
+         aeon_output = gr.DownloadButton(value=aeon_output_path, visible=True)
+
+     # Convert Oncotree codes to names for display (guarded so an empty
+     # DataFrame without a "Cancer Subtype" column does not raise a KeyError)
+     if len(combined_paladin_results) > 0:
+         cancer_subtype_names = [
+             f"{get_oncotree_code_name(code)} ({code})"
+             for code in combined_paladin_results["Cancer Subtype"]
+         ]
+         combined_paladin_results["Cancer Subtype"] = cancer_subtype_names
+         combined_paladin_results["Score"] = combined_paladin_results["Score"].round(3)
+
+     paladin_output = gr.DownloadButton(visible=False)
+     if len(combined_paladin_results) > 0:
+         paladin_output_path = user_dir / f"paladin_results-{timestamp}.csv"
+         combined_paladin_results.to_csv(paladin_output_path, index=False)
+         paladin_output = gr.DownloadButton(value=paladin_output_path, visible=True)
+
+     progress(1.0, desc="All done!")
+
+     return (
+         all_slide_masks,
+         combined_aeon_results,
+         aeon_output,
+         combined_paladin_results if len(combined_paladin_results) > 0 else None,
+         paladin_output,
+         user_dir,
+     )
+
+
+ def analyze_slide(
+     slide_path,
+     seg_config,
+     site_type,
+     cancer_subtype,
+     ihc_subtype="",
+     num_workers=4,
+     progress=gr.Progress(track_tqdm=True),
+ ):
+     if slide_path is None:
+         raise gr.Error("Please upload a slide.")
+     # Step 1: Segment tissue
+     start_time = pd.Timestamp.now()
+
+     if seg_config == "Biopsy":
+         seg_config = BiopsySegConfig()
+     elif seg_config == "Resection":
+         seg_config = ResectionSegConfig()
+     elif seg_config == "TCGA":
+         seg_config = TcgaSegConfig()
+     else:
+         raise ValueError(f"Unknown segmentation configuration: {seg_config}")
+
+     progress(0.0, desc="Segmenting tissue")
+     logger.info(f"Segmenting tissue for slide: {slide_path}")
+     if values := segment_tissue(
+         slide_path=slide_path,
+         patch_size=224,
+         mpp=0.5,
+         seg_level=-1,
+         segment_threshold=seg_config.segment_threshold,
+         median_blur_ksize=seg_config.median_blur_ksize,
+         morphology_ex_kernel=seg_config.morphology_ex_kernel,
+         tissue_area_threshold=seg_config.tissue_area_threshold,
+         hole_area_threshold=seg_config.hole_area_threshold,
+         max_num_holes=seg_config.max_num_holes,
+     ):
+         polygon, _, coords, attrs = values
+     else:
+         gr.Warning(f"No tissue detected in slide: {slide_path}")
+         return None, None, None
+     end_time = pd.Timestamp.now()
+     logger.info(f"Tissue segmentation took {end_time - start_time}")
+     logger.info(f"Found {len(coords)} tissue tiles")
+     progress(0.2, desc="Tissue segmented")
+
+     # Draw slide mask for visualization
+     logger.info("Drawing slide mask")
+     progress(0.25, desc="Drawing slide mask")
+     slide_mask = draw_slide_mask(
+         slide_path, polygon, outline="black", fill=(255, 0, 0, 80), vis_level=-1
+     )
+     logger.info("Slide mask drawn")
+
+     # Step 2: Extract features with CTransPath
+     start_time = pd.Timestamp.now()
+     progress(0.3, desc="Extracting CTransPath features")
+     logger.info("Extracting CTransPath features")
+     ctranspath_features, _ = get_features(
+         coords,
+         slide_path,
+         attrs,
+         model_type=ModelType.CTRANSPATH,
+         model_path="data/ctranspath.pth",
+         num_workers=num_workers,
+         batch_size=64,
+         use_gpu=True,
+     )
+     end_time = pd.Timestamp.now()
+     max_gpu_memory = (
+         torch.cuda.max_memory_allocated() / (1024**3)
+         if torch.cuda.is_available()
+         else 0
+     )
+     logger.info(
+         f"CTransPath Feature extraction took {end_time - start_time} and used {max_gpu_memory:.2f} GB GPU memory"
+     )
+
+     if torch.cuda.is_available():
+         torch.cuda.reset_peak_memory_stats()
+
+     # Step 3: Filter features using marker classifier
+     start_time = pd.Timestamp.now()
+     with open("data/marker_classifier.pkl", "rb") as f:
+         marker_classifier = pickle.load(f)
+     progress(0.35, desc="Filtering features with marker classifier")
+     logger.info("Filtering features with marker classifier")
+     _, filtered_coords = filter_features(
+         ctranspath_features,
+         coords,
+         marker_classifier,
+         threshold=0.25,
+     )
+     end_time = pd.Timestamp.now()
+     logger.info(f"Feature filtering took {end_time - start_time}")
+     logger.info(
+         f"Filtered from {len(coords)} to {len(filtered_coords)} tiles using marker classifier"
+     )
+
+     # Step 4: Extract features with Optimus on filtered coords
+     start_time = pd.Timestamp.now()
+     progress(0.4, desc="Extracting Optimus features")
+     logger.info("Extracting Optimus features")
+     features, _ = get_features(
+         filtered_coords,
+         slide_path,
+         attrs,
+         model_type=ModelType.OPTIMUS,
+         model_path="data/optimus.pkl",
+         num_workers=num_workers,
+         batch_size=64,
+         use_gpu=True,
+     )
+     end_time = pd.Timestamp.now()
+     max_gpu_memory = (
+         torch.cuda.max_memory_allocated() / (1024**3)
+         if torch.cuda.is_available()
+         else 0
+     )
+     logger.info(
+         f"Optimus Feature extraction took {end_time - start_time} and used {max_gpu_memory:.2f} GB GPU memory"
+     )
+
+     if torch.cuda.is_available():
+         torch.cuda.reset_peak_memory_stats()
+
+     # Step 5: Run Aeon to predict histology if not supplied
+     if cancer_subtype == "Unknown":
+         start_time = pd.Timestamp.now()
+         progress(0.9, desc="Running Aeon for cancer subtype inference")
+         logger.info("Running Aeon for cancer subtype inference")
+         aeon_results, _ = run_aeon(
+             features=features,
+             model_path="data/aeon_model.pkl",
+             metastatic=(site_type == "Metastatic"),
+             batch_size=8,
+             num_workers=num_workers,
+             use_cpu=False,
+         )
+         end_time = pd.Timestamp.now()
+         max_gpu_memory = (
+             torch.cuda.max_memory_allocated() / (1024**3)
+             if torch.cuda.is_available()
+             else 0
+         )
+         logger.info(
+             f"Aeon inference took {end_time - start_time} and used {max_gpu_memory:.2f} GB GPU memory"
+         )
+         if torch.cuda.is_available():
+             torch.cuda.reset_peak_memory_stats()
+     else:
+         cancer_subtype_code = cancer_subtype_name_map.get(cancer_subtype)
+         aeon_results = pd.DataFrame(
+             {
+                 "Cancer Subtype": [cancer_subtype_code],
+                 "Confidence": [1.0],
+             }
+         )
+         logger.info(f"Using user-supplied cancer subtype: {cancer_subtype}")
+
+     # Step 6: Run Paladin to predict biomarkers
+     if len(aeon_results) == 0:
+         logger.warning("No Aeon results, skipping Paladin inference")
+         return slide_mask, None, None
+     start_time = pd.Timestamp.now()
+     progress(0.95, desc="Running Paladin for biomarker inference")
+     logger.info("Running Paladin for biomarker inference")
+     paladin_results = run_paladin(
+         features=features,
+         model_map_path="data/paladin_model_map.csv",
+         aeon_results=aeon_results,
+         metastatic=(site_type == "Metastatic"),
+         batch_size=8,
+         num_workers=num_workers,
+         use_cpu=False,
+     )
+     end_time = pd.Timestamp.now()
+     max_gpu_memory = (
+         torch.cuda.max_memory_allocated() / (1024**3)
+         if torch.cuda.is_available()
+         else 0
+     )
+     logger.info(
+         f"Paladin inference took {end_time - start_time} and used {max_gpu_memory:.2f} GB GPU memory"
+     )
+
+     aeon_results.set_index("Cancer Subtype", inplace=True)
+
+     return slide_mask, aeon_results, paladin_results
+
+
+ def launch_gradio(server_name, server_port, share):
+     with gr.Blocks(title="Mosaic") as demo:
+         user_dir_state = gr.State(None)
+         gr.Markdown(
+             "# Mosaic: H&E Whole Slide Image Cancer Subtype and Biomarker Inference"
+         )
+         gr.Markdown(
+             "Upload an H&E whole slide image in SVS or TIFF format. The slide will be processed to infer cancer subtype and relevant biomarkers."
+         )
+         with gr.Row():
+             with gr.Column():
+                 input_slides = gr.File(
+                     label="Upload H&E Whole Slide Image",
+                     file_types=[".svs", ".tiff", ".tif"],
+                     file_count="multiple",
+                 )
+                 site_dropdown = gr.Dropdown(
+                     choices=["Primary", "Metastatic"],
+                     label="Site Type",
+                     value="Primary",
+                 )
+                 cancer_subtype_dropdown = gr.Dropdown(
+                     choices=[name for name in cancer_subtype_name_map.keys()],
+                     label="Cancer Subtype",
+                     value="Unknown",
+                 )
+                 ihc_subtype_dropdown = gr.Dropdown(
+                     choices=IHC_SUBTYPES,
+                     label="IHC Subtype (if applicable)",
+                     value="",
+                     visible=False,
+                 )
+                 seg_config_dropdown = gr.Dropdown(
+                     choices=["Biopsy", "Resection", "TCGA"],
+                     label="Segmentation Config",
+                     value="Biopsy",
+                 )
+                 with gr.Row():
+                     settings_input = gr.Dataframe(
+                         headers=[
+                             "Slide",
+                             "Site Type",
+                             "Cancer Subtype",
+                             "IHC Subtype",
+                             "Segmentation Config",
+                         ],
+                         label="Current Settings",
+                         datatype=["str", "str", "str", "str", "str"],
+                         visible=False,
+                         interactive=True,
+                         static_columns="Slide",
+                     )
+
+                 with gr.Row():
+                     settings_csv = gr.File(
+                         file_types=[".csv"], label="Upload Settings CSV", visible=False
+                     )
+
+                 with gr.Row():
+                     clear_button = gr.Button("Clear")
+                     analyze_button = gr.Button("Analyze", variant="primary")
+             with gr.Column():
+                 slide_masks = gr.Gallery(
+                     label="Slide Masks",
+                     columns=3,
+                     object_fit="contain",
+                     height="auto",
+                 )
+                 aeon_output_table = gr.Dataframe(
+                     headers=["Cancer Subtype", "Slide Name"],
+                     label="Cancer Subtype Inference Confidence",
+                     datatype=["str", "number"],
+                     visible=False,
+                 )
+                 aeon_download_button = gr.DownloadButton(
+                     "Download Aeon Results as CSV",
+                     label="Download Results",
+                     visible=False,
+                 )
+                 paladin_output_table = gr.Dataframe(
+                     headers=["Slide", "Cancer Subtype", "Biomarker", "Score"],
+                     label="Biomarker Inference",
+                     datatype=["str", "str", "str", "number"],
+                 )
+                 paladin_download_button = gr.DownloadButton(
+                     "Download Paladin Results as CSV",
+                     label="Download Results",
+                     visible=False,
+                 )
+
+         @clear_button.click(
+             outputs=[
+                 input_slides,
+                 slide_masks,
+                 paladin_output_table,
+                 paladin_download_button,
+                 aeon_output_table,
+                 aeon_download_button,
+                 settings_input,
+                 settings_csv,
+             ],
+         )
+         def clear_fn():
+             return (
+                 None,
+                 None,
+                 None,
+                 None,
+                 gr.Dataframe(visible=False),
+                 gr.DownloadButton(visible=False),
+                 gr.Dataframe(visible=False),
+                 gr.File(visible=False),
+             )
+
+         def get_settings(files, site_type, cancer_subtype, ihc_subtype, seg_config):
+             if files is None:
+                 return pd.DataFrame()
+             settings = []
+             for file in files:
+                 filename = file.name if hasattr(file, "name") else file
+                 slide_name = filename.split("/")[-1]
+                 settings.append(
+                     [slide_name, site_type, cancer_subtype, ihc_subtype, seg_config]
+                 )
+             df = pd.DataFrame(settings, columns=SETTINGS_COLUMNS)
+             return df
+
+         # Only display settings table and upload button if multiple slides are uploaded
+         @gr.on(
+             [
+                 input_slides.change,
+                 site_dropdown.change,
+                 cancer_subtype_dropdown.change,
+                 ihc_subtype_dropdown.change,
+                 seg_config_dropdown.change,
+             ],
+             inputs=[
+                 input_slides,
+                 site_dropdown,
+                 cancer_subtype_dropdown,
+                 ihc_subtype_dropdown,
+                 seg_config_dropdown,
+             ],
+             outputs=[settings_input, settings_csv, ihc_subtype_dropdown],
+         )
+         def update_settings(files, site_type, cancer_subtype, ihc_subtype, seg_config):
+             has_ihc = "Breast" in cancer_subtype
+             if not files:
+                 return None, None, gr.Dropdown(visible=has_ihc)
+             settings_df = get_settings(
+                 files, site_type, cancer_subtype, ihc_subtype, seg_config
+             )
+             if settings_df is not None:
+                 has_ihc = any("Breast" in cs for cs in settings_df["Cancer Subtype"])
+             visible = len(files) > 1
+             return (
+                 gr.Dataframe(settings_df, visible=visible),
+                 gr.File(visible=visible),
+                 gr.Dropdown(visible=has_ihc),
+             )
+
+         @settings_csv.upload(
+             inputs=[settings_csv],
+             outputs=[settings_input],
+         )
+         def read_settings(file):
+             if file is None:
+                 return None
+             df = load_settings(file.name if hasattr(file, "name") else file)
+             return gr.Dataframe(df, visible=True)
+
+         analyze_button.click(
+             analyze_slides,
+             inputs=[
+                 input_slides,
+                 settings_input,
+                 user_dir_state,
+             ],
+             outputs=[
+                 slide_masks,
+                 aeon_output_table,
+                 aeon_download_button,
+                 paladin_output_table,
+                 paladin_download_button,
+                 user_dir_state,
+             ],
+             queue=True,
+             show_progress_on=paladin_output_table,
+         )
+         settings_input.change(
+             validate_settings, inputs=[settings_input], outputs=[settings_input]
+         )
+         demo.load(
+             create_user_directory,
+             inputs=[user_dir_state],
+             outputs=[user_dir_state],
+         )
+
+     demo.queue(max_size=10, default_concurrency_limit=8)
+     demo.launch(
+         server_name=server_name,
+         share=share,
+         server_port=server_port,
+         show_error=True,
+         favicon_path=current_dir / "favicon.svg",
+     )
+
+
658
+ def main():
659
+ parser = ArgumentParser()
660
+ parser.add_argument("--debug", action="store_true", help="Enable debug logging")
661
+ parser.add_argument(
662
+ "--server-name", type=str, default="0.0.0.0", help="Server name for Gradio app"
663
+ )
664
+ parser.add_argument(
665
+ "--server-port", type=int, default=None, help="Server port for Gradio app"
666
+ )
667
+ parser.add_argument(
668
+ "--share", action="store_true", help="Share Gradio app publicly"
669
+ )
670
+ parser.add_argument(
671
+ "--slide-csv",
672
+ type=str,
673
+ help="CSV file with slide settings (for batch processing), see README for format",
674
+ )
675
+ parser.add_argument(
676
+ "--slide-path",
677
+ type=str,
678
+ help="Path to a single slide (for single slide processing), not used if --slide-csv is provided",
679
+ )
680
+ parser.add_argument(
681
+ "--site-type",
682
+ type=str,
683
+ choices=["Primary", "Metastatic"],
684
+ default="Primary",
685
+ help="Site type of the slide (for single slide processing)",
686
+ )
687
+ parser.add_argument(
688
+ "--cancer-subtype",
689
+ type=str,
690
+ default="Unknown",
691
+ help="Cancer subtype of the slide (for single slide processing), use 'Unknown' to infer with Aeon",
692
+ )
693
+ parser.add_argument(
694
+ "--ihc-subtype",
695
+ type=str,
696
+ choices=IHC_SUBTYPES,
697
+ default="",
698
+ help="IHC subtype if cancer subtype is breast (for single slide processing)",
699
+ )
700
+ parser.add_argument(
701
+ "--segmentation-config",
702
+ type=str,
703
+ choices=["Biopsy", "Resection", "TCGA"],
704
+ default="Biopsy",
705
+ help="Segmentation configuration (for single slide processing)",
706
+ )
707
+ parser.add_argument(
708
+ "--output-dir", type=str, help="Directory to save output results"
709
+ )
710
+ parser.add_argument(
711
+ "--num-workers",
712
+ type=int,
713
+ default=4,
714
+ help="Number of workers for feature extraction",
715
+ )
716
+ args = parser.parse_args()
717
+ if args.debug:
718
+ logger.add("debug.log", level="DEBUG")
719
+ logger.debug("Debug logging enabled")
720
+
721
+ download_and_process_models()
722
+
723
+ if args.slide_path and not args.slide_csv:
724
+ # Single slide processing mode
725
+ if not args.output_dir:
726
+ raise ValueError("Please provide --output-dir to save results")
727
+ settings_df = pd.DataFrame(
728
+ [
729
+ [
730
+ args.slide_path,
731
+ args.site_type,
732
+ args.cancer_subtype,
733
+ args.ihc_subtype,
734
+ args.segmentation_config,
735
+ ]
736
+ ],
737
+ columns=SETTINGS_COLUMNS,
738
+ )
739
+ settings_df = validate_settings(settings_df)
740
+ slide_mask, aeon_results, paladin_results = analyze_slide(
741
+ args.slide_path,
742
+ args.segmentation_config,
743
+ args.site_type,
744
+ args.cancer_subtype,
745
+ args.ihc_subtype,
746
+ num_workers=args.num_workers,
747
+ )
748
+ output_dir = Path(args.output_dir)
749
+ output_dir.mkdir(parents=True, exist_ok=True)
750
+ slide_name = Path(args.slide_path).stem
751
+ if slide_mask is not None:
752
+ mask_path = output_dir / f"{slide_name}_mask.png"
753
+ slide_mask.save(mask_path)
754
+ logger.info(f"Saved slide mask to {mask_path}")
755
+ if aeon_results is not None:
756
+ aeon_output_path = output_dir / f"{slide_name}_aeon_results.csv"
757
+ aeon_results.reset_index().to_csv(aeon_output_path, index=False)
758
+ logger.info(f"Saved Aeon results to {aeon_output_path}")
759
+ if paladin_results is not None and len(paladin_results) > 0:
760
+ paladin_output_path = output_dir / f"{slide_name}_paladin_results.csv"
761
+ paladin_results.to_csv(paladin_output_path, index=False)
762
+ logger.info(f"Saved Paladin results to {paladin_output_path}")
763
+ elif args.slide_csv:
764
+ if not args.output_dir:
765
+ raise ValueError("Please provide --output-dir to save results")
766
+ # Batch processing mode
767
+
768
+ output_dir = Path(args.output_dir)
769
+ output_dir.mkdir(parents=True, exist_ok=True)
770
+ all_paladin_results = []
771
+ all_aeon_results = []
772
+ settings_df = load_settings(args.slide_csv)
773
+ settings_df = validate_settings(settings_df)
774
+ for idx, row in settings_df.iterrows():
775
+ slide_path = row["Slide"]
776
+ seg_config = row["Segmentation Config"]
777
+ site_type = row["Site Type"]
778
+ cancer_subtype = row["Cancer Subtype"]
779
+ ihc_subtype = row.get("IHC Subtype", "")
780
+ logger.info(
781
+ f"Processing slide {slide_path} ({idx + 1} of {len(settings_df)})"
782
+ )
783
+ slide_mask, aeon_results, paladin_results = analyze_slide(
784
+ slide_path,
785
+ seg_config,
786
+ site_type,
787
+ cancer_subtype,
788
+ ihc_subtype,
789
+ num_workers=args.num_workers,
790
+ )
791
+ slide_name = Path(slide_path).stem
792
+             if slide_mask is not None:
+                 mask_path = output_dir / f"{slide_name}_mask.png"
+                 slide_mask.save(mask_path)
+                 logger.info(f"Saved slide mask to {mask_path}")
795
+ if aeon_results is not None:
796
+ aeon_output_path = output_dir / f"{slide_name}_aeon_results.csv"
797
+ aeon_results.reset_index().to_csv(aeon_output_path, index=False)
798
+ logger.info(f"Saved Aeon results to {aeon_output_path}")
799
+ if paladin_results is not None and len(paladin_results) > 0:
800
+ paladin_output_path = output_dir / f"{slide_name}_paladin_results.csv"
801
+ paladin_results.to_csv(paladin_output_path, index=False)
802
+ logger.info(f"Saved Paladin results to {paladin_output_path}")
803
+ if aeon_results is not None:
804
+                 aeon_results.columns = [slide_name]
805
+ all_aeon_results.append(aeon_results)
806
+ if paladin_results is not None and len(paladin_results) > 0:
807
+ paladin_results.insert(
808
+ 0, "Slide", pd.Series([slide_name] * len(paladin_results))
809
+ )
810
+ all_paladin_results.append(paladin_results)
811
+ if all_aeon_results:
812
+ combined_aeon_results = pd.concat(all_aeon_results, axis=1)
813
+ combined_aeon_results.reset_index(inplace=True)
814
+ cancer_subtype_names = [
815
+ f"{get_oncotree_code_name(code)} ({code})"
816
+ for code in combined_aeon_results["Cancer Subtype"]
817
+ ]
818
+ combined_aeon_results["Cancer Subtype"] = cancer_subtype_names
819
+ combined_aeon_output_path = output_dir / "combined_aeon_results.csv"
820
+ combined_aeon_results.to_csv(combined_aeon_output_path, index=False)
821
+ logger.info(f"Saved combined Aeon results to {combined_aeon_output_path}")
822
+ if all_paladin_results:
823
+ combined_paladin_results = pd.concat(all_paladin_results, ignore_index=True)
824
+ cancer_subtype_names = [
825
+ f"{get_oncotree_code_name(code)} ({code})"
826
+ for code in combined_paladin_results["Cancer Subtype"]
827
+ ]
828
+ combined_paladin_results["Cancer Subtype"] = cancer_subtype_names
829
+ combined_paladin_output_path = output_dir / "combined_paladin_results.csv"
830
+ combined_paladin_results.to_csv(combined_paladin_output_path, index=False)
831
+ logger.info(
832
+ f"Saved combined Paladin results to {combined_paladin_output_path}"
833
+ )
834
+ else:
835
+ launch_gradio(
836
+ server_name=args.server_name,
837
+ server_port=args.server_port,
838
+ share=args.share,
839
+ )
840
+
841
+
842
+ if __name__ == "__main__":
843
+ main()
src/mosaic/inference/__init__.py ADDED
@@ -0,0 +1,2 @@
1
+ from .aeon import run as run_aeon
2
+ from .paladin import run as run_paladin
src/mosaic/inference/aeon.py ADDED
@@ -0,0 +1,173 @@
1
+ import pickle # nosec
+ from argparse import ArgumentParser
5
+
6
+ import pandas as pd
7
+ import torch
8
+ from torch.utils.data import DataLoader
9
+
10
+ from mosaic.inference.data import (
11
+ SiteType,
12
+ TileFeatureTensorDataset,
13
+ INT_TO_CANCER_TYPE_MAP,
14
+ CANCER_TYPE_TO_INT_MAP,
15
+ )
16
+
17
+ from loguru import logger
18
+
19
+ cancer_types_to_drop = [
20
+ "UDMN",
21
+ "ADNOS",
22
+ "CUP",
23
+ "CUPNOS",
24
+ "BRCNOS",
25
+ "GNOS",
26
+ "SCCNOS",
27
+ "PDC",
28
+ "NSCLC",
29
+ "BRCA",
30
+ "SARCNOS",
31
+ "NETNOS",
32
+ "MEL",
33
+ "RCC",
34
+ "BRCANOS",
35
+ "COADREAD",
36
+ "MUP",
37
+ "NECNOS",
38
+ "UCEC",
39
+ "NOT",
40
+ ]
41
+ col_indices_to_drop = [CANCER_TYPE_TO_INT_MAP[x] for x in cancer_types_to_drop]
42
+
43
+
44
+ BATCH_SIZE = 8
45
+ NUM_WORKERS = 8
46
+
47
+
48
+ def run(
49
+ features, model_path, metastatic=False, batch_size=8, num_workers=8, use_cpu=False
50
+ ):
51
+ device = torch.device(
52
+ "cuda" if not use_cpu and torch.cuda.is_available() else "cpu"
53
+ )
54
+
55
+ with open(model_path, "rb") as f:
56
+ model = pickle.load(f) # nosec
57
+ model.to(device)
58
+ model.eval()
59
+
60
+ site_type = SiteType.METASTASIS if metastatic else SiteType.PRIMARY
61
+
62
+     # For UI, the dataset holds just a single slide, so sample id is not relevant.
63
+ dataset = TileFeatureTensorDataset(
64
+ site_type=site_type,
65
+ tile_features=features,
66
+ n_max_tiles=20000,
67
+ )
68
+ dataloader = DataLoader(
69
+ dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers
70
+ )
71
+
72
+ results = []
73
+ batch = next(iter(dataloader))
74
+ with torch.no_grad():
75
+ batch["tile_tensor"] = batch["tile_tensor"].to(device)
76
+ y = model(batch)
77
+ y["logits"][:, col_indices_to_drop] = -1e6
78
+
79
+ batch_size = y["logits"].shape[0]
80
+ assert batch_size == 1
81
+
82
+ softmax = torch.nn.functional.softmax(y["logits"][0], dim=0)
83
+ argmax = torch.argmax(softmax, dim=0)
84
+ class_assignment = INT_TO_CANCER_TYPE_MAP[argmax.item()]
85
+ max_confidence = softmax[argmax].item()
86
+ mean_confidence = torch.mean(softmax).item()
87
+
88
+ logger.info(
89
+ f"class {class_assignment} : confidence {max_confidence:8.5f} "
90
+ f"(mean {mean_confidence:8.5f})"
91
+ )
92
+
93
+ part_embedding = y["whole_part_representation"][0].cpu()
94
+
95
+ for cancer_subtype, j in sorted(CANCER_TYPE_TO_INT_MAP.items()):
96
+ confidence = softmax[j].item()
97
+ results.append((cancer_subtype, confidence))
98
+ results.sort(key=lambda row: row[1], reverse=True)
99
+
100
+ results_df = pd.DataFrame(results, columns=["Cancer Subtype", "Confidence"])
101
+
102
+ return results_df, part_embedding
103
+
104
+
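The masking-then-softmax step above (setting dropped classes to `-1e6` before normalizing) can be sketched in isolation. This is a minimal plain-Python sketch with made-up logits; the real code operates on a torch tensor of 183 class logits:

```python
import math

def masked_softmax(logits, drop_indices):
    # Assign a very negative logit to dropped classes, then softmax the rest;
    # the dropped classes receive effectively zero probability mass.
    masked = [(-1e6 if i in drop_indices else v) for i, v in enumerate(logits)]
    m = max(masked)
    exps = [math.exp(v - m) for v in masked]
    total = sum(exps)
    return [e / total for e in exps]

probs = masked_softmax([2.0, 5.0, 1.0], drop_indices={1})
# class 1 gets effectively zero probability; classes 0 and 2 renormalize
```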
105
+ def parse_args():
106
+ parser = ArgumentParser(
107
+ description="Run Aeon inference on a specified set of slides"
108
+ )
109
+ parser.add_argument(
110
+ "-i",
111
+ "--features-path",
112
+ required=True,
113
+ help="Pathname to a .pt file with optimus tile features for this slide",
114
+ )
115
+ parser.add_argument(
116
+ "-o",
117
+ "--output-prediction-path",
118
+ help="The filename for the Aeon predictions file (CSV)",
119
+ required=True,
120
+ )
121
+ parser.add_argument(
122
+ "--output-embedding-path",
123
+ help="The filename for the whole-part representation of the slide (.pt)",
124
+ )
125
+ parser.add_argument(
126
+ "--model-path",
127
+ type=str,
128
+ help="Pathname to the pickle file for an Aeon model",
129
+ required=True,
130
+ )
131
+ parser.add_argument(
132
+ "--metastatic", action="store_true", help="Tissue is from a metastatic site"
133
+ )
134
+ parser.add_argument("--batch-size", type=int, default=BATCH_SIZE, help="Batch size")
135
+ parser.add_argument(
136
+ "--num-workers", type=int, default=NUM_WORKERS, help="Number of workers"
137
+ )
138
+ parser.add_argument("--use-cpu", action="store_true", help="Use CPU")
139
+ opt = parser.parse_args()
140
+
141
+ return opt
142
+
143
+
144
+ def main():
145
+ opt = parse_args()
146
+
147
+ output_path = opt.output_prediction_path
148
+ logger.info(f"output_path: '{output_path}'")
149
+
150
+ embedding_path = opt.output_embedding_path
151
+ logger.info(f"part_embedding_path: '{embedding_path}'")
152
+
153
+ features = torch.load(opt.features_path)
154
+
155
+ results_df, part_embedding = run(
156
+ features=features,
157
+ model_path=opt.model_path,
158
+ metastatic=opt.metastatic,
159
+ batch_size=opt.batch_size,
160
+ num_workers=opt.num_workers,
161
+ use_cpu=opt.use_cpu,
162
+ )
163
+
164
+ results_df.to_csv(output_path, index=False)
165
+ logger.info(f"Wrote {output_path}")
166
+
167
+ if embedding_path:
168
+ torch.save(part_embedding, embedding_path)
169
+ logger.info(f"Wrote {embedding_path}")
170
+
171
+
172
+ if __name__ == "__main__":
173
+ main()
src/mosaic/inference/data.py ADDED
@@ -0,0 +1,262 @@
1
+ from enum import Enum
3
+
4
+ import torch
5
+ from torch.utils.data import Dataset
6
+ import numpy as np
7
+
8
+ CANCER_TYPE_TO_INT_MAP = {
9
+ "AASTR": 0,
10
+ "ACC": 1,
11
+ "ACRM": 2,
12
+ "ACYC": 3,
13
+ "ADNOS": 4,
14
+ "ALUCA": 5,
15
+ "AMPCA": 6,
16
+ "ANGS": 7,
17
+ "ANSC": 8,
18
+ "AODG": 9,
19
+ "APAD": 10,
20
+ "ARMM": 11,
21
+ "ARMS": 12,
22
+ "ASTR": 13,
23
+ "ATM": 14,
24
+ "BA": 15,
25
+ "BCC": 16,
26
+ "BLAD": 17,
27
+ "BLCA": 18,
28
+ "BMGCT": 19,
29
+ "BRCA": 20,
30
+ "BRCANOS": 21,
31
+ "BRCNOS": 22,
32
+ "CCOV": 23,
33
+ "CCRCC": 24,
34
+ "CESC": 25,
35
+ "CHDM": 26,
36
+ "CHOL": 27,
37
+ "CHRCC": 28,
38
+ "CHS": 29,
39
+ "COAD": 30,
40
+ "COADREAD": 31,
41
+ "CSCC": 32,
42
+ "CSCLC": 33,
43
+ "CUP": 34,
44
+ "CUPNOS": 35,
45
+ "DA": 36,
46
+ "DASTR": 37,
47
+ "DDLS": 38,
48
+ "DES": 39,
49
+ "DIFG": 40,
50
+ "DSRCT": 41,
51
+ "DSTAD": 42,
52
+ "ECAD": 43,
53
+ "EGC": 44,
54
+ "EHAE": 45,
55
+ "EHCH": 46,
56
+ "EMPD": 47,
57
+ "EOV": 48,
58
+ "EPDCA": 49,
59
+ "EPIS": 50,
60
+ "EPM": 51,
61
+ "ERMS": 52,
62
+ "ES": 53,
63
+ "ESCA": 54,
64
+ "ESCC": 55,
65
+ "GB": 56,
66
+ "GBAD": 57,
67
+ "GBC": 58,
68
+ "GBM": 59,
69
+ "GCCAP": 60,
70
+ "GEJ": 61,
71
+ "GINET": 62,
72
+ "GIST": 63,
73
+ "GNOS": 64,
74
+ "GRCT": 65,
75
+ "HCC": 66,
76
+ "HGGNOS": 67,
77
+ "HGNEC": 68,
78
+ "HGSFT": 69,
79
+ "HGSOC": 70,
80
+ "HNMUCM": 71,
81
+ "HNSC": 72,
82
+ "IDC": 73,
83
+ "IHCH": 74,
84
+ "ILC": 75,
85
+ "LGGNOS": 76,
86
+ "LGSOC": 77,
87
+ "LMS": 78,
88
+ "LNET": 79,
89
+ "LUAD": 80,
90
+ "LUAS": 81,
91
+ "LUCA": 82,
92
+ "LUNE": 83,
93
+ "LUPC": 84,
94
+ "LUSC": 85,
95
+ "LXSC": 86,
96
+ "MAAP": 87,
97
+ "MACR": 88,
98
+ "MBC": 89,
99
+ "MCC": 90,
100
+ "MDLC": 91,
101
+ "MEL": 92,
102
+ "MFH": 93,
103
+ "MFS": 94,
104
+ "MGCT": 95,
105
+ "MNG": 96,
106
+ "MOV": 97,
107
+ "MPNST": 98,
108
+ "MRLS": 99,
109
+ "MUP": 100,
110
+ "MXOV": 101,
111
+ "NBL": 102,
112
+ "NECNOS": 103,
113
+ "NETNOS": 104,
114
+ "NOT": 105,
115
+ "NPC": 106,
116
+ "NSCLC": 107,
117
+ "NSCLCPD": 108,
118
+ "NSGCT": 109,
119
+ "OCS": 110,
120
+ "OCSC": 111,
121
+ "ODG": 112,
122
+ "OOVC": 113,
123
+ "OPHSC": 114,
124
+ "OS": 115,
125
+ "PAAC": 116,
126
+ "PAAD": 117,
127
+ "PAASC": 118,
128
+ "PAMPCA": 119,
129
+ "PANET": 120,
130
+ "PAST": 121,
131
+ "PDC": 122,
132
+ "PECOMA": 123,
133
+ "PEMESO": 124,
134
+ "PHC": 125,
135
+ "PLBMESO": 126,
136
+ "PLEMESO": 127,
137
+ "PLMESO": 128,
138
+ "PRAD": 129,
139
+ "PRCC": 130,
140
+ "PSEC": 131,
141
+ "PTAD": 132,
142
+ "RBL": 133,
143
+ "RCC": 134,
144
+ "RCSNOS": 135,
145
+ "READ": 136,
146
+ "RMS": 137,
147
+ "SARCNOS": 138,
148
+ "SBC": 139,
149
+ "SBOV": 140,
150
+ "SBWDNET": 141,
151
+ "SCBC": 142,
152
+ "SCCNOS": 143,
153
+ "SCHW": 144,
154
+ "SCLC": 145,
155
+ "SCUP": 146,
156
+ "SDCA": 147,
157
+ "SEM": 148,
158
+ "SFT": 149,
159
+ "SKCM": 150,
160
+ "SOC": 151,
161
+ "SPDAC": 152,
162
+ "SSRCC": 153,
163
+ "STAD": 154,
164
+ "SYNS": 155,
165
+ "TAC": 156,
166
+ "THAP": 157,
167
+ "THHC": 158,
168
+ "THME": 159,
169
+ "THPA": 160,
170
+ "THPD": 161,
171
+ "THYC": 162,
172
+ "THYM": 163,
173
+ "TYST": 164,
174
+ "UCCC": 165,
175
+ "UCEC": 166,
176
+ "UCP": 167,
177
+ "UCS": 168,
178
+ "UCU": 169,
179
+ "UDMN": 170,
180
+ "UEC": 171,
181
+ "ULMS": 172,
182
+ "UM": 173,
183
+ "UMEC": 174,
184
+ "URCC": 175,
185
+ "USARC": 176,
186
+ "USC": 177,
187
+ "UTUC": 178,
188
+ "VMM": 179,
189
+ "VSC": 180,
190
+ "WDLS": 181,
191
+ "WT": 182,
192
+ }
193
+ INT_TO_CANCER_TYPE_MAP = {v: k for k, v in CANCER_TYPE_TO_INT_MAP.items()}
194
+
195
+
196
+ class SiteType(Enum):
197
+ PRIMARY = "Primary"
198
+ METASTASIS = "Metastasis"
199
+
200
+
201
+ class TileFeatureTensorDataset(Dataset):
202
+ def __init__(
203
+ self,
204
+ site_type: SiteType,
205
+ tile_features: np.ndarray,
206
+ n_max_tiles: int = 20000,
207
+ ) -> None:
208
+ """Initialize the dataset.
209
+
210
+ Args:
211
+             site_type: the site type, as SiteType.PRIMARY or SiteType.METASTASIS
212
+ tile_features: the tile feature array
213
+ n_max_tiles: the maximum number of tiles to use as int
214
+
215
+ Returns:
216
+ None
217
+ """
218
+ self.site_type = site_type
219
+ self.n_max_tiles = n_max_tiles
220
+ self.features = self._get_features(tile_features)
221
+
222
+ def __len__(self) -> int:
223
+ """Return the length of the dataset.
224
+
225
+ Returns:
226
+ int: the length of the dataset
227
+ """
228
+ return 1
229
+
230
+ def _get_features(self, features) -> torch.Tensor:
231
+ """Get the tile features
232
+
233
+ Args:
234
+ features: the tile features as a numpy array
235
+
236
+ Returns:
237
+ torch.Tensor: the tile tensor
238
+ """
239
+ features = torch.tensor(features, dtype=torch.float32)
240
+ if features.shape[0] > self.n_max_tiles:
241
+ indices = torch.randperm(features.shape[0])[: self.n_max_tiles]
242
+ features = features[indices]
243
+ if features.shape[0] < self.n_max_tiles:
244
+ padding = torch.zeros(
245
+ self.n_max_tiles - features.shape[0], features.shape[1]
246
+ )
247
+ features = torch.cat([features, padding], dim=0)
248
+ return features
249
+
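The truncate-or-pad logic of `_get_features` can be illustrated without torch. A plain-Python sketch; note the real code samples a random permutation rather than taking a prefix, and pads with zero rows:

```python
def fit_to_length(rows, n_max):
    # Truncate to n_max rows (a prefix for determinism here; the dataset
    # samples randomly), or zero-pad with rows of the same width.
    width = len(rows[0])
    if len(rows) > n_max:
        rows = rows[:n_max]
    while len(rows) < n_max:
        rows.append([0.0] * width)
    return rows

tiles = fit_to_length([[1.0, 2.0]], n_max=3)
# 3 rows: the original plus two zero rows
```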
250
+ def __getitem__(self, idx: int) -> dict:
251
+ """Return an item from the dataset.
252
+
253
+ Args:
254
+ idx: the index of the item to return
255
+
256
+ Returns:
257
+ dict: the item
258
+ """
259
+ return {
260
+ "site": self.site_type.value,
261
+ "tile_tensor": self.features
262
+ }
src/mosaic/inference/paladin.py ADDED
@@ -0,0 +1,310 @@
1
+ import csv
2
+ import pickle # nosec
4
+ from argparse import ArgumentParser
5
+ from collections import defaultdict
6
+ from pathlib import Path
7
+ from typing import Any, List, Optional
8
+
10
+ import pandas as pd
11
+ import numpy as np
12
+ import torch
13
+ from loguru import logger
14
+ from torch.utils.data import DataLoader
15
+
16
+ from mosaic.inference.data import SiteType, TileFeatureTensorDataset
17
+
18
+ # Constants
19
+ BATCH_SIZE = 8
20
+ NUM_WORKERS = 16
21
+
22
+
23
+ class UsageError(Exception):
24
+ """A UsageError is raised when there's a problem with the command-line arguments."""
25
+
26
+ pass
27
+
28
+
29
+ def load_model_map(model_map_path: str) -> dict[Any, Any]:
30
+ """Load the table mapping cancer_subtypes and targets to the paladin
31
+ model (a pickle file) that predicts that target for that cancer subtype.
32
+
33
+ A dict is returned, mapping each cancer_subtype to a table mapping a
34
+ target to the pathname for the model that predicts it.
35
+ """
36
+ models = defaultdict(dict)
37
+ with Path(model_map_path).open() as fp:
38
+ rdr = csv.DictReader(fp)
39
+ for row in rdr:
40
+ cancer_subtype = row["cancer_subtype"]
41
+ target = row["target_name"]
42
+ model = row["model_path"]
43
+ models[cancer_subtype][target] = model
44
+ return models
45
+
46
+
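The model-map CSV consumed by `load_model_map` carries columns `cancer_subtype`, `target_name`, and `model_path`. A small sketch of the parsing, using hypothetical paths in place of real model files:

```python
import csv
import io
from collections import defaultdict

# Hypothetical rows; a real file points at .pkl Paladin models on disk
csv_text = """cancer_subtype,target_name,model_path
LUAD,EGFR,models/luad_egfr.pkl
LUAD,MSI_TYPE,models/luad_msi.pkl
IDC,HER2,models/idc_her2.pkl
"""
models = defaultdict(dict)
for row in csv.DictReader(io.StringIO(csv_text)):
    # Nested dict: subtype -> target -> model path
    models[row["cancer_subtype"]][row["target_name"]] = row["model_path"]
# models["LUAD"] maps two targets to their model paths
```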
47
+ def load_aeon_scores(df: pd.DataFrame) -> dict[str, float]:
48
+ """Load the output table from a single-slide Aeon run, listing Oncotree
49
+ cancer subtypes and their confidence values.
50
+
51
+     A dict is returned, mapping each cancer subtype to its confidence score.
52
+ """
53
+ score = {}
54
+ for _, row in df.iterrows():
55
+ subtype = row["Cancer Subtype"]
56
+ confidence = row["Confidence"]
57
+ score[subtype] = confidence
58
+ return score
59
+
60
+
61
+ def select_cancer_subtypes(aeon_scores: dict[str, float], k=1) -> list[str]:
62
+ """Return the three top-scoring cancer_subtypes, based on the given Aeon scores."""
63
+ sorted_cancer_subtypes = list(
64
+ sorted([(v, k) for k, v in aeon_scores.items()], reverse=True)
65
+ )
66
+ return [cancer_subtype for score, cancer_subtype in sorted_cancer_subtypes[:k]]
67
+
68
+
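The score-to-tuple sort in `select_cancer_subtypes` can be shown with made-up confidences (hypothetical values, not real model output):

```python
aeon_scores = {"LUAD": 0.62, "LUSC": 0.21, "IDC": 0.05}  # hypothetical confidences
# Sort (score, name) pairs descending, then keep the top-k names
ranked = sorted(((v, k) for k, v in aeon_scores.items()), reverse=True)
top_two = [name for score, name in ranked[:2]]
# ["LUAD", "LUSC"]
```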
69
+ def select_models(cancer_subtypes: list[str], model_map: dict[Any, Any]) -> list[Any]:
70
+ """ """
71
+ models = []
72
+ for cancer_subtype, target, model in model_map.items():
73
+ if cancer_subtype in cancer_subtypes:
74
+ models.append((cancer_subtype, target, model))
75
+ return models
76
+
77
+
78
+ def run_model(device, dataset, model_path: str, num_workers, batch_size) -> float:
79
+ """Run inference for the given embeddings and model.
80
+ The point estimate is returned.
81
+ """
82
+
83
+ logger.debug(f"[loading model {model_path}]")
84
+ with Path(model_path).open("rb") as f:
85
+ model = pickle.load(f) # nosec
87
+ model.to(device)
88
+ model.eval()
89
+
90
+ dataloader = DataLoader(
91
+ dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers
92
+ )
93
+
95
+ batch = next(iter(dataloader))
96
+ with torch.no_grad():
97
+ batch["tile_tensor"] = batch["tile_tensor"].to(device)
98
+ outputs = model(batch)
99
+
100
+ logits = outputs["logits"]
101
+ # Apply softplus to ensure positive values for beta-binomial parameters
102
+ logits = torch.nn.functional.softplus(logits) + 1.0 # enforce concavity
103
+ point_estimates = logits_to_point_estimates(logits)
104
+
105
+ # sample_id = batch['sample_id'][0]
106
+ class_assignment = point_estimates[0].item()
107
+ return class_assignment
108
+
109
+
110
+ def logits_to_point_estimates(logits):
111
+ # logits is a tensor of shape (batch_size, 2 * (n_clf_tasks + n_reg_tasks))
112
+ # need to convert it to a tensor of shape (batch_size, n_clf_tasks + n_reg_tasks)
113
+ return logits[:, ::2] / (logits[:, ::2] + logits[:, 1::2])
114
+
115
+
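`logits_to_point_estimates` treats consecutive logit pairs as beta-binomial (alpha, beta) parameters and returns the mean alpha/(alpha+beta). The same arithmetic without torch:

```python
def point_estimates(interleaved):
    # interleaved = [a1, b1, a2, b2, ...]; returns [a_i / (a_i + b_i), ...],
    # the mean of a Beta(a_i, b_i) distribution for each task.
    return [interleaved[i] / (interleaved[i] + interleaved[i + 1])
            for i in range(0, len(interleaved), 2)]

estimates = point_estimates([3.0, 1.0, 2.0, 6.0])
# [0.75, 0.25]
```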
116
+ def run(
117
+ features: np.ndarray,
118
+ aeon_results: Optional[pd.DataFrame] = None,
119
+     cancer_subtype_codes: Optional[List[str]] = None,
+     model_map_path: Optional[str] = None,
+     model_path: Optional[str] = None,
122
+ metastatic: bool = False,
123
+ batch_size: int = BATCH_SIZE,
124
+ num_workers: int = NUM_WORKERS,
125
+ use_cpu: bool = False,
126
+ ):
127
+ """Run Paladin inference on a single slide, using the given embeddings
128
+ and either a single model or a table mapping cancer_subtypes and targets to models.
129
+ If cancer_subtype_codes is given, it is a list of OncoTree codes for the slide.
130
+     If aeon_results is given, it is a DataFrame from an Aeon run on the
+     slide, and the top-scoring subtype is selected from it; otherwise
+     cancer_subtype_codes is used directly.
+     A DataFrame with columns "Cancer Subtype", "Biomarker", and "Score" is returned.
134
+ """
135
+
136
+ if aeon_results is not None:
137
+ aeon_scores = load_aeon_scores(aeon_results)
138
+ target_cancer_subtypes = select_cancer_subtypes(aeon_scores)
139
+ else:
140
+ target_cancer_subtypes = cancer_subtype_codes
141
+
142
+ # Build a dataset to feed to the model
143
+ site = SiteType.METASTASIS if metastatic else SiteType.PRIMARY
144
+
145
+ dataset = TileFeatureTensorDataset(
146
+ tile_features=features,
147
+ site_type=site,
148
+ n_max_tiles=20000,
149
+ )
150
+
151
+ device = torch.device(
152
+ "cuda" if not use_cpu and torch.cuda.is_available() else "cpu"
153
+ )
154
+
155
+ results = []
156
+ if model_path:
157
+ cancer_subtype, target = "None", "None"
158
+ try:
159
+ score = run_model(device, dataset, model_path, num_workers, batch_size)
160
+ results.append((cancer_subtype, target, score))
161
+ logger.info(
162
+ f"cancer_subtype: {cancer_subtype} target: {target} score: {score}"
163
+ )
164
+ except Exception as exc:
165
+ logger.error(
166
+ f"Unable to run model for {cancer_subtype} target {target}\n{exc}"
167
+ )
168
+
169
+ elif model_map_path:
170
+ model_map = load_model_map(model_map_path)
171
+ for cancer_subtype in target_cancer_subtypes:
172
+ if cancer_subtype not in model_map:
173
+ logger.warning(f"Warning: no models found for {cancer_subtype}")
174
+ continue
175
+
176
+ if "MSI_TYPE" in model_map[cancer_subtype]:
177
+ # Run MSI_TYPE model first, to determine if we should run other/MSS models
178
+ logger.info(f"Running MSI_TYPE model for {cancer_subtype} first")
179
+ try:
180
+ msi_score = run_model(
181
+ device,
182
+ dataset,
183
+ model_map[cancer_subtype]["MSI_TYPE"],
184
+ num_workers,
185
+ batch_size,
186
+ )
187
+ results.append((cancer_subtype, "MSI_TYPE", msi_score))
188
+ logger.info(
189
+ f"cancer_subtype: {cancer_subtype} target: MSI score: {msi_score}"
190
+ )
191
+ # If MSI score is high, skip MSS models
192
+ if msi_score >= 0.5:
193
+ logger.info(
194
+ f"Skipping MSS models for {cancer_subtype} due to high MSI score"
195
+ )
196
+ continue
197
+ else:
198
+ logger.info(
199
+ f"Running MSS models for {cancer_subtype} due to low MSI score"
200
+ )
201
+ except Exception as exc:
202
+ logger.error(
203
+ f"Unable to run model for {cancer_subtype} target MSI_TYPE\n{exc}"
204
+ )
205
+
206
+ for target, model in sorted(model_map[cancer_subtype].items()):
207
+ # Skip MSI_TYPE model, already run above
208
+ if target == "MSI_TYPE":
209
+ continue
210
+ try:
211
+ score = run_model(device, dataset, model, num_workers, batch_size)
212
+ results.append((cancer_subtype, target, score))
213
+ logger.info(
214
+ f"cancer_subtype: {cancer_subtype} target: {target} score: {score}"
215
+ )
216
+ except Exception as exc:
217
+ logger.error(
218
+ f"Unable to run model for {cancer_subtype} target {target}\n{exc}"
219
+ )
220
+ df = pd.DataFrame(results, columns=["Cancer Subtype", "Biomarker", "Score"])
221
+
222
+ return df
223
+
224
+
225
+ def parse_args():
226
+ parser = ArgumentParser(description="Run Paladin inference on a single slide")
227
+ parser.add_argument(
228
+ "-i",
229
+ "--features-path",
230
+ required=True,
231
+ help="Pathname to a .pt file with optimus embeddings for this slide",
232
+ )
233
+ parser.add_argument(
234
+ "-o",
235
+ "--output-path",
236
+ help="The filename for the Paladin predictions file (CSV)",
237
+ required=True,
238
+ )
239
+ parser.add_argument(
240
+ "-c",
241
+ "--cancer-subtype-codes",
242
+ help="One or more cancer_subtypes (OncoTree codes, comma-separated)",
243
+ )
244
+ parser.add_argument(
245
+ "-a",
246
+ "--aeon-predictions-path",
247
+ help="Pathname to an aeon-predictions file (CSV) for this slide",
248
+ )
249
+ parser.add_argument(
250
+ "-mm",
251
+ "--model-map-path",
252
+ help="A CSV file mapping cancer subtypes and targets to Paladin models (.pkl files). Contains columns 'cancer_subtype', 'target_name', and 'model_path'.",
253
+ )
254
+ parser.add_argument(
255
+ "-m",
256
+ "--model-path",
257
+ help="The filename for a Paladin model to run inference with",
258
+ )
259
+ parser.add_argument(
260
+ "--metastatic", action="store_true", help="Tissue is from a metastatic site"
261
+ )
262
+ parser.add_argument("--batch-size", type=int, default=BATCH_SIZE, help="Batch size")
263
+ parser.add_argument(
264
+ "--num-workers",
265
+ type=int,
266
+ default=NUM_WORKERS,
267
+ help="Number of workers for data loading",
268
+ )
269
+ parser.add_argument("--use-cpu", action="store_true", help="Use CPU")
270
+ opt = parser.parse_args()
271
+
272
+ if opt.cancer_subtype_codes and opt.aeon_predictions_path:
273
+ raise UsageError(
274
+ "You may specify either --codes or --aeon-predictions-path, but not both."
275
+ )
276
+
277
+ if opt.cancer_subtype_codes:
278
+ opt.cancer_subtype_codes = opt.cancer_subtype_codes.split(",")
279
+
280
+ if opt.model_path is None and opt.model_map_path is None:
281
+ raise UsageError("You must specify either --model-path or --model-map-path")
282
+
283
+ return opt
284
+
285
+
286
+ def main():
287
+ opt = parse_args()
288
+ features = torch.load(opt.features_path)
289
+ logger.info(f"Loaded features from {opt.features_path}")
290
+ aeon_results = None
291
+ if opt.aeon_predictions_path:
292
+ aeon_results = pd.read_csv(opt.aeon_predictions_path)
293
+ logger.info(f"Loaded Aeon results from {opt.aeon_predictions_path}")
294
+ df = run(
295
+ features=features,
296
+ aeon_results=aeon_results,
297
+ cancer_subtype_codes=opt.cancer_subtype_codes,
298
+ model_map_path=opt.model_map_path,
299
+ model_path=opt.model_path,
300
+ metastatic=opt.metastatic,
301
+ batch_size=opt.batch_size,
302
+ num_workers=opt.num_workers,
303
+ use_cpu=opt.use_cpu,
304
+ )
305
+ df.to_csv(opt.output_path, index=False)
306
+ logger.info(f"Wrote {opt.output_path}")
307
+
308
+
309
+ if __name__ == "__main__":
310
+ main()
uv.lock ADDED
The diff for this file is too large to render. See raw diff