import gymnasium as gym import torch import torch.nn as nn from stable_baselines3.common.torch_layers import BaseFeaturesExtractor class CardEncoder(nn.Module): """ Shared encoder for single cards. Input: [Batch, ..., 64] -> Output: [Batch, ..., EmbedDim] Optimized: Reduced layer count, removed intermediate LayerNorm. """ def __init__(self, input_dim=64, embed_dim=128): super().__init__() self.net = nn.Sequential( nn.Linear(input_dim, 128), nn.ReLU(inplace=True), nn.Linear(128, embed_dim), nn.LayerNorm(embed_dim), nn.ReLU(inplace=True), ) def forward(self, x): return self.net(x) class MultiHeadCardAttention(nn.Module): """ Self-Attention block for handling sets of cards. Optimized: Removed post-norm in favor of pre-norm style if desired, but keeping it simple: just standard MHA is fine. """ def __init__(self, embed_dim=128, num_heads=4): super().__init__() # batch_first=True is critical for speed with our data layout self.attn = nn.MultiheadAttention(embed_dim, num_heads, batch_first=True) self.norm = nn.LayerNorm(embed_dim) def forward(self, x, mask=None): # Flattened logic for speed: # Pre-Norm (Original was Post-Norm, let's keep Post-Norm but optimized) # Robustness handling: if mask is not None: # Fast check: are any masked? if mask.any(): all_masked = mask.all(dim=1, keepdim=True) mask = mask & (~all_masked) # MHA attn_out, _ = self.attn(x, x, x, key_padding_mask=mask, need_weights=False) # Add & Norm return self.norm(x + attn_out) class LovecaFeaturesExtractor(BaseFeaturesExtractor): """ Custom Feature Extractor for Love Live TCG. Parses the 2240-dim structured observation into semantic components. """ def __init__(self, observation_space: gym.spaces.Box, features_dim: int = 256): super().__init__(observation_space, features_dim) self.card_dim = 64 self.embed_dim = 128 # Consider reducing to 64 if speed is critical? No, keep 128 for quality. # Calculate offsets based on 2240 layout # Hand (15) + HandOver (1) + Stage (3) + Live (3) + LiveSucc (3) + OppStage (3) + OppHist (6) = 34 Cards # 34 * 64 = 2176 # Global = 64 # Total = 2240 self.n_hand = 16 # 15 + 1 self.n_stage = 3 self.n_live = 6 # 3 Pending + 3 Success self.n_opp = 9 # 3 Stage + 6 History # 1. Shared Card Encoder self.card_encoder = CardEncoder(self.card_dim, self.embed_dim) # 2. Attention Blocks self.hand_attention = MultiHeadCardAttention(self.embed_dim, num_heads=4) self.opp_attention = MultiHeadCardAttention(self.embed_dim, num_heads=2) # 3. Embeddings/Projections # Positional Embeddings for fixed slot zones (Stage, Live, OppStage) self.stage_pos_emb = nn.Parameter(torch.randn(1, 3, self.embed_dim)) self.live_pos_emb = nn.Parameter(torch.randn(1, 6, self.embed_dim)) self.opp_pos_emb = nn.Parameter(torch.randn(1, 9, self.embed_dim)) # 4. Fusion # Inputs to fusion: # - Hand (16 * 128): 2048 # - Stage (3 * 128): 384 # - Live (6 * 128): 768 # - Opp Summary (Mean Pool): 128 # - Global: 64 # Total Fusion Input: 2048+384+768+128+64 = 3392 self.fusion_dim = 3392 self.fusion_net = nn.Sequential( nn.Linear(self.fusion_dim, 512), nn.ReLU(inplace=True), nn.Linear(512, features_dim), nn.LayerNorm(features_dim), nn.ReLU(inplace=True), ) def forward(self, observations: torch.Tensor) -> torch.Tensor: batch_size = observations.shape[0] # 1. Slice Observation hand_flat = observations[:, :1024] stage_flat = observations[:, 1024:1216] live_flat = observations[:, 1216:1600] opp_flat = observations[:, 1600:2176] global_features = observations[:, 2176:] # 2. Reshape & Encode hand_cards = hand_flat.reshape(batch_size, 16, 64) stage_cards = stage_flat.reshape(batch_size, 3, 64) live_cards = live_flat.reshape(batch_size, 6, 64) opp_cards = opp_flat.reshape(batch_size, 9, 64) # Create Masks (Presence bit is index 0) hand_mask = hand_cards[:, :, 0] < 0.5 opp_mask = opp_cards[:, :, 0] < 0.5 # Encode All Cards hand_emb = self.card_encoder(hand_cards) stage_emb = self.card_encoder(stage_cards) live_emb = self.card_encoder(live_cards) opp_emb = self.card_encoder(opp_cards) # 3. Process Zones # A. Hand: Flattened embeddings (preserving slot-to-card mapping) # We still apply the mask to zero out empty slots mask_expanded = hand_mask.unsqueeze(-1).float() hand_processed = hand_emb * (1.0 - mask_expanded) hand_flat_emb = hand_processed.reshape(batch_size, -1) # B. Stage: Positional Encoding stage_processed = stage_emb + self.stage_pos_emb stage_flat_emb = stage_processed.reshape(batch_size, -1) # C. Live: Positional Encoding live_processed = live_emb + self.live_pos_emb live_flat_emb = live_processed.reshape(batch_size, -1) # D. Opponent: Attention + Mean Pool (Strategic summary) opp_processed = self.opp_attention(opp_emb, mask=opp_mask) opp_mask_expanded = opp_mask.unsqueeze(-1).float() opp_processed = opp_processed * (1.0 - opp_mask_expanded) opp_sum = opp_processed.sum(dim=1) opp_counts = 9.0 - opp_mask.sum(dim=1, keepdim=True).float() opp_summary = opp_sum / (opp_counts + 1e-6) # 4. Fusion combined = torch.cat( [ hand_flat_emb, # 2048 stage_flat_emb, # 384 live_flat_emb, # 768 opp_summary, # 128 global_features, # 64 ], dim=1, ) return self.fusion_net(combined)