import torch
import torch.nn as nn
import torch.nn.functional as F


class ConvEncoder(nn.Module):
    """Three-stage convolutional feature extractor.

    Each stage is a 3x3 convolution with padding 1 followed by ReLU. The
    strides are 1, 2, 2, so each spatial side shrinks to roughly a quarter;
    e.g. a 224x224 input yields a 128x56x56 feature map.

    Args:
        in_channels: Number of channels in the input images (default 3).
    """

    def __init__(self, in_channels: int = 3):
        super().__init__()
        self.features = nn.Sequential(
            nn.Conv2d(in_channels, 32, 3, stride=1, padding=1),
            nn.ReLU(),
            nn.Conv2d(32, 64, 3, stride=2, padding=1),
            nn.ReLU(),
            nn.Conv2d(64, 128, 3, stride=2, padding=1),
            nn.ReLU(),
        )

    def forward(self, x):
        """Return the 128-channel feature map for the image batch ``x``."""
        return self.features(x)


class GenConViT(nn.Module):
    """Image classifier: ``ConvEncoder`` followed by a two-layer MLP head.

    The head's first Linear layer is sized for one fixed input resolution,
    so inputs passed to ``forward`` must have spatial size ``input_size``.
    The defaults reproduce the original model exactly (224x224 RGB input,
    128 * 56 * 56 flattened features, 2 output logits).

    Args:
        input_size: Input image side length (int, square input) or an
            ``(height, width)`` tuple. Default 224.
        num_classes: Number of output logits. Default 2.
        in_channels: Number of input image channels. Default 3.
    """

    def __init__(self, input_size=224, num_classes: int = 2, in_channels: int = 3):
        super().__init__()
        self.encoder = ConvEncoder(in_channels)

        if isinstance(input_size, int):
            height, width = input_size, input_size
        else:
            height, width = input_size

        # Derive the flattened feature size from a dry run of the encoder
        # instead of hard-coding 128 * 56 * 56 (valid only for 224x224).
        # The dry run draws no random numbers, so parameter initialization
        # is identical to constructing the head with a literal size.
        with torch.no_grad():
            probe = self.encoder(torch.zeros(1, in_channels, height, width))
        flat_features = probe.numel()

        self.classifier = nn.Sequential(
            nn.Linear(flat_features, 256),
            nn.ReLU(),
            nn.Linear(256, num_classes),
        )

    def forward(self, x):
        """Return class logits of shape ``(batch, num_classes)`` for ``x``."""
        feat = self.encoder(x)
        # torch.flatten also works on non-contiguous feature maps (e.g. a
        # channels_last input propagates through the convs), where
        # .view(batch, -1) would raise; element order is the same.
        feat = torch.flatten(feat, 1)
        return self.classifier(feat)