| """ |
| Dataset loaders for multi-source training. |
| |
| Supported datasets: |
| shakespeare — Tiny Shakespeare (~1M chars, classic GPT demo) |
| alpaca — Stanford Alpaca 52K instruction-following examples |
| openwebtext — Small OpenWebText sample from HuggingFace (~1GB) |
| custom — Any local .txt file passed via --custom_file |
| """ |
| import os |
| import json |
| import urllib.request |
|
|
| import torch |
|
|
|
|
| |
|
|
| def _download(url, dest): |
| os.makedirs(os.path.dirname(dest), exist_ok=True) |
| if not os.path.exists(dest): |
| print(f" Downloading {os.path.basename(dest)} ...") |
| urllib.request.urlretrieve(url, dest) |
| print(f" Saved to {dest}") |
| return dest |
|
|
|
|
def get_shakespeare(data_dir="data"):
    """Fetch Tiny Shakespeare into *data_dir* (reusing a cached copy) and return its path."""
    target = os.path.join(data_dir, "shakespeare.txt")
    source_url = "https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt"
    return _download(source_url, target)
|
|
|
|
def get_alpaca(data_dir="data"):
    """Fetch the Stanford Alpaca JSON into *data_dir* (reusing a cached copy) and return its path."""
    json_path = os.path.join(data_dir, "alpaca.json")
    source_url = "https://raw.githubusercontent.com/tatsu-lab/stanford_alpaca/main/alpaca_data.json"
    _download(source_url, json_path)
    return json_path
|
|
|
|
def get_openwebtext_sample(data_dir="data", max_docs=5000):
    """Download a sample of web text via the HuggingFace datasets API.

    The first ``max_docs`` documents of the streamed ``train`` split of
    ``Skylion007/openwebtext`` are written to
    ``<data_dir>/openwebtext_sample.txt``. The file is UTF-8, each document
    is stripped, and documents are separated by a blank line. An existing
    file is reused as-is.

    Args:
        data_dir: directory for the cached sample (created if missing).
        max_docs: number of documents to keep from the stream.

    Returns:
        The local file path, or ``None`` if the sample could not be
        produced (e.g. ``datasets`` is not installed or streaming failed).
    """
    path = os.path.join(data_dir, "openwebtext_sample.txt")
    if os.path.exists(path):
        return path
    # Write to a side file and rename on success, so a failure halfway
    # through never leaves a truncated sample that would be reused forever.
    tmp_path = path + ".part"
    try:
        from itertools import islice

        from datasets import load_dataset
        print(" Loading web text sample via HuggingFace datasets ...")

        ds = load_dataset("Skylion007/openwebtext", split="train", streaming=True, trust_remote_code=False)
        if data_dir:
            os.makedirs(data_dir, exist_ok=True)
        count = 0
        with open(tmp_path, "w", encoding="utf-8") as f:
            # islice stops after max_docs items without pulling an extra one.
            for item in islice(ds, max_docs):
                f.write(item["text"].strip() + "\n\n")
                count += 1
        os.replace(tmp_path, path)
        print(f" Saved {count} documents to {path}")
    except Exception as e:
        # Best-effort source: report, clean up, and signal "unavailable".
        print(f" Skipping openwebtext: {e}")
        if os.path.exists(tmp_path):
            os.remove(tmp_path)
        return None
    return path
|
|
|
|
| |
|
|
| def _format_alpaca(item): |
| """Format one Alpaca record as a prompt/response string.""" |
| if item.get("input", "").strip(): |
| return ( |
| f"### Instruction:\n{item['instruction']}\n\n" |
| f"### Input:\n{item['input']}\n\n" |
| f"### Response:\n{item['output']}\n\n" |
| ) |
| return ( |
| f"### Instruction:\n{item['instruction']}\n\n" |
| f"### Response:\n{item['output']}\n\n" |
| ) |
|
|
|
|
| |
|
|
# Registry of downloadable sources: name -> getter(data_dir) that returns a
# local file path (the openwebtext getter may return None when unavailable).
# The "custom" source is not listed here; load_text handles it by name.
DATASETS = dict(
    shakespeare=get_shakespeare,
    alpaca=get_alpaca,
    openwebtext=get_openwebtext_sample,
)
|
|
|
|
def load_text(name, data_dir="data", custom_file=None):
    """Return raw text string for the given dataset name.

    Args:
        name: a key of ``DATASETS`` or ``"custom"``.
        data_dir: cache directory passed to the dataset getters.
        custom_file: path to a local text file; required when
            ``name == "custom"``.

    Returns:
        The dataset text. Alpaca records are rendered with
        ``_format_alpaca`` and concatenated. An empty string is returned
        when a getter reports the source as unavailable (returns ``None``).
        All files are decoded as UTF-8.

    Raises:
        ValueError: if ``custom_file`` is missing for ``"custom"``, or if
            *name* is not a known dataset.
    """
    if name == "custom":
        # Explicit check rather than ``assert``, which ``python -O`` strips.
        if not custom_file:
            raise ValueError("--custom_file required for custom dataset")
        with open(custom_file, encoding="utf-8") as f:
            return f.read()

    if name == "alpaca":
        path = get_alpaca(data_dir)
        # JSON text is UTF-8 by spec; don't depend on the platform locale.
        with open(path, encoding="utf-8") as f:
            records = json.load(f)
        return "".join(_format_alpaca(r) for r in records)

    if name not in DATASETS:
        raise ValueError(f"Unknown dataset '{name}'. Choose from: {list(DATASETS)}, custom")

    path = DATASETS[name](data_dir)
    if path is None:
        return ""
    with open(path, encoding="utf-8") as f:
        return f.read()
|
|
|
|
def build_combined_text(names, data_dir="data", custom_file=None, weights=None):
    """
    Load and concatenate multiple datasets.

    names: dataset names understood by ``load_text``.
    weights: list of floats (same length as names) to sub-sample each source.
        e.g. [1.0, 0.5] uses 100% of names[0] and 50% of names[1].
        A weight keeps that leading fraction of the source's characters;
        weights >= 1.0 keep the whole source.

    Sources that are empty, or become empty after sub-sampling, are skipped
    so they don't add stray separators. The remaining sources are joined
    with a blank line between them.

    Raises ValueError if weights and names differ in length or if any
    weight is negative.
    """
    names = list(names)
    weights = [1.0] * len(names) if weights is None else list(weights)
    # Validate before loading anything (some sources download on first use).
    if len(weights) != len(names):
        raise ValueError(
            f"weights has {len(weights)} entries but names has {len(names)}"
        )
    if any(w < 0 for w in weights):
        raise ValueError(f"weights must be non-negative, got {weights}")

    parts = []
    for name, w in zip(names, weights):
        print(f"Loading dataset: {name} (weight {w})")
        text = load_text(name, data_dir=data_dir, custom_file=custom_file)
        if w < 1.0:
            # A negative cut would slice from the *end* instead of
            # sub-sampling, which is why negative weights are rejected above.
            text = text[:int(len(text) * w)]
        if not text:
            continue
        parts.append(text)
        print(f" {name}: {len(text):,} chars")

    combined = "\n\n".join(parts)
    print(f"Total combined: {len(combined):,} chars")
    return combined
|
|
|
|
| |
|
|
def tokenize_and_split(text, tokenizer, split_ratio=0.9):
    """Encode *text* with *tokenizer* and split the ids into (train, val) tensors.

    Both tensors are ``torch.long``. The first ``int(split_ratio * N)`` ids
    form the training split and the rest form the validation split, where
    N is the number of ids returned by ``tokenizer.encode``.
    """
    ids = torch.tensor(tokenizer.encode(text), dtype=torch.long)
    cut = int(split_ratio * len(ids))
    train_ids, val_ids = ids[:cut], ids[cut:]
    return train_ids, val_ids
|
|