import torch.nn as nn class SpaceEncoder(nn.Module): def __init__(self, input_channels=1, feature_dim=128): super().__init__() self.encoder = nn.Sequential( nn.Conv2d(input_channels, 32, kernel_size=3, padding=1), nn.ReLU(inplace=True), nn.Conv2d(32, 64, kernel_size=3, padding=1), nn.ReLU(inplace=True), nn.AdaptiveAvgPool2d((1, 1)), # 输出 [B, 64, 1, 1] nn.Flatten(), # [B, 64] nn.Linear(64, feature_dim), # → [B, 128] nn.ReLU(inplace=True) ) def forward(self, x): # x: [B, 1, H, W] return self.encoder(x) # [B, 128]