Upload model
Browse files- modeling.py +21 -8
modeling.py
CHANGED
|
@@ -162,17 +162,24 @@ class TotalClassifierModel(PreTrainedModel):
|
|
| 162 |
return batch_list
|
| 163 |
|
| 164 |
if return_as_list:
|
| 165 |
-
# list of lists
|
|
|
|
|
|
|
|
|
|
| 166 |
batch_list = []
|
| 167 |
for i in range(probas.shape[0]):
|
| 168 |
probas_i = probas[i]
|
| 169 |
-
|
| 170 |
-
[
|
| 171 |
-
|
| 172 |
-
|
| 173 |
-
|
| 174 |
-
|
| 175 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 176 |
return batch_list
|
| 177 |
|
| 178 |
return probas
|
|
@@ -409,6 +416,7 @@ class TotalClassifierModel(PreTrainedModel):
|
|
| 409 |
x: np.ndarray,
|
| 410 |
mode: str = "2d",
|
| 411 |
torchify: bool = True,
|
|
|
|
| 412 |
device: str | torch.device | None = None,
|
| 413 |
) -> np.ndarray:
|
| 414 |
if device is not None:
|
|
@@ -427,6 +435,11 @@ class TotalClassifierModel(PreTrainedModel):
|
|
| 427 |
x = rearrange(torch.from_numpy(x).float(), "h w c -> c h w")
|
| 428 |
elif x.ndim == 4:
|
| 429 |
x = rearrange(torch.from_numpy(x).float(), "n h w c -> n c h w")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 430 |
if device is not None:
|
| 431 |
x = x.to(device)
|
| 432 |
return x
|
|
|
|
| 162 |
return batch_list
|
| 163 |
|
| 164 |
if return_as_list:
|
| 165 |
+
# returns list of list of lists of strings
|
| 166 |
+
# innermost list - list of strings for each organ present based on threshold
|
| 167 |
+
# inner list - list of above for each slice
|
| 168 |
+
# outer list - list of above for each batch element (studies)
|
| 169 |
batch_list = []
|
| 170 |
for i in range(probas.shape[0]):
|
| 171 |
probas_i = probas[i]
|
| 172 |
+
for each_slice in range(probas_i.shape[0]):
|
| 173 |
+
list_for_batch = []
|
| 174 |
+
for each_class in range(probas_i.shape[1]):
|
| 175 |
+
list_for_batch.append(
|
| 176 |
+
[
|
| 177 |
+
self.index2label[each_class]
|
| 178 |
+
for each_class in range(probas_i.shape[1])
|
| 179 |
+
if probas_i[each_slice, each_class] >= threshold
|
| 180 |
+
]
|
| 181 |
+
)
|
| 182 |
+
batch_list.append(list_for_batch)
|
| 183 |
return batch_list
|
| 184 |
|
| 185 |
return probas
|
|
|
|
| 416 |
x: np.ndarray,
|
| 417 |
mode: str = "2d",
|
| 418 |
torchify: bool = True,
|
| 419 |
+
add_batch_dim: bool = False,
|
| 420 |
device: str | torch.device | None = None,
|
| 421 |
) -> np.ndarray:
|
| 422 |
if device is not None:
|
|
|
|
| 435 |
x = rearrange(torch.from_numpy(x).float(), "h w c -> c h w")
|
| 436 |
elif x.ndim == 4:
|
| 437 |
x = rearrange(torch.from_numpy(x).float(), "n h w c -> n c h w")
|
| 438 |
+
if add_batch_dim:
|
| 439 |
+
if torchify:
|
| 440 |
+
x = x.unsqueeze(0)
|
| 441 |
+
else:
|
| 442 |
+
x = x[np.newaxis]
|
| 443 |
if device is not None:
|
| 444 |
x = x.to(device)
|
| 445 |
return x
|