Spaces:
Runtime error
Runtime error
Update data_gen.py
Browse files- data_gen.py +60 -40
data_gen.py
CHANGED
|
@@ -1,54 +1,74 @@
|
|
| 1 |
"""
|
| 2 |
-
data_gen.py
|
| 3 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 4 |
"""
|
| 5 |
-
import numpy as np
|
| 6 |
-
import json, pathlib, random, argparse
|
| 7 |
|
| 8 |
-
|
| 9 |
-
|
| 10 |
|
| 11 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 12 |
rng = np.random.default_rng(seed)
|
| 13 |
data = []
|
| 14 |
-
|
| 15 |
-
|
| 16 |
-
|
| 17 |
-
|
| 18 |
-
|
| 19 |
-
|
| 20 |
-
|
| 21 |
-
|
| 22 |
-
|
| 23 |
-
|
| 24 |
-
|
| 25 |
-
c_diff = 0.5 + 0.4 * (a - b)
|
| 26 |
-
data.append({'a': round_list(a), 'b': round_list(b), 'c': round_list(c_diff), 'type': 'diff'})
|
| 27 |
-
|
| 28 |
-
# 3. Lateral Routing (A rolls right by 1, B rolls left by 1)
|
| 29 |
-
# This forces the mesh to use its diagonal triangulated springs!
|
| 30 |
-
c_route = 0.5 * np.roll(a, 1) + 0.5 * np.roll(b, -1)
|
| 31 |
-
data.append({'a': round_list(a), 'b': round_list(b), 'c': round_list(c_route), 'type': 'route'})
|
| 32 |
-
|
| 33 |
random.shuffle(data)
|
| 34 |
return data
|
| 35 |
|
| 36 |
-
def round_list(arr):
|
| 37 |
-
return [round(float(x), 4) for x in arr]
|
| 38 |
-
|
| 39 |
if __name__ == '__main__':
|
| 40 |
parser = argparse.ArgumentParser()
|
| 41 |
-
parser.add_argument('--
|
| 42 |
-
parser.add_argument('--
|
|
|
|
| 43 |
args = parser.parse_args()
|
| 44 |
|
| 45 |
-
data = generate(args.n, args.
|
| 46 |
-
|
| 47 |
-
|
| 48 |
-
|
| 49 |
-
|
| 50 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 51 |
out.mkdir(exist_ok=True)
|
| 52 |
-
with open(out/'train.json','w') as f: json.dump(train, f)
|
| 53 |
-
with open(out/'test.json',
|
| 54 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
"""
|
| 2 |
+
data_gen.py v5
|
| 3 |
+
|
| 4 |
+
Each sample: (A, B, C) where A,B,C ∈ ℝ^n, all values in [0.1, 0.9]
|
| 5 |
+
|
| 6 |
+
For n=1 these are plain scalars.
|
| 7 |
+
For n>1 each dimension is an independent weighted combination of A[i] and B[i],
|
| 8 |
+
so the mesh must learn to route each channel correctly through the bulge.
|
| 9 |
+
|
| 10 |
+
SEEN during training : heavy_a | avg | diff
|
| 11 |
+
OOD (test only) : heavy_b ← has never been seen, tests geometric generalisation
|
| 12 |
"""
|
|
|
|
|
|
|
| 13 |
|
| 14 |
+
import numpy as np, json, pathlib, random, argparse
|
| 15 |
+
from collections import Counter
|
| 16 |
|
| 17 |
+
N = 1 # embedding dimension (1 = pure scalar, set >1 for vector)
|
| 18 |
+
SAMPLES_PER_TYPE = 2500
|
| 19 |
+
|
| 20 |
+
DATASETS = {
|
| 21 |
+
'heavy_a': (lambda a, b: 0.8*a + 0.2*b, True),
|
| 22 |
+
'avg': (lambda a, b: 0.5*a + 0.5*b, True),
|
| 23 |
+
'diff': (lambda a, b: 0.5 + 0.4*(a - b), True), # maps to [0.1, 0.9]
|
| 24 |
+
'heavy_b': (lambda a, b: 0.2*a + 0.8*b, False), # OOD
|
| 25 |
+
}
|
| 26 |
+
|
| 27 |
+
def generate(n=N, n_per=SAMPLES_PER_TYPE, seed=42):
|
| 28 |
rng = np.random.default_rng(seed)
|
| 29 |
data = []
|
| 30 |
+
for dtype, (fn, _) in DATASETS.items():
|
| 31 |
+
for _ in range(n_per):
|
| 32 |
+
a = rng.uniform(0.1, 0.9, n).tolist()
|
| 33 |
+
b = rng.uniform(0.1, 0.9, n).tolist()
|
| 34 |
+
c = [round(float(fn(a[i], b[i])), 4) for i in range(n)]
|
| 35 |
+
data.append({
|
| 36 |
+
'A': [round(v, 4) for v in a],
|
| 37 |
+
'B': [round(v, 4) for v in b],
|
| 38 |
+
'C': c,
|
| 39 |
+
'type': dtype,
|
| 40 |
+
})
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 41 |
random.shuffle(data)
|
| 42 |
return data
|
| 43 |
|
|
|
|
|
|
|
|
|
|
| 44 |
if __name__ == '__main__':
|
| 45 |
parser = argparse.ArgumentParser()
|
| 46 |
+
parser.add_argument('--n', type=int, default=N, help='input dimensions')
|
| 47 |
+
parser.add_argument('--spt', type=int, default=SAMPLES_PER_TYPE,help='samples per type')
|
| 48 |
+
parser.add_argument('--out', type=str, default='data')
|
| 49 |
args = parser.parse_args()
|
| 50 |
|
| 51 |
+
data = generate(args.n, args.spt)
|
| 52 |
+
|
| 53 |
+
seen = [d for d in data if DATASETS[d['type']][1]]
|
| 54 |
+
ood = [d for d in data if not DATASETS[d['type']][1]]
|
| 55 |
+
|
| 56 |
+
split = int(len(seen) * 0.9)
|
| 57 |
+
train = seen[:split]
|
| 58 |
+
test = seen[split:] + ood
|
| 59 |
+
random.shuffle(test)
|
| 60 |
+
|
| 61 |
+
out = pathlib.Path(args.out)
|
| 62 |
out.mkdir(exist_ok=True)
|
| 63 |
+
with open(out / 'train.json', 'w') as f: json.dump(train, f)
|
| 64 |
+
with open(out / 'test.json', 'w') as f: json.dump(test, f)
|
| 65 |
+
|
| 66 |
+
tr = Counter(d['type'] for d in train)
|
| 67 |
+
te = Counter(d['type'] for d in test)
|
| 68 |
+
|
| 69 |
+
print(f"\n dim={args.n} total={len(data)}")
|
| 70 |
+
print(f" {'Type':<12} {'Train':>7} {'Test':>7} Split")
|
| 71 |
+
print(f" {'─'*12} {'─'*7} {'─'*7} {'─'*8}")
|
| 72 |
+
for t, (_, seen_flag) in DATASETS.items():
|
| 73 |
+
print(f" {t:<12} {tr.get(t,0):>7} {te.get(t,0):>7} {'SEEN' if seen_flag else 'OOD ✗'}")
|
| 74 |
+
print(f"\n → {out}/train.json {out}/test.json\n")
|