# ugb_zero/train/model.py
from torch import nn


class BasicBlock(nn.Module):
    """Residual block: two 3x3 convolutions with an identity skip connection."""

    def __init__(self, in_channels, channels, bias, k=3, p=1):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels, channels, k, stride=1, padding=p, bias=bias)
        self.bn1 = nn.BatchNorm2d(channels)
        self.relu1 = nn.ReLU()
        self.conv2 = nn.Conv2d(channels, channels, k, stride=1, padding=p, bias=bias)
        self.bn2 = nn.BatchNorm2d(channels)
        self.relu2 = nn.ReLU()

    def forward(self, x):
        y = self.conv1(x)
        y = self.bn1(y)
        y = self.relu1(y)
        y = self.conv2(y)
        y = self.bn2(y)
        x = x + y  # identity skip (requires in_channels == channels)
        x = self.relu2(x)
        return x


class Bottleneck(nn.Module):
    """Residual bottleneck: 1x1 reduce -> 3x3 -> 1x1 expand with an identity skip."""

    def __init__(self, in_channels, channels, bias):
        super().__init__()
        mid_channels = channels // 2
        self.conv1 = nn.Conv2d(in_channels, mid_channels, 1, 1, bias=bias)
        self.bn1 = nn.BatchNorm2d(mid_channels)
        self.relu1 = nn.ReLU()
        self.conv2 = nn.Conv2d(mid_channels, mid_channels, 3, 1, padding=1, bias=bias)
        self.bn2 = nn.BatchNorm2d(mid_channels)
        self.relu2 = nn.ReLU()
        self.conv3 = nn.Conv2d(mid_channels, channels, 1, 1, bias=bias)
        self.bn3 = nn.BatchNorm2d(channels)
        self.relu3 = nn.ReLU()

    def forward(self, x):
        y = self.conv1(x)
        y = self.bn1(y)
        y = self.relu1(y)
        y = self.conv2(y)
        y = self.bn2(y)
        y = self.relu2(y)
        y = self.conv3(y)
        y = self.bn3(y)
        x = x + y  # identity skip connection
        x = self.relu3(x)
        return x


class Bottlenest(nn.Module):
    """Nested bottleneck: a 1x1 reduction, two inner 3x3 residual pairs, and a
    1x1 expansion, wrapped in an outer identity skip."""

    def __init__(self, in_channels, channels, bias):
        super().__init__()
        mid_channels = channels // 2
        self.conv0 = nn.Conv2d(in_channels, mid_channels, 1, 1, bias=bias)
        self.bn0 = nn.BatchNorm2d(mid_channels)
        self.conv1 = nn.Conv2d(mid_channels, mid_channels, 3, 1, padding=1, bias=bias)
        self.bn1 = nn.BatchNorm2d(mid_channels)
        self.relu1 = nn.ReLU()
        self.conv2 = nn.Conv2d(mid_channels, mid_channels, 3, 1, padding=1, bias=bias)
        self.bn2 = nn.BatchNorm2d(mid_channels)
        self.relu2 = nn.ReLU()
        self.conv3 = nn.Conv2d(mid_channels, mid_channels, 3, 1, padding=1, bias=bias)
        self.bn3 = nn.BatchNorm2d(mid_channels)
        self.relu3 = nn.ReLU()
        self.conv4 = nn.Conv2d(mid_channels, mid_channels, 3, 1, padding=1, bias=bias)
        self.bn4 = nn.BatchNorm2d(mid_channels)
        self.relu4 = nn.ReLU()
        self.conv5 = nn.Conv2d(mid_channels, channels, 1, 1, bias=bias)
        self.bn5 = nn.BatchNorm2d(channels)
        self.relu5 = nn.ReLU()

    def forward(self, x):
        # 1x1 reduction
        y = self.conv0(x)
        y = self.bn0(y)
        # first inner residual pair
        z = self.conv1(y)
        z = self.bn1(z)
        z = self.relu1(z)
        z = self.conv2(z)
        z = self.bn2(z)
        y = y + z
        y = self.relu2(y)
        # second inner residual pair
        z = self.conv3(y)
        z = self.bn3(z)
        z = self.relu3(z)
        z = self.conv4(z)
        z = self.bn4(z)
        y = y + z
        y = self.relu4(y)
        # 1x1 expansion and outer identity skip
        y = self.conv5(y)
        y = self.bn5(y)
        x = x + y
        x = self.relu5(x)
        return x


class ResNet(nn.Module):
    """Residual trunk: a 5x5 stem convolution followed by `layers` residual blocks."""

    def __init__(self, block, in_channels, layers, channels, bias):
        super().__init__()
        self.conv1 = nn.Sequential(
            nn.Conv2d(
                in_channels, channels, kernel_size=5, stride=1, padding=2, bias=bias
            ),
            nn.BatchNorm2d(channels),
            nn.ReLU(),
        )
        self.convs = nn.ModuleList(
            [block(channels, channels, bias) for _ in range(layers)]
        )

    def forward(self, x):
        x = self.conv1(x)
        for conv in self.convs:
            x = conv(x)
        return x


class AlphaZero(nn.Module):
    """AlphaZero-style network: a residual trunk feeding a policy head and a value head.

    The heads flatten their feature maps to ``2 * board_size`` and ``board_size``
    features, so ``board_size`` is expected to equal the number of cells
    (height * width) of the trunk output.
    """

    def __init__(
        self,
        in_channels,
        layers,
        channels,
        moves,
        board_size,
        value_heads=1,
        bias=False,
        block=BasicBlock,
    ):
        super().__init__()
        self.board_size = board_size
        self.resnet = ResNet(block, in_channels, layers, channels, bias)
        # policy head
        self.policy_head_front = nn.Sequential(
            nn.Conv2d(channels, 2, 1),
            nn.BatchNorm2d(2),
            nn.ReLU(),
        )
        self.policy_head_end = nn.Linear(2 * board_size, moves)
        # value head
        self.value_head_front = nn.Sequential(
            nn.Conv2d(channels, 1, 1),
            nn.BatchNorm2d(1),
            nn.ReLU(),
        )
        self.value_head_end = nn.Sequential(
            nn.Linear(board_size, channels),
            nn.ReLU(),
            nn.Linear(channels, value_heads),
            nn.Tanh(),
        )

    def forward(self, x):
        x = self.resnet(x)
        # policy head: move logits over `moves` actions
        p = self.policy_head_front(x)
        p = p.view(-1, 2 * self.board_size)
        p = self.policy_head_end(p)
        # value head: position evaluation in (-1, 1)
        v = self.value_head_front(x)
        v = v.view(-1, self.board_size)
        v = self.value_head_end(v)
        return p, v
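

# A minimal smoke-test sketch, not part of the training pipeline. All sizes below
# are assumptions chosen for illustration: a hypothetical 9x9 board with 4 input
# feature planes and one extra pass move. The trunk block type can be swapped via
# the `block` argument (BasicBlock, Bottleneck, or Bottlenest), since each maps
# (N, channels, H, W) back to the same shape.
if __name__ == "__main__":
    import torch

    board = 9  # assumed board edge length
    model = AlphaZero(
        in_channels=4,             # assumed number of input planes
        layers=6,
        channels=64,
        moves=board * board + 1,   # one logit per cell plus a pass move (assumed)
        board_size=board * board,
        block=Bottlenest,          # or BasicBlock / Bottleneck
    )
    x = torch.randn(2, 4, board, board)
    p, v = model(x)
    print(p.shape, v.shape)  # expected: torch.Size([2, 82]) torch.Size([2, 1])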