import json
import gzip
import pathlib
import traceback

import numpy as np
import requests
import torch
from torch import nn, Tensor
from torch.nn import functional as F
from torch.nn.utils.rnn import pack_padded_sequence, pad_sequence
from torch.distributions import Normal, Categorical
from typing import *
from functools import partial
from itertools import permutations

from libriichi.mjai import Bot
from libriichi.consts import obs_shape, oracle_obs_shape, ACTION_SPACE, GRP_SIZE

OT_REQUEST_TIMEOUT = 2
ot_settings = {
    "server": "http://example.com",
    "online": False,
    "api_key": "example_api_key",
}
is_online = False


def online_settings_init():
    global ot_settings

    settings_path = pathlib.Path(__file__).parent / 'ot_settings.json'
    if settings_path.exists():
        with open(settings_path, 'r') as f:
            ot_settings = json.load(f)


online_settings_init()
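
# The call above falls back to the in-code defaults when no ot_settings.json
# sits next to this file. A hypothetical ot_settings.json mirroring those
# defaults (the key names are the only assumption here) would look like:
#
#     {
#         "server": "http://example.com",
#         "online": false,
#         "api_key": "example_api_key"
#     }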


class ChannelAttention(nn.Module):
    def __init__(self, channels, ratio=16, actv_builder=nn.ReLU, bias=True):
        super().__init__()
        self.shared_mlp = nn.Sequential(
            nn.Linear(channels, channels // ratio, bias=bias),
            actv_builder(),
            nn.Linear(channels // ratio, channels, bias=bias),
        )
        if bias:
            for mod in self.modules():
                if isinstance(mod, nn.Linear):
                    nn.init.constant_(mod.bias, 0)

    def forward(self, x: Tensor):
        avg_out = self.shared_mlp(x.mean(-1))
        max_out = self.shared_mlp(x.amax(-1))
        weight = (avg_out + max_out).sigmoid()
        x = weight.unsqueeze(-1) * x
        return x
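
# Shape sketch for ChannelAttention (illustration only; the dummy sizes are
# assumptions, not values from this file): for x of shape (N, C, L), the
# mean/amax pooling over the last dim yields (N, C); the shared MLP plus
# sigmoid produces per-channel gates that broadcast back over L.
#
#     ca = ChannelAttention(channels=64)
#     y = ca(torch.randn(8, 64, 34))   # y.shape == (8, 64, 34)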


class ResBlock(nn.Module):
    def __init__(
        self,
        channels,
        *,
        norm_builder = nn.Identity,
        actv_builder = nn.ReLU,
        pre_actv = False,
    ):
        super().__init__()
        self.pre_actv = pre_actv

        if pre_actv:
            self.res_unit = nn.Sequential(
                norm_builder(),
                actv_builder(),
                nn.Conv1d(channels, channels, kernel_size=3, padding=1, bias=False),
                norm_builder(),
                actv_builder(),
                nn.Conv1d(channels, channels, kernel_size=3, padding=1, bias=False),
            )
        else:
            self.res_unit = nn.Sequential(
                nn.Conv1d(channels, channels, kernel_size=3, padding=1, bias=False),
                norm_builder(),
                actv_builder(),
                nn.Conv1d(channels, channels, kernel_size=3, padding=1, bias=False),
                norm_builder(),
            )
        self.actv = actv_builder()
        self.ca = ChannelAttention(channels, actv_builder=actv_builder, bias=True)

    def forward(self, x):
        out = self.res_unit(x)
        out = self.ca(out)
        out = out + x
        if not self.pre_actv:
            out = self.actv(out)
        return out


class ResNet(nn.Module):
    def __init__(
        self,
        in_channels,
        conv_channels,
        num_blocks,
        *,
        norm_builder = nn.Identity,
        actv_builder = nn.ReLU,
        pre_actv = False,
    ):
        super().__init__()

        blocks = []
        for _ in range(num_blocks):
            blocks.append(ResBlock(
                conv_channels,
                norm_builder = norm_builder,
                actv_builder = actv_builder,
                pre_actv = pre_actv,
            ))

        layers = [nn.Conv1d(in_channels, conv_channels, kernel_size=3, padding=1, bias=False)]
        if pre_actv:
            layers += [*blocks, norm_builder(), actv_builder()]
        else:
            layers += [norm_builder(), actv_builder(), *blocks]
        layers += [
            nn.Conv1d(conv_channels, 32, kernel_size=3, padding=1),
            actv_builder(),
            nn.Flatten(),
            nn.Linear(32 * 34, 1024),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        return self.net(x)
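
# Size walkthrough (descriptive note, not original code): the 34 positions are
# the 34 tile kinds, so an input of shape (N, in_channels, 34) stays
# (N, conv_channels, 34) through the residual tower, is projected to
# (N, 32, 34) by the last conv, flattens to 32 * 34 = 1088 features, and the
# final Linear maps that to the 1024-dim feature consumed by Brain.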


class Brain(nn.Module):
    def __init__(self, *, conv_channels, num_blocks, is_oracle=False, version=1):
        super().__init__()
        self.is_oracle = is_oracle
        self.version = version

        in_channels = obs_shape(version)[0]
        if is_oracle:
            in_channels += oracle_obs_shape(version)[0]

        norm_builder = partial(nn.BatchNorm1d, conv_channels, momentum=0.01)
        actv_builder = partial(nn.Mish, inplace=True)
        pre_actv = True

        match version:
            case 1:
                actv_builder = partial(nn.ReLU, inplace=True)
                pre_actv = False
                self.latent_net = nn.Sequential(
                    nn.Linear(1024, 512),
                    nn.ReLU(inplace=True),
                )
                self.mu_head = nn.Linear(512, 512)
                self.logsig_head = nn.Linear(512, 512)
            case 2:
                pass
            case 3 | 4:
                norm_builder = partial(nn.BatchNorm1d, conv_channels, momentum=0.01, eps=1e-3)
            case _:
                raise ValueError(f'Unexpected version {self.version}')

        self.encoder = ResNet(
            in_channels = in_channels,
            conv_channels = conv_channels,
            num_blocks = num_blocks,
            norm_builder = norm_builder,
            actv_builder = actv_builder,
            pre_actv = pre_actv,
        )
        self.actv = actv_builder()
        self._freeze_bn = False

    def forward(self, obs: Tensor, invisible_obs: Optional[Tensor] = None) -> Union[Tuple[Tensor, Tensor], Tensor]:
        if self.is_oracle:
            assert invisible_obs is not None
            obs = torch.cat((obs, invisible_obs), dim=1)
        phi = self.encoder(obs)

        match self.version:
            case 1:
                latent_out = self.latent_net(phi)
                mu = self.mu_head(latent_out)
                logsig = self.logsig_head(latent_out)
                return mu, logsig
            case 2 | 3 | 4:
                return self.actv(phi)
            case _:
                raise ValueError(f'Unexpected version {self.version}')

    def train(self, mode=True):
        super().train(mode)
        if self._freeze_bn:
            for mod in self.modules():
                if isinstance(mod, nn.BatchNorm1d):
                    mod.eval()
        return self

    def reset_running_stats(self):
        for mod in self.modules():
            if isinstance(mod, nn.BatchNorm1d):
                mod.reset_running_stats()

    def freeze_bn(self, value: bool):
        self._freeze_bn = value
        return self.train(self.training)
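
# Version contract (descriptive note, not original code): version 1 returns a
# (mu, logsig) pair parameterizing a Gaussian latent that DQN's 512-dim heads
# consume, while versions 2-4 return the activated 1024-dim feature directly.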


class AuxNet(nn.Module):
    def __init__(self, dims=None):
        super().__init__()
        self.dims = dims
        self.net = nn.Linear(1024, sum(dims), bias=False)

    def forward(self, x):
        return self.net(x).split(self.dims, dim=-1)


class DQN(nn.Module):
    def __init__(self, *, version=1):
        super().__init__()
        self.version = version
        match version:
            case 1:
                self.v_head = nn.Linear(512, 1)
                self.a_head = nn.Linear(512, ACTION_SPACE)
            case 2 | 3:
                hidden_size = 512 if version == 2 else 256
                self.v_head = nn.Sequential(
                    nn.Linear(1024, hidden_size),
                    nn.Mish(inplace=True),
                    nn.Linear(hidden_size, 1),
                )
                self.a_head = nn.Sequential(
                    nn.Linear(1024, hidden_size),
                    nn.Mish(inplace=True),
                    nn.Linear(hidden_size, ACTION_SPACE),
                )
            case 4:
                self.net = nn.Linear(1024, 1 + ACTION_SPACE)
                nn.init.constant_(self.net.bias, 0)
            case _:
                raise ValueError(f'Unexpected version {version}')

    def forward(self, phi, mask):
        if self.version == 4:
            v, a = self.net(phi).split((1, ACTION_SPACE), dim=-1)
        else:
            v = self.v_head(phi)
            a = self.a_head(phi)
        a_sum = a.masked_fill(~mask, 0.).sum(-1, keepdim=True)
        mask_sum = mask.sum(-1, keepdim=True)
        a_mean = a_sum / mask_sum
        q = (v + a - a_mean).masked_fill(~mask, -torch.inf)
        return q
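
# Dueling-head arithmetic (worked example, not original code): the advantage is
# centered over legal actions only, so v is not shifted by the mean advantage
# of moves that can never be taken. With mask = [True, True, False] and
# a = [2., 4., 9.]:
#     a_mean = (2 + 4) / 2 = 3
#     q      = [v - 1, v + 1, -inf]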


class MortalEngine:
    def __init__(
        self,
        brain,
        dqn,
        is_oracle,
        version,
        device = None,
        stochastic_latent = False,
        enable_amp = False,
        enable_quick_eval = True,
        enable_rule_based_agari_guard = False,
        name = 'NoName',
        boltzmann_epsilon = 0,
        boltzmann_temp = 1,
        top_p = 1,
    ):
        self.engine_type = 'mortal'
        self.device = device or torch.device('cpu')
        assert isinstance(self.device, torch.device)
        self.brain = brain.to(self.device).eval()
        self.dqn = dqn.to(self.device).eval()
        self.is_oracle = is_oracle
        self.version = version
        self.stochastic_latent = stochastic_latent

        self.enable_amp = enable_amp
        self.enable_quick_eval = enable_quick_eval
        self.enable_rule_based_agari_guard = enable_rule_based_agari_guard
        self.name = name

        self.boltzmann_epsilon = boltzmann_epsilon
        self.boltzmann_temp = boltzmann_temp
        self.top_p = top_p

    def react_batch(self, obs, masks, invisible_obs):
        global ot_settings, is_online
        try:
            with (
                torch.autocast(self.device.type, enabled=self.enable_amp),
                torch.inference_mode(),
            ):
                return self._react_batch(obs, masks, invisible_obs)
        except Exception as ex:
            raise Exception(f'{ex}\n{traceback.format_exc()}')

    def _react_batch(self, obs, masks, invisible_obs):
        obs = torch.as_tensor(np.stack(obs, axis=0), device=self.device)
        masks = torch.as_tensor(np.stack(masks, axis=0), device=self.device)
        if self.is_oracle:
            invisible_obs = torch.as_tensor(np.stack(invisible_obs, axis=0), device=self.device)
        else:
            invisible_obs = None
        batch_size = obs.shape[0]

        match self.version:
            case 1:
                mu, logsig = self.brain(obs, invisible_obs)
                if self.stochastic_latent:
                    latent = Normal(mu, logsig.exp() + 1e-6).sample()
                else:
                    latent = mu
                q_out = self.dqn(latent, masks)
            case 2 | 3 | 4:
                phi = self.brain(obs)
                q_out = self.dqn(phi, masks)
            case _:
                raise ValueError(f'Unexpected version {self.version}')

        if self.boltzmann_epsilon > 0:
            is_greedy = torch.full((batch_size,), 1 - self.boltzmann_epsilon, device=self.device).bernoulli().to(torch.bool)
            logits = (q_out / self.boltzmann_temp).masked_fill(~masks, -torch.inf)
            sampled = sample_top_p(logits, self.top_p)
            actions = torch.where(is_greedy, q_out.argmax(-1), sampled)
        else:
            is_greedy = torch.ones(batch_size, dtype=torch.bool, device=self.device)
            actions = q_out.argmax(-1)

        return actions.tolist(), q_out.tolist(), masks.tolist(), is_greedy.tolist()
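
    # Exploration note (descriptive, not original code): with
    # boltzmann_epsilon = 0.1, about 90% of rows keep the greedy argmax while
    # the rest sample from softmax(q_out / boltzmann_temp) over legal actions,
    # optionally truncated by sample_top_p below.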


def sample_top_p(logits, p):
    if p >= 1:
        return Categorical(logits=logits).sample()
    if p <= 0:
        return logits.argmax(-1)
    probs = logits.softmax(-1)
    probs_sort, probs_idx = probs.sort(-1, descending=True)
    probs_sum = probs_sort.cumsum(-1)
    # Drop every token whose preceding cumulative mass already exceeds p, then
    # sample from the surviving head (multinomial renormalizes the remainder).
    mask = probs_sum - probs_sort > p
    probs_sort[mask] = 0.
    sampled = probs_idx.gather(-1, probs_sort.multinomial(1)).squeeze(-1)
    return sampled
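
# Worked top-p example (illustration only): for sorted probs (0.5, 0.3, 0.2)
# and p = 0.6, probs_sum - probs_sort = (0.0, 0.5, 0.8), so the third token is
# masked out and sampling happens over the first two with odds 5:3.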


def load_model(seat: int) -> Bot:
    device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

    control_state_file = pathlib.Path(__file__).parent / "Elite4zMarch4th9407.pth"
    state = torch.load(control_state_file, map_location=device)

    version = state['config']['control']['version']
    mortal = Brain(
        version = version,
        conv_channels = state['config']['resnet']['conv_channels'],
        num_blocks = state['config']['resnet']['num_blocks'],
    ).eval()
    dqn = DQN(version=version).eval()
    mortal.load_state_dict(state['mortal'])
    dqn.load_state_dict(state['current_dqn'])

    engine = MortalEngine(
        mortal,
        dqn,
        is_oracle = False,
        version = version,
        device = device,
        enable_amp = False,
        enable_quick_eval = False,
        enable_rule_based_agari_guard = True,
        name = 'mortal',
    )

    bot = Bot(engine, seat)
    return bot
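
# Minimal smoke test (a sketch, not part of the original module). It assumes
# the checkpoint named above sits next to this file and that libriichi's
# Bot.react accepts a single mjai event serialized as JSON; the start_game
# event below is a hypothetical example.
if __name__ == '__main__':
    bot = load_model(seat=0)
    print(bot.react('{"type":"start_game","id":0}'))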