# 1x1 conv to collapse 2-ch coarse logits into 1-ch conditioning map
# TODO: Wire with coarse decoder outputs and proper resize/cropping.
import torch
import torch.nn as nn
class Conditioning1x1(nn.Module):
    """Collapse multi-channel coarse logits into a conditioning map.

    Applies a learned 1x1 convolution, i.e. a per-pixel linear combination
    of the input channels (plus an optional bias). Spatial dimensions are
    left untouched because the kernel size is 1 with default stride and
    padding.

    Args:
        in_channels: Number of channels in the incoming logits. Defaults to
            2, the original hard-coded value.
        out_channels: Number of channels in the produced conditioning map.
            Defaults to 1, the original hard-coded value.
        bias: Whether the convolution learns an additive bias. Defaults to
            True, matching the original layer.
    """

    def __init__(
        self,
        in_channels: int = 2,
        out_channels: int = 1,
        bias: bool = True,
    ) -> None:
        super().__init__()
        # Keep the attribute name ``conv`` so existing checkpoints
        # (``conv.weight`` / ``conv.bias``) continue to load.
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=1, bias=bias)

    def forward(self, coarse_logits: torch.Tensor) -> torch.Tensor:
        """Project ``coarse_logits`` through the 1x1 convolution.

        Args:
            coarse_logits: Tensor accepted by ``nn.Conv2d`` whose channel
                dimension equals ``in_channels`` (e.g. ``(N, C, H, W)``).

        Returns:
            Tensor with ``out_channels`` channels and the same spatial size
            as the input.
        """
        return self.conv(coarse_logits)
|