Deepcube / model /DNN.py
hanxiaofeng
first commit
b570cf2
import torch
import torch.nn as nn
import torch.nn.functional as F
from model.Cube import TARGET_STATE_ONE_HOT
class ResidualBlock(nn.Module):
def __init__(self, input_dim, hidden_dim):
super(ResidualBlock, self).__init__()
self.fc1 = nn.Linear(input_dim, hidden_dim)
self.bn1 = nn.BatchNorm1d(hidden_dim)
self.fc2 = nn.Linear(hidden_dim, hidden_dim)
self.bn2 = nn.BatchNorm1d(hidden_dim)
def forward(self, x):
residual = x
out = F.relu(self.bn1(self.fc1(x)))
out = self.bn2(self.fc2(out))
out += residual
out = F.relu(out)
return out
class DNN(nn.Module):
def __init__(self, input_dim, num_residual_blocks=4):
super(DNN, self).__init__()
# 前两个隐藏层
self.fc1 = nn.Linear(input_dim, 5000)
self.bn1 = nn.BatchNorm1d(5000)
self.fc2 = nn.Linear(5000, 1000)
self.bn2 = nn.BatchNorm1d(1000)
# 残差块
self.residual_blocks = nn.ModuleList()
for _ in range(num_residual_blocks):
self.residual_blocks.append(ResidualBlock(1000, 1000))
# 输出层
self.output_layer = nn.Linear(1000, 1)
def forward(self, x):
# 前两个隐藏层
x = F.relu(self.bn1(self.fc1(x)))
x = F.relu(self.bn2(self.fc2(x)))
# 残差块
for block in self.residual_blocks:
x = block(x)
# 输出层
x = self.output_layer(x)
return x # * self.K
# 示例用法
if __name__ == '__main__':
# 假设输入维度为54*6=324(根据Readme中提到的魔方状态表示)
input_dim = 324
model = DNN(input_dim, num_residual_blocks=4)
print(model)
# 测试前向传播
test_input = torch.randn(10, input_dim)
output = model(test_input)
print(f'Input shape: {test_input.shape}')
print(f'Output shape: {output.shape}')