raylim commited on
Commit
49fbf68
·
unverified ·
1 Parent(s): de40714

Complete implementation of sex and tissue site parameters

Browse files

- Add UI dropdowns for Sex and Tissue Site selection
- Update settings dataframe to include Sex and Tissue Site columns
- Pass sex and tissue_site parameters through the analysis pipeline
- Convert sex string to index (0=Male, 1=Female, None=Unknown)
- Convert tissue site string to index using tissue site mapping
- Add get_tissue_sites() helper function in ui/utils.py
- Update all inference pipeline functions to handle these parameters
- Pass sex and tissue_site_idx to Aeon model for improved inference

src/mosaic/analysis.py CHANGED
@@ -154,13 +154,15 @@ def _extract_optimus_features(filtered_coords, slide_path, attrs, num_workers):
154
  return features
155
 
156
 
157
- def _run_aeon_inference(features, site_type, num_workers):
158
  """Run Aeon cancer subtype inference on GPU.
159
 
160
  Args:
161
  features: Optimus features
162
  site_type: Site type ("Primary" or "Metastatic")
163
  num_workers: Number of worker processes
 
 
164
 
165
  Returns:
166
  Aeon results DataFrame
@@ -183,6 +185,8 @@ def _run_aeon_inference(features, site_type, num_workers):
183
  metastatic=(site_type == "Metastatic"),
184
  batch_size=8,
185
  num_workers=num_workers,
 
 
186
  use_cpu=False,
187
  )
188
  end_time = pd.Timestamp.now()
@@ -260,6 +264,8 @@ def _run_inference_pipeline_free(
260
  slide_path,
261
  attrs,
262
  site_type,
 
 
263
  cancer_subtype,
264
  cancer_subtype_name_map,
265
  num_workers,
@@ -267,8 +273,8 @@ def _run_inference_pipeline_free(
267
  ):
268
  """Run inference pipeline with 60s GPU limit (for free users)."""
269
  return _run_inference_pipeline_impl(
270
- coords, slide_path, attrs, site_type, cancer_subtype,
271
- cancer_subtype_name_map, num_workers, progress
272
  )
273
 
274
 
@@ -278,6 +284,8 @@ def _run_inference_pipeline_pro(
278
  slide_path,
279
  attrs,
280
  site_type,
 
 
281
  cancer_subtype,
282
  cancer_subtype_name_map,
283
  num_workers,
@@ -285,8 +293,8 @@ def _run_inference_pipeline_pro(
285
  ):
286
  """Run inference pipeline with 300s GPU limit (for PRO users)."""
287
  return _run_inference_pipeline_impl(
288
- coords, slide_path, attrs, site_type, cancer_subtype,
289
- cancer_subtype_name_map, num_workers, progress
290
  )
291
 
292
 
@@ -295,6 +303,8 @@ def _run_inference_pipeline_impl(
295
  slide_path,
296
  attrs,
297
  site_type,
 
 
298
  cancer_subtype,
299
  cancer_subtype_name_map,
300
  num_workers,
@@ -351,7 +361,7 @@ def _run_inference_pipeline_impl(
351
  # Step 5: Run Aeon to predict histology if not supplied
352
  if cancer_subtype == "Unknown":
353
  progress(0.9, desc="Running Aeon for cancer subtype inference")
354
- aeon_results = _run_aeon_inference(features, site_type, num_workers)
355
  else:
356
  cancer_subtype_code = cancer_subtype_name_map.get(cancer_subtype)
357
  aeon_results = pd.DataFrame(
@@ -509,6 +519,19 @@ def analyze_slide(
509
  import traceback
510
  logger.warning(traceback.format_exc())
511
 
 
 
 
 
 
 
 
 
 
 
 
 
 
512
  if is_logged_in:
513
  logger.info("Using 300s GPU allocation (logged-in user)")
514
  aeon_results, paladin_results = _run_inference_pipeline_pro(
@@ -516,6 +539,8 @@ def analyze_slide(
516
  slide_path,
517
  attrs,
518
  site_type,
 
 
519
  cancer_subtype,
520
  cancer_subtype_name_map,
521
  num_workers,
@@ -528,6 +553,8 @@ def analyze_slide(
528
  slide_path,
529
  attrs,
530
  site_type,
 
 
531
  cancer_subtype,
532
  cancer_subtype_name_map,
533
  num_workers,
 
154
  return features
155
 
156
 
157
+ def _run_aeon_inference(features, site_type, num_workers, sex=None, tissue_site_idx=None):
158
  """Run Aeon cancer subtype inference on GPU.
159
 
160
  Args:
161
  features: Optimus features
162
  site_type: Site type ("Primary" or "Metastatic")
163
  num_workers: Number of worker processes
164
+ sex: Patient sex (0=Male, 1=Female), optional
165
+ tissue_site_idx: Tissue site index (0-56), optional
166
 
167
  Returns:
168
  Aeon results DataFrame
 
185
  metastatic=(site_type == "Metastatic"),
186
  batch_size=8,
187
  num_workers=num_workers,
188
+ sex=sex,
189
+ tissue_site_idx=tissue_site_idx,
190
  use_cpu=False,
191
  )
192
  end_time = pd.Timestamp.now()
 
264
  slide_path,
265
  attrs,
266
  site_type,
267
+ sex,
268
+ tissue_site_idx,
269
  cancer_subtype,
270
  cancer_subtype_name_map,
271
  num_workers,
 
273
  ):
274
  """Run inference pipeline with 60s GPU limit (for free users)."""
275
  return _run_inference_pipeline_impl(
276
+ coords, slide_path, attrs, site_type, sex, tissue_site_idx,
277
+ cancer_subtype, cancer_subtype_name_map, num_workers, progress
278
  )
279
 
280
 
 
284
  slide_path,
285
  attrs,
286
  site_type,
287
+ sex,
288
+ tissue_site_idx,
289
  cancer_subtype,
290
  cancer_subtype_name_map,
291
  num_workers,
 
293
  ):
294
  """Run inference pipeline with 300s GPU limit (for PRO users)."""
295
  return _run_inference_pipeline_impl(
296
+ coords, slide_path, attrs, site_type, sex, tissue_site_idx,
297
+ cancer_subtype, cancer_subtype_name_map, num_workers, progress
298
  )
299
 
300
 
 
303
  slide_path,
304
  attrs,
305
  site_type,
306
+ sex,
307
+ tissue_site_idx,
308
  cancer_subtype,
309
  cancer_subtype_name_map,
310
  num_workers,
 
361
  # Step 5: Run Aeon to predict histology if not supplied
362
  if cancer_subtype == "Unknown":
363
  progress(0.9, desc="Running Aeon for cancer subtype inference")
364
+ aeon_results = _run_aeon_inference(features, site_type, num_workers, sex, tissue_site_idx)
365
  else:
366
  cancer_subtype_code = cancer_subtype_name_map.get(cancer_subtype)
367
  aeon_results = pd.DataFrame(
 
519
  import traceback
520
  logger.warning(traceback.format_exc())
521
 
522
+ # Convert sex and tissue_site to indices for Aeon model
523
+ sex_idx = None
524
+ if sex and sex != "Unknown":
525
+ sex_idx = 0 if sex == "Male" else 1
526
+
527
+ tissue_site_idx = None
528
+ if tissue_site and tissue_site != "Unknown":
529
+ from mosaic.inference.data import get_tissue_site_map
530
+ tissue_site_map = get_tissue_site_map()
531
+ tissue_site_idx = tissue_site_map.get(tissue_site)
532
+ if tissue_site_idx is None:
533
+ logger.warning(f"Unknown tissue site: {tissue_site}")
534
+
535
  if is_logged_in:
536
  logger.info("Using 300s GPU allocation (logged-in user)")
537
  aeon_results, paladin_results = _run_inference_pipeline_pro(
 
539
  slide_path,
540
  attrs,
541
  site_type,
542
+ sex_idx,
543
+ tissue_site_idx,
544
  cancer_subtype,
545
  cancer_subtype_name_map,
546
  num_workers,
 
553
  slide_path,
554
  attrs,
555
  site_type,
556
+ sex_idx,
557
+ tissue_site_idx,
558
  cancer_subtype,
559
  cancer_subtype_name_map,
560
  num_workers,
src/mosaic/inference/aeon.py CHANGED
@@ -53,7 +53,8 @@ NUM_WORKERS = 8
53
 
54
 
55
  def run(
56
- features, model_path, metastatic=False, batch_size=8, num_workers=8, use_cpu=False
 
57
  ):
58
  """Run Aeon model inference for cancer subtype prediction.
59
 
@@ -64,6 +65,8 @@ def run(
64
  batch_size: Batch size for inference
65
  num_workers: Number of workers for data loading
66
  use_cpu: Force CPU usage instead of GPU
 
 
67
 
68
  Returns:
69
  tuple: (results_df, part_embedding)
@@ -85,6 +88,8 @@ def run(
85
  dataset = TileFeatureTensorDataset(
86
  site_type=site_type,
87
  tile_features=features,
 
 
88
  n_max_tiles=20000,
89
  )
90
  dataloader = DataLoader(
@@ -95,6 +100,10 @@ def run(
95
  batch = next(iter(dataloader))
96
  with torch.no_grad():
97
  batch["tile_tensor"] = batch["tile_tensor"].to(device)
 
 
 
 
98
  y = model(batch)
99
  y["logits"][:, col_indices_to_drop] = -1e6
100
 
 
53
 
54
 
55
  def run(
56
+ features, model_path, metastatic=False, batch_size=8, num_workers=8, use_cpu=False,
57
+ sex=None, tissue_site_idx=None
58
  ):
59
  """Run Aeon model inference for cancer subtype prediction.
60
 
 
65
  batch_size: Batch size for inference
66
  num_workers: Number of workers for data loading
67
  use_cpu: Force CPU usage instead of GPU
68
+ sex: Patient sex (0=Male, 1=Female), optional
69
+ tissue_site_idx: Tissue site index (0-56), optional
70
 
71
  Returns:
72
  tuple: (results_df, part_embedding)
 
88
  dataset = TileFeatureTensorDataset(
89
  site_type=site_type,
90
  tile_features=features,
91
+ sex=sex,
92
+ tissue_site_idx=tissue_site_idx,
93
  n_max_tiles=20000,
94
  )
95
  dataloader = DataLoader(
 
100
  batch = next(iter(dataloader))
101
  with torch.no_grad():
102
  batch["tile_tensor"] = batch["tile_tensor"].to(device)
103
+ if "SEX" in batch:
104
+ batch["SEX"] = batch["SEX"].to(device)
105
+ if "TISSUE_SITE" in batch:
106
+ batch["TISSUE_SITE"] = batch["TISSUE_SITE"].to(device)
107
  y = model(batch)
108
  y["logits"][:, col_indices_to_drop] = -1e6
109
 
src/mosaic/inference/data.py CHANGED
@@ -287,6 +287,8 @@ class TileFeatureTensorDataset(Dataset):
287
  self,
288
  site_type: SiteType,
289
  tile_features: np.ndarray,
 
 
290
  n_max_tiles: int = 20000,
291
  ) -> None:
292
  """Initialize the dataset.
@@ -294,12 +296,16 @@ class TileFeatureTensorDataset(Dataset):
294
  Args:
295
  site_type: the site type as str, either "Primary" or "Metastasis"
296
  tile_features: the tile feature array
 
 
297
  n_max_tiles: the maximum number of tiles to use as int
298
 
299
  Returns:
300
  None
301
  """
302
  self.site_type = site_type
 
 
303
  self.n_max_tiles = n_max_tiles
304
  self.features = self._get_features(tile_features)
305
 
@@ -340,7 +346,22 @@ class TileFeatureTensorDataset(Dataset):
340
  Returns:
341
  dict: the item
342
  """
343
- return {
344
  "site": self.site_type.value,
345
  "tile_tensor": self.features
346
  }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
287
  self,
288
  site_type: SiteType,
289
  tile_features: np.ndarray,
290
+ sex: int = None,
291
+ tissue_site_idx: int = None,
292
  n_max_tiles: int = 20000,
293
  ) -> None:
294
  """Initialize the dataset.
 
296
  Args:
297
  site_type: the site type as str, either "Primary" or "Metastasis"
298
  tile_features: the tile feature array
299
+ sex: patient sex (0=Male, 1=Female), optional for Aeon
300
+ tissue_site_idx: tissue site index (0-56), optional for Aeon
301
  n_max_tiles: the maximum number of tiles to use as int
302
 
303
  Returns:
304
  None
305
  """
306
  self.site_type = site_type
307
+ self.sex = sex
308
+ self.tissue_site_idx = tissue_site_idx
309
  self.n_max_tiles = n_max_tiles
310
  self.features = self._get_features(tile_features)
311
 
 
346
  Returns:
347
  dict: the item
348
  """
349
+ result = {
350
  "site": self.site_type.value,
351
  "tile_tensor": self.features
352
  }
353
+
354
+ # Add sex and tissue_site if provided (for Aeon)
355
+ if self.sex is not None:
356
+ result["SEX"] = torch.tensor(
357
+ tissue_site_to_one_hot(self.sex, num_classes=3),
358
+ dtype=torch.float32
359
+ )
360
+
361
+ if self.tissue_site_idx is not None:
362
+ result["TISSUE_SITE"] = torch.tensor(
363
+ tissue_site_to_one_hot(self.tissue_site_idx, num_classes=57),
364
+ dtype=torch.float32
365
+ )
366
+
367
+ return result
src/mosaic/ui/app.py CHANGED
@@ -18,7 +18,9 @@ from mosaic.ui.utils import (
18
  create_user_directory,
19
  load_settings,
20
  validate_settings,
 
21
  IHC_SUBTYPES,
 
22
  SETTINGS_COLUMNS,
23
  )
24
  from mosaic.analysis import analyze_slide
@@ -80,6 +82,8 @@ def analyze_slides(
80
  slides[idx],
81
  row["Segmentation Config"],
82
  row["Site Type"],
 
 
83
  row["Cancer Subtype"],
84
  cancer_subtype_name_map,
85
  row["IHC Subtype"],
@@ -177,6 +181,16 @@ def launch_gradio(server_name, server_port, share):
177
  label="Site Type",
178
  value="Primary",
179
  )
 
 
 
 
 
 
 
 
 
 
180
  cancer_subtype_dropdown = gr.Dropdown(
181
  choices=[name for name in cancer_subtype_name_map.keys()],
182
  label="Cancer Subtype",
@@ -195,15 +209,9 @@ def launch_gradio(server_name, server_port, share):
195
  )
196
  with gr.Row():
197
  settings_input = gr.Dataframe(
198
- headers=[
199
- "Slide",
200
- "Site Type",
201
- "Cancer Subtype",
202
- "IHC Subtype",
203
- "Segmentation Config",
204
- ],
205
  label="Current Settings",
206
- datatype=["str", "str", "str", "str", "str"],
207
  visible=False,
208
  interactive=True,
209
  static_columns="Slide",
@@ -270,7 +278,7 @@ def launch_gradio(server_name, server_port, share):
270
  gr.File(visible=False),
271
  )
272
 
273
- def get_settings(files, site_type, cancer_subtype, ihc_subtype, seg_config):
274
  if files is None:
275
  return pd.DataFrame()
276
  settings = []
@@ -278,7 +286,7 @@ def launch_gradio(server_name, server_port, share):
278
  filename = file.name if hasattr(file, "name") else file
279
  slide_name = filename.split("/")[-1]
280
  settings.append(
281
- [slide_name, site_type, cancer_subtype, ihc_subtype, seg_config]
282
  )
283
  df = pd.DataFrame(settings, columns=SETTINGS_COLUMNS)
284
  return df
@@ -288,6 +296,8 @@ def launch_gradio(server_name, server_port, share):
288
  [
289
  input_slides.change,
290
  site_dropdown.change,
 
 
291
  cancer_subtype_dropdown.change,
292
  ihc_subtype_dropdown.change,
293
  seg_config_dropdown.change,
@@ -295,18 +305,20 @@ def launch_gradio(server_name, server_port, share):
295
  inputs=[
296
  input_slides,
297
  site_dropdown,
 
 
298
  cancer_subtype_dropdown,
299
  ihc_subtype_dropdown,
300
  seg_config_dropdown,
301
  ],
302
  outputs=[settings_input, settings_csv, ihc_subtype_dropdown],
303
  )
304
- def update_settings(files, site_type, cancer_subtype, ihc_subtype, seg_config):
305
  has_ihc = "Breast" in cancer_subtype
306
  if not files:
307
  return None, None, gr.Dropdown(visible=has_ihc)
308
  settings_df = get_settings(
309
- files, site_type, cancer_subtype, ihc_subtype, seg_config
310
  )
311
  if settings_df is not None:
312
  has_ihc = any("Breast" in cs for cs in settings_df["Cancer Subtype"])
 
18
  create_user_directory,
19
  load_settings,
20
  validate_settings,
21
+ get_tissue_sites,
22
  IHC_SUBTYPES,
23
+ SEX_OPTIONS,
24
  SETTINGS_COLUMNS,
25
  )
26
  from mosaic.analysis import analyze_slide
 
82
  slides[idx],
83
  row["Segmentation Config"],
84
  row["Site Type"],
85
+ row["Sex"],
86
+ row["Tissue Site"],
87
  row["Cancer Subtype"],
88
  cancer_subtype_name_map,
89
  row["IHC Subtype"],
 
181
  label="Site Type",
182
  value="Primary",
183
  )
184
+ sex_dropdown = gr.Dropdown(
185
+ choices=SEX_OPTIONS,
186
+ label="Sex",
187
+ value="Unknown",
188
+ )
189
+ tissue_site_dropdown = gr.Dropdown(
190
+ choices=get_tissue_sites(),
191
+ label="Tissue Site",
192
+ value="Unknown",
193
+ )
194
  cancer_subtype_dropdown = gr.Dropdown(
195
  choices=[name for name in cancer_subtype_name_map.keys()],
196
  label="Cancer Subtype",
 
209
  )
210
  with gr.Row():
211
  settings_input = gr.Dataframe(
212
+ headers=SETTINGS_COLUMNS,
 
 
 
 
 
 
213
  label="Current Settings",
214
+ datatype=["str"] * len(SETTINGS_COLUMNS),
215
  visible=False,
216
  interactive=True,
217
  static_columns="Slide",
 
278
  gr.File(visible=False),
279
  )
280
 
281
+ def get_settings(files, site_type, sex, tissue_site, cancer_subtype, ihc_subtype, seg_config):
282
  if files is None:
283
  return pd.DataFrame()
284
  settings = []
 
286
  filename = file.name if hasattr(file, "name") else file
287
  slide_name = filename.split("/")[-1]
288
  settings.append(
289
+ [slide_name, site_type, sex, tissue_site, cancer_subtype, ihc_subtype, seg_config]
290
  )
291
  df = pd.DataFrame(settings, columns=SETTINGS_COLUMNS)
292
  return df
 
296
  [
297
  input_slides.change,
298
  site_dropdown.change,
299
+ sex_dropdown.change,
300
+ tissue_site_dropdown.change,
301
  cancer_subtype_dropdown.change,
302
  ihc_subtype_dropdown.change,
303
  seg_config_dropdown.change,
 
305
  inputs=[
306
  input_slides,
307
  site_dropdown,
308
+ sex_dropdown,
309
+ tissue_site_dropdown,
310
  cancer_subtype_dropdown,
311
  ihc_subtype_dropdown,
312
  seg_config_dropdown,
313
  ],
314
  outputs=[settings_input, settings_csv, ihc_subtype_dropdown],
315
  )
316
+ def update_settings(files, site_type, sex, tissue_site, cancer_subtype, ihc_subtype, seg_config):
317
  has_ihc = "Breast" in cancer_subtype
318
  if not files:
319
  return None, None, gr.Dropdown(visible=has_ihc)
320
  settings_df = get_settings(
321
+ files, site_type, sex, tissue_site, cancer_subtype, ihc_subtype, seg_config
322
  )
323
  if settings_df is not None:
324
  has_ihc = any("Breast" in cs for cs in settings_df["Cancer Subtype"])
src/mosaic/ui/utils.py CHANGED
@@ -17,6 +17,7 @@ import requests
17
  TEMP_USER_DATA_DIR = Path(tempfile.gettempdir()) / "mosaic_user_data"
18
 
19
  IHC_SUBTYPES = ["", "HR+/HER2+", "HR+/HER2-", "HR-/HER2+", "HR-/HER2-"]
 
20
 
21
  SETTINGS_COLUMNS = [
22
  "Slide",
@@ -29,6 +30,23 @@ SETTINGS_COLUMNS = [
29
  ]
30
 
31
  oncotree_code_map = {}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
32
 
33
 
34
  def get_oncotree_code_name(code):
@@ -100,6 +118,10 @@ def load_settings(slide_csv_path):
100
  settings_df["Cancer Subtype"] = "Unknown"
101
  if "IHC Subtype" not in settings_df.columns:
102
  settings_df["IHC Subtype"] = ""
 
 
 
 
103
  if not set(SETTINGS_COLUMNS).issubset(settings_df.columns):
104
  raise ValueError("Missing required column in CSV file")
105
  settings_df = settings_df[SETTINGS_COLUMNS]
 
17
  TEMP_USER_DATA_DIR = Path(tempfile.gettempdir()) / "mosaic_user_data"
18
 
19
  IHC_SUBTYPES = ["", "HR+/HER2+", "HR+/HER2-", "HR-/HER2+", "HR-/HER2-"]
20
+ SEX_OPTIONS = ["Unknown", "Male", "Female"]
21
 
22
  SETTINGS_COLUMNS = [
23
  "Slide",
 
30
  ]
31
 
32
  oncotree_code_map = {}
33
+ tissue_site_list = None
34
+
35
+
36
+ def get_tissue_sites():
37
+ """Get the list of tissue sites from the tissue site map file.
38
+
39
+ Returns:
40
+ List of tissue site names
41
+ """
42
+ global tissue_site_list
43
+ if tissue_site_list is None:
44
+ current_dir = Path(__file__).parent.parent.parent
45
+ tissue_site_map_path = current_dir / "data" / "tissue_site_original_to_idx.csv"
46
+ df = pd.read_csv(tissue_site_map_path)
47
+ # Get unique tissue sites and sort them
48
+ tissue_site_list = ["Unknown"] + sorted(df["TISSUE_SITE"].unique().tolist())
49
+ return tissue_site_list
50
 
51
 
52
  def get_oncotree_code_name(code):
 
118
  settings_df["Cancer Subtype"] = "Unknown"
119
  if "IHC Subtype" not in settings_df.columns:
120
  settings_df["IHC Subtype"] = ""
121
+ if "Sex" not in settings_df.columns:
122
+ settings_df["Sex"] = "Unknown"
123
+ if "Tissue Site" not in settings_df.columns:
124
+ settings_df["Tissue Site"] = "Unknown"
125
  if not set(SETTINGS_COLUMNS).issubset(settings_df.columns):
126
  raise ValueError("Missing required column in CSV file")
127
  settings_df = settings_df[SETTINGS_COLUMNS]