ianpan commited on
Commit
9eaf668
·
verified ·
1 Parent(s): b1abe51

Upload model

Browse files
Files changed (1) hide show
  1. modeling.py +21 -8
modeling.py CHANGED
@@ -162,17 +162,24 @@ class TotalClassifierModel(PreTrainedModel):
162
  return batch_list
163
 
164
  if return_as_list:
165
- # list of lists
 
 
 
166
  batch_list = []
167
  for i in range(probas.shape[0]):
168
  probas_i = probas[i]
169
- batch_list.append(
170
- [
171
- self.index2label[each_class]
172
- for each_class in range(probas_i.shape[1])
173
- if probas_i[:, each_class] >= threshold
174
- ]
175
- )
 
 
 
 
176
  return batch_list
177
 
178
  return probas
@@ -409,6 +416,7 @@ class TotalClassifierModel(PreTrainedModel):
409
  x: np.ndarray,
410
  mode: str = "2d",
411
  torchify: bool = True,
 
412
  device: str | torch.device | None = None,
413
  ) -> np.ndarray:
414
  if device is not None:
@@ -427,6 +435,11 @@ class TotalClassifierModel(PreTrainedModel):
427
  x = rearrange(torch.from_numpy(x).float(), "h w c -> c h w")
428
  elif x.ndim == 4:
429
  x = rearrange(torch.from_numpy(x).float(), "n h w c -> n c h w")
 
 
 
 
 
430
  if device is not None:
431
  x = x.to(device)
432
  return x
 
162
  return batch_list
163
 
164
  if return_as_list:
165
+ # returns list of list of lists of strings
166
+ # innermost list - list of strings for each organ present based on threshold
167
+ # inner list - list of above for each slice
168
+ # outer list - list of above for each batch element (studies)
169
  batch_list = []
170
  for i in range(probas.shape[0]):
171
  probas_i = probas[i]
172
+ for each_slice in range(probas_i.shape[0]):
173
+ list_for_batch = []
174
+ for each_class in range(probas_i.shape[1]):
175
+ list_for_batch.append(
176
+ [
177
+ self.index2label[each_class]
178
+ for each_class in range(probas_i.shape[1])
179
+ if probas_i[each_slice, each_class] >= threshold
180
+ ]
181
+ )
182
+ batch_list.append(list_for_batch)
183
  return batch_list
184
 
185
  return probas
 
416
  x: np.ndarray,
417
  mode: str = "2d",
418
  torchify: bool = True,
419
+ add_batch_dim: bool = False,
420
  device: str | torch.device | None = None,
421
  ) -> np.ndarray:
422
  if device is not None:
 
435
  x = rearrange(torch.from_numpy(x).float(), "h w c -> c h w")
436
  elif x.ndim == 4:
437
  x = rearrange(torch.from_numpy(x).float(), "n h w c -> n c h w")
438
+ if add_batch_dim:
439
+ if torchify:
440
+ x = x.unsqueeze(0)
441
+ else:
442
+ x = x[np.newaxis]
443
  if device is not None:
444
  x = x.to(device)
445
  return x