| """Tiny CPU fighter model for real-time NPC move selection. |
| |
| Architecture: ~79k parameter MLP with LayerNorm (behaves correctly at |
| batch=1 inference, unlike BatchNorm1d which has degenerate running variance |
| when there's only a single sample). Fast enough for real-time combat |
| (< 1ms on CPU) while having enough capacity to learn strategy-conditioned |
| move selection. |
| """ |
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
| from typing import List, Optional |
|
|
# Full move vocabulary. The list position of each move is its class index in
# the model's output logits and in the one-hot history features.
MOVES = [
    # strikes
    "jab",
    "cross",
    "hook",
    "kick",
    "uppercut",
    # defensive options
    "block",
    "parry",
    "dodge",
    # footwork
    "advance",
    "retreat",
    # clinch work
    "grapple",
    "throw",
    # everything else
    "sweep",
    "feint",
    "wait",
]
NUM_MOVES = len(MOVES)
# Reverse lookup: move name -> index into MOVES.
MOVE_TO_IDX = {name: position for position, name in enumerate(MOVES)}


# Move categories, grouped by name.
ATTACKS = {"jab", "cross", "hook", "kick", "uppercut", "sweep"}
DEFENSES = {"block", "parry", "dodge"}
MOVEMENT = {"advance", "retreat"}
GRAPPLES = {"grapple", "throw"}
UTILITY = {"feint", "wait"}


# Layer widths of the MLP: input features -> first hidden -> second hidden.
INPUT_DIM = 168
HIDDEN1 = 256
HIDDEN2 = 128
|
|
|
|
class TinyFighter(nn.Module):
    """Real-time NPC move policy. CPU-friendly, strategy-conditioned.

    The network is two ``Linear -> LayerNorm -> ReLU -> Dropout(0.1)`` stages
    followed by a ``Linear`` head. It maps an ``INPUT_DIM`` feature vector to
    one logit per entry in ``MOVES``.
    """

    def __init__(self):
        super().__init__()
        # NOTE: the position of each layer inside ``net`` determines the
        # state-dict key names (``net.0.weight`` ...), so keep the order stable.
        self.net = nn.Sequential(
            nn.Linear(INPUT_DIM, HIDDEN1),
            nn.LayerNorm(HIDDEN1),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(HIDDEN1, HIDDEN2),
            nn.LayerNorm(HIDDEN2),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(HIDDEN2, NUM_MOVES),
        )
        # Kaiming-normal weights (ReLU gain) and zero biases on every Linear;
        # the LayerNorm modules keep their PyTorch defaults.
        for m in self.net:
            if isinstance(m, nn.Linear):
                nn.init.kaiming_normal_(m.weight, nonlinearity="relu")
                nn.init.zeros_(m.bias)

    def forward(self, x: torch.Tensor, mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Compute move logits.

        Args:
            x: Feature tensor. A 1-D tensor is treated as a single sample and
                gets a leading batch dimension.
            mask: Optional move mask. Positions equal to 0 are suppressed by
                setting their logit to -1e9. A 1-D mask gets a leading batch
                dimension, like ``x``.

        Returns:
            The logits, always with a leading batch dimension.
        """
        if x.dim() == 1:
            x = x.unsqueeze(0)
        logits = self.net(x)
        if mask is not None:
            if mask.dim() == 1:
                mask = mask.unsqueeze(0)
            # A large negative logit drives the softmax probability of a
            # disallowed move to zero without producing inf/nan.
            logits = logits.masked_fill(mask == 0, -1e9)
        return logits

    @torch.inference_mode()
    def predict(self, feats: torch.Tensor, mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Single-sample inference helper.

        Runs ``forward`` under ``torch.inference_mode`` and with the module in
        eval mode, so the Dropout layers never perturb a prediction. If the
        model is in training mode when called, it is switched to eval for the
        duration of the call and switched back afterwards. A model already in
        eval mode only pays for one boolean check.

        Callers that batch many samples can still call ``forward`` under their
        own no-grad context; this is the fast path for one move per request.

        Args:
            feats: Feature tensor, 1-D for a single sample.
            mask: Optional move mask, see ``forward``.

        Returns:
            The logits, with a leading batch dimension.
        """
        if not self.training:
            return self.forward(feats, mask)
        # The caller forgot model.eval(): disable dropout for this call only,
        # then restore training mode even if forward raises.
        self.eval()
        try:
            return self.forward(feats, mask)
        finally:
            self.train()
|
|
|
|
def _is_norm_affine_key(key: str) -> bool:
    """Return True if *key* names the weight/bias of a norm layer in ``net``."""
    parts = key.split(".")
    if parts[-1] not in ("weight", "bias"):
        return False
    try:
        net_pos = parts.index("net")
    except ValueError:
        return False
    # The component right after "net" is the layer index inside the
    # Sequential; the normalisation layers sit at positions 1 and 5.
    layer_pos = net_pos + 1
    return layer_pos < len(parts) - 1 and parts[layer_pos] in ("1", "5")


def remap_bn_state_to_ln(state_dict: dict) -> dict:
    """Drop BatchNorm1d entries from a state dict so it can load into the
    LayerNorm-based TinyFighter architecture.

    The Linear weights pass through unchanged. BatchNorm buffers
    (running_mean, running_var, num_batches_tracked) and the BN affine
    parameters (weight, bias of ``net.1`` and ``net.5``) are discarded, so
    the LayerNorm modules keep their PyTorch defaults (weight=1, bias=0).
    The model still produces a well-defined output, though the policy will
    likely need some additional training to recover its previous quality.

    Keys are matched both with a module prefix (``model.net.1.weight``) and
    without one (``net.1.weight``).

    Args:
        state_dict: Mapping of parameter/buffer names to values.

    Returns:
        A new dict holding only the entries to keep; the input is not
        modified.
    """
    drop_suffixes = ("running_mean", "running_var", "num_batches_tracked")
    return {
        key: value
        for key, value in state_dict.items()
        if not key.endswith(drop_suffixes) and not _is_norm_affine_key(key)
    }
|
|
|
|
def _one_hot_history(moves: List[str], history_len: int) -> List[float]:
    """One-hot encode the last ``history_len`` moves, most recent first."""
    fallback_idx = NUM_MOVES - 1
    encoded: List[float] = []
    for age in range(history_len):
        # Slots older than the recorded history are filled with "wait".
        name = moves[-(age + 1)] if len(moves) > age else "wait"
        slot = [0.0] * NUM_MOVES
        # Unknown move names fall back to the last move index.
        slot[MOVE_TO_IDX.get(name, fallback_idx)] = 1.0
        encoded.extend(slot)
    return encoded


def state_to_features(
    last_npc_moves: List[str],
    last_player_moves: List[str],
    player_hp: float,
    npc_hp: float,
    player_stamina: float,
    npc_stamina: float,
    distance: str,
    aggression: float,
    defense: float,
    parry_affinity: float,
    kick_affinity: float,
    grapple_affinity: float,
    round_num: int = 1,
    history_len: int = 5,
) -> torch.Tensor:
    """Convert game state to an ``INPUT_DIM``-long (168) feature tensor.

    Layout, in order:
      1. NPC move history: ``history_len`` one-hot vectors, most recent first.
      2. Player move history, same encoding.
      3. HP difference and stamina difference (NPC minus player), each / 100.
      4. Distance one-hot over ("near", "mid", "far"); any other value is
         encoded as "mid".
      5. The five strategy knobs: aggression, defense, parry, kick and
         grapple affinity, passed through as given.
      6. Round number capped at 10 and divided by 10, then player HP, NPC HP,
         player stamina and NPC stamina, each / 100.
      7. Zero padding up to ``INPUT_DIM``.

    With the default ``history_len`` of 5 this yields 165 real features plus
    3 padding zeros.

    Raises:
        ValueError: If ``history_len`` is large enough that the features no
            longer fit in ``INPUT_DIM``.
    """
    features = _one_hot_history(last_npc_moves, history_len)
    features.extend(_one_hot_history(last_player_moves, history_len))

    # Relative advantage of the NPC over the player.
    features.append((npc_hp - player_hp) / 100.0)
    features.append((npc_stamina - player_stamina) / 100.0)

    bands = ("near", "mid", "far")
    dist_oh = [0.0] * len(bands)
    dist_oh[bands.index(distance) if distance in bands else 1] = 1.0
    features.extend(dist_oh)

    # Strategy conditioning.
    features.extend(
        [aggression, defense, parry_affinity, kick_affinity, grapple_affinity]
    )

    # Absolute state.
    features.append(min(round_num, 10) / 10.0)
    features.append(player_hp / 100.0)
    features.append(npc_hp / 100.0)
    features.append(player_stamina / 100.0)
    features.append(npc_stamina / 100.0)

    # An oversize vector cannot be fed to the model; fail here with a clear
    # message instead of returning a tensor of the wrong length.
    if len(features) > INPUT_DIM:
        raise ValueError(
            f"history_len={history_len} produces {len(features)} features, "
            f"which exceeds INPUT_DIM={INPUT_DIM}"
        )
    features.extend([0.0] * (INPUT_DIM - len(features)))

    return torch.tensor(features, dtype=torch.float32)
|
|
|
|
def make_move_mask(distance: str) -> torch.Tensor:
    """Build a float32 mask over ``MOVES``: 1.0 = allowed, 0.0 = disallowed.

    At "far" distance grapple, throw and sweep are disallowed; at "near"
    distance advance is disallowed. Any other distance allows every move.
    """
    blocked_by_distance = {
        "far": ("grapple", "throw", "sweep"),
        "near": ("advance",),
    }
    blocked = {MOVE_TO_IDX[name] for name in blocked_by_distance.get(distance, ())}
    flags = [0.0 if idx in blocked else 1.0 for idx in range(NUM_MOVES)]
    return torch.tensor(flags, dtype=torch.float32)
|
|
|
|
if __name__ == "__main__":
    # Smoke test and micro-benchmark: build a model, time single-sample
    # inference on one example state, and print the chosen move.
    import time

    model = TinyFighter()
    total = sum(param.numel() for param in model.parameters())
    print(f"Total params: {total:,}")

    model.eval()
    distance = "mid"
    features = state_to_features(
        last_npc_moves=["jab", "block", "kick"],
        last_player_moves=["cross", "retreat", "jab"],
        player_hp=80.0,
        npc_hp=50.0,
        player_stamina=60.0,
        npc_stamina=40.0,
        distance=distance,
        aggression=0.7,
        defense=0.3,
        parry_affinity=0.4,
        kick_affinity=0.6,
        grapple_affinity=0.2,
        round_num=3,
    )
    mask = make_move_mask(distance)

    # Two untimed warm-up calls so one-off setup cost stays out of the timing.
    for _ in range(2):
        model.predict(features, mask)

    n_calls = 1000
    with torch.inference_mode():
        start = time.perf_counter()
        for _ in range(n_calls):
            logits = model.predict(features, mask)
        # Total seconds / calls -> seconds per call; * 1000 -> milliseconds.
        elapsed = (time.perf_counter() - start) / n_calls * 1000
        probs = F.softmax(logits, dim=-1)
        move_idx = probs.argmax().item()

    print(f"Inference: {elapsed:.3f}ms per call")
    print(f"Suggested move: {MOVES[move_idx]} (prob={probs[0][move_idx]:.3f})")
|
|