File size: 684 Bytes
a19a7aa |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 |
import torch.nn as nn
class SpaceEncoder(nn.Module):
def __init__(self, input_channels=1, feature_dim=128):
super().__init__()
self.encoder = nn.Sequential(
nn.Conv2d(input_channels, 32, kernel_size=3, padding=1),
nn.ReLU(inplace=True),
nn.Conv2d(32, 64, kernel_size=3, padding=1),
nn.ReLU(inplace=True),
nn.AdaptiveAvgPool2d((1, 1)), # 输出 [B, 64, 1, 1]
nn.Flatten(), # [B, 64]
nn.Linear(64, feature_dim), # → [B, 128]
nn.ReLU(inplace=True)
)
def forward(self, x): # x: [B, 1, H, W]
return self.encoder(x) # [B, 128]
|