raylim Claude Sonnet 4.5 committed on
Commit
0ab0da6
·
unverified ·
1 Parent(s): 14c1ba6

Refactor inference modules to eliminate code duplication

Browse files

Simplified run() and run_model() functions to delegate to their
_with_model/_with_preloaded counterparts after loading the model,
removing ~125 lines of duplicated inference logic across both modules.

Co-Authored-By: Claude Sonnet 4.5 <noreply@anthropic.com>

src/mosaic/inference/aeon.py CHANGED
@@ -179,79 +179,17 @@ def run(
179
  with open(model_path, "rb") as f:
180
  model = pickle.load(f) # nosec
181
  model.to(device)
182
- model.eval()
183
-
184
- # Load the correct mapping from metadata for this model
185
- data_dir = get_data_directory()
186
- metadata_path = data_dir / "metadata" / "target_dict.tsv"
187
- with open(metadata_path) as f:
188
- target_dict_str = f.read().strip().replace("'", '"')
189
- target_dict = json.loads(target_dict_str)
190
-
191
- histologies = target_dict["histologies"]
192
- INT_TO_CANCER_TYPE_MAP_LOCAL = {
193
- i: histology for i, histology in enumerate(histologies)
194
- }
195
- CANCER_TYPE_TO_INT_MAP_LOCAL = {
196
- v: k for k, v in INT_TO_CANCER_TYPE_MAP_LOCAL.items()
197
- }
198
 
199
- # Calculate col_indices_to_drop using local mapping
200
- col_indices_to_drop_local = [
201
- CANCER_TYPE_TO_INT_MAP_LOCAL[x]
202
- for x in CANCER_TYPES_TO_DROP
203
- if x in CANCER_TYPE_TO_INT_MAP_LOCAL
204
- ]
205
-
206
- site_type = SiteType.METASTASIS if metastatic else SiteType.PRIMARY
207
-
208
- # For UI, InferenceDataset will just be a single slide. Sample id is not relevant.
209
- dataset = TileFeatureTensorDataset(
210
- site_type=site_type,
211
- tile_features=features,
212
  sex=sex,
213
  tissue_site_idx=tissue_site_idx,
214
- n_max_tiles=20000,
215
  )
216
- dataloader = DataLoader(
217
- dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers
218
- )
219
-
220
- results = []
221
- batch = next(iter(dataloader))
222
- with torch.no_grad():
223
- batch["tile_tensor"] = batch["tile_tensor"].to(device)
224
- if "SEX" in batch:
225
- batch["SEX"] = batch["SEX"].to(device)
226
- if "TISSUE_SITE" in batch:
227
- batch["TISSUE_SITE"] = batch["TISSUE_SITE"].to(device)
228
- y = model(batch)
229
- y["logits"][:, col_indices_to_drop_local] = -1e6
230
-
231
- batch_size = y["logits"].shape[0]
232
- assert batch_size == 1
233
-
234
- softmax = torch.nn.functional.softmax(y["logits"][0], dim=0)
235
- argmax = torch.argmax(softmax, dim=0)
236
- class_assignment = INT_TO_CANCER_TYPE_MAP_LOCAL[argmax.item()]
237
- max_confidence = softmax[argmax].item()
238
- mean_confidence = torch.mean(softmax).item()
239
-
240
- logger.info(
241
- f"class {class_assignment} : confidence {max_confidence:8.5f} "
242
- f"(mean {mean_confidence:8.5f})"
243
- )
244
-
245
- part_embedding = y["whole_part_representation"][0].cpu()
246
-
247
- for cancer_subtype, j in sorted(CANCER_TYPE_TO_INT_MAP_LOCAL.items()):
248
- confidence = softmax[j].item()
249
- results.append((cancer_subtype, confidence))
250
- results.sort(key=lambda row: row[1], reverse=True)
251
-
252
- results_df = pd.DataFrame(results, columns=["Cancer Subtype", "Confidence"])
253
-
254
- return results_df, part_embedding
255
 
256
 
257
  def parse_args():
 
179
  with open(model_path, "rb") as f:
180
  model = pickle.load(f) # nosec
181
  model.to(device)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
182
 
183
+ return run_with_model(
184
+ features=features,
185
+ model=model,
186
+ device=device,
187
+ metastatic=metastatic,
188
+ batch_size=batch_size,
189
+ num_workers=num_workers,
 
 
 
 
 
 
190
  sex=sex,
191
  tissue_site_idx=tissue_site_idx,
 
192
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
193
 
194
 
195
  def parse_args():
src/mosaic/inference/paladin.py CHANGED
@@ -161,28 +161,9 @@ def run_model(device, dataset, model_path: str, num_workers, batch_size) -> floa
161
  logger.debug(f"[loading model {model_path}]")
162
  with Path(model_path).open("rb") as f:
163
  model = pickle.load(f) # nosec
164
- # model = CPU_Unpickler(f).load() # nosec
165
  model.to(device)
166
- model.eval()
167
-
168
- dataloader = DataLoader(
169
- dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers
170
- )
171
-
172
- results_df = []
173
- batch = next(iter(dataloader))
174
- with torch.no_grad():
175
- batch["tile_tensor"] = batch["tile_tensor"].to(device)
176
- outputs = model(batch)
177
 
178
- logits = outputs["logits"]
179
- # Apply softplus to ensure positive values for beta-binomial parameters
180
- logits = torch.nn.functional.softplus(logits) + 1.0 # enforce concavity
181
- point_estimates = logits_to_point_estimates(logits)
182
-
183
- # sample_id = batch['sample_id'][0]
184
- class_assignment = point_estimates[0].item()
185
- return class_assignment
186
 
187
 
188
  def logits_to_point_estimates(logits):
 
161
  logger.debug(f"[loading model {model_path}]")
162
  with Path(model_path).open("rb") as f:
163
  model = pickle.load(f) # nosec
 
164
  model.to(device)
 
 
 
 
 
 
 
 
 
 
 
165
 
166
+ return run_model_with_preloaded(device, dataset, model, num_workers, batch_size)
 
 
 
 
 
 
 
167
 
168
 
169
  def logits_to_point_estimates(logits):