File size: 432 Bytes
8ea2eff
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
# 1x1 convolution that collapses 2-channel coarse logits into a 1-channel conditioning map.
# TODO: Wire this up to the coarse decoder outputs and add proper resizing/cropping.

import torch
import torch.nn as nn


class Conditioning1x1(nn.Module):
    """Collapse a multi-channel map into a conditioning map with a 1x1 conv.

    By default this maps 2 input channels (e.g. coarse logits) to a single
    output channel, matching the original hard-coded behavior. The channel
    counts and bias flag are configurable for reuse with other decoders.

    Args:
        in_channels: Number of channels in the input map (default 2).
        out_channels: Number of channels in the conditioning map (default 1).
        bias: Whether the 1x1 convolution has a learnable bias (default True).
    """

    def __init__(
        self,
        in_channels: int = 2,
        out_channels: int = 1,
        *,
        bias: bool = True,
    ) -> None:
        super().__init__()
        # Attribute name kept as ``conv`` so existing state_dict keys
        # (``conv.weight`` / ``conv.bias``) remain loadable.
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=1, bias=bias)

    def forward(self, coarse_logits: torch.Tensor) -> torch.Tensor:
        """Apply the 1x1 convolution.

        Args:
            coarse_logits: Tensor laid out as ``nn.Conv2d`` expects, i.e.
                ``(N, in_channels, H, W)`` or ``(in_channels, H, W)``.

        Returns:
            Tensor with the same spatial size and ``out_channels`` channels
            (a 1x1 kernel with default stride/padding preserves H and W).
        """
        return self.conv(coarse_logits)