faz9_recur: gen_khop ZINCIR gorevi (perm dongu-kisayolu kaldirildi; gercek k-hop)
Browse files- code/kod/faz9_recur.py +12 -12
code/kod/faz9_recur.py
CHANGED
|
@@ -137,21 +137,21 @@ def smoke():
|
|
| 137 |
|
| 138 |
# ───────────── FAZ B: derinlik-gerektiren görev + fixed-vs-recurrent (GO/NO-GO) ─────────────
|
| 139 |
def gen_khop(batch, n_keys, k, device="cpu"):
|
| 140 |
-
"""k-hop
|
| 141 |
-
|
| 142 |
-
|
|
|
|
|
|
|
| 143 |
SEP = n_keys; seqs = []; tgts = []
|
| 144 |
for _ in range(batch):
|
| 145 |
-
|
|
|
|
| 146 |
toks = []
|
| 147 |
-
for
|
| 148 |
-
toks += [
|
| 149 |
-
|
| 150 |
-
toks += [SEP,
|
| 151 |
-
|
| 152 |
-
for _ in range(k):
|
| 153 |
-
cur = int(perm[cur])
|
| 154 |
-
seqs.append(toks); tgts.append(cur)
|
| 155 |
return (torch.tensor(seqs, device=device), torch.tensor(tgts, device=device))
|
| 156 |
|
| 157 |
|
|
|
|
| 137 |
|
| 138 |
# ───────────── FAZ B: derinlik-gerektiren görev + fixed-vs-recurrent (GO/NO-GO) ─────────────
|
| 139 |
def gen_khop(batch, n_keys, k, device="cpu"):
|
| 140 |
+
"""k-hop ZİNCİR traversali (döngü-kısayolu YOK). Düğümler tek bir zincir (chain=randperm),
|
| 141 |
+
kenarlar [chain[i], chain[i+1]] KARIŞIK sırada → SEP → start=chain[p]. Hedef=chain[p+k], p∈[0,n-k).
|
| 142 |
+
Zincir döngüsüz → hedef≠start her zaman → 'başlangıcı kopyala' kısayolu yok → gerçek k sıralı hop gerek.
|
| 143 |
+
Vocab 0..n_keys-1 + SEP(n_keys). Loss SADECE son pozisyon. (Eski permütasyon görevi σ^k(s)=s döngü-kısayoluyla kirleniyordu.)"""
|
| 144 |
+
assert n_keys > k, "n_keys > k olmalı"
|
| 145 |
SEP = n_keys; seqs = []; tgts = []
|
| 146 |
for _ in range(batch):
|
| 147 |
+
chain = torch.randperm(n_keys)
|
| 148 |
+
edges = [[int(chain[i]), int(chain[i + 1])] for i in range(n_keys - 1)]
|
| 149 |
toks = []
|
| 150 |
+
for j in torch.randperm(len(edges)).tolist():
|
| 151 |
+
toks += edges[j]
|
| 152 |
+
p = int(torch.randint(0, n_keys - k, (1,)))
|
| 153 |
+
toks += [SEP, int(chain[p])]
|
| 154 |
+
seqs.append(toks); tgts.append(int(chain[p + k]))
|
|
|
|
|
|
|
|
|
|
| 155 |
return (torch.tensor(seqs, device=device), torch.tensor(tgts, device=device))
|
| 156 |
|
| 157 |
|