Spaces:
Running on Zero
Refactor inference modules to eliminate code duplication
Browse files
Simplified run() and run_model() functions to delegate to their
_with_model/_with_preloaded counterparts after loading the model,
removing ~125 lines of duplicated inference logic across both modules.
Co-Authored-By: Claude Sonnet 4.5 <noreply@anthropic.com>
- src/mosaic/inference/aeon.py +7 -69
- src/mosaic/inference/paladin.py +1 -20
src/mosaic/inference/aeon.py
CHANGED
|
@@ -179,79 +179,17 @@ def run(
|
|
| 179 |
with open(model_path, "rb") as f:
|
| 180 |
model = pickle.load(f) # nosec
|
| 181 |
model.to(device)
|
| 182 |
-
model.eval()
|
| 183 |
-
|
| 184 |
-
# Load the correct mapping from metadata for this model
|
| 185 |
-
data_dir = get_data_directory()
|
| 186 |
-
metadata_path = data_dir / "metadata" / "target_dict.tsv"
|
| 187 |
-
with open(metadata_path) as f:
|
| 188 |
-
target_dict_str = f.read().strip().replace("'", '"')
|
| 189 |
-
target_dict = json.loads(target_dict_str)
|
| 190 |
-
|
| 191 |
-
histologies = target_dict["histologies"]
|
| 192 |
-
INT_TO_CANCER_TYPE_MAP_LOCAL = {
|
| 193 |
-
i: histology for i, histology in enumerate(histologies)
|
| 194 |
-
}
|
| 195 |
-
CANCER_TYPE_TO_INT_MAP_LOCAL = {
|
| 196 |
-
v: k for k, v in INT_TO_CANCER_TYPE_MAP_LOCAL.items()
|
| 197 |
-
}
|
| 198 |
|
| 199 |
-
|
| 200 |
-
|
| 201 |
-
|
| 202 |
-
|
| 203 |
-
|
| 204 |
-
|
| 205 |
-
|
| 206 |
-
site_type = SiteType.METASTASIS if metastatic else SiteType.PRIMARY
|
| 207 |
-
|
| 208 |
-
# For UI, InferenceDataset will just be a single slide. Sample id is not relevant.
|
| 209 |
-
dataset = TileFeatureTensorDataset(
|
| 210 |
-
site_type=site_type,
|
| 211 |
-
tile_features=features,
|
| 212 |
sex=sex,
|
| 213 |
tissue_site_idx=tissue_site_idx,
|
| 214 |
-
n_max_tiles=20000,
|
| 215 |
)
|
| 216 |
-
dataloader = DataLoader(
|
| 217 |
-
dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers
|
| 218 |
-
)
|
| 219 |
-
|
| 220 |
-
results = []
|
| 221 |
-
batch = next(iter(dataloader))
|
| 222 |
-
with torch.no_grad():
|
| 223 |
-
batch["tile_tensor"] = batch["tile_tensor"].to(device)
|
| 224 |
-
if "SEX" in batch:
|
| 225 |
-
batch["SEX"] = batch["SEX"].to(device)
|
| 226 |
-
if "TISSUE_SITE" in batch:
|
| 227 |
-
batch["TISSUE_SITE"] = batch["TISSUE_SITE"].to(device)
|
| 228 |
-
y = model(batch)
|
| 229 |
-
y["logits"][:, col_indices_to_drop_local] = -1e6
|
| 230 |
-
|
| 231 |
-
batch_size = y["logits"].shape[0]
|
| 232 |
-
assert batch_size == 1
|
| 233 |
-
|
| 234 |
-
softmax = torch.nn.functional.softmax(y["logits"][0], dim=0)
|
| 235 |
-
argmax = torch.argmax(softmax, dim=0)
|
| 236 |
-
class_assignment = INT_TO_CANCER_TYPE_MAP_LOCAL[argmax.item()]
|
| 237 |
-
max_confidence = softmax[argmax].item()
|
| 238 |
-
mean_confidence = torch.mean(softmax).item()
|
| 239 |
-
|
| 240 |
-
logger.info(
|
| 241 |
-
f"class {class_assignment} : confidence {max_confidence:8.5f} "
|
| 242 |
-
f"(mean {mean_confidence:8.5f})"
|
| 243 |
-
)
|
| 244 |
-
|
| 245 |
-
part_embedding = y["whole_part_representation"][0].cpu()
|
| 246 |
-
|
| 247 |
-
for cancer_subtype, j in sorted(CANCER_TYPE_TO_INT_MAP_LOCAL.items()):
|
| 248 |
-
confidence = softmax[j].item()
|
| 249 |
-
results.append((cancer_subtype, confidence))
|
| 250 |
-
results.sort(key=lambda row: row[1], reverse=True)
|
| 251 |
-
|
| 252 |
-
results_df = pd.DataFrame(results, columns=["Cancer Subtype", "Confidence"])
|
| 253 |
-
|
| 254 |
-
return results_df, part_embedding
|
| 255 |
|
| 256 |
|
| 257 |
def parse_args():
|
|
|
|
| 179 |
with open(model_path, "rb") as f:
|
| 180 |
model = pickle.load(f) # nosec
|
| 181 |
model.to(device)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 182 |
|
| 183 |
+
return run_with_model(
|
| 184 |
+
features=features,
|
| 185 |
+
model=model,
|
| 186 |
+
device=device,
|
| 187 |
+
metastatic=metastatic,
|
| 188 |
+
batch_size=batch_size,
|
| 189 |
+
num_workers=num_workers,
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 190 |
sex=sex,
|
| 191 |
tissue_site_idx=tissue_site_idx,
|
|
|
|
| 192 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 193 |
|
| 194 |
|
| 195 |
def parse_args():
|
src/mosaic/inference/paladin.py
CHANGED
|
@@ -161,28 +161,9 @@ def run_model(device, dataset, model_path: str, num_workers, batch_size) -> floa
|
|
| 161 |
logger.debug(f"[loading model {model_path}]")
|
| 162 |
with Path(model_path).open("rb") as f:
|
| 163 |
model = pickle.load(f) # nosec
|
| 164 |
-
# model = CPU_Unpickler(f).load() # nosec
|
| 165 |
model.to(device)
|
| 166 |
-
model.eval()
|
| 167 |
-
|
| 168 |
-
dataloader = DataLoader(
|
| 169 |
-
dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers
|
| 170 |
-
)
|
| 171 |
-
|
| 172 |
-
results_df = []
|
| 173 |
-
batch = next(iter(dataloader))
|
| 174 |
-
with torch.no_grad():
|
| 175 |
-
batch["tile_tensor"] = batch["tile_tensor"].to(device)
|
| 176 |
-
outputs = model(batch)
|
| 177 |
|
| 178 |
-
|
| 179 |
-
# Apply softplus to ensure positive values for beta-binomial parameters
|
| 180 |
-
logits = torch.nn.functional.softplus(logits) + 1.0 # enforce concavity
|
| 181 |
-
point_estimates = logits_to_point_estimates(logits)
|
| 182 |
-
|
| 183 |
-
# sample_id = batch['sample_id'][0]
|
| 184 |
-
class_assignment = point_estimates[0].item()
|
| 185 |
-
return class_assignment
|
| 186 |
|
| 187 |
|
| 188 |
def logits_to_point_estimates(logits):
|
|
|
|
| 161 |
logger.debug(f"[loading model {model_path}]")
|
| 162 |
with Path(model_path).open("rb") as f:
|
| 163 |
model = pickle.load(f) # nosec
|
|
|
|
| 164 |
model.to(device)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 165 |
|
| 166 |
+
return run_model_with_preloaded(device, dataset, model, num_workers, batch_size)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 167 |
|
| 168 |
|
| 169 |
def logits_to_point_estimates(logits):
|