Update CBCT/Model.py
Browse files- CBCT/Model.py +3 -2
CBCT/Model.py
CHANGED
|
@@ -37,8 +37,9 @@ class Concat(Reduction):
|
|
| 37 |
|
| 38 |
def __call__(self, tensor: torch.Tensor | list[torch.Tensor]) -> torch.Tensor:
|
| 39 |
if isinstance(tensor, list):
|
| 40 |
-
|
| 41 |
-
|
|
|
|
| 42 |
|
| 43 |
class Uncertainty(Transform):
|
| 44 |
|
|
|
|
| 37 |
|
| 38 |
def __call__(self, tensor: torch.Tensor | list[torch.Tensor]) -> torch.Tensor:
|
| 39 |
if isinstance(tensor, list):
|
| 40 |
+
return torch.stack(tensor, dim=2).squeeze(1)
|
| 41 |
+
else:
|
| 42 |
+
return tensor.view(tensor.shape[0]*tensor.shape[1], -1, *tensor.shape[3:])
|
| 43 |
|
| 44 |
class Uncertainty(Transform):
|
| 45 |
|