VBoussot commited on
Commit
8efb82a
·
verified ·
1 Parent(s): 0d8855a

Update CBCT/Model.py

Browse files
Files changed (1) hide show
  1. CBCT/Model.py +3 -2
CBCT/Model.py CHANGED
@@ -37,8 +37,9 @@ class Concat(Reduction):
37
 
38
  def __call__(self, tensor: torch.Tensor | list[torch.Tensor]) -> torch.Tensor:
39
  if isinstance(tensor, list):
40
- tensor = torch.stack(tensor, dim=0)
41
- return tensor.view(tensor.shape[1], -1, *tensor.shape[3:])
 
42
 
43
  class Uncertainty(Transform):
44
 
 
37
 
38
  def __call__(self, tensor: torch.Tensor | list[torch.Tensor]) -> torch.Tensor:
39
  if isinstance(tensor, list):
40
+ return torch.stack(tensor, dim=2).squeeze(1)
41
+ else:
42
+ return tensor.view(tensor.shape[0]*tensor.shape[1], -1, *tensor.shape[3:])
43
 
44
  class Uncertainty(Transform):
45