"""
Dataset loaders for multi-source training.

Supported datasets:
  shakespeare  — Tiny Shakespeare (~1M chars, classic GPT demo)
  alpaca       — Stanford Alpaca 52K instruction-following examples
  openwebtext  — Small OpenWebText sample from HuggingFace (~1GB)
  custom       — Any local .txt file passed via --custom_file
"""
import os
import json
import urllib.request
import torch
# ── Download helpers ──────────────────────────────────────────────────────────
def _download(url, dest):
    """Fetch *url* into the file *dest* unless *dest* already exists.

    The parent directory of *dest* is created when needed. The payload is
    first written to ``dest + ".part"`` and only renamed to *dest* once the
    transfer finished, so an interrupted download never leaves a truncated
    file that a later call would mistake for a cached copy.

    Args:
        url: Location passed to ``urllib.request.urlretrieve``.
        dest: Target file path; may be a bare filename in the current
            directory.

    Returns:
        *dest*, unchanged.

    Raises:
        Whatever ``urllib.request.urlretrieve`` or the filesystem calls
        raise; errors are propagated after the partial file is removed.
    """
    parent = os.path.dirname(dest)
    # os.makedirs("") raises FileNotFoundError, so skip it for bare filenames.
    if parent:
        os.makedirs(parent, exist_ok=True)

    if not os.path.exists(dest):
        print(f" Downloading {os.path.basename(dest)} ...")
        partial = dest + ".part"
        try:
            urllib.request.urlretrieve(url, partial)
            os.replace(partial, dest)
        finally:
            # After a successful rename the partial file is gone; it only
            # still exists here when the transfer or the rename failed.
            if os.path.exists(partial):
                os.remove(partial)
        print(f" Saved to {dest}")
    return dest
def get_shakespeare(data_dir="data"):
    """Make sure the Tiny Shakespeare text is available locally.

    The file is stored as ``shakespeare.txt`` inside *data_dir*; the fetch
    itself is delegated to ``_download``.

    Returns:
        Whatever ``_download`` returns for the target path.
    """
    source_url = "https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt"
    target = os.path.join(data_dir, "shakespeare.txt")
    return _download(source_url, target)
def get_alpaca(data_dir="data"):
    """Make sure the Stanford Alpaca JSON file is available locally.

    The file is stored as ``alpaca.json`` inside *data_dir*; the fetch itself
    is delegated to ``_download``.

    Returns:
        The path of the JSON file inside *data_dir*.
    """
    source_url = "https://raw.githubusercontent.com/tatsu-lab/stanford_alpaca/main/alpaca_data.json"
    target = os.path.join(data_dir, "alpaca.json")
    _download(source_url, target)
    return target
def get_openwebtext_sample(data_dir="data"):
    """Download a sample of web text via the HuggingFace datasets API.

    Streams the ``Skylion007/openwebtext`` train split and writes the first
    5000 documents, each followed by a blank line, to
    ``openwebtext_sample.txt`` inside *data_dir*. An already existing sample
    file is reused without contacting the hub.

    This loader is best-effort: if the ``datasets`` package is missing or
    anything goes wrong while streaming, the problem is printed and ``None``
    is returned so callers can skip this source.

    Args:
        data_dir: Directory that holds (or will hold) the sample file.

    Returns:
        Path of the sample file, or ``None`` when it could not be produced.
    """
    path = os.path.join(data_dir, "openwebtext_sample.txt")
    if os.path.exists(path):
        return path

    # Write to a side file first so a failed run never leaves a truncated
    # sample behind that the cache check above would accept next time.
    partial = path + ".part"
    try:
        from datasets import load_dataset

        print(" Loading web text sample via HuggingFace datasets ...")
        # Use a small slice of HuggingFace's text datasets that work without scripts
        ds = load_dataset(
            "Skylion007/openwebtext",
            split="train",
            streaming=True,
            trust_remote_code=False,
        )

        # Nothing else creates data_dir when this is the first dataset loaded.
        parent = os.path.dirname(path)
        if parent:
            os.makedirs(parent, exist_ok=True)

        count = 0
        # Explicit UTF-8: web text contains characters that some locale
        # default encodings cannot represent.
        with open(partial, "w", encoding="utf-8") as f:
            for item in ds:
                f.write(item["text"].strip() + "\n\n")
                count += 1
                if count >= 5000:
                    break
        os.replace(partial, path)
        print(f" Saved {count} documents to {path}")
    except Exception as e:
        # Deliberate best-effort boundary: report and let the caller go on.
        print(f" Skipping openwebtext: {e}")
        if os.path.exists(partial):
            os.remove(partial)
        return None
    return path
# ── Text formatters ───────────────────────────────────────────────────────────
def _format_alpaca(item):
    """Format one Alpaca record as a prompt/response string.

    Args:
        item: Mapping with ``"instruction"`` and ``"output"`` keys and an
            optional ``"input"`` key.

    Returns:
        A string made of ``### Instruction:``, an optional ``### Input:`` and
        a ``### Response:`` section, each terminated by a blank line. The
        input section is emitted only when the input contains something
        other than whitespace.

    Raises:
        KeyError: If ``"instruction"`` or ``"output"`` is missing.
    """
    # `or ""` also covers an explicit None (JSON null), not just a missing key.
    extra_input = item.get("input") or ""

    sections = [f"### Instruction:\n{item['instruction']}\n\n"]
    if extra_input.strip():
        # The input is inserted as given; stripping is only used for the test.
        sections.append(f"### Input:\n{extra_input}\n\n")
    sections.append(f"### Response:\n{item['output']}\n\n")
    return "".join(sections)
# ── Dataset registry ──────────────────────────────────────────────────────────
# Registry of downloadable datasets: dataset name -> getter function.
# Each getter is called with the data directory as its only argument.
DATASETS = dict(
    shakespeare=get_shakespeare,
    alpaca=get_alpaca,
    openwebtext=get_openwebtext_sample,
)
def load_text(name, data_dir="data", custom_file=None):
    """Return raw text string for the given dataset name.

    All files are read as UTF-8 so the result does not depend on the
    platform's locale encoding.

    Args:
        name: ``"custom"`` or one of the keys of ``DATASETS``.
        data_dir: Directory handed to the dataset getter.
        custom_file: Path of the text file to read when *name* is
            ``"custom"``; ignored otherwise.

    Returns:
        The dataset's text. For ``"alpaca"`` the JSON records are formatted
        with ``_format_alpaca`` and concatenated. An empty string is
        returned when the getter reports the dataset as unavailable
        (returns ``None``).

    Raises:
        ValueError: If *name* is unknown, or if *name* is ``"custom"`` and
            *custom_file* is not given.
    """
    if name == "custom":
        # Explicit raise instead of assert: asserts vanish under `python -O`.
        if not custom_file:
            raise ValueError("--custom_file required for custom dataset")
        with open(custom_file, encoding="utf-8") as f:
            return f.read()

    if name == "alpaca":
        # Alpaca is stored as JSON records and needs formatting into text.
        path = get_alpaca(data_dir)
        with open(path, encoding="utf-8") as f:
            records = json.load(f)
        return "".join(_format_alpaca(r) for r in records)

    if name not in DATASETS:
        raise ValueError(f"Unknown dataset '{name}'. Choose from: {list(DATASETS)}, custom")

    path = DATASETS[name](data_dir)
    if path is None:
        # The getter could not provide the dataset; treat it as empty.
        return ""
    with open(path, encoding="utf-8") as f:
        return f.read()
def build_combined_text(names, data_dir="data", custom_file=None, weights=None):
    """Load and concatenate multiple datasets.

    Args:
        names: Sequence of dataset names understood by ``load_text``.
        data_dir: Directory handed to ``load_text`` for every dataset.
        custom_file: File handed to ``load_text`` for every dataset.
        weights: Sequence of non-negative floats (same length as *names*)
            used to sub-sample each source, e.g. ``[1.0, 0.5]`` uses 100% of
            ``names[0]`` and the first 50% of ``names[1]``. Weights of 1.0 or
            more keep the whole text. Defaults to 1.0 for every source.

    Returns:
        The kept texts joined by a blank line (``"\\n\\n"``). Sources that
        end up empty are left out, so they add no stray separators.

    Raises:
        ValueError: If *weights* and *names* differ in length or a weight is
            negative.
    """
    if weights is None:
        weights = [1.0] * len(names)
    # Explicit raises instead of assert: asserts vanish under `python -O`.
    if len(weights) != len(names):
        raise ValueError(
            f"weights has {len(weights)} entries but names has {len(names)}"
        )
    # A negative weight would turn the cut below into a slice from the end
    # of the text, silently keeping the wrong portion.
    if any(w < 0 for w in weights):
        raise ValueError(f"weights must be non-negative, got {list(weights)}")

    parts = []
    for name, w in zip(names, weights):
        print(f"Loading dataset: {name} (weight {w})")
        text = load_text(name, data_dir=data_dir, custom_file=custom_file)
        if w < 1.0:
            # Sub-sampling keeps the leading fraction of the text.
            text = text[: int(len(text) * w)]
        if not text:
            continue
        parts.append(text)
        print(f" {name}: {len(text):,} chars")

    combined = "\n\n".join(parts)
    print(f"Total combined: {len(combined):,} chars")
    return combined
# ── Token tensor builder ──────────────────────────────────────────────────────
def tokenize_and_split(text, tokenizer, split_ratio=0.9):
    """Encode *text* and cut the token tensor into two consecutive pieces.

    Args:
        text: Text handed to ``tokenizer.encode``.
        tokenizer: Object exposing an ``encode`` method whose result
            ``torch.tensor`` can convert to ``torch.long``.
        split_ratio: Fraction of the tokens that goes into the first piece;
            the cut index is ``int(split_ratio * number_of_tokens)``.

    Returns:
        A ``(head, tail)`` tuple of ``torch.long`` tensors: the tokens before
        the cut index and the tokens from the cut index onwards.
    """
    ids = torch.tensor(tokenizer.encode(text), dtype=torch.long)
    boundary = int(split_ratio * len(ids))
    head = ids[:boundary]
    tail = ids[boundary:]
    return head, tail
|