Spaces:
Running
Running
File size: 6,420 Bytes
d98decc |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 |
import gymnasium as gym
import torch
import torch.nn as nn
from stable_baselines3.common.torch_layers import BaseFeaturesExtractor
class CardEncoder(nn.Module):
    """
    Shared MLP encoder applied independently to every card feature vector.

    Maps ``[..., input_dim]`` to ``[..., embed_dim]`` through
    Linear -> ReLU -> Linear -> LayerNorm -> ReLU. Linear and LayerNorm act on
    the last dimension only, so any number of leading batch/slot dimensions is
    accepted.

    Args:
        input_dim: Size of a single card feature vector.
        embed_dim: Size of the output embedding.
        hidden_dim: Width of the hidden layer (default 128).
    """

    def __init__(self, input_dim=64, embed_dim=128, hidden_dim=128):
        super().__init__()
        # Layer order (indices 0..4) is fixed so state_dict keys
        # ("net.0.*", "net.2.*", "net.3.*") stay stable for saved checkpoints.
        self.net = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(inplace=True),
            nn.Linear(hidden_dim, embed_dim),
            nn.LayerNorm(embed_dim),
            nn.ReLU(inplace=True),
        )

    def forward(self, x):
        """Return the card embedding(s) of ``x``, shape ``[..., embed_dim]``."""
        return self.net(x)
class MultiHeadCardAttention(nn.Module):
    """
    Self-attention block over a set of card embeddings.

    Computes ``LayerNorm(x + MHA(x, x, x))``: multi-head self-attention
    followed by a residual connection and a post-norm LayerNorm.

    Args:
        embed_dim: Size of each card embedding; must be divisible by
            ``num_heads`` (a requirement of ``nn.MultiheadAttention``).
        num_heads: Number of attention heads.
    """

    def __init__(self, embed_dim=128, num_heads=4):
        super().__init__()
        # batch_first=True: inputs are laid out as (batch, cards, embed_dim).
        self.attn = nn.MultiheadAttention(embed_dim, num_heads, batch_first=True)
        self.norm = nn.LayerNorm(embed_dim)

    def forward(self, x, mask=None):
        """
        Apply self-attention + residual + LayerNorm.

        Args:
            x: Tensor of shape (batch, cards, embed_dim).
            mask: Optional boolean key-padding mask of shape (batch, cards);
                True marks slots that must not be attended to (PyTorch
                ``key_padding_mask`` convention). It is not modified in place.

        Returns:
            Tensor with the same shape as ``x``.
        """
        if mask is not None:
            # If every key in a row is masked, the attention softmax has no
            # valid entries and yields NaN. For those rows, drop the mask so
            # they attend to all slots; the caller is expected to ignore the
            # outputs at padded positions.
            # Computed unconditionally, with no `if mask.any():` branch: such
            # a branch forces a device-to-host sync on GPU, and this
            # expression is already a no-op when nothing is masked.
            all_masked = mask.all(dim=1, keepdim=True)
            mask = mask & ~all_masked

        attn_out, _ = self.attn(x, x, x, key_padding_mask=mask, need_weights=False)
        # Residual connection followed by post-norm.
        return self.norm(x + attn_out)
class LovecaFeaturesExtractor(BaseFeaturesExtractor):
    """
    Custom feature extractor for the Love Live TCG observation vector.

    Splits the flat structured observation (2240 features in the standard
    layout) into semantic zones. Each card occupies ``card_dim`` (64) columns,
    and feature 0 of a card is read as a presence bit (``< 0.5`` = empty slot).

    Column layout for the standard 2240-dim observation::

        [   0, 1024)  hand      16 cards (15 hand + 1 "HandOver")
        [1024, 1216)  stage      3 cards
        [1216, 1600)  live       6 cards (3 pending + 3 success)
        [1600, 2176)  opponent   9 cards (3 stage + 6 history)
        [2176, 2240)  global features (everything after the card zones)

    Args:
        observation_space: 1-D Box observation space. Columns after the card
            zones are passed through as global features.
        features_dim: Size of the output feature vector.

    Raises:
        ValueError: If the observation is too short to hold every card zone.
    """

    def __init__(self, observation_space: gym.spaces.Box, features_dim: int = 256):
        super().__init__(observation_space, features_dim)
        self.card_dim = 64
        self.embed_dim = 128

        # Cards per zone, in observation order (34 cards in total).
        self.n_hand = 16  # 15 hand slots + 1 "HandOver" slot
        self.n_stage = 3
        self.n_live = 6  # 3 pending + 3 success
        self.n_opp = 9  # 3 opponent stage + 6 opponent history

        # Column offsets of each zone, derived from the counts above so the
        # slicing in forward() cannot drift out of sync with them.
        self._hand_end = self.n_hand * self.card_dim  # 1024
        self._stage_end = self._hand_end + self.n_stage * self.card_dim  # 1216
        self._live_end = self._stage_end + self.n_live * self.card_dim  # 1600
        self._cards_end = self._live_end + self.n_opp * self.card_dim  # 2176

        # All remaining columns are global features (64 for the 2240 layout).
        obs_dim = int(observation_space.shape[0])
        self.global_dim = obs_dim - self._cards_end
        if self.global_dim < 0:
            raise ValueError(
                f"Observation has {obs_dim} features but the card zones alone "
                f"need {self._cards_end}."
            )

        # NOTE: keep the module/parameter creation order below unchanged. It
        # fixes parameters() order (used by optimizer state in checkpoints) and
        # the RNG draw order for seeded initialisation.

        # 1. Shared card encoder, applied to every card slot.
        self.card_encoder = CardEncoder(self.card_dim, self.embed_dim)

        # 2. Attention blocks.
        # NOTE(review): hand_attention is never used in forward(); the hand is
        # flattened slot-by-slot instead. It is kept so existing checkpoints
        # (state_dict keys) still load. TODO confirm whether it can be dropped.
        self.hand_attention = MultiHeadCardAttention(self.embed_dim, num_heads=4)
        self.opp_attention = MultiHeadCardAttention(self.embed_dim, num_heads=2)

        # 3. Learned per-slot positional embeddings for fixed-slot zones.
        # NOTE(review): opp_pos_emb is never added in forward(). The opponent
        # zone goes through attention + mean pooling with no positional
        # signal, so stage and history slots are distinguishable only if the
        # card features themselves encode the zone. Kept for checkpoint
        # compatibility; confirm whether this is intended.
        self.stage_pos_emb = nn.Parameter(torch.randn(1, self.n_stage, self.embed_dim))
        self.live_pos_emb = nn.Parameter(torch.randn(1, self.n_live, self.embed_dim))
        self.opp_pos_emb = nn.Parameter(torch.randn(1, self.n_opp, self.embed_dim))

        # 4. Fusion MLP over the concatenation of:
        #   hand slots (n_hand * embed_dim)      2048
        #   stage slots (n_stage * embed_dim)     384
        #   live slots (n_live * embed_dim)       768
        #   opponent mean-pool (embed_dim)        128
        #   global features (global_dim)           64
        #   = 3392 for the standard 2240-dim layout.
        self.fusion_dim = (
            self.n_hand + self.n_stage + self.n_live + 1
        ) * self.embed_dim + self.global_dim
        self.fusion_net = nn.Sequential(
            nn.Linear(self.fusion_dim, 512),
            nn.ReLU(inplace=True),
            nn.Linear(512, features_dim),
            nn.LayerNorm(features_dim),
            nn.ReLU(inplace=True),
        )

    def forward(self, observations: torch.Tensor) -> torch.Tensor:
        """
        Encode a batch of flat observations.

        Args:
            observations: Tensor of shape (batch, obs_dim) laid out as
                described in the class docstring.

        Returns:
            Tensor of shape (batch, features_dim).
        """
        batch_size = observations.shape[0]
        card_dim = self.card_dim

        # 1. Slice the zones out and view them as (batch, cards, card_dim).
        hand_cards = observations[:, : self._hand_end].reshape(
            batch_size, self.n_hand, card_dim
        )
        stage_cards = observations[:, self._hand_end : self._stage_end].reshape(
            batch_size, self.n_stage, card_dim
        )
        live_cards = observations[:, self._stage_end : self._live_end].reshape(
            batch_size, self.n_live, card_dim
        )
        opp_cards = observations[:, self._live_end : self._cards_end].reshape(
            batch_size, self.n_opp, card_dim
        )
        global_features = observations[:, self._cards_end :]

        # Empty-slot masks: True where the presence bit (feature 0) is off.
        hand_mask = hand_cards[:, :, 0] < 0.5
        opp_mask = opp_cards[:, :, 0] < 0.5

        # 2. Encode every card with the shared encoder.
        hand_emb = self.card_encoder(hand_cards)
        stage_emb = self.card_encoder(stage_cards)
        live_emb = self.card_encoder(live_cards)
        opp_emb = self.card_encoder(opp_cards)

        # 3A. Hand: keep per-slot embeddings (no pooling) so slot identity is
        # preserved, zeroing out empty slots.
        hand_keep = (~hand_mask).unsqueeze(-1).float()
        hand_flat_emb = (hand_emb * hand_keep).reshape(batch_size, -1)

        # 3B/3C. Stage and live: add learned positional embeddings, flatten.
        stage_flat_emb = (stage_emb + self.stage_pos_emb).reshape(batch_size, -1)
        live_flat_emb = (live_emb + self.live_pos_emb).reshape(batch_size, -1)

        # 3D. Opponent: self-attention, then mean-pool over occupied slots.
        opp_keep = ~opp_mask
        opp_processed = self.opp_attention(opp_emb, mask=opp_mask)
        opp_sum = (opp_processed * opp_keep.unsqueeze(-1).float()).sum(dim=1)
        opp_counts = opp_keep.sum(dim=1, keepdim=True).float()
        # The epsilon avoids 0/0 when the whole zone is empty (summary is 0).
        opp_summary = opp_sum / (opp_counts + 1e-6)

        # 4. Fuse all zone representations with the global features.
        combined = torch.cat(
            [
                hand_flat_emb,
                stage_flat_emb,
                live_flat_emb,
                opp_summary,
                global_features,
            ],
            dim=1,
        )
        return self.fusion_net(combined)
|