Spaces:
Sleeping
Sleeping
| import json | |
| import gzip | |
| import torch | |
| import pathlib | |
| import requests | |
| import traceback | |
| import numpy as np | |
| from torch import nn, Tensor | |
| from torch.nn import functional as F | |
| from torch.nn.utils.rnn import pack_padded_sequence, pad_sequence | |
| from torch.distributions import Normal, Categorical | |
| from typing import * | |
| from functools import partial | |
| from itertools import permutations | |
try:
    from libriichi3p.mjai import Bot
    from libriichi3p.consts import obs_shape, oracle_obs_shape, ACTION_SPACE, GRP_SIZE
except ImportError:
    # The extension is not on sys.path; fall back to loading the compiled
    # .so directly from an absolute path (e.g. Google Drive on Colab).
    import importlib.util
    import sys
    import os

    # NOTE: this must be the absolute path to the compiled library. If your
    # file has a different name or lives in another folder, edit this path.
    SO_FILE_PATH = "/content/drive/MyDrive/MahjongTest/libriichi3p.so"

    # 1. Check whether the file actually exists before trying to load it.
    if not os.path.exists(SO_FILE_PATH):
        print(f"❌ 致命错误:在路径 {SO_FILE_PATH} 下根本找不到文件!请检查路径拼写。")
    else:
        print(f"✅ 找到文件: {SO_FILE_PATH},正在尝试强行加载...")
        try:
            # 2. Build a module loading spec from the absolute file path.
            #    First arg is the internal module name, second is the path.
            spec = importlib.util.spec_from_file_location("libriichi3p", SO_FILE_PATH)
            # 3. Instantiate the module object from the spec.
            libriichi3p_module = importlib.util.module_from_spec(spec)
            # 4. Register it in sys.modules so later `import libriichi3p`
            #    statements (here and in other files) resolve to this module.
            sys.modules["libriichi3p"] = libriichi3p_module
            # 5. Execute the native module's initialization code.
            spec.loader.exec_module(libriichi3p_module)
            # 6. Re-run the imports that failed above so Bot/obs_shape/...
            #    are actually bound in this module's namespace (the original
            #    code skipped this, leaving the names undefined after a
            #    successful manual load).
            from libriichi3p.mjai import Bot
            from libriichi3p.consts import obs_shape, oracle_obs_shape, ACTION_SPACE, GRP_SIZE
            print("🎉 强行导入成功!现在可以在代码里正常使用了。")
        except Exception as e:
            print(f"❌ 导入失败,暴露出真实报错: {e}")
# ========== Online Server =========== #
OT_REQUEST_TIMEOUT = 2

# Defaults used when no ot_settings.json exists next to this file.
ot_settings = {
    "server": "http://example.com",
    "online": False,
    "api_key": "example_api_key",
}
is_online = False

def online_settings_init():
    """Load ot_settings.json (next to this file) into the `ot_settings` global.

    Keeps the built-in defaults when the file does not exist.
    """
    global ot_settings
    # Compute the path once instead of once for the existence check and
    # once for the open() call.
    settings_path = pathlib.Path(__file__).parent / 'ot_settings.json'
    if settings_path.exists():
        with open(settings_path, 'r', encoding='utf-8') as f:
            ot_settings = json.load(f)

online_settings_init()
# ==================================== #
class ChannelAttention(nn.Module):
    """Squeeze-and-excitation style channel gate for 1D feature maps.

    Builds per-channel weights from both average- and max-pooled channel
    descriptors via a shared bottleneck MLP, then rescales the input.
    """

    def __init__(self, channels, ratio=16, actv_builder=nn.ReLU, bias=True):
        super().__init__()
        self.shared_mlp = nn.Sequential(
            nn.Linear(channels, channels // ratio, bias=bias),
            actv_builder(),
            nn.Linear(channels // ratio, channels, bias=bias),
        )
        if bias:
            # Zero-initialize every linear bias in the MLP.
            for mod in self.modules():
                if isinstance(mod, nn.Linear):
                    nn.init.constant_(mod.bias, 0)

    def forward(self, x: Tensor) -> Tensor:
        # Squeeze the spatial (last) axis with two complementary pools.
        pooled_avg = x.mean(-1)
        pooled_max = x.amax(-1)
        # Shared MLP on both descriptors, summed, squashed into (0, 1).
        gate = torch.sigmoid(self.shared_mlp(pooled_avg) + self.shared_mlp(pooled_max))
        # Broadcast the per-channel gate back over the spatial axis.
        return x * gate.unsqueeze(-1)
class ResBlock(nn.Module):
    """Residual block with channel attention on the residual branch.

    Two 3x3 Conv1d layers (padding preserves length), in either
    pre-activation (norm->actv->conv) or post-activation ordering, plus an
    identity skip connection.
    """

    def __init__(
        self,
        channels,
        *,
        norm_builder = nn.Identity,
        actv_builder = nn.ReLU,
        pre_actv = False,
    ):
        super().__init__()
        self.pre_actv = pre_actv

        def conv():
            return nn.Conv1d(channels, channels, kernel_size=3, padding=1, bias=False)

        if pre_actv:
            # Pre-activation ordering: norm -> actv -> conv, twice.
            layers = [
                norm_builder(), actv_builder(), conv(),
                norm_builder(), actv_builder(), conv(),
            ]
        else:
            # Post-activation ordering; the final activation is applied
            # after the skip connection in forward().
            layers = [
                conv(), norm_builder(), actv_builder(),
                conv(), norm_builder(),
            ]
        self.res_unit = nn.Sequential(*layers)
        self.actv = actv_builder()
        self.ca = ChannelAttention(channels, actv_builder=actv_builder, bias=True)

    def forward(self, x):
        out = self.ca(self.res_unit(x)) + x
        if self.pre_actv:
            return out
        return self.actv(out)
class ResNet(nn.Module):
    """Conv1d stem + stack of ResBlocks + flat 1024-dim projection head.

    Expects input shaped (N, in_channels, 34) and returns (N, 1024).
    """

    def __init__(
        self,
        in_channels,
        conv_channels,
        num_blocks,
        *,
        norm_builder = nn.Identity,
        actv_builder = nn.ReLU,
        pre_actv = False,
    ):
        super().__init__()
        # Blocks are created before the stem conv, matching the original
        # parameter-creation order.
        blocks = [
            ResBlock(
                conv_channels,
                norm_builder = norm_builder,
                actv_builder = actv_builder,
                pre_actv = pre_actv,
            )
            for _ in range(num_blocks)
        ]
        stem = nn.Conv1d(in_channels, conv_channels, kernel_size=3, padding=1, bias=False)
        if pre_actv:
            # Pre-activation nets put the norm/actv pair after the stack.
            trunk = [stem, *blocks, norm_builder(), actv_builder()]
        else:
            # Post-activation nets put it between the stem and the stack.
            trunk = [stem, norm_builder(), actv_builder(), *blocks]
        head = [
            nn.Conv1d(conv_channels, 32, kernel_size=3, padding=1),
            actv_builder(),
            nn.Flatten(),
            # 34 positions (tile kinds) per channel after the convolutions.
            nn.Linear(32 * 34, 1024),
        ]
        self.net = nn.Sequential(*trunk, *head)

    def forward(self, x):
        return self.net(x)
class Brain(nn.Module):
    """Observation encoder producing the features consumed by DQN.

    Version 1 outputs a Gaussian latent as (mu, logsig); versions 2-4
    output a single activated 1024-dim feature tensor.
    """

    def __init__(self, *, conv_channels, num_blocks, is_oracle=False, version=1):
        super().__init__()
        self.is_oracle = is_oracle
        self.version = version

        in_channels = obs_shape(version)[0]
        if is_oracle:
            # Oracle mode concatenates privileged (invisible) observation planes.
            in_channels += oracle_obs_shape(version)[0]

        norm_builder = partial(nn.BatchNorm1d, conv_channels, momentum=0.01)
        actv_builder = partial(nn.Mish, inplace=True)
        pre_actv = True
        if version == 1:
            actv_builder = partial(nn.ReLU, inplace=True)
            pre_actv = False
            self.latent_net = nn.Sequential(
                nn.Linear(1024, 512),
                nn.ReLU(inplace=True),
            )
            self.mu_head = nn.Linear(512, 512)
            self.logsig_head = nn.Linear(512, 512)
        elif version == 2:
            pass
        elif version in (3, 4):
            norm_builder = partial(nn.BatchNorm1d, conv_channels, momentum=0.01, eps=1e-3)
        else:
            raise ValueError(f'Unexpected version {self.version}')

        self.encoder = ResNet(
            in_channels = in_channels,
            conv_channels = conv_channels,
            num_blocks = num_blocks,
            norm_builder = norm_builder,
            actv_builder = actv_builder,
            pre_actv = pre_actv,
        )
        self.actv = actv_builder()
        # When True, BatchNorm layers are pinned to eval mode (running
        # statistics frozen) even while the rest of the model trains.
        self._freeze_bn = False

    def forward(self, obs: Tensor, invisible_obs: Optional[Tensor] = None) -> Union[Tuple[Tensor, Tensor], Tensor]:
        if self.is_oracle:
            assert invisible_obs is not None
            obs = torch.cat((obs, invisible_obs), dim=1)
        phi = F.dropout(self.encoder(obs), p=0.1, training=self.training)
        if self.version == 1:
            latent_out = self.latent_net(phi)
            return self.mu_head(latent_out), self.logsig_head(latent_out)
        if self.version in (2, 3, 4):
            return self.actv(phi)
        raise ValueError(f'Unexpected version {self.version}')

    def train(self, mode=True):
        super().train(mode)
        if self._freeze_bn:
            # Keep BN in eval mode so running stats are not updated.
            for mod in self.modules():
                if isinstance(mod, nn.BatchNorm1d):
                    mod.eval()
                    # I don't think this benefits
                    # module.requires_grad_(False)
        return self

    def reset_running_stats(self):
        for mod in self.modules():
            if isinstance(mod, nn.BatchNorm1d):
                mod.reset_running_stats()

    def freeze_bn(self, value: bool):
        self._freeze_bn = value
        # Re-apply the current train/eval state so the change takes effect now.
        return self.train(self.training)
class AuxNet(nn.Module):
    """Auxiliary prediction heads sharing one bias-free linear layer.

    A single Linear maps the 1024-dim feature to the concatenation of all
    head outputs, which forward() splits back into per-head tensors.
    """

    def __init__(self, dims=None):
        """
        Args:
            dims: iterable of per-head output sizes. Required in practice;
                passing None fails immediately (as before, but earlier).
        """
        super().__init__()
        # Freeze as a tuple so a caller mutating its list afterwards cannot
        # desynchronize the split sizes from the layer's output width.
        self.dims = tuple(dims)
        self.net = nn.Linear(1024, sum(self.dims), bias=False)

    def forward(self, x):
        # Returns a tuple of tensors, one per entry in self.dims.
        return self.net(x).split(self.dims, dim=-1)
| class DQN(nn.Module): | |
| def __init__(self, *, version=1): | |
| super().__init__() | |
| self.version = version | |
| match version: | |
| case 1: | |
| self.v_head = nn.Linear(512, 1) | |
| self.a_head = nn.Linear(512, ACTION_SPACE) | |
| case 2 | 3: | |
| hidden_size = 512 if version == 2 else 256 | |
| self.v_head = nn.Sequential( | |
| nn.Linear(1024, hidden_size), | |
| nn.Mish(inplace=True), | |
| nn.Linear(hidden_size, 1), | |
| ) | |
| self.a_head = nn.Sequential( | |
| nn.Linear(1024, hidden_size), | |
| nn.Mish(inplace=True), | |
| nn.Linear(hidden_size, ACTION_SPACE), | |
| ) | |
| case 4: | |
| self.net = nn.Linear(1024, 1 + ACTION_SPACE) | |
| nn.init.constant_(self.net.bias, 0) | |
| def forward(self, phi, mask): | |
| if self.version == 4: | |
| v, a = self.net(phi).split((1, ACTION_SPACE), dim=-1) | |
| else: | |
| v = self.v_head(phi) | |
| a = self.a_head(phi) | |
| a_sum = a.masked_fill(~mask, 0.).sum(-1, keepdim=True) | |
| mask_sum = mask.sum(-1, keepdim=True) | |
| a_mean = a_sum / mask_sum | |
| q = (v + a - a_mean).masked_fill(~mask, -1e9) | |
| return q | |
| class MortalEngine: | |
| def __init__( | |
| self, | |
| brain, | |
| dqn, | |
| is_oracle, | |
| version, | |
| device = None, | |
| stochastic_latent = False, | |
| enable_amp = False, | |
| enable_quick_eval = True, | |
| enable_rule_based_agari_guard = False, | |
| name = 'NoName', | |
| boltzmann_epsilon = 0, | |
| boltzmann_temp = 1, | |
| top_p = 1, | |
| ): | |
| self.engine_type = 'mortal' | |
| self.device = device or torch.device('cpu') | |
| assert isinstance(self.device, torch.device) | |
| self.brain = brain.to(self.device).eval() | |
| self.dqn = dqn.to(self.device).eval() | |
| self.is_oracle = is_oracle | |
| self.version = version | |
| self.stochastic_latent = stochastic_latent | |
| self.enable_amp = enable_amp | |
| self.enable_quick_eval = enable_quick_eval | |
| self.enable_rule_based_agari_guard = enable_rule_based_agari_guard | |
| self.name = name | |
| self.boltzmann_epsilon = boltzmann_epsilon | |
| self.boltzmann_temp = boltzmann_temp | |
| self.top_p = top_p | |
| def react_batch(self, obs, masks, invisible_obs): | |
| # ========== Online Server =========== # | |
| global ot_settings, is_online | |
| # print('Reacting Batch') | |
| if ot_settings['online']: | |
| try: | |
| list_obs = [o.tolist() for o in obs] | |
| list_masks = [m.tolist() for m in masks] | |
| post_data = { | |
| 'obs': list_obs, | |
| 'masks': list_masks, | |
| } | |
| data = json.dumps(post_data, separators=(',', ':')) | |
| compressed_data = gzip.compress(data.encode('utf-8')) | |
| headers = { | |
| 'Authorization': ot_settings['api_key'], | |
| 'Content-Encoding': 'gzip', | |
| } | |
| r = requests.post( | |
| f'{ot_settings["server"]}/react_batch_3p', | |
| headers=headers, | |
| data=compressed_data, | |
| timeout=OT_REQUEST_TIMEOUT | |
| ) | |
| assert r.status_code == 200 | |
| is_online = True | |
| r_json = r.json() | |
| return r_json['actions'], r_json['q_out'], r_json['masks'], r_json['is_greedy'] | |
| except: | |
| is_online = False | |
| pass | |
| # ==================================== # | |
| try: | |
| with ( | |
| torch.autocast(self.device.type, enabled=self.enable_amp), | |
| torch.inference_mode(), | |
| ): | |
| return self._react_batch(obs, masks, invisible_obs) | |
| except Exception as ex: | |
| raise Exception(f'{ex}\n{traceback.format_exc()}') | |
| def _react_batch(self, obs, masks, invisible_obs): | |
| obs = torch.as_tensor(np.stack(obs, axis=0), device=self.device) | |
| masks = torch.as_tensor(np.stack(masks, axis=0), device=self.device) | |
| invisible_obs = None | |
| if self.is_oracle: | |
| invisible_obs = torch.as_tensor(np.stack(invisible_obs, axis=0), device=self.device) | |
| batch_size = obs.shape[0] | |
| match self.version: | |
| case 1: | |
| mu, logsig = self.brain(obs, invisible_obs) | |
| if self.stochastic_latent: | |
| latent = Normal(mu, logsig.exp() + 1e-6).sample() | |
| else: | |
| latent = mu | |
| q_out = self.dqn(latent, masks) | |
| case 2 | 3 | 4: | |
| phi = self.brain(obs) | |
| q_out = self.dqn(phi, masks) | |
| if self.boltzmann_epsilon > 0: | |
| is_greedy = torch.full((batch_size,), 1-self.boltzmann_epsilon, device=self.device).bernoulli().to(torch.bool) | |
| logits = (q_out / self.boltzmann_temp).masked_fill(~masks, -torch.inf) | |
| sampled = sample_top_p(logits, self.top_p) | |
| actions = torch.where(is_greedy, q_out.argmax(-1), sampled) | |
| else: | |
| is_greedy = torch.ones(batch_size, dtype=torch.bool, device=self.device) | |
| actions = q_out.argmax(-1) | |
| return actions.tolist(), q_out.tolist(), masks.tolist(), is_greedy.tolist() | |
def sample_top_p(logits, p):
    """Nucleus (top-p) sampling over the last dimension of `logits`.

    p >= 1 degenerates to plain categorical sampling and p <= 0 to argmax;
    otherwise sampling is restricted to the smallest prefix of
    probability-sorted tokens whose cumulative mass reaches p.
    """
    if p >= 1:
        return Categorical(logits=logits).sample()
    if p <= 0:
        return logits.argmax(-1)
    probs = logits.softmax(-1)
    sorted_probs, sorted_idx = probs.sort(-1, descending=True)
    cumulative = sorted_probs.cumsum(-1)
    # Drop every token whose *preceding* cumulative mass already exceeds p;
    # the top token always survives since its preceding mass is zero.
    outside_nucleus = cumulative - sorted_probs > p
    kept = sorted_probs.masked_fill(outside_nucleus, 0.)
    # multinomial renormalizes the surviving mass internally.
    choice_in_sorted = kept.multinomial(1)
    return sorted_idx.gather(-1, choice_in_sorted).squeeze(-1)
def load_model(seat: int, model: str) -> Bot:
    """Build a Bot for `seat` from a checkpoint file next to this script.

    Args:
        seat: player seat index the bot occupies.
        model: checkpoint filename, optionally carrying a '?...' query
            suffix which is stripped; None selects the bundled default.

    Returns:
        A libriichi3p Bot wrapping a MortalEngine on CUDA if available,
        otherwise CPU.
    """
    # Check if a GPU is available.
    if torch.cuda.is_available():
        device = torch.device('cuda')
    else:
        device = torch.device('cpu')

    # Latest binary model used when none is specified.
    if model is None:  # fixed: identity check instead of `== None`
        model = 'Elite4zWeightedBest5.pth'
    # Strip any query-string style suffix (e.g. 'model.pth?v=2').
    model = str(model).split('?')[0]
    control_state_file = model
    print(control_state_file, 'loaded')

    # Resolve the checkpoint relative to this file's directory.
    control_state_file = pathlib.Path(__file__).parent / control_state_file
    # NOTE(review): torch.load unpickles arbitrary objects — only load
    # trusted checkpoint files.
    state = torch.load(control_state_file, map_location=device)

    version = state['config']['control']['version']
    mortal = Brain(
        version=version,
        conv_channels=state['config']['resnet']['conv_channels'],
        num_blocks=state['config']['resnet']['num_blocks'],
    ).eval()
    dqn = DQN(version=version).eval()
    mortal.load_state_dict(state['mortal'])
    dqn.load_state_dict(state['current_dqn'])

    engine = MortalEngine(
        mortal,
        dqn,
        is_oracle = False,
        version = version,
        device = device,
        enable_amp = False,
        enable_quick_eval = False,
        enable_rule_based_agari_guard = True,
        name = 'mortal',
        top_p = 1,
    )
    bot = Bot(engine, seat)
    return bot