Spaces:
Sleeping
Sleeping
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| from model.Cube import TARGET_STATE_ONE_HOT | |
class ResidualBlock(nn.Module):
    """Two-layer fully connected residual block with batch normalization.

    Computes ``relu(bn2(fc2(relu(bn1(fc1(x))))) + shortcut(x))``.

    Args:
        input_dim: Number of features of the incoming tensor.
        hidden_dim: Number of features of both linear layers and of the
            block's output.

    When ``input_dim == hidden_dim`` the shortcut is the input itself and the
    block has exactly the parameters ``fc1``, ``bn1``, ``fc2`` and ``bn2``.
    When the two differ, a bias-free linear projection (``shortcut``) maps the
    input to ``hidden_dim`` features so that the residual addition is
    well-defined instead of failing with a shape mismatch.
    """

    def __init__(self, input_dim, hidden_dim):
        super().__init__()
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.bn1 = nn.BatchNorm1d(hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, hidden_dim)
        self.bn2 = nn.BatchNorm1d(hidden_dim)
        # Created last and only when needed, so the equal-dims case keeps the
        # same parameter set, initialization order and printed structure.
        if input_dim == hidden_dim:
            self.shortcut = None
        else:
            self.shortcut = nn.Linear(input_dim, hidden_dim, bias=False)

    def forward(self, x):
        """Apply the block to ``x`` and return the activated residual sum."""
        residual = x if self.shortcut is None else self.shortcut(x)
        out = F.relu(self.bn1(self.fc1(x)))
        out = self.bn2(self.fc2(out))
        out = out + residual
        return F.relu(out)
class DNN(nn.Module):
    """Fully connected network: two hidden layers, residual blocks, one output.

    Each hidden layer is ``Linear -> BatchNorm1d -> ReLU``. The result is passed
    through ``num_residual_blocks`` ``ResidualBlock`` modules in sequence and
    then through a final linear layer with a single output feature.

    Args:
        input_dim: Number of input features per sample.
        num_residual_blocks: Number of ``ResidualBlock`` modules to stack.
        first_hidden_dim: Width of the first hidden layer (keyword-only).
        hidden_dim: Width of the second hidden layer and of every residual
            block (keyword-only).
    """

    def __init__(
        self,
        input_dim,
        num_residual_blocks=4,
        *,
        first_hidden_dim=5000,
        hidden_dim=1000,
    ):
        super().__init__()
        # First two hidden layers
        self.fc1 = nn.Linear(input_dim, first_hidden_dim)
        self.bn1 = nn.BatchNorm1d(first_hidden_dim)
        self.fc2 = nn.Linear(first_hidden_dim, hidden_dim)
        self.bn2 = nn.BatchNorm1d(hidden_dim)
        # Residual blocks
        self.residual_blocks = nn.ModuleList(
            [
                ResidualBlock(hidden_dim, hidden_dim)
                for _ in range(num_residual_blocks)
            ]
        )
        # Output layer
        self.output_layer = nn.Linear(hidden_dim, 1)

    def forward(self, x):
        """Run the network on ``x``; the result's last dimension has size 1."""
        # First two hidden layers
        x = F.relu(self.bn1(self.fc1(x)))
        x = F.relu(self.bn2(self.fc2(x)))
        # Residual blocks
        for block in self.residual_blocks:
            x = block(x)
        # Output layer
        return self.output_layer(x)
# Example usage
if __name__ == '__main__':
    # Assume an input dimension of 54 * 6 = 324 (the original note attributes
    # this to the cube-state representation described in the README).
    input_dim = 54 * 6
    model = DNN(input_dim=input_dim, num_residual_blocks=4)
    print(model)

    # Smoke-test the forward pass on a random batch of 10 samples.
    batch_size = 10
    test_input = torch.randn(batch_size, input_dim)
    output = model(test_input)
    print(f'Input shape: {test_input.shape}')
    print(f'Output shape: {output.shape}')