import json import gzip import torch import pathlib import requests import traceback import numpy as np from torch import nn, Tensor from torch.nn import functional as F from torch.nn.utils.rnn import pack_padded_sequence, pad_sequence from torch.distributions import Normal, Categorical from typing import * from functools import partial from itertools import permutations from libriichi.mjai import Bot from libriichi.consts import obs_shape, oracle_obs_shape, ACTION_SPACE, GRP_SIZE # ========== Online Server =========== # OT_REQUEST_TIMEOUT = 2 ot_settings = { "server": "http://example.com", "online": False, "api_key": "example_api_key", } is_online = False def online_settings_init(): global ot_settings # Check if the file exists if (pathlib.Path(__file__).parent / 'ot_settings.json').exists(): with open(pathlib.Path(__file__).parent / 'ot_settings.json', 'r') as f: ot_settings = json.load(f) online_settings_init() # ==================================== # class ChannelAttention(nn.Module): def __init__(self, channels, ratio=16, actv_builder=nn.ReLU, bias=True): super().__init__() self.shared_mlp = nn.Sequential( nn.Linear(channels, channels // ratio, bias=bias), actv_builder(), nn.Linear(channels // ratio, channels, bias=bias), ) if bias: for mod in self.modules(): if isinstance(mod, nn.Linear): nn.init.constant_(mod.bias, 0) def forward(self, x: Tensor): avg_out = self.shared_mlp(x.mean(-1)) max_out = self.shared_mlp(x.amax(-1)) weight = (avg_out + max_out).sigmoid() x = weight.unsqueeze(-1) * x return x class ResBlock(nn.Module): def __init__( self, channels, *, norm_builder = nn.Identity, actv_builder = nn.ReLU, pre_actv = False, ): super().__init__() self.pre_actv = pre_actv if pre_actv: self.res_unit = nn.Sequential( norm_builder(), actv_builder(), nn.Conv1d(channels, channels, kernel_size=3, padding=1, bias=False), norm_builder(), actv_builder(), nn.Conv1d(channels, channels, kernel_size=3, padding=1, bias=False), ) else: self.res_unit = nn.Sequential( nn.Conv1d(channels, channels, kernel_size=3, padding=1, bias=False), norm_builder(), actv_builder(), nn.Conv1d(channels, channels, kernel_size=3, padding=1, bias=False), norm_builder(), ) self.actv = actv_builder() self.ca = ChannelAttention(channels, actv_builder=actv_builder, bias=True) def forward(self, x): out = self.res_unit(x) out = self.ca(out) out = out + x if not self.pre_actv: out = self.actv(out) return out class ResNet(nn.Module): def __init__( self, in_channels, conv_channels, num_blocks, *, norm_builder = nn.Identity, actv_builder = nn.ReLU, pre_actv = False, ): super().__init__() blocks = [] for _ in range(num_blocks): blocks.append(ResBlock( conv_channels, norm_builder = norm_builder, actv_builder = actv_builder, pre_actv = pre_actv, )) layers = [nn.Conv1d(in_channels, conv_channels, kernel_size=3, padding=1, bias=False)] if pre_actv: layers += [*blocks, norm_builder(), actv_builder()] else: layers += [norm_builder(), actv_builder(), *blocks] layers += [ nn.Conv1d(conv_channels, 32, kernel_size=3, padding=1), actv_builder(), nn.Flatten(), nn.Linear(32 * 34, 1024), ] self.net = nn.Sequential(*layers) def forward(self, x): return self.net(x) class Brain(nn.Module): def __init__(self, *, conv_channels, num_blocks, is_oracle=False, version=1): super().__init__() self.is_oracle = is_oracle self.version = version in_channels = obs_shape(version)[0] if is_oracle: in_channels += oracle_obs_shape(version)[0] norm_builder = partial(nn.BatchNorm1d, conv_channels, momentum=0.01) actv_builder = partial(nn.Mish, inplace=True) pre_actv = True match version: case 1: actv_builder = partial(nn.ReLU, inplace=True) pre_actv = False self.latent_net = nn.Sequential( nn.Linear(1024, 512), nn.ReLU(inplace=True), ) self.mu_head = nn.Linear(512, 512) self.logsig_head = nn.Linear(512, 512) case 2: pass case 3 | 4: norm_builder = partial(nn.BatchNorm1d, conv_channels, momentum=0.01, eps=1e-3) case _: raise ValueError(f'Unexpected version {self.version}') self.encoder = ResNet( in_channels = in_channels, conv_channels = conv_channels, num_blocks = num_blocks, norm_builder = norm_builder, actv_builder = actv_builder, pre_actv = pre_actv, ) self.actv = actv_builder() # always use EMA or CMA when True self._freeze_bn = False def forward(self, obs: Tensor, invisible_obs: Optional[Tensor] = None) -> Union[Tuple[Tensor, Tensor], Tensor]: if self.is_oracle: assert invisible_obs is not None obs = torch.cat((obs, invisible_obs), dim=1) phi = self.encoder(obs) match self.version: case 1: latent_out = self.latent_net(phi) mu = self.mu_head(latent_out) logsig = self.logsig_head(latent_out) return mu, logsig case 2 | 3 | 4: return self.actv(phi) case _: raise ValueError(f'Unexpected version {self.version}') def train(self, mode=True): super().train(mode) if self._freeze_bn: for mod in self.modules(): if isinstance(mod, nn.BatchNorm1d): mod.eval() # I don't think this benefits # module.requires_grad_(False) return self def reset_running_stats(self): for mod in self.modules(): if isinstance(mod, nn.BatchNorm1d): mod.reset_running_stats() def freeze_bn(self, value: bool): self._freeze_bn = value return self.train(self.training) class AuxNet(nn.Module): def __init__(self, dims=None): super().__init__() self.dims = dims self.net = nn.Linear(1024, sum(dims), bias=False) def forward(self, x): return self.net(x).split(self.dims, dim=-1) class DQN(nn.Module): def __init__(self, *, version=1): super().__init__() self.version = version print(version) match version: case 1: self.v_head = nn.Linear(512, 1) self.a_head = nn.Linear(512, ACTION_SPACE) case 2 | 3: hidden_size = 512 if version == 2 else 256 self.v_head = nn.Sequential( nn.Linear(1024, hidden_size), nn.Mish(inplace=True), nn.Linear(hidden_size, 1), ) self.a_head = nn.Sequential( nn.Linear(1024, hidden_size), nn.Mish(inplace=True), nn.Linear(hidden_size, ACTION_SPACE), ) case 4: self.net = nn.Linear(1024, 1 + ACTION_SPACE) nn.init.constant_(self.net.bias, 0) def forward(self, phi, mask): if self.version == 4: v, a = self.net(phi).split((1, ACTION_SPACE), dim=-1) else: v = self.v_head(phi) a = self.a_head(phi) a_sum = a.masked_fill(~mask, 0.).sum(-1, keepdim=True) mask_sum = mask.sum(-1, keepdim=True) a_mean = a_sum / mask_sum q = (v + a - a_mean).masked_fill(~mask, -torch.inf) return q class MortalEngine: def __init__( self, brain, dqn, is_oracle, version, device = None, stochastic_latent = False, enable_amp = False, enable_quick_eval = True, enable_rule_based_agari_guard = False, name = 'NoName', boltzmann_epsilon = 0, boltzmann_temp = 1, top_p = 1, ): self.engine_type = 'mortal' self.device = device or torch.device('cpu') assert isinstance(self.device, torch.device) self.brain = brain.to(self.device).eval() self.dqn = dqn.to(self.device).eval() self.is_oracle = is_oracle self.version = version self.stochastic_latent = stochastic_latent self.enable_amp = enable_amp self.enable_quick_eval = enable_quick_eval self.enable_rule_based_agari_guard = enable_rule_based_agari_guard self.name = name self.boltzmann_epsilon = boltzmann_epsilon self.boltzmann_temp = boltzmann_temp self.top_p = top_p def react_batch(self, obs, masks, invisible_obs): # ========== Online Server =========== # global ot_settings, is_online try: with ( torch.autocast(self.device.type, enabled=self.enable_amp), torch.inference_mode(), ): return self._react_batch(obs, masks, invisible_obs) except Exception as ex: raise Exception(f'{ex}\n{traceback.format_exc()}') def _react_batch(self, obs, masks, invisible_obs): obs = torch.as_tensor(np.stack(obs, axis=0), device=self.device) masks = torch.as_tensor(np.stack(masks, axis=0), device=self.device) invisible_obs = None if self.is_oracle: invisible_obs = torch.as_tensor(np.stack(invisible_obs, axis=0), device=self.device) batch_size = obs.shape[0] match self.version: case 1: mu, logsig = self.brain(obs, invisible_obs) if self.stochastic_latent: latent = Normal(mu, logsig.exp() + 1e-6).sample() else: latent = mu q_out = self.dqn(latent, masks) case 2 | 3 | 4: phi = self.brain(obs) q_out = self.dqn(phi, masks) if self.boltzmann_epsilon > 0: is_greedy = torch.full((batch_size,), 1-self.boltzmann_epsilon, device=self.device).bernoulli().to(torch.bool) logits = (q_out / self.boltzmann_temp).masked_fill(~masks, -torch.inf) sampled = sample_top_p(logits, self.top_p) actions = torch.where(is_greedy, q_out.argmax(-1), sampled) else: is_greedy = torch.ones(batch_size, dtype=torch.bool, device=self.device) actions = q_out.argmax(-1) return actions.tolist(), q_out.tolist(), masks.tolist(), is_greedy.tolist() def sample_top_p(logits, p): if p >= 1: return Categorical(logits=logits).sample() if p <= 0: return logits.argmax(-1) probs = logits.softmax(-1) probs_sort, probs_idx = probs.sort(-1, descending=True) probs_sum = probs_sort.cumsum(-1) mask = probs_sum - probs_sort > p probs_sort[mask] = 0. sampled = probs_idx.gather(-1, probs_sort.multinomial(1)).squeeze(-1) return sampled def load_model(seat: int) -> Bot: # check if GPU is available # device = torch.device('cpu') if torch.cuda.is_available(): device = torch.device('cuda') else: device = torch.device('cpu') # latest binary model control_state_file = "./Elite4zMarch4th9407.pth" # Get the path of control_state_file = current directory / control_state_file control_state_file = pathlib.Path(__file__).parent / control_state_file state = torch.load(control_state_file, map_location=device) mortal = Brain(version=state['config']['control']['version'], conv_channels=state['config']['resnet']['conv_channels'], num_blocks=state['config']['resnet']['num_blocks']).eval() dqn = DQN(version=state['config']['control']['version']).eval() mortal.load_state_dict(state['mortal']) dqn.load_state_dict(state['current_dqn']) engine = MortalEngine( mortal, dqn, is_oracle = False, version = state['config']['control']['version'], device = device, enable_amp = False, enable_quick_eval = False, enable_rule_based_agari_guard = True, name = 'mortal', ) bot = Bot(engine, seat) return bot