"""Residual building blocks and an AlphaZero-style policy/value network."""

from torch import nn


class BasicBlock(nn.Module):
    """Two-convolution residual block: ``relu(x + bn(conv(relu(bn(conv(x))))))``.

    Args:
        in_channels: Number of channels of the input feature map.
        channels: Number of channels produced by both convolutions.
        bias: Whether the convolutions learn an additive bias.
        k: Square kernel size used by both convolutions.
        p: Zero padding applied by both convolutions.

    Note:
        The block input is added to the branch output without any projection,
        so the two must be broadcast-compatible. In practice this means
        ``in_channels == channels`` and a padding that preserves the spatial
        size (``p == (k - 1) // 2`` for odd ``k``).
    """

    def __init__(self, in_channels: int, channels: int, bias: bool, k: int = 3, p: int = 1):
        super().__init__()
        self.conv1 = nn.Conv2d(
            in_channels, channels, kernel_size=k, stride=1, padding=p, bias=bias
        )
        self.bn1 = nn.BatchNorm2d(channels)
        self.relu1 = nn.ReLU()
        self.conv2 = nn.Conv2d(
            channels, channels, kernel_size=k, stride=1, padding=p, bias=bias
        )
        self.bn2 = nn.BatchNorm2d(channels)
        self.relu2 = nn.ReLU()

    def forward(self, x):
        """Return ``relu(x + branch(x))``."""
        y = self.relu1(self.bn1(self.conv1(x)))
        y = self.bn2(self.conv2(y))
        # The last activation is applied after the skip connection is added.
        return self.relu2(x + y)


class Bottleneck(nn.Module):
    """Residual bottleneck block: 1x1 reduce -> 3x3 -> 1x1 expand.

    The inner width is ``channels // 2``.

    Args:
        in_channels: Number of channels of the input feature map.
        channels: Number of channels of the block output.
        bias: Whether the convolutions learn an additive bias.

    Note:
        The block input is added to the branch output without any projection,
        so ``in_channels`` has to match ``channels`` (or be broadcastable).
    """

    def __init__(self, in_channels: int, channels: int, bias: bool):
        super().__init__()
        mid_channels = channels // 2

        self.conv1 = nn.Conv2d(
            in_channels, mid_channels, kernel_size=1, stride=1, bias=bias
        )
        self.bn1 = nn.BatchNorm2d(mid_channels)
        self.relu1 = nn.ReLU()

        self.conv2 = nn.Conv2d(
            mid_channels, mid_channels, kernel_size=3, stride=1, padding=1, bias=bias
        )
        self.bn2 = nn.BatchNorm2d(mid_channels)
        self.relu2 = nn.ReLU()

        self.conv3 = nn.Conv2d(
            mid_channels, channels, kernel_size=1, stride=1, bias=bias
        )
        self.bn3 = nn.BatchNorm2d(channels)
        self.relu3 = nn.ReLU()

    def forward(self, x):
        """Return ``relu(x + branch(x))``."""
        y = self.relu1(self.bn1(self.conv1(x)))
        y = self.relu2(self.bn2(self.conv2(y)))
        y = self.bn3(self.conv3(y))
        # The last activation is applied after the skip connection is added.
        return self.relu3(x + y)


class Bottlenest(nn.Module):
    """Bottleneck block whose narrow middle section is itself residual.

    A 1x1 convolution reduces the input to ``channels // 2`` channels, two
    inner two-convolution residual stages run at that width, and a final 1x1
    convolution expands back to ``channels`` before the outer skip connection.

    Args:
        in_channels: Number of channels of the input feature map.
        channels: Number of channels of the block output.
        bias: Whether the convolutions learn an additive bias.

    Note:
        The block input is added to the branch output without any projection,
        so ``in_channels`` has to match ``channels`` (or be broadcastable).
    """

    def __init__(self, in_channels: int, channels: int, bias: bool):
        super().__init__()
        mid_channels = channels // 2

        # 1x1 reduction (batch-normalised, no activation of its own).
        self.conv0 = nn.Conv2d(
            in_channels, mid_channels, kernel_size=1, stride=1, bias=bias
        )
        self.bn0 = nn.BatchNorm2d(mid_channels)

        # First inner residual stage.
        self.conv1 = nn.Conv2d(
            mid_channels, mid_channels, kernel_size=3, stride=1, padding=1, bias=bias
        )
        self.bn1 = nn.BatchNorm2d(mid_channels)
        self.relu1 = nn.ReLU()
        self.conv2 = nn.Conv2d(
            mid_channels, mid_channels, kernel_size=3, stride=1, padding=1, bias=bias
        )
        self.bn2 = nn.BatchNorm2d(mid_channels)
        self.relu2 = nn.ReLU()

        # Second inner residual stage.
        self.conv3 = nn.Conv2d(
            mid_channels, mid_channels, kernel_size=3, stride=1, padding=1, bias=bias
        )
        self.bn3 = nn.BatchNorm2d(mid_channels)
        self.relu3 = nn.ReLU()
        self.conv4 = nn.Conv2d(
            mid_channels, mid_channels, kernel_size=3, stride=1, padding=1, bias=bias
        )
        self.bn4 = nn.BatchNorm2d(mid_channels)
        self.relu4 = nn.ReLU()

        # 1x1 expansion back to the outer width.
        self.conv5 = nn.Conv2d(
            mid_channels, channels, kernel_size=1, stride=1, bias=bias
        )
        self.bn5 = nn.BatchNorm2d(channels)
        self.relu5 = nn.ReLU()

    def forward(self, x):
        """Return ``relu(x + branch(x))``."""
        y = self.bn0(self.conv0(x))

        # First inner residual stage.
        z = self.relu1(self.bn1(self.conv1(y)))
        z = self.bn2(self.conv2(z))
        y = self.relu2(y + z)

        # Second inner residual stage.
        z = self.relu3(self.bn3(self.conv3(y)))
        z = self.bn4(self.conv4(z))
        y = self.relu4(y + z)

        y = self.bn5(self.conv5(y))
        # The last activation is applied after the outer skip connection.
        return self.relu5(x + y)


class ResNet(nn.Module):
    """Convolutional stem followed by a stack of residual blocks.

    Args:
        block: Residual block class; it is called as
            ``block(channels, channels, bias)``.
        in_channels: Number of channels of the network input.
        layers: Number of residual blocks to stack.
        channels: Width of the stem output and of every block.
        bias: Whether the stem and block convolutions learn an additive bias.
    """

    def __init__(self, block, in_channels: int, layers: int, channels: int, bias: bool):
        super().__init__()
        # 5x5 stem with padding 2, so the spatial size is preserved.
        self.conv1 = nn.Sequential(
            nn.Conv2d(
                in_channels, channels, kernel_size=5, stride=1, padding=2, bias=bias
            ),
            nn.BatchNorm2d(channels),
            nn.ReLU(),
        )
        self.convs = nn.ModuleList(
            [block(channels, channels, bias) for _ in range(layers)]
        )

    def forward(self, x):
        """Apply the stem and then every residual block in order."""
        x = self.conv1(x)
        for conv in self.convs:
            x = conv(x)
        return x


class AlphaZero(nn.Module):
    """Residual tower with a policy head and a value head.

    Args:
        in_channels: Number of channels of the network input.
        layers: Number of residual blocks in the tower.
        channels: Width of the tower.
        moves: Size of the policy output.
        board_size: Number of spatial positions (height * width) of the input;
            it sizes the linear layers that follow the flattened head
            activations.
        value_heads: Size of the value output.
        bias: Whether the tower convolutions learn an additive bias. The 1x1
            head convolutions always use the ``nn.Conv2d`` default.
        block: Residual block class used by the tower.
    """

    def __init__(
        self,
        in_channels: int,
        layers: int,
        channels: int,
        moves: int,
        board_size: int,
        value_heads: int = 1,
        bias: bool = False,
        block=BasicBlock,
    ):
        super().__init__()
        self.board_size = board_size
        self.resnet = ResNet(block, in_channels, layers, channels, bias)

        # Policy head: 1x1 convolution to 2 planes, then a linear layer.
        self.policy_head_front = nn.Sequential(
            nn.Conv2d(channels, 2, 1),
            nn.BatchNorm2d(2),
            nn.ReLU(),
        )
        self.policy_head_end = nn.Linear(2 * board_size, moves)

        # Value head: 1x1 convolution to 1 plane, then a small MLP with tanh.
        self.value_head_front = nn.Sequential(
            nn.Conv2d(channels, 1, 1),
            nn.BatchNorm2d(1),
            nn.ReLU(),
        )
        self.value_head_end = nn.Sequential(
            nn.Linear(board_size, channels),
            nn.ReLU(),
            nn.Linear(channels, value_heads),
            nn.Tanh(),
        )

    def forward(self, x):
        """Compute the policy and value outputs for a batch of inputs.

        Args:
            x: Input of shape ``(batch, in_channels, height, width)`` with
                ``height * width == board_size``.

        Returns:
            A tuple ``(p, v)``: ``p`` has shape ``(batch, moves)`` and holds
            the raw output of the final policy linear layer; ``v`` has shape
            ``(batch, value_heads)`` and lies in ``[-1, 1]`` (tanh output).

        Raises:
            RuntimeError: If the spatial size of ``x`` does not match
                ``board_size`` (raised by the linear layers).
        """
        x = self.resnet(x)

        # Flatten everything except the batch dimension. Unlike
        # ``view(-1, n)`` this also works for non-contiguous activations
        # (e.g. channels_last inputs) and never folds a spatial-size mismatch
        # into the batch dimension; the linear layer reports it instead.
        p = self.policy_head_front(x)
        p = self.policy_head_end(p.flatten(start_dim=1))

        v = self.value_head_front(x)
        v = self.value_head_end(v.flatten(start_dim=1))

        return p, v