Maxlegrec commited on
Commit
8378d33
·
verified ·
1 Parent(s): d7ecccf

Upload attn_map.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. attn_map.py +116 -0
attn_map.py ADDED
@@ -0,0 +1,116 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import torch
3
+
4
+ move = np.arange(1, 8)
5
+
6
+ diag = np.array([
7
+ move + move*8,
8
+ move - move*8,
9
+ move*-1 - move*8,
10
+ move*-1 + move*8
11
+ ])
12
+
13
+ orthog = np.array([
14
+ move,
15
+ move*-8,
16
+ move*-1,
17
+ move*8
18
+ ])
19
+
20
+ knight = np.array([
21
+ [2 + 1*8],
22
+ [2 - 1*8],
23
+ [1 - 2*8],
24
+ [-1 - 2*8],
25
+ [-2 - 1*8],
26
+ [-2 + 1*8],
27
+ [-1 + 2*8],
28
+ [1 + 2*8]
29
+ ])
30
+
31
+ promos = np.array([2*8, 3*8, 4*8])
32
+ pawn_promotion = np.array([
33
+ -1 + promos,
34
+ 0 + promos,
35
+ 1 + promos
36
+ ])
37
+
38
+ def make_map():
39
+ """theoretically possible put-down squares (numpy array) for each pick-up square (list element).
40
+ squares are [0, 1, ..., 63] for [a1, b1, ..., h8]. squares after 63 are for promotion squares.
41
+ each successive "row" beyond 63 (ie. 64:72, 72:80, 80:88) are for over-promotions to queen, rook, and bishop;
42
+ respectively. a pawn traverse to row 56:64 signifies a "default" promotion to a knight."""
43
+ traversable = []
44
+ for i in range(8):
45
+ for j in range(8):
46
+ sq = (8*i + j)
47
+ traversable.append(
48
+ sq +
49
+ np.sort(
50
+ np.int32(
51
+ np.concatenate((
52
+ orthog[0][:7-j], orthog[2][:j], orthog[1][:i], orthog[3][:7-i],
53
+ diag[0][:np.min((7-i, 7-j))], diag[3][:np.min((7-i, j))],
54
+ diag[1][:np.min((i, 7-j))], diag[2][:np.min((i, j))],
55
+ knight[0] if i < 7 and j < 6 else [], knight[1] if i > 0 and j < 6 else [],
56
+ knight[2] if i > 1 and j < 7 else [], knight[3] if i > 1 and j > 0 else [],
57
+ knight[4] if i > 0 and j > 1 else [], knight[5] if i < 7 and j > 1 else [],
58
+ knight[6] if i < 6 and j > 0 else [], knight[7] if i < 6 and j < 7 else [],
59
+ pawn_promotion[0] if i == 6 and j > 0 else [],
60
+ pawn_promotion[1] if i == 6 else [],
61
+ pawn_promotion[2] if i == 6 and j < 7 else [],
62
+ ))
63
+ )
64
+ )
65
+ )
66
+ z = np.zeros((64*64+8*24, 1858), dtype=np.int32)
67
+ apm_out = np.zeros((1858,), dtype=np.int32)
68
+ apm_in = np.zeros((64*64+8*24), dtype=np.int32)
69
+ # first loop for standard moves (for i in 0:1858, stride by 1)
70
+ i = 0
71
+ for pickup_index, putdown_indices in enumerate(traversable):
72
+ for putdown_index in putdown_indices:
73
+ if putdown_index < 64:
74
+ du_idx = putdown_index + (64*pickup_index)
75
+ z[du_idx, i] = 1
76
+ apm_out[i] = du_idx
77
+ apm_in[du_idx] = i
78
+ i += 1
79
+ # second loop for promotions (for i in 1792:1858, stride by ls[j])
80
+ j = 0
81
+ j1 = np.array([3, -2, 3, -2, 3])
82
+ j2 = np.array([3, 3, -5, 3, 3, -5, 3, 3, 1])
83
+ ls = np.append(j1, 1)
84
+ for k in range(6):
85
+ ls = np.append(ls, j2)
86
+ ls = np.append(ls, j1)
87
+ ls = np.append(ls, 0)
88
+ for pickup_index, putdown_indices in enumerate(traversable):
89
+ for putdown_index in putdown_indices:
90
+ if putdown_index >= 64:
91
+ pickup_file = pickup_index % 8
92
+ promotion_file = putdown_index % 8
93
+ promotion_rank = (putdown_index // 8) - 8
94
+ du_idx = 4096 + pickup_file*24 + (promotion_file*3+promotion_rank)
95
+ z[du_idx, i] = 1
96
+ apm_out[i] = du_idx
97
+ apm_in[du_idx] = i
98
+ i += ls[j]
99
+ j += 1
100
+
101
+ return z, apm_out, apm_in
102
+
103
+ apm_map, apm_out, apm_in = make_map()
104
+
105
+ def set_zero_sum(x):
106
+ x = x + (1 - torch.sum(x, dim=1, keepdim=True)) * (1.0 / 64)
107
+ return x
108
+
109
+ def get_up_down(moves):
110
+ apm_map_tensor = torch.from_numpy(apm_map)
111
+ out = torch.matmul(moves, apm_map_tensor.T.float())
112
+ out = out[..., :64*64]
113
+ out = out.view(-1, 64, 64)
114
+ pu = set_zero_sum(torch.sum(out, dim=-1))
115
+ pd = set_zero_sum(torch.sum(out, dim=-2))
116
+ return pu, pd