# Source: Hugging Face Spaces page (status: Runtime error)
import torch


def pad_to_22_channels(input_tensor):
    """Expand a 3-channel tensor to 22 channels by cyclically repeating channels.

    The channel axis is dimension 1 (e.g. ``(N, C, H, W)``); presumably
    callers pass image batches -- TODO confirm against the caller.

    When ``input_tensor.shape[1] == 3`` (treated as RGB), the result is
    seven full copies of the three channels (21 channels) followed by the
    first channel once more, i.e. channel ``c`` of the output equals input
    channel ``c % 3``. Any other channel count is returned unchanged (the
    same object, not a copy).

    Args:
        input_tensor: Tensor with at least two dimensions.

    Returns:
        A new 22-channel tensor for 3-channel input, otherwise
        ``input_tensor`` itself.

    Raises:
        IndexError: If ``input_tensor`` has fewer than two dimensions.
    """
    if input_tensor.shape[1] != 3:
        # Only RGB input is expanded; everything else passes through as-is.
        return input_tensor
    # 7 full copies -> 21 channels; one extra copy of channel 0 -> 22.
    return torch.cat([input_tensor] * 7 + [input_tensor[:, 0:1]], dim=1)