trioskosmos committed on
Commit
d98decc
·
verified ·
1 Parent(s): 1f19911

Upload ai/utils/loveca_features_extractor.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. ai/utils/loveca_features_extractor.py +175 -0
ai/utils/loveca_features_extractor.py ADDED
@@ -0,0 +1,175 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gymnasium as gym
2
+ import torch
3
+ import torch.nn as nn
4
+ from stable_baselines3.common.torch_layers import BaseFeaturesExtractor
5
+
6
+
7
class CardEncoder(nn.Module):
    """Shared encoder that maps one 64-dim card vector to a dense embedding.

    Works over any number of leading batch dimensions:
    [..., input_dim] -> [..., embed_dim].
    Kept deliberately shallow (two linear layers, a single LayerNorm) so the
    same module can be applied cheaply to every card zone.
    """

    def __init__(self, input_dim=64, embed_dim=128):
        super().__init__()
        # Linear -> ReLU -> Linear -> LayerNorm -> ReLU, assembled from a list
        # so the layer order is explicit in one place.
        layers = [
            nn.Linear(input_dim, 128),
            nn.ReLU(inplace=True),
            nn.Linear(128, embed_dim),
            nn.LayerNorm(embed_dim),
            nn.ReLU(inplace=True),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        """Encode card features; leading dimensions pass through unchanged."""
        return self.net(x)
26
+
27
+
28
class MultiHeadCardAttention(nn.Module):
    """Post-norm self-attention block over a set of card embeddings.

    forward(x, mask): x is [B, N, E]; mask is an optional boolean
    key-padding mask of shape [B, N] where True marks a slot to ignore.
    Output has the same shape as x: Norm(x + Attention(x)).
    """

    def __init__(self, embed_dim=128, num_heads=4):
        super().__init__()
        # batch_first matches the [B, N, E] layout used by the extractor.
        self.attn = nn.MultiheadAttention(embed_dim, num_heads, batch_first=True)
        self.norm = nn.LayerNorm(embed_dim)

    def forward(self, x, mask=None):
        """Self-attend over the card set with residual connection + LayerNorm."""
        if mask is not None and mask.any():
            # A row whose keys are ALL masked would make the attention softmax
            # produce NaNs; un-mask those rows entirely so they degrade to
            # plain (unmasked) attention instead of poisoning the batch.
            fully_masked = mask.all(dim=1, keepdim=True)  # [B, 1], broadcasts
            mask = mask & ~fully_masked

        attended, _ = self.attn(x, x, x, key_padding_mask=mask, need_weights=False)

        # Post-norm residual.
        return self.norm(x + attended)
57
+
58
+
59
class LovecaFeaturesExtractor(BaseFeaturesExtractor):
    """Custom SB3 feature extractor for the Love Live TCG environment.

    Parses a flat 2240-dim observation into semantic zones, encodes each card
    with a shared encoder, and fuses the zone representations into a single
    `features_dim` vector.

    Observation layout (34 cards x 64 dims = 2176, plus 64 global dims):
        [0:1024)     16 hand slots (15 cards + 1 hand-overflow slot)
        [1024:1216)  3 own stage slots
        [1216:1600)  6 live slots (3 pending + 3 succeeded)
        [1600:2176)  9 opponent slots (3 stage + 6 history)
        [2176:2240)  global scalar features
    """

    def __init__(self, observation_space: gym.spaces.Box, features_dim: int = 256):
        super().__init__(observation_space, features_dim)

        self.card_dim = 64
        self.embed_dim = 128  # 128 kept over 64 for representation quality

        # Zone sizes (in card slots) matching the layout documented above.
        self.n_hand = 16   # 15 cards + 1 overflow slot
        self.n_stage = 3
        self.n_live = 6    # 3 pending + 3 success
        self.n_opp = 9     # 3 stage + 6 history

        # 1. Shared per-card encoder, applied to every zone.
        self.card_encoder = CardEncoder(self.card_dim, self.embed_dim)

        # 2. Attention blocks.
        # NOTE(review): hand_attention is constructed but never called in
        # forward() — the hand is used as flat per-slot embeddings instead.
        # Kept to preserve state_dict compatibility; confirm whether it
        # should be wired in or dropped.
        self.hand_attention = MultiHeadCardAttention(self.embed_dim, num_heads=4)
        self.opp_attention = MultiHeadCardAttention(self.embed_dim, num_heads=2)

        # 3. Learned positional embeddings for fixed-slot zones.
        # NOTE(review): opp_pos_emb is likewise unused in forward() — verify.
        self.stage_pos_emb = nn.Parameter(torch.randn(1, 3, self.embed_dim))
        self.live_pos_emb = nn.Parameter(torch.randn(1, 6, self.embed_dim))
        self.opp_pos_emb = nn.Parameter(torch.randn(1, 9, self.embed_dim))

        # 4. Fusion MLP. Input width:
        #    hand 16*128 + stage 3*128 + live 6*128 + opp summary 128
        #    + global 64 = 2048 + 384 + 768 + 128 + 64 = 3392
        self.fusion_dim = 3392
        self.fusion_net = nn.Sequential(
            nn.Linear(self.fusion_dim, 512),
            nn.ReLU(inplace=True),
            nn.Linear(512, features_dim),
            nn.LayerNorm(features_dim),
            nn.ReLU(inplace=True),
        )

    def forward(self, observations: torch.Tensor) -> torch.Tensor:
        """Extract a `features_dim` vector from a [B, 2240] observation batch."""
        bsz = observations.shape[0]

        # 1. Slice the flat observation into zones and reshape to card grids.
        hand = observations[:, :1024].reshape(bsz, 16, 64)
        stage = observations[:, 1024:1216].reshape(bsz, 3, 64)
        live = observations[:, 1216:1600].reshape(bsz, 6, 64)
        opp = observations[:, 1600:2176].reshape(bsz, 9, 64)
        global_feats = observations[:, 2176:]

        # 2. Presence masks: feature index 0 is the presence bit,
        #    < 0.5 means the slot is empty (True = empty/masked).
        hand_empty = hand[:, :, 0] < 0.5
        opp_empty = opp[:, :, 0] < 0.5

        # 3. Encode every card with the shared encoder.
        hand_e = self.card_encoder(hand)
        stage_e = self.card_encoder(stage)
        live_e = self.card_encoder(live)
        opp_e = self.card_encoder(opp)

        # A. Hand: keep the per-slot layout (slot index encodes which card),
        #    zeroing the embeddings of empty slots.
        hand_keep = (~hand_empty).unsqueeze(-1).float()
        hand_vec = (hand_e * hand_keep).reshape(bsz, -1)

        # B/C. Stage & Live: add learned slot-position embeddings, keep layout.
        stage_vec = (stage_e + self.stage_pos_emb).reshape(bsz, -1)
        live_vec = (live_e + self.live_pos_emb).reshape(bsz, -1)

        # D. Opponent: self-attention over the set, then masked mean-pool
        #    into one strategic summary vector.
        opp_att = self.opp_attention(opp_e, mask=opp_empty)
        opp_keep = (~opp_empty).unsqueeze(-1).float()
        opp_att = opp_att * opp_keep
        present = 9.0 - opp_empty.sum(dim=1, keepdim=True).float()
        # epsilon keeps the division finite when every opponent slot is empty
        opp_summary = opp_att.sum(dim=1) / (present + 1e-6)

        # 4. Concatenate all zone representations and fuse.
        fused = torch.cat(
            [
                hand_vec,      # 2048
                stage_vec,     # 384
                live_vec,      # 768
                opp_summary,   # 128
                global_feats,  # 64
            ],
            dim=1,
        )
        return self.fusion_net(fused)