""" Dataset loaders for multi-source training. Supported datasets: shakespeare — Tiny Shakespeare (~1M chars, classic GPT demo) alpaca — Stanford Alpaca 52K instruction-following examples openwebtext — Small OpenWebText sample from HuggingFace (~1GB) custom — Any local .txt file passed via --custom_file """ import os import json import urllib.request import torch # ── Download helpers ────────────────────────────────────────────────────────── def _download(url, dest): os.makedirs(os.path.dirname(dest), exist_ok=True) if not os.path.exists(dest): print(f" Downloading {os.path.basename(dest)} ...") urllib.request.urlretrieve(url, dest) print(f" Saved to {dest}") return dest def get_shakespeare(data_dir="data"): return _download( "https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt", os.path.join(data_dir, "shakespeare.txt"), ) def get_alpaca(data_dir="data"): path = os.path.join(data_dir, "alpaca.json") _download( "https://raw.githubusercontent.com/tatsu-lab/stanford_alpaca/main/alpaca_data.json", path, ) return path def get_openwebtext_sample(data_dir="data"): """Downloads a sample of web text via the HuggingFace datasets API.""" path = os.path.join(data_dir, "openwebtext_sample.txt") if os.path.exists(path): return path try: from datasets import load_dataset print(" Loading web text sample via HuggingFace datasets ...") # Use a small slice of HuggingFace's text datasets that work without scripts ds = load_dataset("Skylion007/openwebtext", split="train", streaming=True, trust_remote_code=False) count = 0 with open(path, "w") as f: for item in ds: f.write(item["text"].strip() + "\n\n") count += 1 if count >= 5000: break print(f" Saved {count} documents to {path}") except Exception as e: print(f" Skipping openwebtext: {e}") return None return None return path # ── Text formatters ─────────────────────────────────────────────────────────── def _format_alpaca(item): """Format one Alpaca record as a prompt/response string.""" if item.get("input", "").strip(): return ( f"### Instruction:\n{item['instruction']}\n\n" f"### Input:\n{item['input']}\n\n" f"### Response:\n{item['output']}\n\n" ) return ( f"### Instruction:\n{item['instruction']}\n\n" f"### Response:\n{item['output']}\n\n" ) # ── Dataset registry ────────────────────────────────────────────────────────── DATASETS = { "shakespeare": get_shakespeare, "alpaca": get_alpaca, "openwebtext": get_openwebtext_sample, } def load_text(name, data_dir="data", custom_file=None): """Return raw text string for the given dataset name.""" if name == "custom": assert custom_file, "--custom_file required for custom dataset" with open(custom_file) as f: return f.read() if name == "alpaca": path = get_alpaca(data_dir) with open(path) as f: records = json.load(f) return "".join(_format_alpaca(r) for r in records) if name not in DATASETS: raise ValueError(f"Unknown dataset '{name}'. Choose from: {list(DATASETS)}, custom") path = DATASETS[name](data_dir) if path is None: return "" with open(path) as f: return f.read() def build_combined_text(names, data_dir="data", custom_file=None, weights=None): """ Load and concatenate multiple datasets. weights: list of floats (same length as names) to sub-sample each source. e.g. [1.0, 0.5] uses 100% of names[0] and 50% of names[1]. """ if weights is None: weights = [1.0] * len(names) assert len(weights) == len(names) parts = [] for name, w in zip(names, weights): print(f"Loading dataset: {name} (weight {w})") text = load_text(name, data_dir=data_dir, custom_file=custom_file) if not text: continue if w < 1.0: cut = int(len(text) * w) text = text[:cut] parts.append(text) print(f" {name}: {len(text):,} chars") combined = "\n\n".join(parts) print(f"Total combined: {len(combined):,} chars") return combined # ── Token tensor builder ────────────────────────────────────────────────────── def tokenize_and_split(text, tokenizer, split_ratio=0.9): tokens = tokenizer.encode(text) data = torch.tensor(tokens, dtype=torch.long) n = int(split_ratio * len(data)) return data[:n], data[n:]