duel-tiny-fighter / tiny_fighter.py
sankalphs's picture
Upload tiny fighter source
11b9fe5 verified
Raw
History Blame Contribute Delete
7.15 kB
"""Tiny CPU fighter model for real-time NPC move selection.
Architecture: ~79k parameter MLP with LayerNorm (behaves correctly at
batch=1 inference, unlike BatchNorm1d which has degenerate running variance
when there's only a single sample). Fast enough for real-time combat
(< 1ms on CPU) while having enough capacity to learn strategy-conditioned
move selection.
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import List, Optional
# Every action the policy can emit. The position in this list is the logit
# index, so the order must stay stable across training and inference.
MOVES = [
    # strikes
    "jab",
    "cross",
    "hook",
    "kick",
    "uppercut",
    # defensive reactions
    "block",
    "parry",
    "dodge",
    # footwork
    "advance",
    "retreat",
    # clinch work
    "grapple",
    "throw",
    # everything else
    "sweep",
    "feint",
    "wait",
]
NUM_MOVES = len(MOVES)
# Reverse lookup: move name -> logit index.
MOVE_TO_IDX = dict(zip(MOVES, range(NUM_MOVES)))

# Category groupings of the move names above.
ATTACKS = {
    "jab",
    "cross",
    "hook",
    "kick",
    "uppercut",
    "sweep",
}
DEFENSES = {"block", "parry", "dodge"}
MOVEMENT = {"advance", "retreat"}
GRAPPLES = {"grapple", "throw"}
UTILITY = {"feint", "wait"}

# Network dimensions: input feature width and the two hidden layer widths.
INPUT_DIM = 168
HIDDEN1 = 256
HIDDEN2 = 128
class TinyFighter(nn.Module):
    """Real-time NPC move policy. CPU-friendly, strategy-conditioned.

    The network is a small MLP: two hidden blocks of
    Linear -> LayerNorm -> ReLU -> Dropout(0.1) followed by a Linear head
    that emits ``NUM_MOVES`` logits from an ``INPUT_DIM``-wide feature vector.
    """

    def __init__(self):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(INPUT_DIM, HIDDEN1),
            nn.LayerNorm(HIDDEN1),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(HIDDEN1, HIDDEN2),
            nn.LayerNorm(HIDDEN2),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(HIDDEN2, NUM_MOVES),
        )
        # Kaiming-normal weights (ReLU gain) and zero biases for every Linear;
        # the LayerNorm modules keep their PyTorch defaults.
        for m in self.net:
            if isinstance(m, nn.Linear):
                nn.init.kaiming_normal_(m.weight, nonlinearity="relu")
                nn.init.zeros_(m.bias)

    def forward(self, x: torch.Tensor, mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Compute move logits, optionally suppressing masked-out moves.

        Args:
            x: Feature tensor. A 1-D tensor is treated as a single sample and
                gets a leading batch axis.
            mask: Optional tensor aligned with the logits. Positions whose
                value equals 0 have their logit replaced by ``-1e9``. A 1-D
                mask gets a leading batch axis and broadcasts over the batch.

        Returns:
            Logits with a leading batch axis (also for 1-D input).
        """
        if x.dim() == 1:
            x = x.unsqueeze(0)
        logits = self.net(x)
        if mask is not None:
            if mask.dim() == 1:
                mask = mask.unsqueeze(0)
            # A large negative logit drives the softmax probability of a
            # disallowed move to zero.
            logits = logits.masked_fill(mask == 0, -1e9)
        return logits

    @torch.inference_mode()
    def predict(self, feats: torch.Tensor, mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Single-sample inference helper.

        Cheaper than a manual `with torch.no_grad(): forward(...)` because
        inference_mode disables more bookkeeping. Callers that batch many
        samples should still use forward() under their own no_grad context,
        but for the real-time path (batch=1, one move per request) this is
        the fast path.

        The call always runs with the module in eval mode so the Dropout
        layers are inactive and the result is deterministic, even if the
        caller forgot ``model.eval()``. The previous training flag is
        restored before returning.
        """
        was_training = self.training
        if was_training:
            # Dropout must not perturb an inference result.
            self.eval()
        try:
            return self.forward(feats, mask)
        finally:
            if was_training:
                self.train()
def remap_bn_state_to_ln(state_dict: dict) -> dict:
    """Drop BatchNorm1d entries from a state dict so it can load into the
    LayerNorm-based TinyFighter architecture.

    The Linear weights load unchanged. BatchNorm buffers (running_mean,
    running_var, num_batches_tracked) and the BN affine (weight, bias) are
    discarded -- the LayerNorm modules start with their PyTorch defaults
    (weight=1, bias=0), so the model still produces a well-defined output
    even if the policy will need a few rounds of additional training to
    re-converge to its previous quality.

    Keys are matched by component, so both bare keys (``net.1.weight``) and
    prefixed keys (``model.net.1.weight``, ``module.net.5.bias``) are handled.

    Args:
        state_dict: Mapping of parameter/buffer names to values, as saved
            from the BatchNorm variant of the network.

    Returns:
        A new dict without the normalisation entries; the input is not
        modified. Because the norm affine keys are absent, load the result
        with ``strict=False``.
    """
    drop_suffixes = ("running_mean", "running_var", "num_batches_tracked")
    # Positions of the normalisation layers inside the ``net`` Sequential.
    norm_slots = {"1", "5"}
    affine_names = {"weight", "bias"}

    out = {}
    for key, value in state_dict.items():
        if key.endswith(drop_suffixes):
            continue
        parts = key.split(".")
        # A norm affine key ends with: net.<slot>.<weight|bias>
        is_norm_affine = (
            len(parts) >= 3
            and parts[-1] in affine_names
            and parts[-2] in norm_slots
            and parts[-3] == "net"
        )
        if is_norm_affine:
            continue
        out[key] = value
    return out
def _one_hot_history(moves: List[str], history_len: int) -> List[float]:
    """One-hot encode the last ``history_len`` moves, most recent first.

    Slots with no recorded move are encoded as "wait"; an unrecognised move
    name falls back to the last move index.
    """
    encoded: List[float] = []
    for age in range(history_len):
        name = moves[-(age + 1)] if len(moves) > age else "wait"
        slot = [0.0] * NUM_MOVES
        slot[MOVE_TO_IDX.get(name, NUM_MOVES - 1)] = 1.0
        encoded.extend(slot)
    return encoded


def state_to_features(
    last_npc_moves: List[str],
    last_player_moves: List[str],
    player_hp: float,
    npc_hp: float,
    player_stamina: float,
    npc_stamina: float,
    distance: str,
    aggression: float,
    defense: float,
    parry_affinity: float,
    kick_affinity: float,
    grapple_affinity: float,
    round_num: int = 1,
    history_len: int = 5,
) -> torch.Tensor:
    """Convert game state to a 168-dim feature tensor.

    Layout, in order:
      * NPC move history, then player move history: ``history_len`` one-hot
        slots each, most recent move first.
      * HP difference and stamina difference (NPC minus player), each / 100.
      * One-hot distance over ("near", "mid", "far"); any other value is
        encoded as "mid".
      * aggression, defense, parry_affinity, kick_affinity, grapple_affinity
        passed through unchanged.
      * Round number capped at 10 and divided by 10.
      * player_hp, npc_hp, player_stamina, npc_stamina, each / 100.
      * Zero padding up to ``INPUT_DIM``.

    Returns:
        A 1-D float32 tensor of length ``INPUT_DIM``.

    Raises:
        ValueError: If the encoded state is longer than ``INPUT_DIM`` (for
            example when ``history_len`` is too large). Without this check
            the oversize vector would only fail later, inside the model.
    """
    features = _one_hot_history(last_npc_moves, history_len)
    features.extend(_one_hot_history(last_player_moves, history_len))

    features.append((npc_hp - player_hp) / 100.0)
    features.append((npc_stamina - player_stamina) / 100.0)

    bands = ("near", "mid", "far")
    dist_oh = [0.0, 0.0, 0.0]
    # Unknown distance labels are treated as "mid" (index 1).
    dist_oh[bands.index(distance) if distance in bands else 1] = 1.0
    features.extend(dist_oh)

    features.extend(
        [aggression, defense, parry_affinity, kick_affinity, grapple_affinity]
    )
    features.append(min(round_num, 10) / 10.0)
    features.extend(
        [
            player_hp / 100.0,
            npc_hp / 100.0,
            player_stamina / 100.0,
            npc_stamina / 100.0,
        ]
    )

    if len(features) > INPUT_DIM:
        raise ValueError(
            f"encoded state has {len(features)} features, which exceeds "
            f"INPUT_DIM={INPUT_DIM}; reduce history_len (got {history_len})"
        )
    # Pad the tail with zeros so the vector matches the model's input width.
    features.extend([0.0] * (INPUT_DIM - len(features)))
    return torch.tensor(features, dtype=torch.float32)
def make_move_mask(distance: str) -> torch.Tensor:
    """Build a 0/1 validity mask over the move list for a distance band.

    At "far" the moves grapple, throw and sweep are disabled; at "near"
    advance is disabled. Any other distance leaves every move enabled.

    Returns:
        A float32 tensor of length ``NUM_MOVES`` with 1.0 for allowed moves
        and 0.0 for disallowed ones.
    """
    if distance == "far":
        blocked = ("grapple", "throw", "sweep")
    elif distance == "near":
        blocked = ("advance",)
    else:
        blocked = ()

    blocked_slots = {MOVE_TO_IDX[name] for name in blocked}
    flags = [0.0 if slot in blocked_slots else 1.0 for slot in range(NUM_MOVES)]
    return torch.tensor(flags, dtype=torch.float32)
if __name__ == "__main__":
    # Smoke test and micro-benchmark: build a fresh (untrained) model, encode
    # one sample state, then time repeated single-sample predictions.
    import time

    model = TinyFighter()
    total = sum(p.numel() for p in model.parameters())
    print(f"Total params: {total:,}")
    model.eval()

    sample_state = dict(
        last_npc_moves=["jab", "block", "kick"],
        last_player_moves=["cross", "retreat", "jab"],
        player_hp=80.0,
        npc_hp=50.0,
        player_stamina=60.0,
        npc_stamina=40.0,
        distance="mid",
        aggression=0.7,
        defense=0.3,
        parry_affinity=0.4,
        kick_affinity=0.6,
        grapple_affinity=0.2,
        round_num=3,
    )
    features = state_to_features(**sample_state)
    mask = make_move_mask("mid")

    # Warmup so the first timed call isn't paying one-off dispatch cost.
    for _ in range(2):
        model.predict(features, mask)

    n_calls = 1000
    with torch.inference_mode():
        start = time.perf_counter()
        for _ in range(n_calls):
            logits = model.predict(features, mask)
        # Seconds per call, converted to milliseconds.
        elapsed = (time.perf_counter() - start) / n_calls * 1000

    probs = F.softmax(logits, dim=-1)
    move_idx = probs.argmax().item()
    print(f"Inference: {elapsed:.3f}ms per call")
    print(f"Suggested move: {MOVES[move_idx]} (prob={probs[0][move_idx]:.3f})")