Spaces:
Runtime error
Runtime error
| """Download + tokenize instruction data for HYDRA SFT. | |
Writes int16 token shards to `data/sft/shard_XXXX.bin`, parallel uint8
loss-mask shards to `data/sft/mask_XXXX.bin`, and a `data/sft/meta.json`
with counts + special-token mapping.
| Chat format (vocab's 4 reserved special tokens are repurposed): | |
| <BOS=8188> <|user|=8189>\n{instruction}\n{input?}\n <|assistant|=8190>\n | |
| {output}<|end|=8191>\n | |
Special-token IDs are hard-coded constants matching the tokenizer (they are
the last 4 IDs in an 8192-vocab). They are stored in meta.json for the SFT
script to read.
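Sources (each network source is used only if its download succeeds; all
available sources are encoded in this order until the token target is met):
  1. yahma/alpaca-cleaned (~52K pairs via HF parquet auto-convert)
  2. databricks/databricks-dolly-15k (~15K pairs; fetched regardless of
     whether alpaca succeeded)
  3. Hard-coded backup: 50 simple Q&A pairs repeated 4x = 200 examples
     (always appended; the only source under --offline)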
| Usage: | |
| python scripts/download_sft_data.py # full download | |
| python scripts/download_sft_data.py --test # small smoke run | |
| python scripts/download_sft_data.py --offline # skip network; use backup | |
| """ | |
| from __future__ import annotations | |
| import argparse | |
| import json | |
| import os | |
| import pickle | |
| import sys | |
| import time | |
| from pathlib import Path | |
| import numpy as np | |
| import requests | |
# Make `prepare` and `hydra.*` importable when this file is executed directly
# as a script: put the directory two levels above this file on sys.path.
_REPO_ROOT = Path(__file__).resolve().parent.parent
if str(_REPO_ROOT) not in sys.path:
    sys.path.insert(0, str(_REPO_ROOT))

# ---------------------------------------------------------------------------
# Constants
# ---------------------------------------------------------------------------
CACHE_DIR = Path.home().joinpath(".cache", "autoresearch")
TOKENIZER_PKL = CACHE_DIR.joinpath("tokenizer", "tokenizer.pkl")
SFT_DIR = _REPO_ROOT.joinpath("data", "sft")
# Created at import time so later writes can assume the directory exists.
SFT_DIR.mkdir(exist_ok=True, parents=True)

# The last four IDs of the 8192-entry BPE vocab (<|reserved_0|> ..
# <|reserved_3|>) are repurposed as chat-format markers, in this order.
# Must match the SPECIAL_TOKENS list in prepare.py.
BOS_ID, USER_ID, ASSISTANT_ID, END_ID = range(8188, 8192)

# Token shards are flat arrays of packed token IDs. Every ID is < 8192, so
# int16 is wide enough; 2**20 int16 tokens is 2 MiB per shard.
TOKENS_PER_SHARD = 1 << 20
DTYPE = np.int16
TARGET_TOKENS_DEFAULT = 15_000_000  # ~15M instruction tokens
TARGET_TOKENS_TEST = 1_500_000      # --test smoke run

# HuggingFace auto-converted parquet endpoints (first train file of each).
ALPACA_URL = (
    "https://huggingface.co/api/datasets/"
    "yahma/alpaca-cleaned/parquet/default/train/0.parquet"
)
DOLLY_URL = (
    "https://huggingface.co/api/datasets/"
    "databricks/databricks-dolly-15k/parquet/default/train/0.parquet"
)
| # --------------------------------------------------------------------------- | |
| # Offline backup Q&A pairs (used only if network unavailable) | |
| # --------------------------------------------------------------------------- | |
_BACKUP_QA = [
    # -- capitals --
    ("What is the capital of France?", "The capital of France is Paris."),
    ("What is the capital of Germany?", "The capital of Germany is Berlin."),
    ("What is the capital of Japan?", "The capital of Japan is Tokyo."),
    ("What is the capital of Italy?", "The capital of Italy is Rome."),
    ("What is the capital of Spain?", "The capital of Spain is Madrid."),
    ("What is the capital of England?", "The capital of England is London."),
    ("What is the capital of Canada?", "The capital of Canada is Ottawa."),
    ("What is the capital of Australia?", "The capital of Australia is Canberra."),
    # -- arithmetic --
    ("What is 2 plus 2?", "Two plus two equals four."),
    ("What is 5 plus 5?", "Five plus five equals ten."),
    ("What is 3 times 4?", "Three times four equals twelve."),
    ("What is 10 minus 3?", "Ten minus three equals seven."),
    # -- everyday things and nature --
    ("Name a color.", "Red is a color. Other colors include blue, green, and yellow."),
    ("Name a fruit.", "Apple is a fruit. Other fruits include banana, orange, and grape."),
    ("Name an animal.", "A dog is an animal. Other animals include cat, horse, and cow."),
    ("What do birds do?", "Birds fly through the air using their wings."),
    ("What do fish do?", "Fish swim in water using their fins and tails."),
    ("What is water made of?", "Water is made of hydrogen and oxygen."),
    ("What is the sun?", "The sun is a star at the center of our solar system."),
    ("What is the moon?", "The moon is Earth's natural satellite."),
    ("What is a cloud?", "A cloud is a collection of tiny water droplets in the sky."),
    ("What is rain?", "Rain is water that falls from clouds to the ground."),
    ("What is snow?", "Snow is frozen water that falls from clouds in cold weather."),
    ("What is wind?", "Wind is moving air."),
    ("What is fire?", "Fire is a chemical reaction that produces heat and light."),
    ("What is ice?", "Ice is frozen water."),
    # -- conversational --
    ("Say hello.", "Hello! How can I help you today?"),
    ("Say goodbye.", "Goodbye! Have a great day."),
    ("Tell me a short story.", "Once upon a time, a brave rabbit lived in the forest. The end."),
    ("Tell me a joke.", "Why did the chicken cross the road? To get to the other side."),
    # -- history, literature, art --
    ("Who wrote Hamlet?", "William Shakespeare wrote the play Hamlet."),
    ("Who wrote Romeo and Juliet?", "William Shakespeare wrote Romeo and Juliet."),
    ("Who painted the Mona Lisa?", "Leonardo da Vinci painted the Mona Lisa."),
    ("When did World War 2 end?", "World War 2 ended in 1945."),
    # -- science --
    ("What is gravity?", "Gravity is the force that pulls objects toward the Earth."),
    ("What is the speed of light?", "The speed of light is approximately 300,000 kilometers per second."),
    ("What is the largest planet?", "Jupiter is the largest planet in our solar system."),
    ("What is the smallest planet?", "Mercury is the smallest planet in our solar system."),
    ("At what temperature does water boil?", "Water boils at 100 degrees Celsius or 212 degrees Fahrenheit."),
    ("At what temperature does water freeze?", "Water freezes at 0 degrees Celsius or 32 degrees Fahrenheit."),
    ("How many legs does a spider have?", "A spider has eight legs."),
    ("How many legs does an insect have?", "An insect has six legs."),
    ("What do plants need to grow?", "Plants need sunlight, water, soil, and air to grow."),
    ("What do humans eat?", "Humans eat a variety of foods including fruits, vegetables, meat, and grains."),
    # -- definitions --
    ("What is a book?", "A book is a collection of written or printed pages bound together."),
    ("What is a computer?", "A computer is an electronic device that processes information."),
    ("What is a phone?", "A phone is a device used to communicate with people at a distance."),
    ("What is music?", "Music is an arrangement of sounds that is pleasing to hear."),
    ("What is art?", "Art is the expression of human creativity and imagination."),
    ("What is a language?", "A language is a system of communication used by a group of people."),
]
# The 50 unique pairs above, repeated as four consecutive full passes, give
# exactly 200 samples; the [:200] cap only matters if the list above grows.
BACKUP_QA = [pair for _ in range(4) for pair in _BACKUP_QA][:200]
| # --------------------------------------------------------------------------- | |
| # Tokenizer loader | |
| # --------------------------------------------------------------------------- | |
class _TokenizerWrapper:
    """Minimal wrapper around the pickled tiktoken.Encoding.

    We avoid importing `prepare.Tokenizer` to sidestep its side effects
    (which touch the running pretrain's cache files).

    `enc` must provide `encode_ordinary(text)` and `n_vocab`, the two
    members of the tiktoken.Encoding interface this wrapper uses.
    """

    def __init__(self, enc):
        self.enc = enc

    def encode(self, text: str) -> list[int]:
        """Encode `text` via `encode_ordinary` (no special-token parsing)."""
        return self.enc.encode_ordinary(text)

    @property
    def vocab_size(self) -> int:
        """Vocabulary size of the wrapped encoding.

        A property rather than a method because every caller reads it as an
        attribute (`tok.vocab_size == 8192`, f-strings, the meta.json dict).
        As a plain method those sites received a bound method: the vocab
        check always failed and `json.dump` could not serialize it.
        """
        return self.enc.n_vocab
def load_tokenizer() -> _TokenizerWrapper:
    """Load the pickled BPE tokenizer and check its vocabulary size.

    Returns:
        A `_TokenizerWrapper` around the unpickled encoding.

    Raises:
        FileNotFoundError: if `TOKENIZER_PKL` does not exist.
        ValueError: if the vocab size is not 8192. The hard-coded special
            token IDs (8188-8191) and the int16 shard dtype depend on it.
    """
    if not TOKENIZER_PKL.exists():
        raise FileNotFoundError(
            f"Tokenizer not found at {TOKENIZER_PKL}. Run `python prepare.py` "
            f"first."
        )
    # NOTE(review): pickle.load can execute arbitrary code. This is acceptable
    # only because the file is expected to be the local cache produced by
    # `python prepare.py`; never point TOKENIZER_PKL at untrusted data.
    with open(TOKENIZER_PKL, "rb") as f:
        enc = pickle.load(f)
    tok = _TokenizerWrapper(enc)
    # Explicit raise instead of `assert`: asserts are stripped under
    # `python -O`, and a mismatched vocab would silently corrupt the
    # special-token layout.
    if tok.vocab_size != 8192:
        raise ValueError(f"Expected vocab=8192, got {tok.vocab_size}")
    return tok
| # --------------------------------------------------------------------------- | |
| # Source downloaders | |
| # --------------------------------------------------------------------------- | |
def _download_parquet(url: str, local_path: Path, timeout: int = 60,
                      retries: int = 3) -> bool:
    """Stream-download `url` to `local_path`, retrying on failure.

    The body is written to a sibling ``<name>.tmp`` file, and that file is
    moved onto `local_path` with `Path.replace` only after the whole response
    has been read. An interrupted transfer therefore never leaves a
    truncated file at `local_path`.

    Args:
        url: Source URL (redirects are followed).
        local_path: Destination file; its parent directories are created.
        timeout: Per-request timeout in seconds, passed to `requests.get`.
        retries: Total number of attempts before giving up (default 3).

    Returns:
        True on success. False once every attempt has failed; in that case
        both the temp file and any pre-existing `local_path` are removed.
    """
    local_path.parent.mkdir(parents=True, exist_ok=True)
    tmp = local_path.with_suffix(local_path.suffix + ".tmp")
    for attempt in range(1, retries + 1):
        try:
            with requests.get(url, stream=True, timeout=timeout,
                              allow_redirects=True) as r:
                r.raise_for_status()
                with open(tmp, "wb") as f:
                    for chunk in r.iter_content(chunk_size=1 << 20):
                        if chunk:  # skip keep-alive empty chunks
                            f.write(chunk)
            tmp.replace(local_path)
            return True
        except Exception as e:
            # Deliberately broad: any failure here should fall through to the
            # next data source (the caller has an offline backup), not crash.
            print(f" [net] attempt {attempt} failed: {e}", flush=True)
            for p in (tmp, local_path):
                try:
                    p.unlink()
                except FileNotFoundError:
                    pass
            # Exponential backoff (2s, 4s, ...), but don't sleep after the
            # final attempt; the original wasted 8s before returning False.
            if attempt < retries:
                time.sleep(2 ** attempt)
    return False
def _iter_alpaca(local_path: Path):
    """Stream alpaca-cleaned rows as (instruction, input, output) triples.

    The file is read one row group at a time. Rows with an empty instruction
    or empty output are dropped, and a missing/None input becomes "".
    """
    import pyarrow.parquet as pq

    parquet = pq.ParquetFile(str(local_path))
    for group in range(parquet.num_row_groups):
        table = parquet.read_row_group(group)
        rows = zip(
            table.column("instruction").to_pylist(),
            table.column("input").to_pylist(),
            table.column("output").to_pylist(),
        )
        for instr, extra, answer in rows:
            if not (instr and answer):
                continue
            yield instr, extra or "", answer
def _iter_dolly(local_path: Path):
    """Yield (instruction, input, output) from dolly-15k parquet.

    Expected schema: instruction, context, response, category. Capitalised
    column names are accepted as a fallback. `context` is optional and
    becomes the input field; `category` is not read. Rows with an empty
    instruction or response are dropped.

    Raises:
        KeyError: if a row group has no instruction or no response column.
    """
    import pyarrow.parquet as pq

    def _first_present(names, *candidates):
        # Pick by column *name*, not by list truthiness. An empty row group
        # yields [] columns, and the old `[] or cols.get(...)` chain fell
        # through to None and crashed in len()/zip().
        return next((c for c in candidates if c in names), None)

    pf = pq.ParquetFile(str(local_path))
    for rg_idx in range(pf.num_row_groups):
        rg = pf.read_row_group(rg_idx)
        names = set(rg.schema.names)
        instr_name = _first_present(names, "instruction", "Instruction")
        resp_name = _first_present(names, "response", "Response")
        ctx_name = _first_present(names, "context", "Context")
        if instr_name is None or resp_name is None:
            raise KeyError(
                f"dolly parquet {local_path} lacks an instruction or response "
                f"column; found {sorted(names)}"
            )
        # Convert only the columns we use (category is skipped).
        instr_col = rg.column(instr_name).to_pylist()
        resp_col = rg.column(resp_name).to_pylist()
        if ctx_name is not None:
            ctx_col = rg.column(ctx_name).to_pylist()
        else:
            ctx_col = [""] * len(instr_col)
        for instruction, context, response in zip(instr_col, ctx_col, resp_col):
            if instruction and response:
                yield instruction, (context or ""), response
def _iter_backup():
    """Yield each hard-coded backup pair as an (instruction, "", output) triple."""
    yield from ((question, "", answer) for question, answer in BACKUP_QA)
| # --------------------------------------------------------------------------- | |
| # Encoding | |
| # --------------------------------------------------------------------------- | |
def encode_example(tok: _TokenizerWrapper, instruction: str,
                   input_text: str, output: str) -> list[int]:
    """Serialize one instruction/response pair into a flat token list.

    Format:
        <BOS> <|user|> \\n {instr}\\n[{input}\\n] <|assistant|> \\n {output} <|end|> \\n

    The input segment is included only when `input_text` is non-blank after
    stripping. This delegates to `encode_example_with_mask` and discards the
    mask, so the unmasked and masked encoders share one definition of the
    chat format. Previously the logic was duplicated line for line and
    could drift.

    Returns:
        The token IDs, identical to the first element returned by
        `encode_example_with_mask` for the same arguments.
    """
    ids, _mask = encode_example_with_mask(tok, instruction, input_text, output)
    return ids
def encode_example_with_mask(tok: _TokenizerWrapper, instruction: str,
                             input_text: str, output: str
                             ) -> tuple[list[int], list[int]]:
    """Encode one pair and build its loss mask.

    Returns:
        (tokens, mask), equal in length. mask[i] == 1 marks a token that
        contributes to the loss; mask[i] == 0 marks prompt tokens to ignore.
        Everything up to and including <|assistant|> is prompt (mask 0). The
        response text, <|end|> and the trailing newline are trained on
        (mask 1).
    """
    # Prompt: <BOS><|user|>\n{instruction}[\n{input}]\n<|assistant|>
    prompt = [BOS_ID, USER_ID]
    prompt.extend(tok.encode("\n" + instruction.strip()))
    extra = input_text.strip() if input_text else ""
    if extra:
        prompt.extend(tok.encode("\n" + extra))
    prompt.extend(tok.encode("\n"))
    prompt.append(ASSISTANT_ID)

    # Response: \n{output}<|end|>\n
    response = [*tok.encode("\n" + output.strip()), END_ID, *tok.encode("\n")]

    mask = [0] * len(prompt) + [1] * len(response)
    return prompt + response, mask
| # --------------------------------------------------------------------------- | |
| # Shard writer | |
| # --------------------------------------------------------------------------- | |
class ShardWriter:
    """Pack (tokens, mask) examples into fixed-size parallel shard files.

    Each shard is a pair of flat binary files with the same element count:
        {out_dir}/shard_NNNN.bin  token IDs, dtype `DTYPE` (int16)
        {out_dir}/mask_NNNN.bin   0/1 loss mask, dtype uint8
    Examples are packed back to back with no padding or separators, so one
    example may straddle two shards. At runtime, SFT builds sequences of
    length MAX_SEQ_LEN by slicing across these flat arrays (per the original
    design notes; confirm against the SFT script).

    Attributes:
        shard_idx: number of shards written so far (index of the next one).
        total_tokens: tokens accepted via `add`, including unflushed ones.
    """

    def __init__(self, out_dir: Path, tokens_per_shard: int = TOKENS_PER_SHARD):
        # A non-positive size would make `add` flush empty shards forever.
        if tokens_per_shard <= 0:
            raise ValueError(
                f"tokens_per_shard must be positive, got {tokens_per_shard}"
            )
        self.out_dir = out_dir
        self.tokens_per_shard = tokens_per_shard
        self.shard_idx = 0
        self._buf_tok: list[int] = []
        self._buf_mask: list[int] = []
        self.total_tokens = 0

    def add(self, tokens: list[int], mask: list[int]):
        """Buffer one example and write out every shard that becomes full.

        Raises:
            ValueError: if `tokens` and `mask` differ in length.
        """
        # Explicit check (not `assert`) so it survives `python -O`.
        if len(tokens) != len(mask):
            raise ValueError(
                f"tokens/mask length mismatch: {len(tokens)} != {len(mask)}"
            )
        self._buf_tok.extend(tokens)
        self._buf_mask.extend(mask)
        self.total_tokens += len(tokens)
        while len(self._buf_tok) >= self.tokens_per_shard:
            self._flush_one(self.tokens_per_shard)

    def _flush_one(self, n: int):
        """Write the first `n` buffered tokens/mask values as the next shard."""
        tok_path = self.out_dir / f"shard_{self.shard_idx:04d}.bin"
        mask_path = self.out_dir / f"mask_{self.shard_idx:04d}.bin"
        np.array(self._buf_tok[:n], dtype=DTYPE).tofile(tok_path)
        np.array(self._buf_mask[:n], dtype=np.uint8).tofile(mask_path)
        # Drop the flushed prefix in place rather than rebinding to a copy.
        del self._buf_tok[:n]
        del self._buf_mask[:n]
        print(f" wrote {tok_path.name} ({n:,} tokens)", flush=True)
        self.shard_idx += 1

    def finalize(self):
        """Flush the remaining partial shard, which may be shorter than
        `tokens_per_shard`. No-op if nothing is buffered."""
        if self._buf_tok:
            self._flush_one(len(self._buf_tok))
| # --------------------------------------------------------------------------- | |
| # Main | |
| # --------------------------------------------------------------------------- | |
def _collect_sources(offline: bool) -> list:
    """Return the ordered (name, iterator_fn) data sources to encode.

    Each network source is included only if its parquet download succeeds.
    The hard-coded backup is always appended last.
    """
    sources = []
    if not offline:
        alpaca_path = SFT_DIR / "alpaca_raw.parquet"
        print(f"\n[src] downloading alpaca-cleaned -> {alpaca_path.name} ...")
        if _download_parquet(ALPACA_URL, alpaca_path):
            print(f" ok ({alpaca_path.stat().st_size // (1 << 20)} MiB)")
            sources.append(("alpaca-cleaned", lambda: _iter_alpaca(alpaca_path)))
        else:
            print(" alpaca download FAILED, trying dolly...")
        dolly_path = SFT_DIR / "dolly_raw.parquet"
        if _download_parquet(DOLLY_URL, dolly_path):
            print(f" ok ({dolly_path.stat().st_size // (1 << 20)} MiB)")
            sources.append(("dolly-15k", lambda: _iter_dolly(dolly_path)))
    # Always include backup — cheap, catches tail
    sources.append(("backup-200", _iter_backup))
    return sources


def main():
    """Download, tokenize and shard SFT data, then write meta.json.

    Exit status: 1 if no data source is available (defensive only, because
    the backup source is always appended); 2 if `--test` produced fewer than
    1M tokens; argparse's usage error (2) for a non-positive --target-tokens.
    """
    ap = argparse.ArgumentParser()
    ap.add_argument("--test", action="store_true",
                    help="Small smoke run: write ~1.5M tokens and exit.")
    ap.add_argument("--offline", action="store_true",
                    help="Skip network, use hard-coded backup only.")
    ap.add_argument("--target-tokens", type=int, default=None,
                    help="Override target token count.")
    args = ap.parse_args()
    # Test against None explicitly: the old `args.target_tokens or ...`
    # silently replaced `--target-tokens 0` with the default target.
    if args.target_tokens is None:
        target = TARGET_TOKENS_TEST if args.test else TARGET_TOKENS_DEFAULT
    elif args.target_tokens > 0:
        target = args.target_tokens
    else:
        ap.error("--target-tokens must be a positive integer")
    print(f"SFT_DIR: {SFT_DIR}")
    print(f"Target tokens: {target:,}")
    print(f"Offline mode: {args.offline}")

    # Load and validate the tokenizer BEFORE deleting old shards, so a
    # missing or wrong tokenizer leaves the previous dataset intact.
    tok = load_tokenizer()

    # Clear any prior shards
    for pattern in ("shard_*.bin", "mask_*.bin"):
        for p in SFT_DIR.glob(pattern):
            p.unlink()

    print(f"Tokenizer vocab: {tok.vocab_size}")
    print(f"Special tokens: BOS={BOS_ID} USER={USER_ID} "
          f"ASSISTANT={ASSISTANT_ID} END={END_ID}")

    sources = _collect_sources(args.offline)
    if not sources:
        print("FATAL: no data sources available.", file=sys.stderr)
        sys.exit(1)

    # Stream-encode
    writer = ShardWriter(SFT_DIR)
    n_examples = 0
    n_assistant_tokens = 0
    source_counts = {}
    for src_name, src_fn in sources:
        print(f"\n[src] encoding {src_name} ...")
        src_examples = 0
        src_tokens = 0
        for (instruction, input_text, output) in src_fn():
            # Truncate (not skip) outputs to 2000 chars — a 7.5M model can't
            # use longer ones. NOTE(review): a truncated response still gets
            # <|end|> appended, which teaches mid-sentence stops; consider
            # skipping such examples instead.
            output = output[:2000]
            ids, mask = encode_example_with_mask(tok, instruction,
                                                 input_text, output)
            if len(ids) < 4 or len(ids) > 512:
                # Skip degenerate / too-long examples
                continue
            writer.add(ids, mask)
            n_examples += 1
            src_examples += 1
            src_tokens += len(ids)
            n_assistant_tokens += sum(mask)
            if writer.total_tokens >= target:
                break
        source_counts[src_name] = {
            "examples": src_examples,
            "tokens": src_tokens,
        }
        print(f" {src_name}: {src_examples:,} examples, {src_tokens:,} tokens")
        if writer.total_tokens >= target:
            break
    writer.finalize()

    meta = {
        "total_tokens": writer.total_tokens,
        "total_examples": n_examples,
        "assistant_tokens_in_loss": n_assistant_tokens,
        "num_shards": writer.shard_idx,
        "tokens_per_shard": TOKENS_PER_SHARD,
        "dtype": "int16",
        # ShardWriter writes mask_*.bin as uint8, not int16.
        "mask_dtype": "uint8",
        "vocab_size": tok.vocab_size,
        "special_tokens": {
            "bos": BOS_ID,
            "user": USER_ID,
            "assistant": ASSISTANT_ID,
            "end": END_ID,
        },
        "sources": source_counts,
        "format_hint": (
            "<BOS><|user|>\\n{instr}\\n[{input}\\n]<|assistant|>\\n"
            "{output}<|end|>\\n"
        ),
    }
    meta_path = SFT_DIR / "meta.json"
    with open(meta_path, "w") as f:
        json.dump(meta, f, indent=2)

    print("\n===== SFT data ready =====")
    print(f" examples: {n_examples:,}")
    print(f" total tokens: {writer.total_tokens:,}")
    print(f" loss tokens: {n_assistant_tokens:,}")
    print(f" shards: {writer.shard_idx}")
    print(f" meta: {meta_path}")
    if args.test and writer.total_tokens < 1_000_000:
        print(f"\nWARN: test mode produced only {writer.total_tokens:,} "
              f"tokens — below 1M threshold.")
        sys.exit(2)


if __name__ == "__main__":
    main()