Adrian Gabriel committed on
Commit
f5a78dd
·
1 Parent(s): b0b05a2

Latest additions to TabPFN

Browse files
Files changed (1) hide show
  1. models/current_code.py +7 -2
models/current_code.py CHANGED
@@ -47,6 +47,7 @@ def group(X):
47
  else:
48
  X_encoded[idx][1] = group_matmul.data + + E_feat.data[1]
49
  col = 0
 
50
  idx += 1
51
  X_encoded_tensor = Tensor(X_encoded)
52
  return X_encoded_tensor
@@ -68,6 +69,7 @@ def label_embeddings(y_train):
68
  for (idx, row) in enumerate(y_train.data):
69
  res = Tensor((row)).matmul(W_y)
70
  lbl_embds[idx] = res.data
 
71
 
72
  return Tensor(lbl_embds)
73
 
@@ -138,15 +140,17 @@ def column_attention_inplace(E: Tensor):
138
  A = softmax.forward(scores, dim=-1) # (3,3)
139
  O = A.matmul(V) # (3,4)
140
 
 
 
141
  # In-place residual update of ALL tokens
142
  E.data[s] = E.data[s] + O.data
143
 
144
 
145
  column_attention_inplace(E)
146
- box("Updated Logits", E, "5")
147
 
148
 
149
- def row_attention_inplace(E: Tensor, W_q: Tensor, W_k: Tensor, W_v: Tensor, single_eval_pos: int):
150
  """
151
  In-place row attention:
152
  For each token slot t:
@@ -175,4 +179,5 @@ def row_attention_inplace(E: Tensor, W_q: Tensor, W_k: Tensor, W_v: Tensor, sing
175
  O = A.matmul(V) # (S, D)
176
 
177
  # In-place residual update for this token slot
 
178
  E.data[:, t, :] = E.data[:, t, :] + O.data
 
47
  else:
48
  X_encoded[idx][1] = group_matmul.data + + E_feat.data[1]
49
  col = 0
50
+ box("grouping", [group_window, group_matmul])
51
  idx += 1
52
  X_encoded_tensor = Tensor(X_encoded)
53
  return X_encoded_tensor
 
69
  for (idx, row) in enumerate(y_train.data):
70
  res = Tensor((row)).matmul(W_y)
71
  lbl_embds[idx] = res.data
72
+ box("test", [res], "5")
73
 
74
  return Tensor(lbl_embds)
75
 
 
140
  A = softmax.forward(scores, dim=-1) # (3,3)
141
  O = A.matmul(V) # (3,4)
142
 
143
+ box("test", [Q, K, V, scores, A, O], "5")
144
+
145
  # In-place residual update of ALL tokens
146
  E.data[s] = E.data[s] + O.data
147
 
148
 
149
  column_attention_inplace(E)
150
+ box("Updated Logits", E + 0, "5")
151
 
152
 
153
+ def row_attention_inplace(E: Tensor, single_eval_pos: int):
154
  """
155
  In-place row attention:
156
  For each token slot t:
 
179
  O = A.matmul(V) # (S, D)
180
 
181
  # In-place residual update for this token slot
182
+ box("test", [Q, K, V, scores, A, O], "5")
183
  E.data[:, t, :] = E.data[:, t, :] + O.data