"""Download + tokenize instruction data for HYDRA SFT.

Writes int16 token shards to `data/sft/shard_XXXX.bin`, parallel uint8
loss masks to `data/sft/mask_XXXX.bin`, and a `data/sft/meta.json` with
counts + the special-token mapping.

Chat format (the vocab's 4 reserved special tokens are repurposed):

    <BOS=8188><|user|=8189>\n{instruction}\n{input?}\n<|assistant|=8190>\n
    {output}<|end|=8191>\n

Special-token IDs are constants derived from the tokenizer (they are the
last 4 IDs in an 8192-vocab). They are stored in meta.json for the SFT
script to read.

Sources (tried in order):
  1. yahma/alpaca-cleaned (~52K pairs via HF parquet auto-convert)
  2. databricks/databricks-dolly-15k (~15K pairs)
  3. Hard-coded 200 simple Q&A pairs (offline backup)

Usage:
  python scripts/download_sft_data.py            # full download
  python scripts/download_sft_data.py --test     # small smoke run
  python scripts/download_sft_data.py --offline  # skip network; use backup
"""
|
|
from __future__ import annotations

import argparse
import json
import os
import pickle
import sys
import time
from pathlib import Path

import numpy as np
import requests

_REPO_ROOT = Path(__file__).resolve().parent.parent
if str(_REPO_ROOT) not in sys.path:
    sys.path.insert(0, str(_REPO_ROOT))
|
CACHE_DIR = Path.home() / ".cache" / "autoresearch"
TOKENIZER_PKL = CACHE_DIR / "tokenizer" / "tokenizer.pkl"

SFT_DIR = _REPO_ROOT / "data" / "sft"
SFT_DIR.mkdir(parents=True, exist_ok=True)

# The vocab's last 4 IDs (8192-vocab) are repurposed as chat special tokens.
BOS_ID = 8188
USER_ID = 8189
ASSISTANT_ID = 8190
END_ID = 8191
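
# Worked illustration of the layout (text-token values are hypothetical; the
# special IDs are the real constants above). encode_example() below turns a
# pair like ("Hi", "", "Hello") into roughly:
#
#   [8188, 8189, <ids of "\nHi">, <id of "\n">, 8190, <ids of "\nHello">,
#    8191, <id of "\n">]
#
# i.e. BOS, user tag, prompt text, assistant tag, response text, end tag.
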
TOKENS_PER_SHARD = 1_048_576
DTYPE = np.int16  # token IDs max out at 8191, well inside int16 range

TARGET_TOKENS_DEFAULT = 15_000_000
TARGET_TOKENS_TEST = 1_500_000

ALPACA_URL = (
    "https://huggingface.co/api/datasets/yahma/alpaca-cleaned/parquet/"
    "default/train/0.parquet"
)
DOLLY_URL = (
    "https://huggingface.co/api/datasets/databricks/databricks-dolly-15k/"
    "parquet/default/train/0.parquet"
)
|
|
|
_BACKUP_QA = [
    ("What is the capital of France?", "The capital of France is Paris."),
    ("What is the capital of Germany?", "The capital of Germany is Berlin."),
    ("What is the capital of Japan?", "The capital of Japan is Tokyo."),
    ("What is the capital of Italy?", "The capital of Italy is Rome."),
    ("What is the capital of Spain?", "The capital of Spain is Madrid."),
    ("What is the capital of England?", "The capital of England is London."),
    ("What is the capital of Canada?", "The capital of Canada is Ottawa."),
    ("What is the capital of Australia?", "The capital of Australia is Canberra."),
    ("What is 2 plus 2?", "Two plus two equals four."),
    ("What is 5 plus 5?", "Five plus five equals ten."),
    ("What is 3 times 4?", "Three times four equals twelve."),
    ("What is 10 minus 3?", "Ten minus three equals seven."),
    ("Name a color.", "Red is a color. Other colors include blue, green, and yellow."),
    ("Name a fruit.", "Apple is a fruit. Other fruits include banana, orange, and grape."),
    ("Name an animal.", "A dog is an animal. Other animals include cat, horse, and cow."),
    ("What do birds do?", "Birds fly through the air using their wings."),
    ("What do fish do?", "Fish swim in water using their fins and tails."),
    ("What is water made of?", "Water is made of hydrogen and oxygen."),
    ("What is the sun?", "The sun is a star at the center of our solar system."),
    ("What is the moon?", "The moon is Earth's natural satellite."),
    ("What is a cloud?", "A cloud is a collection of tiny water droplets in the sky."),
    ("What is rain?", "Rain is water that falls from clouds to the ground."),
    ("What is snow?", "Snow is frozen water that falls from clouds in cold weather."),
    ("What is wind?", "Wind is moving air."),
    ("What is fire?", "Fire is a chemical reaction that produces heat and light."),
    ("What is ice?", "Ice is frozen water."),
    ("Say hello.", "Hello! How can I help you today?"),
    ("Say goodbye.", "Goodbye! Have a great day."),
    ("Tell me a short story.", "Once upon a time, a brave rabbit lived in the forest. The end."),
    ("Tell me a joke.", "Why did the chicken cross the road? To get to the other side."),
    ("Who wrote Hamlet?", "William Shakespeare wrote the play Hamlet."),
    ("Who wrote Romeo and Juliet?", "William Shakespeare wrote Romeo and Juliet."),
    ("Who painted the Mona Lisa?", "Leonardo da Vinci painted the Mona Lisa."),
    ("When did World War 2 end?", "World War 2 ended in 1945."),
    ("What is gravity?", "Gravity is the force that pulls objects toward the Earth."),
    ("What is the speed of light?", "The speed of light is approximately 300,000 kilometers per second."),
    ("What is the largest planet?", "Jupiter is the largest planet in our solar system."),
    ("What is the smallest planet?", "Mercury is the smallest planet in our solar system."),
    ("At what temperature does water boil?", "Water boils at 100 degrees Celsius or 212 degrees Fahrenheit."),
    ("At what temperature does water freeze?", "Water freezes at 0 degrees Celsius or 32 degrees Fahrenheit."),
    ("How many legs does a spider have?", "A spider has eight legs."),
    ("How many legs does an insect have?", "An insect has six legs."),
    ("What do plants need to grow?", "Plants need sunlight, water, soil, and air to grow."),
    ("What do humans eat?", "Humans eat a variety of foods including fruits, vegetables, meat, and grains."),
    ("What is a book?", "A book is a collection of written or printed pages bound together."),
    ("What is a computer?", "A computer is an electronic device that processes information."),
    ("What is a phone?", "A phone is a device used to communicate with people at a distance."),
    ("What is music?", "Music is an arrangement of sounds that is pleasing to hear."),
    ("What is art?", "Art is the expression of human creativity and imagination."),
    ("What is a language?", "A language is a system of communication used by a group of people."),
]

# 50 base pairs, cycled to exactly 200 examples for the offline backup.
BACKUP_QA = (_BACKUP_QA * 4)[:200]
|
|
class _TokenizerWrapper:
    """Minimal wrapper around the pickled tiktoken.Encoding. We avoid
    importing `prepare.Tokenizer` to sidestep its side effects (which
    touch the running pretrain's cache files)."""

    def __init__(self, enc):
        self.enc = enc

    def encode(self, text: str) -> list[int]:
        return self.enc.encode_ordinary(text)

    @property
    def vocab_size(self) -> int:
        return self.enc.n_vocab


def load_tokenizer() -> _TokenizerWrapper:
    if not TOKENIZER_PKL.exists():
        raise FileNotFoundError(
            f"Tokenizer not found at {TOKENIZER_PKL}. Run `python prepare.py` "
            "first."
        )
    with open(TOKENIZER_PKL, "rb") as f:
        enc = pickle.load(f)
    tok = _TokenizerWrapper(enc)
    # The chat special tokens (8188-8191) and the int16 shard dtype both
    # assume the 8192-token vocab described in the module docstring.
    expected_vocab = int(os.environ.get("HYDRA_VOCAB_SIZE", "8192"))
    assert tok.vocab_size == expected_vocab, (
        f"download_sft_data: tokenizer vocab {tok.vocab_size} != "
        f"HYDRA_VOCAB_SIZE {expected_vocab}; rerun prepare.py or set "
        "HYDRA_VOCAB_SIZE to match."
    )
    return tok
|
|
def _download_parquet(url: str, local_path: Path, timeout: int = 60) -> bool:
    """Stream-download a parquet file with retry. Returns True on success."""
    local_path.parent.mkdir(parents=True, exist_ok=True)
    tmp = local_path.with_suffix(local_path.suffix + ".tmp")
    for attempt in range(1, 4):
        try:
            with requests.get(url, stream=True, timeout=timeout,
                              allow_redirects=True) as r:
                r.raise_for_status()
                with open(tmp, "wb") as f:
                    for chunk in r.iter_content(chunk_size=1 << 20):
                        if chunk:
                            f.write(chunk)
            tmp.replace(local_path)
            return True
        except Exception as e:
            print(f"  [net] attempt {attempt} failed: {e}", flush=True)
            for p in (tmp, local_path):
                try:
                    p.unlink()
                except FileNotFoundError:
                    pass
            time.sleep(2 ** attempt)
    return False
|
|
|
|
def _iter_alpaca(local_path: Path):
    """Yield (instruction, input, output) from alpaca-cleaned parquet."""
    import pyarrow.parquet as pq
    pf = pq.ParquetFile(str(local_path))
    for rg_idx in range(pf.num_row_groups):
        rg = pf.read_row_group(rg_idx)
        instr_col = rg.column("instruction").to_pylist()
        input_col = rg.column("input").to_pylist()
        output_col = rg.column("output").to_pylist()
        for instruction, input_text, output in zip(instr_col, input_col, output_col):
            if instruction and output:
                yield instruction, (input_text or ""), output
|
|
|
|
def _iter_dolly(local_path: Path):
    """Yield (instruction, input, output) from dolly-15k parquet."""
    import pyarrow.parquet as pq
    pf = pq.ParquetFile(str(local_path))

    for rg_idx in range(pf.num_row_groups):
        rg = pf.read_row_group(rg_idx)
        cols = {n: rg.column(n).to_pylist() for n in rg.schema.names}
        instr_col = cols.get("instruction") or cols.get("Instruction")
        resp_col = cols.get("response") or cols.get("Response")
        if not instr_col or not resp_col:
            continue  # unexpected schema in this row group
        ctx_col = cols.get("context") or cols.get("Context") or [""] * len(instr_col)
        for instruction, context, response in zip(instr_col, ctx_col, resp_col):
            if instruction and response:
                yield instruction, (context or ""), response
|
|
|
|
def _iter_backup():
    for q, a in BACKUP_QA:
        yield q, "", a
|
|
def encode_example(tok: _TokenizerWrapper, instruction: str,
                   input_text: str, output: str) -> list[int]:
    """Serialize one instruction/response pair into a flat token list.

    Format:
        <BOS><|user|>\\n{instr}\\n[{input}\\n]<|assistant|>\\n{output}<|end|>\\n
    """
    ids: list[int] = [BOS_ID, USER_ID]
    ids += tok.encode("\n" + instruction.strip())
    if input_text and input_text.strip():
        ids += tok.encode("\n" + input_text.strip())
    ids += tok.encode("\n")
    ids.append(ASSISTANT_ID)
    ids += tok.encode("\n" + output.strip())
    ids.append(END_ID)
    ids += tok.encode("\n")
    return ids
|
|
|
|
def encode_example_with_mask(tok: _TokenizerWrapper, instruction: str,
                             input_text: str, output: str
                             ) -> tuple[list[int], list[int]]:
    """Return (tokens, mask) where mask[i]=1 means 'compute loss on token i'
    and mask[i]=0 means 'prompt, ignore'. The boundary is the <|assistant|>
    token: the assistant response (and <|end|>) contribute to loss; the
    user prompt does not."""
    prompt_ids = [BOS_ID, USER_ID] + tok.encode("\n" + instruction.strip())
    if input_text and input_text.strip():
        prompt_ids += tok.encode("\n" + input_text.strip())
    prompt_ids += tok.encode("\n")
    prompt_ids.append(ASSISTANT_ID)

    response_ids = tok.encode("\n" + output.strip())
    response_ids.append(END_ID)
    response_ids += tok.encode("\n")

    ids = prompt_ids + response_ids
    mask = [0] * len(prompt_ids) + [1] * len(response_ids)
    return ids, mask
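
# Tiny self-check of the prompt/response mask boundary (illustration only;
# nothing in this script calls it). It uses a stub encoder so it runs without
# the real tokenizer pickle; the per-character IDs are hypothetical.
def _mask_demo() -> None:
    class _StubEnc:
        def encode(self, text: str) -> list[int]:
            return [ord(c) % 4096 for c in text]  # toy character-level IDs

    ids, mask = encode_example_with_mask(_StubEnc(), "Hi", "", "Hello")
    n_prompt = len(mask) - sum(mask)
    assert ids[n_prompt - 1] == ASSISTANT_ID  # last token with mask == 0
    assert mask[:n_prompt] == [0] * n_prompt and all(mask[n_prompt:])
    assert ids[-2] == END_ID  # only the trailing newline follows <|end|>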

class ShardWriter:
    """Writes two parallel files per shard:
        data/sft/shard_XXXX.bin — int16 token IDs
        data/sft/mask_XXXX.bin  — uint8 0/1 loss mask

    Packs one example after another with no padding. At runtime, SFT builds
    sequences of length MAX_SEQ_LEN by slicing across these flat arrays.
    """

    def __init__(self, out_dir: Path, tokens_per_shard: int = TOKENS_PER_SHARD):
        self.out_dir = out_dir
        self.tokens_per_shard = tokens_per_shard
        self.shard_idx = 0
        self._buf_tok: list[int] = []
        self._buf_mask: list[int] = []
        self.total_tokens = 0

    def add(self, tokens: list[int], mask: list[int]):
        assert len(tokens) == len(mask)
        self._buf_tok.extend(tokens)
        self._buf_mask.extend(mask)
        self.total_tokens += len(tokens)
        while len(self._buf_tok) >= self.tokens_per_shard:
            self._flush_one(self.tokens_per_shard)

    def _flush_one(self, n: int):
        tok_path = self.out_dir / f"shard_{self.shard_idx:04d}.bin"
        mask_path = self.out_dir / f"mask_{self.shard_idx:04d}.bin"
        arr_tok = np.array(self._buf_tok[:n], dtype=DTYPE)
        arr_mask = np.array(self._buf_mask[:n], dtype=np.uint8)
        arr_tok.tofile(tok_path)
        arr_mask.tofile(mask_path)
        self._buf_tok = self._buf_tok[n:]
        self._buf_mask = self._buf_mask[n:]
        print(f"  wrote {tok_path.name} ({n:,} tokens)", flush=True)
        self.shard_idx += 1

    def finalize(self):
        if self._buf_tok:
            self._flush_one(len(self._buf_tok))
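
# Illustration of how the SFT side might consume these files (a minimal
# sketch, not called here; `max_seq_len` and the slicing policy are the SFT
# script's choice). meta.json carries the special-token IDs written below.
def _example_read_back(shard_idx: int = 0, max_seq_len: int = 512):
    meta = json.loads((SFT_DIR / "meta.json").read_text())
    end_id = meta["special_tokens"]["end"]
    tokens = np.fromfile(SFT_DIR / f"shard_{shard_idx:04d}.bin", dtype=np.int16)
    mask = np.fromfile(SFT_DIR / f"mask_{shard_idx:04d}.bin", dtype=np.uint8)
    assert tokens.shape == mask.shape
    # One fixed-length training window sliced from the flat token stream:
    window, window_mask = tokens[:max_seq_len], mask[:max_seq_len]
    return window, window_mask, end_id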
|
|
def main():
    ap = argparse.ArgumentParser()
    ap.add_argument("--test", action="store_true",
                    help="Small smoke run: write ~1.5M tokens and exit.")
    ap.add_argument("--offline", action="store_true",
                    help="Skip network, use hard-coded backup only.")
    ap.add_argument("--target-tokens", type=int, default=None,
                    help="Override target token count.")
    args = ap.parse_args()

    target = args.target_tokens or (
        TARGET_TOKENS_TEST if args.test else TARGET_TOKENS_DEFAULT
    )

    print(f"SFT_DIR: {SFT_DIR}")
    print(f"Target tokens: {target:,}")
    print(f"Offline mode: {args.offline}")

    # Clear shards/masks left over from any previous run.
    for p in SFT_DIR.glob("shard_*.bin"):
        p.unlink()
    for p in SFT_DIR.glob("mask_*.bin"):
        p.unlink()

    tok = load_tokenizer()
    print(f"Tokenizer vocab: {tok.vocab_size}")
    print(f"Special tokens: BOS={BOS_ID} USER={USER_ID} "
          f"ASSISTANT={ASSISTANT_ID} END={END_ID}")

    sources = []
    if not args.offline:
        alpaca_path = SFT_DIR / "alpaca_raw.parquet"
        print(f"\n[src] downloading alpaca-cleaned -> {alpaca_path.name} ...")
        if _download_parquet(ALPACA_URL, alpaca_path):
            print(f"  ok ({alpaca_path.stat().st_size // (1 << 20)} MiB)")
            sources.append(("alpaca-cleaned", lambda: _iter_alpaca(alpaca_path)))
        else:
            print("  alpaca download FAILED, trying dolly...")
            dolly_path = SFT_DIR / "dolly_raw.parquet"
            if _download_parquet(DOLLY_URL, dolly_path):
                print(f"  ok ({dolly_path.stat().st_size // (1 << 20)} MiB)")
                sources.append(("dolly-15k", lambda: _iter_dolly(dolly_path)))

    # The offline backup always goes last as a final fallback.
    sources.append(("backup-200", _iter_backup))

    if not sources:
        print("FATAL: no data sources available.", file=sys.stderr)
        sys.exit(1)

    writer = ShardWriter(SFT_DIR)
    n_examples = 0
    n_assistant_tokens = 0
    source_counts = {}

    for src_name, src_fn in sources:
        print(f"\n[src] encoding {src_name} ...")
        src_examples = 0
        src_tokens = 0
        for (instruction, input_text, output) in src_fn():
            # Truncate pathologically long outputs before encoding.
            if len(output) > 2000:
                output = output[:2000]
            ids, mask = encode_example_with_mask(tok, instruction,
                                                 input_text, output)
            if len(ids) < 4 or len(ids) > 512:
                # Skip degenerate or overlong examples.
                continue
            writer.add(ids, mask)
            n_examples += 1
            src_examples += 1
            src_tokens += len(ids)
            n_assistant_tokens += sum(mask)
            if writer.total_tokens >= target:
                break
        source_counts[src_name] = {
            "examples": src_examples,
            "tokens": src_tokens,
        }
        print(f"  {src_name}: {src_examples:,} examples, {src_tokens:,} tokens")
        if writer.total_tokens >= target:
            break

    writer.finalize()

    meta = {
        "total_tokens": writer.total_tokens,
        "total_examples": n_examples,
        "assistant_tokens_in_loss": n_assistant_tokens,
        "num_shards": writer.shard_idx,
        "tokens_per_shard": TOKENS_PER_SHARD,
        "dtype": "int16",
        "vocab_size": tok.vocab_size,
        "special_tokens": {
            "bos": BOS_ID,
            "user": USER_ID,
            "assistant": ASSISTANT_ID,
            "end": END_ID,
        },
        "sources": source_counts,
        "format_hint": (
            "<BOS><|user|>\\n{instr}\\n[{input}\\n]<|assistant|>\\n"
            "{output}<|end|>\\n"
        ),
    }
    meta_path = SFT_DIR / "meta.json"
    with open(meta_path, "w") as f:
        json.dump(meta, f, indent=2)

    print("\n===== SFT data ready =====")
    print(f"  examples:     {n_examples:,}")
    print(f"  total tokens: {writer.total_tokens:,}")
    print(f"  loss tokens:  {n_assistant_tokens:,}")
    print(f"  shards:       {writer.shard_idx}")
    print(f"  meta:         {meta_path}")

    if args.test and writer.total_tokens < 1_000_000:
        print(f"\nWARN: test mode produced only {writer.total_tokens:,} "
              f"tokens — below 1M threshold.")
        sys.exit(2)


if __name__ == "__main__":
    main()