Spitfire1970 committed
Commit 687309c · 1 Parent(s): d5f9381
Files changed (6)
  1. encoder/model.py +129 -0
  2. encoder/transformer.py +143 -0
  3. handler.py +29 -0
  4. params_data.py +7 -0
  5. params_model.py +24 -0
  6. requirements.txt +2 -0
encoder/model.py ADDED
@@ -0,0 +1,129 @@
+ import sys
+ from .transformer import ViT
+ # make the repository root importable so params_model/params_data resolve
+ sys.path.append("/".join(__file__.split('/')[:-2]))
+ from params_model import *
+ from params_data import *
+
+ from collections import OrderedDict
+ from torch import nn
+ import torch
+
+ class ConvBlock(nn.Sequential):
+     def __init__(self, in_channels, out_channels, kernel_size, padding=0):
+         super().__init__(OrderedDict([
+             ('conv', nn.Conv2d(in_channels, out_channels, kernel_size, padding=padding, bias=False)),
+             ('bn', nn.BatchNorm2d(out_channels)),
+             ('relu', nn.ReLU(inplace=True)),
+         ]))
+
+ class SqueezeExcitation(nn.Module):
+     def __init__(self, channels, ratio):
+         super().__init__()
+
+         self.pool = nn.AdaptiveAvgPool2d(1)
+         # small bottleneck MLP; it outputs 2*channels so the block can
+         # produce both a per-channel scale and a per-channel shift
+         self.lin1 = nn.Linear(channels, channels // ratio)
+         self.relu = nn.ReLU(inplace=True)
+         self.lin2 = nn.Linear(channels // ratio, 2 * channels)
+
+     def forward(self, x):
+         n, c, h, w = x.size()
+         x_in = x
+
+         x = self.pool(x).view(n, c)
+         x = self.lin1(x)
+         x = self.relu(x)
+         x = self.lin2(x)
+
+         x = x.view(n, 2 * c, 1, 1)
+         scale, shift = x.chunk(2, dim=1)
+
+         x = scale.sigmoid() * x_in + shift
+         return x
+
+ class ResidualBlock(nn.Module):
+     def __init__(self, channels, se_ratio):
+         super().__init__()
+         self.layers = nn.Sequential(OrderedDict([
+             ('conv1', nn.Conv2d(channels, channels, 3, padding=1, bias=False)),
+             ('bn1', nn.BatchNorm2d(channels)),
+             ('relu', nn.ReLU(inplace=True)),
+
+             ('conv2', nn.Conv2d(channels, channels, 3, padding=1, bias=False)),
+             ('bn2', nn.BatchNorm2d(channels)),
+
+             ('se', SqueezeExcitation(channels, se_ratio)),
+         ]))
+         self.relu2 = nn.ReLU(inplace=True)
+
+     def forward(self, x):
+         x_in = x
+
+         x = self.layers(x)
+
+         x = x + x_in
+         x = self.relu2(x)
+         return x
+
+ class Encoder(nn.Module):
+
+     def __init__(self, loss_device, loss_method="softmax"):
+         super().__init__()
+         self.loss_device = loss_device
+
+         channels = residual_channels
+
+         # 34 input planes: the 8x8 board encoding of one position
+         self.conv_block = ConvBlock(34, channels, 3, padding=1)
+         blocks = [(f'block{i+1}', ResidualBlock(channels, se_ratio)) for i in range(residual_blocks)]
+         self.residual_stack = nn.Sequential(OrderedDict(blocks))
+
+         self.conv_block2 = ConvBlock(channels, channels, 3, padding=1)
+         self.final_feature = ConvBlock(channels, vit_input_channels, 3, padding=1)
+         self.global_avgpool = nn.AvgPool2d(kernel_size=8)
+
+         self.cnn = nn.Sequential(
+             self.conv_block,
+             self.residual_stack,
+             self.conv_block2,
+             self.final_feature,
+             self.global_avgpool,
+             nn.Flatten(),
+         )
+
+         self.transformer = ViT(input_dim=vit_input_channels,
+                                output_dim=model_embedding_size,
+                                dim=transformer_input_dim,
+                                depth=transformer_depth,
+                                heads=attention_heads,
+                                mlp_dim=mlp_dim,
+                                pool='mean',
+                                dim_head=dim_head,
+                                dropout=dropout,
+                                emb_dropout=emb_dropout)
+
+         # Cosine similarity scaling (with fixed initial parameter values)
+         self.similarity_weight = nn.Parameter(torch.tensor([similarity_weight_init]))
+         self.similarity_bias = nn.Parameter(torch.tensor([similarity_bias_init]))
+
+     def forward(self, games):
+
+         batch_size, n_frames, feature_shape = games.shape[0], games.shape[1], games.shape[2:]
+
+         # (batch_size, n_frames, 34, 8, 8) -> (batch_size*n_frames, 34, 8, 8)
+         games = torch.reshape(games, (batch_size*n_frames, *feature_shape))
+
+         # (batch_size*n_frames, cnn_out_features)
+         game_features = self.cnn(games)
+
+         # (batch_size*n_frames, cnn_out_features) -> (batch_size, n_frames, cnn_out_features)
+         game_features = torch.reshape(game_features, (batch_size, n_frames, game_features.shape[-1]))
+
+         # Pass the sequence of per-position features through the transformer
+         # (batch_size, n_frames, n_features) -> (batch_size, model_embedding_size)
+         embeds_raw = self.transformer(game_features)
+
+         # L2-normalize each embedding
+         embeds = embeds_raw / torch.norm(embeds_raw, dim=1, keepdim=True)
+
+         return embeds
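
A minimal smoke test for this encoder — hypothetical, not part of the commit; it assumes the hyperparameters in params_model.py (model_embedding_size = 512) and an input shaped as the forward() comments describe:

import torch
from encoder.model import Encoder

model = Encoder(torch.device("cpu")).eval()
games = torch.randn(2, 32, 34, 8, 8)   # 2 games x 32 positions x (34, 8, 8) board planes
with torch.no_grad():
    embeds = model(games)
print(embeds.shape)         # torch.Size([2, 512])
print(embeds.norm(dim=1))   # ~1.0 per row, thanks to the L2 normalization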
encoder/transformer.py ADDED
@@ -0,0 +1,143 @@
+ # original vision transformer from https://github.com/lucidrains/vit-pytorch/blob/main/vit_pytorch/vit.py
+ import numpy as np
+ import torch
+ from torch import nn, einsum
+
+ # https://einops.rocks/pytorch-examples.html
+ from einops import rearrange
+
+ class PreNorm(nn.Module):
+     def __init__(self, dim, fn):
+         super().__init__()
+         self.norm = nn.LayerNorm(dim)
+         self.fn = fn
+
+     def forward(self, x):
+         return self.fn(self.norm(x))
+
+ class FeedForward(nn.Module):
+     def __init__(self, dim, hidden_dim, dropout=0.):
+         super().__init__()
+         self.net = nn.Sequential(
+             nn.Linear(dim, hidden_dim),
+             nn.GELU(),
+             nn.Dropout(dropout),
+             nn.Linear(hidden_dim, dim),
+             nn.Dropout(dropout)
+         )
+
+     def forward(self, x):
+         return self.net(x)
+
+ class Attention(nn.Module):
+     def __init__(self, dim, heads=8, dim_head=64, dropout=0.):
+         super().__init__()
+         inner_dim = dim_head * heads
+         project_out = not (heads == 1 and dim_head == dim)
+
+         self.heads = heads
+         self.scale = dim_head ** -0.5
+
+         self.attend = nn.Softmax(dim=-1)
+         self.to_qkv = nn.Linear(dim, inner_dim * 3, bias=False)
+
+         self.to_out = nn.Sequential(
+             nn.Linear(inner_dim, dim),
+             nn.Dropout(dropout)
+         ) if project_out else nn.Identity()
+
+     def forward(self, x):
+         b, n, _, h = *x.shape, self.heads
+         qkv = self.to_qkv(x).chunk(3, dim=-1)
+         q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h=h), qkv)
+
+         # for each batch and each head, dot every query position (i) with
+         # every key position (j), summing over the embedding dimension (d)
+         dots = einsum('b h i d, b h j d -> b h i j', q, k) * self.scale
+
+         attn = self.attend(dots)
+
+         out = einsum('b h i j, b h j d -> b h i d', attn, v)
+         out = rearrange(out, 'b h n d -> b n (h d)')
+         return self.to_out(out)
+
+ class Transformer(nn.Module):
+     def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout=0.):
+         super().__init__()
+         self.layers = nn.ModuleList([])
+         for _ in range(depth):
+             self.layers.append(nn.ModuleList([
+                 PreNorm(dim, Attention(dim, heads=heads, dim_head=dim_head, dropout=dropout)),
+                 PreNorm(dim, FeedForward(dim, mlp_dim, dropout=dropout))
+             ]))
+
+     def forward(self, x):
+         for attn, ff in self.layers:
+             x = attn(x) + x
+             x = ff(x) + x
+         return x
+
+ class PositionalEncoding(nn.Module):
+     # https://discuss.pytorch.org/t/positional-encoding/175953
+     def __init__(self, d_model, max_len=500):
+         super().__init__()
+
+         pe = torch.zeros(max_len, d_model)
+         position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
+         # alternate sine and cosine waves whose frequency decreases
+         # geometrically across the embedding dimensions
+         div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-np.log(10000.0) / d_model))
+         pe[:, 0::2] = torch.sin(position * div_term)
+         pe[:, 1::2] = torch.cos(position * div_term)
+         pe = pe.unsqueeze(0)
+         self.register_buffer('pe', pe)
+
+     def forward(self, x):
+         # index with x.size(1), not x.size(0): x.size(0) is the batch size
+         # whereas x.size(1) is the sequence length
+         x = x + self.pe[:, :x.size(1), :]
+         return x
+
+ class ViT(nn.Module):
+     """
+     input_dim: number of channels per input token (the CNN feature size)
+     output_dim: size of the final embedding
+     dim: last dimension of the tensor after the linear projection nn.Linear(..., dim)
+     depth: number of Transformer blocks
+     heads: number of heads in each multi-head attention layer
+     mlp_dim: dimension of the MLP (FeedForward) layer
+     dropout: dropout rate
+     emb_dropout: embedding dropout rate
+     pool: either cls token pooling or mean pooling
+     """
+     # * forces keyword-only arguments
+     def __init__(self, *, input_dim, output_dim, dim, depth, heads, mlp_dim, pool='mean', dim_head=64, dropout, emb_dropout):
+         super().__init__()
+
+         self.project = nn.Linear(input_dim, dim)
+
+         self.pos_encoder = PositionalEncoding(dim)
+
+         self.dropout = nn.Dropout(emb_dropout)
+
+         self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim, dropout)
+
+         self.pool = pool
+
+         self.mlp_head = nn.Sequential(
+             nn.LayerNorm(dim),
+             nn.Linear(dim, output_dim)
+         )
+
+         self.tanh = nn.Tanh()
+
+     def forward(self, x):
+
+         x = self.project(x)
+
+         x = self.pos_encoder(x)
+
+         x = self.dropout(x)
+
+         x = self.transformer(x)
+
+         x = x.mean(dim=1) if self.pool == 'mean' else x[:, 0]
+
+         return self.tanh(self.mlp_head(x))
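
A standalone sanity check of the shape flow — a sketch with small, hypothetical hyperparameters; the production values live in params_model.py:

import torch
from encoder.transformer import ViT

vit = ViT(input_dim=320, output_dim=512, dim=256, depth=2, heads=4,
          mlp_dim=512, pool='mean', dim_head=64, dropout=0., emb_dropout=0.)
x = torch.randn(2, 32, 320)   # (batch, sequence length, token channels)
out = vit(x)                  # tokens projected to dim, attended over,
print(out.shape)              # mean-pooled: torch.Size([2, 512]), tanh-bounded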
handler.py ADDED
@@ -0,0 +1,29 @@
+ import torch
+ from encoder.model import Encoder
+
+ class EndpointHandler:
+     def __init__(self, path="6.pt"):
+         self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+         checkpoint = torch.load(path, map_location=self.device, weights_only=True)
+         self.model = Encoder(self.device)
+         state_dict = checkpoint['model_state']
+         self.model.load_state_dict(state_dict)
+         self.model = self.model.to(self.device)
+         self.model.eval()
+
+     def __call__(self, data):
+         tensor = torch.from_numpy(data.tensor).float().to(self.device)
+         if len(data) == 1:
+             # single game: embed it and L2-normalize
+             with torch.no_grad():
+                 embed = self.model(tensor)
+                 embed = embed / torch.norm(embed)
+                 return {"reply": embed.cpu().numpy()}
+         else:
+             # several games from one player: embed each, average into a
+             # centroid, then renormalize the centroid
+             with torch.no_grad():
+                 embeds = self.model(tensor)
+                 embeds = embeds.view((1, data.num_games, -1)).to(self.device)
+                 centroids_incl = torch.mean(embeds, dim=1, keepdim=True)
+                 centroids_incl = centroids_incl.clone() / torch.norm(centroids_incl, dim=2, keepdim=True)
+                 centroids_incl = centroids_incl.cpu().squeeze(1)
+                 final_embeds = centroids_incl[0].numpy().tolist()
+                 return {"reply": final_embeds}
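
The handler reads three things off the request object: data.tensor (a NumPy array of board planes), len(data), and data.num_games. The actual request schema is not part of this commit; a hypothetical stand-in to exercise the multi-game branch could look like:

import numpy as np
from handler import EndpointHandler

class FakePayload:
    # stand-in exposing the three attributes __call__ reads
    def __init__(self, tensor, num_games):
        self.tensor = tensor
        self.num_games = num_games
    def __len__(self):
        return self.num_games

handler = EndpointHandler(path="6.pt")   # needs the committed checkpoint
games = np.random.randn(10, 32, 34, 8, 8).astype(np.float32)
reply = handler(FakePayload(games, num_games=10))["reply"]
# reply: one 512-d list, the renormalized centroid of the 10 game embeddings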
params_data.py ADDED
@@ -0,0 +1,7 @@
+ # set random_partial_low < random_partial_high to vary the number of frames per batch
+ random_partial_low = 32
+ random_partial_high = 32
+ game_start = 0
+
+ # 32 moves as a window
+ partials_n_frames = 32
params_model.py ADDED
@@ -0,0 +1,24 @@
+ ## Model parameters
+ residual_channels = 64
+ residual_blocks = 6
+ se_ratio = 8
+ vit_input_channels = 320  # input dimension to ViT
+ transformer_input_dim = 1024
+ model_embedding_size = 512
+ transformer_depth = 12
+ attention_heads = 8
+ mlp_dim = 2048
+ dim_head = 64  # per-head query/key/value dimension; risky to tune?
+ dropout = 0.
+ emb_dropout = 0.
+ similarity_weight_init = 10.
+ similarity_bias_init = -5.
+
+ ## Training parameters
+ learning_rate_init = 0.005
+ players_per_batch = 36
+ games_per_player = 10
+
+ v_players_per_batch = 40
+ v_games_per_player = 10
+ num_validate = 10
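
The initial values similarity_weight_init = 10 and similarity_bias_init = -5 match the scaled-cosine-similarity scheme of GE2E-style encoders. A sketch of how such parameters are typically applied (an assumption on my part — the loss code itself is not in this commit):

import torch

cos_sim = torch.tensor([0.2, 0.9])   # cosine similarities to a centroid
w, b = 10., -5.                      # the initial values above
logits = w * cos_sim + b             # tensor([-3., 4.]), fed to the softmax loss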
requirements.txt ADDED
@@ -0,0 +1,2 @@
+ torch
+ numpy