amirali1985 commited on
Commit
216e5a0
Β·
verified Β·
1 Parent(s): 06cce54

Upload modular/code/modular_data.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. modular/code/modular_data.py +112 -0
modular/code/modular_data.py ADDED
@@ -0,0 +1,112 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Modular arithmetic dataset β€” a + b mod p.
3
+
4
+ Matches Nanda et al. (2023) exactly: p=113, 30% train split, fixed seed.
5
+
6
+ Token layout (contiguous integer IDs, no external tokenizer):
7
+ 0 … p-1 β†’ numbers
8
+ p β†’ '+'
9
+ p+1 β†’ '='
10
+ p+2 β†’ PAD
11
+ total trajectory vocab = p+3
12
+
13
+ Sequence format: [a, +, b, =, result] (5 tokens, prompt_len=4)
14
+ """
15
+ import json
16
+ import random
17
+ from dataclasses import dataclass, asdict
18
+ from pathlib import Path
19
+ from typing import List, Tuple
20
+
21
+ P: int = 113 # Nanda's prime
22
+
23
+ PLUS = P # 113
24
+ EQUALS = P + 1 # 114
25
+ PAD = P + 2 # 115
26
+ VOCAB_SIZE = P + 3 # 116 β€” trajectory vocab passed to SorlModelWrapper
27
+
28
+ PROMPT_LEN = 4 # [a, +, b, =]
29
+ ANSWER_LEN = 1 # [result]
30
+ SEQ_LEN = 5
31
+
32
+ EVAL_CACHE_DIR = Path(__file__).resolve().parent / "eval_sets"
33
+
34
+
35
+ @dataclass
36
+ class ModularExample:
37
+ tokens: List[int] # [a, +, b, =, result]
38
+ a: int
39
+ b: int
40
+ result: int
41
+
42
+
43
+ def make_example(a: int, b: int, p: int = P) -> ModularExample:
44
+ return ModularExample(tokens=[a, PLUS, b, EQUALS, (a + b) % p], a=a, b=b, result=(a + b) % p)
45
+
46
+
47
+ def generate_dataset(p: int = P, train_fraction: float = 0.3, seed: int = 42) -> Tuple[List[ModularExample], List[ModularExample]]:
48
+ """All pΒ² pairs shuffled and split into fixed train/test."""
49
+ rng = random.Random(seed)
50
+ all_pairs = [(a, b) for a in range(p) for b in range(p)]
51
+ rng.shuffle(all_pairs)
52
+ n_train = int(len(all_pairs) * train_fraction)
53
+ train = [make_example(a, b, p) for a, b in all_pairs[:n_train]]
54
+ test = [make_example(a, b, p) for a, b in all_pairs[n_train:]]
55
+ return train, test
56
+
57
+
58
+ DATASET_REPO = "thoughtworks/arithmetic-sorl-data"
59
+
60
+
61
+ def get_eval_set(p: int = P, seed: int = 42) -> List[ModularExample]:
62
+ """Load test set β€” local cache β†’ HF download β†’ regenerate."""
63
+ cache = EVAL_CACHE_DIR / f"modular_p{p}_test_seed{seed}.json"
64
+ if cache.exists():
65
+ with open(cache) as f:
66
+ return [ModularExample(**ex) for ex in json.load(f)]
67
+
68
+ EVAL_CACHE_DIR.mkdir(parents=True, exist_ok=True)
69
+ try:
70
+ from huggingface_hub import hf_hub_download
71
+ path = hf_hub_download(
72
+ repo_id=DATASET_REPO, repo_type="dataset",
73
+ filename=f"modular/test_seed{seed}.json",
74
+ )
75
+ import shutil
76
+ shutil.copy(path, cache)
77
+ with open(cache) as f:
78
+ return [ModularExample(**ex) for ex in json.load(f)]
79
+ except Exception:
80
+ pass
81
+
82
+ _, test = generate_dataset(p=p, seed=seed)
83
+ with open(cache, "w") as f:
84
+ json.dump([asdict(ex) for ex in test], f)
85
+ return test
86
+
87
+
88
+ def get_train_set(p: int = P, seed: int = 42) -> List[ModularExample]:
89
+ """Load train set β€” local cache β†’ HF download β†’ regenerate."""
90
+ cache = EVAL_CACHE_DIR / f"modular_p{p}_train_seed{seed}.json"
91
+ if cache.exists():
92
+ with open(cache) as f:
93
+ return [ModularExample(**ex) for ex in json.load(f)]
94
+
95
+ EVAL_CACHE_DIR.mkdir(parents=True, exist_ok=True)
96
+ try:
97
+ from huggingface_hub import hf_hub_download
98
+ path = hf_hub_download(
99
+ repo_id=DATASET_REPO, repo_type="dataset",
100
+ filename=f"modular/train_seed{seed}.json",
101
+ )
102
+ import shutil
103
+ shutil.copy(path, cache)
104
+ with open(cache) as f:
105
+ return [ModularExample(**ex) for ex in json.load(f)]
106
+ except Exception:
107
+ pass
108
+
109
+ train, _ = generate_dataset(p=p, seed=seed)
110
+ with open(cache, "w") as f:
111
+ json.dump([asdict(ex) for ex in train], f)
112
+ return train