Spaces:
Running
Running
Adrian Gabriel
committed on
Commit
·
f5a78dd
1
Parent(s):
b0b05a2
Latest additions to TabPFN
Browse files - models/current_code.py +7 -2
models/current_code.py
CHANGED
|
@@ -47,6 +47,7 @@ def group(X):
|
|
| 47 |
else:
|
| 48 |
X_encoded[idx][1] = group_matmul.data + + E_feat.data[1]
|
| 49 |
col = 0
|
|
|
|
| 50 |
idx += 1
|
| 51 |
X_encoded_tensor = Tensor(X_encoded)
|
| 52 |
return X_encoded_tensor
|
|
@@ -68,6 +69,7 @@ def label_embeddings(y_train):
|
|
| 68 |
for (idx, row) in enumerate(y_train.data):
|
| 69 |
res = Tensor((row)).matmul(W_y)
|
| 70 |
lbl_embds[idx] = res.data
|
|
|
|
| 71 |
|
| 72 |
return Tensor(lbl_embds)
|
| 73 |
|
|
@@ -138,15 +140,17 @@ def column_attention_inplace(E: Tensor):
|
|
| 138 |
A = softmax.forward(scores, dim=-1) # (3,3)
|
| 139 |
O = A.matmul(V) # (3,4)
|
| 140 |
|
|
|
|
|
|
|
| 141 |
# In-place residual update of ALL tokens
|
| 142 |
E.data[s] = E.data[s] + O.data
|
| 143 |
|
| 144 |
|
| 145 |
column_attention_inplace(E)
|
| 146 |
-
box("Updated Logits", E, "5")
|
| 147 |
|
| 148 |
|
| 149 |
-
def row_attention_inplace(E: Tensor,
|
| 150 |
"""
|
| 151 |
In-place row attention:
|
| 152 |
For each token slot t:
|
|
@@ -175,4 +179,5 @@ def row_attention_inplace(E: Tensor, W_q: Tensor, W_k: Tensor, W_v: Tensor, sing
|
|
| 175 |
O = A.matmul(V) # (S, D)
|
| 176 |
|
| 177 |
# In-place residual update for this token slot
|
|
|
|
| 178 |
E.data[:, t, :] = E.data[:, t, :] + O.data
|
|
|
|
| 47 |
else:
|
| 48 |
X_encoded[idx][1] = group_matmul.data + + E_feat.data[1]
|
| 49 |
col = 0
|
| 50 |
+
box("grouping", [group_window, group_matmul])
|
| 51 |
idx += 1
|
| 52 |
X_encoded_tensor = Tensor(X_encoded)
|
| 53 |
return X_encoded_tensor
|
|
|
|
| 69 |
for (idx, row) in enumerate(y_train.data):
|
| 70 |
res = Tensor((row)).matmul(W_y)
|
| 71 |
lbl_embds[idx] = res.data
|
| 72 |
+
box("test", [res], "5")
|
| 73 |
|
| 74 |
return Tensor(lbl_embds)
|
| 75 |
|
|
|
|
| 140 |
A = softmax.forward(scores, dim=-1) # (3,3)
|
| 141 |
O = A.matmul(V) # (3,4)
|
| 142 |
|
| 143 |
+
box("test", [Q, K, V, scores, A, O], "5")
|
| 144 |
+
|
| 145 |
# In-place residual update of ALL tokens
|
| 146 |
E.data[s] = E.data[s] + O.data
|
| 147 |
|
| 148 |
|
| 149 |
column_attention_inplace(E)
|
| 150 |
+
box("Updated Logits", E + 0, "5")
|
| 151 |
|
| 152 |
|
| 153 |
+
def row_attention_inplace(E: Tensor, single_eval_pos: int):
|
| 154 |
"""
|
| 155 |
In-place row attention:
|
| 156 |
For each token slot t:
|
|
|
|
| 179 |
O = A.matmul(V) # (S, D)
|
| 180 |
|
| 181 |
# In-place residual update for this token slot
|
| 182 |
+
box("test", [Q, K, V, scores, A, O], "5")
|
| 183 |
E.data[:, t, :] = E.data[:, t, :] + O.data
|