everydaytok committed on
Commit
8e0a66b
·
verified ·
1 Parent(s): 1feca89

Update data_gen.py

Browse files
Files changed (1) hide show
  1. data_gen.py +56 -187
data_gen.py CHANGED
@@ -1,211 +1,80 @@
1
  """
2
- data_gen.py — Training / test data for the elastic mesh.
3
 
4
- OOD TEST DESIGN
5
- ───────────────
6
- SEEN during training : box_proj | halfspace | elastic
7
- UNSEEN (OOD) at test : sphere | simplex
8
 
9
- This lets us distinguish:
10
- • Memorisation → high acc on seen, low acc on unseen
11
- • Geometry → high acc on both (the real claim)
12
 
13
- Each sample: (A, B, C) where A=constraints, B=objectives, C=feasibility center.
14
- DIM = 64 (double from previous run, stress-tests before LLM scale).
15
- """
16
-
17
- import numpy as np
18
- import json, pathlib, argparse
19
- from typing import List, Dict
20
-
21
-
22
- DIM = 64
23
- SAMPLES_PER_TYPE = 1000 # Γ— 5 types = 5 000 total
24
-
25
-
26
- # ── UTILITIES ─────────────────────────────────────────────────────────────────
27
-
28
- def norm(v: np.ndarray) -> np.ndarray:
29
- return v / (np.linalg.norm(v) + 1e-12)
30
-
31
- def pack(*arrays, dim):
32
- v = np.concatenate(arrays)
33
- return v[:dim] if len(v) >= dim else np.pad(v, (0, dim - len(v)))
34
-
35
-
36
- # ── PROBLEM TYPE 1 (SEEN): BOX PROJECTION ────────────────────────────────────
37
- # C = clip(B, lo, hi)
38
- # A encodes the box bounds
39
-
40
- def gen_box(n, dim, rng):
41
- data = []
42
- for _ in range(n):
43
- center = rng.uniform(-2, 2, dim)
44
- half = rng.uniform(0.3, 2.0, dim)
45
- lo, hi = center - half, center + half
46
- B = rng.uniform(-4, 4, dim)
47
- C = np.clip(B, lo, hi)
48
- A = pack(lo[:dim//2], hi[:dim//2], dim=dim)
49
- data.append({'A': A.tolist(), 'B': B.tolist(), 'C': C.tolist(), 'type': 'box_proj'})
50
- return data
51
-
52
-
53
- # ── PROBLEM TYPE 2 (SEEN): HALFSPACE PROJECTION ───────────────────────────────
54
- # C = B − (nᵀB − b)·n (project B onto hyperplane nᵀx = b)
55
-
56
- def gen_halfspace(n, dim, rng):
57
- data = []
58
- for _ in range(n):
59
- normal = norm(rng.standard_normal(dim))
60
- b = float(rng.uniform(-1, 1))
61
- B = rng.uniform(-3, 3, dim)
62
- C = B - (float(np.dot(normal, B)) - b) * normal
63
- A = normal.copy(); A[0] = b
64
- data.append({'A': A.tolist(), 'B': B.tolist(), 'C': C.tolist(), 'type': 'halfspace'})
65
- return data
66
-
67
-
68
- # ── PROBLEM TYPE 3 (SEEN): ELASTIC BALANCE ────────────────────────────────────
69
- # C[j] = w[j]·a_center[j] + (1−w[j])·B[j] per-dimension soft trade-off
70
 
71
- def gen_elastic(n, dim, rng):
72
- data = []
73
- for _ in range(n):
74
- a_center = rng.uniform(-2, 2, dim)
75
- w = rng.uniform(0.05, 0.95, dim)
76
- B = rng.uniform(-3, 3, dim)
77
- C = w * a_center + (1.0 - w) * B
78
- A = pack(a_center[:dim//2], w[:dim//2], dim=dim)
79
- data.append({'A': A.tolist(), 'B': B.tolist(), 'C': C.tolist(), 'type': 'elastic'})
80
- return data
81
-
82
-
83
- # ── PROBLEM TYPE 4 (OOD): SPHERE SURFACE ─────────────────────────────────────
84
- # C = center + r·(B−center)/‖B−center‖ (nearest point on sphere to B)
85
-
86
- def gen_sphere(n, dim, rng):
87
- data = []
88
- for _ in range(n):
89
- center = rng.uniform(-1.5, 1.5, dim)
90
- r = float(rng.uniform(1.0, 3.0))
91
- B = rng.uniform(-4, 4, dim)
92
- diff = B - center
93
- nd = np.linalg.norm(diff)
94
- if nd < 1e-10:
95
- diff = np.ones(dim) / np.sqrt(dim); nd = 1.0
96
- C = center + r * diff / nd
97
- A = center.copy(); A[0] = r
98
- data.append({'A': A.tolist(), 'B': B.tolist(), 'C': C.tolist(), 'type': 'sphere'})
99
- return data
100
-
101
-
102
- # ── PROBLEM TYPE 5 (OOD): SIMPLEX PROJECTION ─────────────────────────────────
103
- # C = nearest point on probability simplex to B (Σxᵢ=1, xᵢ≥0)
104
-
105
- def _proj_simplex(v):
106
- n = len(v)
107
- u = np.sort(v)[::-1]
108
- cs = np.cumsum(u) - 1.0
109
- rho = int(np.where(u * np.arange(1, n+1) > cs)[0][-1])
110
- theta = cs[rho] / (rho + 1.0)
111
- return np.maximum(v - theta, 0.0)
112
 
113
- def gen_simplex(n, dim, rng):
114
- data = []
115
- for _ in range(n):
116
- A = np.ones(dim)
117
- B = rng.uniform(-1.0, 3.0, dim)
118
- C = _proj_simplex(B)
119
- data.append({'A': A.tolist(), 'B': B.tolist(), 'C': C.tolist(), 'type': 'simplex'})
120
- return data
121
 
 
122
 
123
- # ── ASSEMBLY ──────────────────────────────────────────────────────────────────
124
 
125
- SEEN_TYPES = {
126
- 'box_proj': gen_box,
127
- 'halfspace': gen_halfspace,
128
- 'elastic': gen_elastic,
 
 
129
  }
130
- OOD_TYPES = {
131
- 'sphere': gen_sphere,
132
- 'simplex': gen_simplex,
133
- }
134
- ALL_TYPES = {**SEEN_TYPES, **OOD_TYPES}
135
-
136
 
137
- def generate_all(n_per_type=SAMPLES_PER_TYPE, dim=DIM, seed=42):
138
- rng = np.random.default_rng(seed)
139
  data = []
140
- for fn in ALL_TYPES.values():
141
- data.extend(fn(n_per_type, dim, rng))
142
- idx = rng.permutation(len(data))
143
- return [data[i] for i in idx]
144
-
 
 
 
 
145
 
146
  if __name__ == '__main__':
147
  parser = argparse.ArgumentParser()
148
- parser.add_argument('--dim', type=int, default=DIM)
149
  parser.add_argument('--n', type=int, default=SAMPLES_PER_TYPE)
150
  parser.add_argument('--out', type=str, default='data')
151
  args = parser.parse_args()
152
 
153
- print(f"\n{'─'*55}")
154
- print(f" Generating {5 * args.n} samples | dim={args.dim}")
155
- print(f" SEEN : box_proj | halfspace | elastic")
156
- print(f" OOD : sphere | simplex")
157
- print(f"{'─'*55}")
158
-
159
- rng = np.random.default_rng(42)
160
-
161
- seen_data, ood_data = [], []
162
- for t, fn in SEEN_TYPES.items():
163
- seen_data.extend(fn(args.n, args.dim, rng))
164
- for t, fn in OOD_TYPES.items():
165
- ood_data.extend(fn(args.n, args.dim, rng))
166
-
167
- # Shuffle within splits
168
- si = rng.permutation(len(seen_data))
169
- oi = rng.permutation(len(ood_data))
170
- seen_data = [seen_data[i] for i in si]
171
- ood_data = [ood_data[i] for i in oi]
172
-
173
- # Train = 90% of SEEN only
174
- # Test = 10% of SEEN + ALL OOD (so model never trained on OOD)
175
- split = int(len(seen_data) * 0.9)
176
- train = seen_data[:split]
177
- test_seen = seen_data[split:]
178
- test = test_seen + ood_data
179
-
180
- # Re-shuffle test so seen/OOD are interleaved
181
- ti = rng.permutation(len(test))
182
- test = [test[i] for i in ti]
183
 
184
  out = pathlib.Path(args.out)
185
  out.mkdir(exist_ok=True)
186
- with open(out / 'train.json', 'w') as f: json.dump(train, f)
187
- with open(out / 'test.json', 'w') as f: json.dump(test, f)
188
 
189
  from collections import Counter
190
- tr_types = Counter(d['type'] for d in train)
191
- te_types = Counter(d['type'] for d in test)
192
-
193
- print(f"\n {'Type':<14} {'Train':>7} {'Test':>7} {'Split'}")
194
- print(f" {'─'*14} {'─'*7} {'─'*7} {'─'*10}")
195
- for t in ALL_TYPES:
196
- label = 'OOD βœ—' if t in OOD_TYPES else 'SEEN βœ“'
197
- print(f" {t:<14} {tr_types.get(t,0):>7} {te_types.get(t,0):>7} {label}")
198
- print(f"\n Total train={len(train)} test={len(test)}\n")
199
-
200
- # Quick sanity: verify C is geometrically correct for first sample per type
201
- print(f" Sanity check:")
202
- seen_set = set()
203
- for d in train + test:
204
- t = d['type']
205
- if t in seen_set: continue
206
- seen_set.add(t)
207
- A, B, C = map(np.array, [d['A'], d['B'], d['C']])
208
- print(f" [{t:<12}] β€–Aβ€–={np.linalg.norm(A):.2f} "
209
- f"β€–Bβ€–={np.linalg.norm(B):.2f} β€–Cβ€–={np.linalg.norm(C):.2f}")
210
-
211
- print(f"\n Saved β†’ {out}/train.json {out}/test.json\n")
 
1
  """
2
+ data_gen.py
3
 
4
+ Scalar mesh test data. Each sample is (a, b, c) — all single floats.
 
 
 
5
 
6
+ A ∈ [0.1, 0.9] — constraint scalar (top input)
7
+ B ∈ [0.1, 0.9] — objective scalar (bottom input)
8
+ C — feasibility center (what the mesh must learn to produce)
9
 
10
+ All C values kept in [0.1, 0.9] so the scalar Hooke mesh can represent them
11
+ without needing to amplify beyond the input range.
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
12
 
13
+ SEEN during training : heavy_a | avg | diff
14
+ OOD (test only) : heavy_b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
15
 
16
+ If the mesh generalises:
17
+ It learned "weighted combination of A and B" as a concept
18
+ → it applies unseen weighting (B-heavy) without training on it.
19
+ """
 
 
 
 
20
 
21
+ import numpy as np, json, pathlib, random, argparse
22
 
23
+ SAMPLES_PER_TYPE = 2500 # Γ— 4 types = 10 000 total
24
 
25
+ DATASETS = {
26
+ # name : (lambda, seen?)
27
+ 'heavy_a': (lambda a, b: 0.8*a + 0.2*b, True),
28
+ 'avg': (lambda a, b: 0.5*a + 0.5*b, True),
29
+ 'diff': (lambda a, b: 0.5 + 0.4*(a - b), True), # signed diff, offset to [0.1,0.9]
30
+ 'heavy_b': (lambda a, b: 0.2*a + 0.8*b, False), # OOD
31
  }
 
 
 
 
 
 
32
 
33
def generate(n_per=SAMPLES_PER_TYPE, seed=42):
    """Build the shuffled scalar dataset for every entry in ``DATASETS``.

    For each dataset type, ``n_per`` pairs ``(a, b)`` are drawn uniformly
    from the range 0.1..0.9 and the type's function is applied to obtain
    ``c``. Each sample is stored as ``{'a', 'b', 'c', 'type'}`` with the
    three numbers rounded to 4 decimals.

    Args:
        n_per: number of samples generated per dataset type.
        seed: seed for the NumPy generator. It drives both the random draws
            and the final shuffle, so the same seed always yields the same
            list in the same order.

    Returns:
        A list of sample dicts in shuffled order.
    """
    rng = np.random.default_rng(seed)
    data = []
    for dtype, (fn, _seen) in DATASETS.items():
        for _ in range(n_per):
            a = float(rng.uniform(0.1, 0.9))
            b = float(rng.uniform(0.1, 0.9))
            # c is computed from the unrounded inputs; rounding is applied
            # only to the stored values.
            data.append({'a': round(a, 4), 'b': round(b, 4),
                         'c': round(fn(a, b), 4), 'type': dtype})
    # Shuffle with the seeded generator rather than the global ``random``
    # module: callers slice this list into train/test, so an unseeded
    # shuffle would make the split differ on every run despite ``seed``.
    order = rng.permutation(len(data))
    return [data[i] for i in order]
45
 
46
if __name__ == '__main__':
    # Command-line entry point: generate the data, split it into train/test,
    # write both as JSON and print a per-type summary table.
    from collections import Counter

    parser = argparse.ArgumentParser()
    parser.add_argument('--n', type=int, default=SAMPLES_PER_TYPE)
    parser.add_argument('--out', type=str, default='data')
    # Optional; the default matches generate()'s default seed.
    parser.add_argument('--seed', type=int, default=42)
    args = parser.parse_args()

    data = generate(args.n, seed=args.seed)

    # Split: train = SEEN only (90%), test = 10% SEEN + ALL OOD
    seen = [d for d in data if DATASETS[d['type']][1]]
    ood = [d for d in data if not DATASETS[d['type']][1]]

    split = int(len(seen) * 0.9)
    train = seen[:split]
    test = seen[split:] + ood

    # Interleave seen/OOD test samples with a dedicated seeded RNG so the
    # order written to test.json is the same on every run.
    random.Random(args.seed).shuffle(test)

    out = pathlib.Path(args.out)
    # parents=True lets --out point at a nested directory that does not
    # exist yet.
    out.mkdir(parents=True, exist_ok=True)
    with open(out / 'train.json', 'w') as f:
        json.dump(train, f)
    with open(out / 'test.json', 'w') as f:
        json.dump(test, f)

    tr = Counter(d['type'] for d in train)
    te = Counter(d['type'] for d in test)

    print(f"\n{'─'*50}")
    print(f" {'Type':<12} {'Train':>7} {'Test':>7} Split")
    print(f" {'─'*12} {'─'*7} {'─'*7} {'─'*8}")
    for t, (_fn, seen_flag) in DATASETS.items():
        label = 'SEEN' if seen_flag else 'OOD ✗'
        print(f" {t:<12} {tr.get(t,0):>7} {te.get(t,0):>7} {label}")
    print(f"\n Train total: {len(train)} Test total: {len(test)}")
    print(f" Saved → {out}/train.json {out}/test.json\n")