File size: 6,420 Bytes
d98decc
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
import gymnasium as gym
import torch
import torch.nn as nn
from stable_baselines3.common.torch_layers import BaseFeaturesExtractor


class CardEncoder(nn.Module):
    """Shared per-card encoder: a 64-dim raw card vector -> embed_dim embedding.

    Works over any leading batch dimensions ([..., input_dim] -> [..., embed_dim]).
    Kept deliberately shallow (two linear layers, one LayerNorm) for throughput.
    """

    def __init__(self, input_dim=64, embed_dim=128):
        super().__init__()
        hidden_width = 128
        # NOTE: layer creation order is preserved so seeded initialization and
        # state_dict keys ("net.0.*", ...) stay compatible with checkpoints.
        self.net = nn.Sequential(
            nn.Linear(input_dim, hidden_width),
            nn.ReLU(inplace=True),
            nn.Linear(hidden_width, embed_dim),
            nn.LayerNorm(embed_dim),
            nn.ReLU(inplace=True),
        )

    def forward(self, x):
        # Stateless feed-forward pass.
        return self.net(x)


class MultiHeadCardAttention(nn.Module):
    """

    Self-Attention block for handling sets of cards.

    Optimized: Removed post-norm in favor of pre-norm style if desired,

    but keeping it simple: just standard MHA is fine.

    """

    def __init__(self, embed_dim=128, num_heads=4):
        super().__init__()
        # batch_first=True is critical for speed with our data layout
        self.attn = nn.MultiheadAttention(embed_dim, num_heads, batch_first=True)
        self.norm = nn.LayerNorm(embed_dim)

    def forward(self, x, mask=None):
        # Flattened logic for speed:
        # Pre-Norm (Original was Post-Norm, let's keep Post-Norm but optimized)

        # Robustness handling:
        if mask is not None:
            # Fast check: are any masked?
            if mask.any():
                all_masked = mask.all(dim=1, keepdim=True)
                mask = mask & (~all_masked)

        # MHA
        attn_out, _ = self.attn(x, x, x, key_padding_mask=mask, need_weights=False)

        # Add & Norm
        return self.norm(x + attn_out)


class LovecaFeaturesExtractor(BaseFeaturesExtractor):
    """SB3 feature extractor for the Love Live TCG 2240-dim observation.

    Layout: 34 cards x 64 dims = 2176, plus a 64-dim global block.
    Zones: hand (15 + 1 overflow), own stage (3), live pending/success (3+3),
    opponent stage + history (3+6). Each card is encoded with a shared
    encoder; the opponent zone is summarized via attention + masked mean
    pooling, and everything is fused down to ``features_dim``.
    """

    def __init__(self, observation_space: gym.spaces.Box, features_dim: int = 256):
        super().__init__(observation_space, features_dim)

        # Per-card raw width and shared embedding width.
        self.card_dim = 64
        self.embed_dim = 128

        # Card counts per zone (34 total -> 34 * 64 = 2176, + 64 global = 2240).
        self.n_hand = 16   # 15 hand slots + 1 hand-overflow slot
        self.n_stage = 3
        self.n_live = 6    # 3 pending + 3 succeeded
        self.n_opp = 9     # 3 opponent stage + 6 opponent history

        # Module/parameter creation order is preserved exactly so seeded
        # initialization and checkpoint state_dict keys stay compatible.

        # 1. Shared card encoder used by every zone.
        self.card_encoder = CardEncoder(self.card_dim, self.embed_dim)

        # 2. Attention blocks.
        # NOTE(review): hand_attention is registered but never called in
        # forward() — confirm whether that is intentional before removing
        # (removal would change RNG draws and state_dict keys).
        self.hand_attention = MultiHeadCardAttention(self.embed_dim, num_heads=4)
        self.opp_attention = MultiHeadCardAttention(self.embed_dim, num_heads=2)

        # 3. Learned positional embeddings for the fixed-slot zones.
        # NOTE(review): opp_pos_emb is also unused in forward() — see above.
        self.stage_pos_emb = nn.Parameter(torch.randn(1, self.n_stage, self.embed_dim))
        self.live_pos_emb = nn.Parameter(torch.randn(1, self.n_live, self.embed_dim))
        self.opp_pos_emb = nn.Parameter(torch.randn(1, self.n_opp, self.embed_dim))

        # 4. Fusion head. Input width:
        #    hand 16*128 + stage 3*128 + live 6*128 + opp summary 128 + global 64
        #    = 2048 + 384 + 768 + 128 + 64 = 3392
        self.fusion_dim = 3392
        self.fusion_net = nn.Sequential(
            nn.Linear(self.fusion_dim, 512),
            nn.ReLU(inplace=True),
            nn.Linear(512, features_dim),
            nn.LayerNorm(features_dim),
            nn.ReLU(inplace=True),
        )

    def forward(self, observations: torch.Tensor) -> torch.Tensor:
        """Map a [B, 2240] observation batch to [B, features_dim] features."""
        batch = observations.shape[0]

        # Carve the flat observation into its zones:
        # hand 16*64, stage 3*64, live 6*64, opponent 9*64, global 64.
        hand_raw, stage_raw, live_raw, opp_raw, global_feats = torch.split(
            observations, [1024, 192, 384, 576, 64], dim=1
        )

        hand_cards = hand_raw.reshape(batch, self.n_hand, self.card_dim)
        stage_cards = stage_raw.reshape(batch, self.n_stage, self.card_dim)
        live_cards = live_raw.reshape(batch, self.n_live, self.card_dim)
        opp_cards = opp_raw.reshape(batch, self.n_opp, self.card_dim)

        # Index 0 of every card vector is the presence bit; < 0.5 means empty.
        hand_empty = hand_cards[:, :, 0] < 0.5
        opp_empty = opp_cards[:, :, 0] < 0.5

        # Encode all zones with the shared encoder.
        hand_emb = self.card_encoder(hand_cards)
        stage_emb = self.card_encoder(stage_cards)
        live_emb = self.card_encoder(live_cards)
        opp_emb = self.card_encoder(opp_cards)

        # Hand: keep per-slot embeddings (slot-to-card mapping matters),
        # zeroing out empty slots before flattening.
        hand_vec = hand_emb.masked_fill(hand_empty.unsqueeze(-1), 0.0).flatten(1)

        # Stage / live: add slot positional embeddings, then flatten.
        stage_vec = (stage_emb + self.stage_pos_emb).flatten(1)
        live_vec = (live_emb + self.live_pos_emb).flatten(1)

        # Opponent: self-attention over the 9 slots, then a masked mean pool
        # into a single strategic-summary vector.
        opp_ctx = self.opp_attention(opp_emb, mask=opp_empty)
        opp_ctx = opp_ctx.masked_fill(opp_empty.unsqueeze(-1), 0.0)
        n_present = (~opp_empty).sum(dim=1, keepdim=True).float()
        opp_summary = opp_ctx.sum(dim=1) / (n_present + 1e-6)  # eps guards /0

        fused_input = torch.cat(
            [hand_vec, stage_vec, live_vec, opp_summary, global_feats], dim=1
        )
        return self.fusion_net(fused_input)