| import json |
| import gzip |
| import torch |
| import pathlib |
| import requests |
| import traceback |
| import numpy as np |
|
|
| from torch import nn, Tensor |
| from torch.nn import functional as F |
| from torch.nn.utils.rnn import pack_padded_sequence, pad_sequence |
| from torch.distributions import Normal, Categorical |
| from typing import * |
| from functools import partial |
| from itertools import permutations |
| from libriichi3p.mjai import Bot |
| from libriichi3p.consts import obs_shape, oracle_obs_shape, ACTION_SPACE, GRP_SIZE |
|
|
| |
# Timeout (seconds) handed to `requests.post` for the remote inference call.
OT_REQUEST_TIMEOUT = 2

# Built-in remote-inference settings; `online_settings_init` overrides them
# from an `ot_settings.json` file located next to this module, if present.
ot_settings = dict(
    server="http://example.com",
    online=False,
    api_key="example_api_key",
)

# Set by `MortalEngine.react_batch`: True after a successful remote call,
# False after a failed one.
is_online = False
|
|
def online_settings_init():
    """Overlay ``ot_settings.json`` (next to this module) onto ``ot_settings``.

    Keys present in the file override the current values; keys the file omits
    keep their existing value, so a partial file can no longer cause a
    ``KeyError`` when the settings are read later. A missing file is not an
    error: the current settings are left untouched.

    Raises:
        json.JSONDecodeError: if the file exists but is not valid JSON.
        TypeError: if the JSON document is not an object (mapping).
    """
    global ot_settings

    settings_path = pathlib.Path(__file__).parent / 'ot_settings.json'
    try:
        # Open directly instead of exists()+open(): one lookup, no race.
        with open(settings_path, 'r', encoding='utf-8') as f:
            loaded = json.load(f)
    except FileNotFoundError:
        return

    # Rebind (rather than mutate) so the previous dict object stays intact.
    ot_settings = {**ot_settings, **loaded}


# Apply file overrides once, at import time.
online_settings_init()
| |
|
|
class ChannelAttention(nn.Module):
    """Gate each channel of the input by a learned sigmoid weight.

    The last dimension of the input is pooled twice (mean and max); both
    pooled tensors go through the same two-layer MLP, the results are summed
    and squashed with a sigmoid, and the resulting per-channel weight is
    broadcast back over the last dimension.

    Args:
        channels: size of the input's second-to-last dimension.
        ratio: reduction factor of the MLP's hidden layer (``channels // ratio``).
        actv_builder: zero-argument factory for the MLP's activation module.
        bias: whether the linear layers carry a bias (initialised to zero).
    """

    def __init__(self, channels, ratio=16, actv_builder=nn.ReLU, bias=True):
        super().__init__()
        hidden = channels // ratio
        # Built in the same order as they appear in the Sequential.
        squeeze = nn.Linear(channels, hidden, bias=bias)
        activation = actv_builder()
        excite = nn.Linear(hidden, channels, bias=bias)
        self.shared_mlp = nn.Sequential(squeeze, activation, excite)

        if not bias:
            return
        for layer in filter(lambda m: isinstance(m, nn.Linear), self.modules()):
            nn.init.constant_(layer.bias, 0)

    def forward(self, x: Tensor):
        """Return ``x`` scaled channel-wise; the output has the shape of ``x``."""
        pooled_avg = x.mean(-1)
        pooled_max = x.amax(-1)
        gate = (self.shared_mlp(pooled_avg) + self.shared_mlp(pooled_max)).sigmoid()
        return gate.unsqueeze(-1) * x
|
|
class ResBlock(nn.Module):
    """Residual block of two 3-wide 1-D convolutions followed by channel attention.

    Args:
        channels: number of input and output channels (the block preserves shape).
        norm_builder: zero-argument factory for the normalisation module.
        actv_builder: zero-argument factory for the activation module.
        pre_actv: if True, use the norm -> activation -> conv ordering and skip
            the activation after the residual sum; otherwise use
            conv -> norm -> activation and activate after the sum.
    """

    def __init__(
        self,
        channels,
        *,
        norm_builder = nn.Identity,
        actv_builder = nn.ReLU,
        pre_actv = False,
    ):
        super().__init__()
        self.pre_actv = pre_actv

        def conv3():
            # Length-preserving convolution without bias.
            return nn.Conv1d(channels, channels, kernel_size=3, padding=1, bias=False)

        if pre_actv:
            stages = [
                norm_builder(), actv_builder(), conv3(),
                norm_builder(), actv_builder(), conv3(),
            ]
        else:
            stages = [
                conv3(), norm_builder(), actv_builder(),
                conv3(), norm_builder(),
            ]
        self.res_unit = nn.Sequential(*stages)
        # Always created, but only applied in forward() when pre_actv is False.
        self.actv = actv_builder()
        self.ca = ChannelAttention(channels, actv_builder=actv_builder, bias=True)

    def forward(self, x):
        """Return ``attention(res_unit(x)) + x``, activated unless ``pre_actv``."""
        branch = self.ca(self.res_unit(x))
        summed = branch + x
        if self.pre_actv:
            return summed
        return self.actv(summed)
|
|
class ResNet(nn.Module):
    """Convolutional encoder: stem conv, a stack of ``ResBlock``s, and a linear head.

    The head is ``Conv1d(conv_channels, 32) -> activation -> Flatten ->
    Linear(32 * 34, 1024)``. All convolutions preserve length, so the input's
    last dimension must be 34 for the final linear layer to fit.

    Args:
        in_channels: channels of the input tensor.
        conv_channels: channels used by the stem and every residual block.
        num_blocks: number of residual blocks.
        norm_builder: zero-argument factory for normalisation modules.
        actv_builder: zero-argument factory for activation modules.
        pre_actv: forwarded to each ``ResBlock``; also decides whether the
            norm/activation pair sits after (True) or before (False) the blocks.
    """

    def __init__(
        self,
        in_channels,
        conv_channels,
        num_blocks,
        *,
        norm_builder = nn.Identity,
        actv_builder = nn.ReLU,
        pre_actv = False,
    ):
        super().__init__()

        # Residual blocks are instantiated before the stem, as in the original
        # construction order.
        body = [
            ResBlock(
                conv_channels,
                norm_builder = norm_builder,
                actv_builder = actv_builder,
                pre_actv = pre_actv,
            )
            for _ in range(num_blocks)
        ]

        stem = nn.Conv1d(in_channels, conv_channels, kernel_size=3, padding=1, bias=False)
        pipeline = [stem]
        if pre_actv:
            pipeline.extend(body)
            pipeline.append(norm_builder())
            pipeline.append(actv_builder())
        else:
            pipeline.append(norm_builder())
            pipeline.append(actv_builder())
            pipeline.extend(body)

        pipeline.append(nn.Conv1d(conv_channels, 32, kernel_size=3, padding=1))
        pipeline.append(actv_builder())
        pipeline.append(nn.Flatten())
        pipeline.append(nn.Linear(32 * 34, 1024))
        self.net = nn.Sequential(*pipeline)

    def forward(self, x):
        """Run the full pipeline; the last dimension of the result is 1024."""
        return self.net(x)
|
|
class Brain(nn.Module):
    """Observation encoder wrapping a ``ResNet`` with version-specific settings.

    * version 1: ReLU, post-activation blocks, plus a latent MLP with ``mu`` and
      ``logsig`` heads; ``forward`` returns the pair ``(mu, logsig)``.
    * version 2: Mish, pre-activation blocks; ``forward`` returns the activated
      encoder output.
    * versions 3 and 4: as version 2, with BatchNorm ``eps=1e-3``.

    Args:
        conv_channels: channel width of the encoder.
        num_blocks: number of residual blocks in the encoder.
        is_oracle: if True, ``forward`` also needs ``invisible_obs``, which is
            concatenated to ``obs`` along dim 1.
        version: model version, one of 1, 2, 3, 4.

    Raises:
        ValueError: if ``version`` is not one of the supported values.
    """

    def __init__(self, *, conv_channels, num_blocks, is_oracle=False, version=1):
        super().__init__()
        self.is_oracle = is_oracle
        self.version = version

        in_channels = obs_shape(version)[0]
        if is_oracle:
            in_channels += oracle_obs_shape(version)[0]

        # Defaults (versions 2+); version 1 overrides activation and ordering.
        bn_kwargs = {'momentum': 0.01}
        actv_builder = partial(nn.Mish, inplace=True)
        pre_actv = True

        if version == 1:
            actv_builder = partial(nn.ReLU, inplace=True)
            pre_actv = False
            self.latent_net = nn.Sequential(
                nn.Linear(1024, 512),
                nn.ReLU(inplace=True),
            )
            self.mu_head = nn.Linear(512, 512)
            self.logsig_head = nn.Linear(512, 512)
        elif version in (3, 4):
            bn_kwargs['eps'] = 1e-3
        elif version != 2:
            raise ValueError(f'Unexpected version {self.version}')

        norm_builder = partial(nn.BatchNorm1d, conv_channels, **bn_kwargs)

        self.encoder = ResNet(
            in_channels = in_channels,
            conv_channels = conv_channels,
            num_blocks = num_blocks,
            norm_builder = norm_builder,
            actv_builder = actv_builder,
            pre_actv = pre_actv,
        )
        self.actv = actv_builder()

        # When True, train() keeps every BatchNorm1d layer in eval mode.
        self._freeze_bn = False

    def forward(self, obs: Tensor, invisible_obs: Optional[Tensor] = None) -> Union[Tuple[Tensor, Tensor], Tensor]:
        """Encode ``obs`` (plus ``invisible_obs`` in oracle mode).

        Dropout with ``p=0.1`` is applied to the encoder output while the module
        is in training mode.

        Returns:
            ``(mu, logsig)`` for version 1, otherwise the activated features.
        """
        if self.is_oracle:
            assert invisible_obs is not None
            obs = torch.cat((obs, invisible_obs), dim=1)

        phi = F.dropout(self.encoder(obs), p=0.1, training=self.training)

        if self.version == 1:
            hidden = self.latent_net(phi)
            mu = self.mu_head(hidden)
            logsig = self.logsig_head(hidden)
            return mu, logsig
        if self.version in (2, 3, 4):
            return self.actv(phi)
        raise ValueError(f'Unexpected version {self.version}')

    def train(self, mode=True):
        """Set training mode, re-freezing BatchNorm layers if requested."""
        super().train(mode)
        if not self._freeze_bn:
            return self
        batch_norms = (m for m in self.modules() if isinstance(m, nn.BatchNorm1d))
        for bn in batch_norms:
            bn.eval()
        return self

    def reset_running_stats(self):
        """Reset the running statistics of every BatchNorm1d layer."""
        batch_norms = (m for m in self.modules() if isinstance(m, nn.BatchNorm1d))
        for bn in batch_norms:
            bn.reset_running_stats()

    def freeze_bn(self, value: bool):
        """Enable or disable BatchNorm freezing and re-apply the current mode."""
        self._freeze_bn = value
        return self.train(self.training)
|
|
class AuxNet(nn.Module):
    """Bias-free linear layer over 1024 features, split into several heads.

    Args:
        dims: iterable of output sizes, one per head. Required in practice:
            the ``None`` default is kept only for signature compatibility.

    Raises:
        TypeError: if ``dims`` is omitted (``None``).
    """

    def __init__(self, dims=None):
        super().__init__()
        if dims is None:
            # Previously failed inside sum() with "'NoneType' object is not
            # iterable"; keep the exception type but say what is wrong.
            raise TypeError('AuxNet requires `dims`: an iterable of output sizes, one per head')
        if not isinstance(dims, (list, tuple)):
            # Materialise one-shot iterables: dims is consumed by sum() here
            # and again by split() on every forward pass.
            dims = tuple(dims)
        self.dims = dims
        self.net = nn.Linear(1024, sum(dims), bias=False)

    def forward(self, x):
        """Return a tuple of tensors, one per entry of ``dims``, split on the last dim."""
        return self.net(x).split(self.dims, dim=-1)
|
|
| class DQN(nn.Module): |
| def __init__(self, *, version=1): |
| super().__init__() |
| self.version = version |
| match version: |
| case 1: |
| self.v_head = nn.Linear(512, 1) |
| self.a_head = nn.Linear(512, ACTION_SPACE) |
| case 2 | 3: |
| hidden_size = 512 if version == 2 else 256 |
| self.v_head = nn.Sequential( |
| nn.Linear(1024, hidden_size), |
| nn.Mish(inplace=True), |
| nn.Linear(hidden_size, 1), |
| ) |
| self.a_head = nn.Sequential( |
| nn.Linear(1024, hidden_size), |
| nn.Mish(inplace=True), |
| nn.Linear(hidden_size, ACTION_SPACE), |
| ) |
| case 4: |
| self.net = nn.Linear(1024, 1 + ACTION_SPACE) |
| nn.init.constant_(self.net.bias, 0) |
|
|
| def forward(self, phi, mask): |
| if self.version == 4: |
| v, a = self.net(phi).split((1, ACTION_SPACE), dim=-1) |
| else: |
| v = self.v_head(phi) |
| a = self.a_head(phi) |
| a_sum = a.masked_fill(~mask, 0.).sum(-1, keepdim=True) |
| mask_sum = mask.sum(-1, keepdim=True) |
| a_mean = a_sum / mask_sum |
| q = (v + a - a_mean).masked_fill(~mask, -1e9) |
| return q |
|
|
|
|
| class MortalEngine: |
| def __init__( |
| self, |
| brain, |
| dqn, |
| is_oracle, |
| version, |
| device = None, |
| stochastic_latent = False, |
| enable_amp = False, |
| enable_quick_eval = True, |
| enable_rule_based_agari_guard = False, |
| name = 'NoName', |
| boltzmann_epsilon = 0, |
| boltzmann_temp = 1, |
| top_p = 1, |
| ): |
| self.engine_type = 'mortal' |
| self.device = device or torch.device('cpu') |
| assert isinstance(self.device, torch.device) |
| self.brain = brain.to(self.device).eval() |
| self.dqn = dqn.to(self.device).eval() |
| self.is_oracle = is_oracle |
| self.version = version |
| self.stochastic_latent = stochastic_latent |
|
|
| self.enable_amp = enable_amp |
| self.enable_quick_eval = enable_quick_eval |
| self.enable_rule_based_agari_guard = enable_rule_based_agari_guard |
| self.name = name |
|
|
| self.boltzmann_epsilon = boltzmann_epsilon |
| self.boltzmann_temp = boltzmann_temp |
| self.top_p = top_p |
|
|
| def react_batch(self, obs, masks, invisible_obs): |
| |
| global ot_settings, is_online |
| |
| if ot_settings['online']: |
| try: |
| list_obs = [o.tolist() for o in obs] |
| list_masks = [m.tolist() for m in masks] |
| post_data = { |
| 'obs': list_obs, |
| 'masks': list_masks, |
| } |
| data = json.dumps(post_data, separators=(',', ':')) |
| compressed_data = gzip.compress(data.encode('utf-8')) |
| headers = { |
| 'Authorization': ot_settings['api_key'], |
| 'Content-Encoding': 'gzip', |
| } |
| r = requests.post( |
| f'{ot_settings["server"]}/react_batch_3p', |
| headers=headers, |
| data=compressed_data, |
| timeout=OT_REQUEST_TIMEOUT |
| ) |
| assert r.status_code == 200 |
| is_online = True |
| r_json = r.json() |
| return r_json['actions'], r_json['q_out'], r_json['masks'], r_json['is_greedy'] |
| except: |
| is_online = False |
| pass |
| |
| try: |
| with ( |
| torch.autocast(self.device.type, enabled=self.enable_amp), |
| torch.inference_mode(), |
| ): |
| return self._react_batch(obs, masks, invisible_obs) |
| except Exception as ex: |
| raise Exception(f'{ex}\n{traceback.format_exc()}') |
|
|
| def _react_batch(self, obs, masks, invisible_obs): |
| obs = torch.as_tensor(np.stack(obs, axis=0), device=self.device) |
| masks = torch.as_tensor(np.stack(masks, axis=0), device=self.device) |
| invisible_obs = None |
| if self.is_oracle: |
| invisible_obs = torch.as_tensor(np.stack(invisible_obs, axis=0), device=self.device) |
| batch_size = obs.shape[0] |
|
|
| match self.version: |
| case 1: |
| mu, logsig = self.brain(obs, invisible_obs) |
| if self.stochastic_latent: |
| latent = Normal(mu, logsig.exp() + 1e-6).sample() |
| else: |
| latent = mu |
| q_out = self.dqn(latent, masks) |
| case 2 | 3 | 4: |
| phi = self.brain(obs) |
| q_out = self.dqn(phi, masks) |
|
|
| if self.boltzmann_epsilon > 0: |
| is_greedy = torch.full((batch_size,), 1-self.boltzmann_epsilon, device=self.device).bernoulli().to(torch.bool) |
| logits = (q_out / self.boltzmann_temp).masked_fill(~masks, -torch.inf) |
| sampled = sample_top_p(logits, self.top_p) |
| actions = torch.where(is_greedy, q_out.argmax(-1), sampled) |
| else: |
| is_greedy = torch.ones(batch_size, dtype=torch.bool, device=self.device) |
| actions = q_out.argmax(-1) |
| return actions.tolist(), q_out.tolist(), masks.tolist(), is_greedy.tolist() |
|
|
def sample_top_p(logits, p):
    """Sample one index per row using nucleus (top-p) sampling.

    Args:
        logits: unnormalised log-probabilities; sampling is over the last dim.
        p: nucleus threshold. ``p >= 1`` samples from the full categorical
            distribution; ``p <= 0`` returns the argmax.

    Returns:
        Tensor of sampled indices with the last dimension removed.
    """
    if p <= 0:
        return logits.argmax(-1)
    if p >= 1:
        return Categorical(logits=logits).sample()

    sorted_probs, sorted_idx = logits.softmax(-1).sort(-1, descending=True)
    # Probability mass strictly ahead of each entry in descending order; an
    # entry is dropped once that mass already exceeds p, so the top entry
    # is always kept.
    mass_before = sorted_probs.cumsum(-1) - sorted_probs
    nucleus = sorted_probs.masked_fill(mass_before > p, 0.)
    # multinomial accepts unnormalised weights, so no renormalisation is needed.
    choice = nucleus.multinomial(1)
    return sorted_idx.gather(-1, choice).squeeze(-1)
|
|
def load_model(seat: int) -> Bot:
    """Build a ``Bot`` for ``seat`` from the checkpoint stored next to this module.

    Uses CUDA when available, otherwise the CPU. The checkpoint supplies the
    model version, the encoder size (``config.resnet``), and the weights under
    ``'mortal'`` and ``'current_dqn'``.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    checkpoint_path = pathlib.Path(__file__).parent / "./Backup_B50_Rank1.75_Rwd0.12_Ent0.149.pth"
    # NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
    state = torch.load(checkpoint_path, map_location=device)

    config = state['config']
    version = config['control']['version']
    resnet_cfg = config['resnet']

    mortal = Brain(
        version=version,
        conv_channels=resnet_cfg['conv_channels'],
        num_blocks=resnet_cfg['num_blocks'],
    ).eval()
    dqn = DQN(version=version).eval()
    mortal.load_state_dict(state['mortal'])
    dqn.load_state_dict(state['current_dqn'])

    engine = MortalEngine(
        mortal,
        dqn,
        is_oracle = False,
        version = version,
        device = device,
        enable_amp = False,
        enable_quick_eval = False,
        enable_rule_based_agari_guard = True,
        name = 'mortal',
        top_p = 1,
    )
    return Bot(engine, seat)
|
|