my-gpt-from-scratch / data_loader.py
edgemindroboticslabs's picture
Upload data_loader.py with huggingface_hub
1059a9e verified
"""
Dataset loaders for multi-source training.
Supported datasets:
  shakespeare  — Tiny Shakespeare (~1M chars, classic GPT demo)
  alpaca       — Stanford Alpaca 52K instruction-following examples
  openwebtext  — Small OpenWebText sample from HuggingFace (~1GB)
  custom       — Any local .txt file passed via --custom_file
"""
import os
import json
import urllib.request
import torch
# ── Download helpers ──────────────────────────────────────────────────────────
def _download(url, dest):
os.makedirs(os.path.dirname(dest), exist_ok=True)
if not os.path.exists(dest):
print(f" Downloading {os.path.basename(dest)} ...")
urllib.request.urlretrieve(url, dest)
print(f" Saved to {dest}")
return dest
def get_shakespeare(data_dir="data"):
    """Return the local path of Tiny Shakespeare, fetching it if not cached.

    The file is stored as ``shakespeare.txt`` inside ``data_dir``.
    """
    source_url = "https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt"
    target = os.path.join(data_dir, "shakespeare.txt")
    return _download(source_url, target)
def get_alpaca(data_dir="data"):
    """Return the local path of the Stanford Alpaca JSON, fetching it if not cached.

    The file is stored as ``alpaca.json`` inside ``data_dir``.
    """
    source_url = "https://raw.githubusercontent.com/tatsu-lab/stanford_alpaca/main/alpaca_data.json"
    target = os.path.join(data_dir, "alpaca.json")
    # _download hands back the destination path it was given.
    return _download(source_url, target)
def get_openwebtext_sample(data_dir="data"):
    """Return the path of a cached web-text sample, building it on first use.

    Streams ``Skylion007/openwebtext`` through the HuggingFace ``datasets``
    library and writes the first 5000 documents, each followed by a blank
    line, to ``<data_dir>/openwebtext_sample.txt``.

    Args:
        data_dir: Directory holding the cached sample; created if missing.

    Returns:
        The sample path, or ``None`` when the sample could not be produced
        (for example ``datasets`` is not installed or the stream fails).
        Failure is deliberately non-fatal: the reason is printed and the
        caller is expected to skip this source.
    """
    path = os.path.join(data_dir, "openwebtext_sample.txt")
    if os.path.exists(path):
        return path

    max_docs = 5000
    # Write to a side file and rename on success, so a stream that dies
    # half-way is not mistaken for a complete cached sample next time.
    partial = path + ".part"
    try:
        from datasets import load_dataset

        print(" Loading web text sample via HuggingFace datasets ...")
        # Use a small slice of HuggingFace's text datasets that work without scripts
        ds = load_dataset(
            "Skylion007/openwebtext",
            split="train",
            streaming=True,
            trust_remote_code=False,
        )
        if data_dir:
            os.makedirs(data_dir, exist_ok=True)
        count = 0
        # Web text is full of non-ASCII characters; pin the encoding instead
        # of relying on the platform default.
        with open(partial, "w", encoding="utf-8") as f:
            for item in ds:
                f.write(item["text"].strip() + "\n\n")
                count += 1
                if count >= max_docs:
                    break
        os.replace(partial, path)
        print(f" Saved {count} documents to {path}")
    except Exception as e:
        print(f" Skipping openwebtext: {e}")
        return None
    finally:
        # Only still present when something above failed.
        if os.path.exists(partial):
            os.remove(partial)
    return path
# ── Text formatters ───────────────────────────────────────────────────────────
def _format_alpaca(item):
"""Format one Alpaca record as a prompt/response string."""
if item.get("input", "").strip():
return (
f"### Instruction:\n{item['instruction']}\n\n"
f"### Input:\n{item['input']}\n\n"
f"### Response:\n{item['output']}\n\n"
)
return (
f"### Instruction:\n{item['instruction']}\n\n"
f"### Response:\n{item['output']}\n\n"
)
# ── Dataset registry ──────────────────────────────────────────────────────────
# Maps a dataset name to the helper that returns the local path of its file.
# Each helper takes the data directory as its only argument.
# "custom" is intentionally absent: load_text handles it separately because
# it reads a caller-supplied file instead of fetching one.
DATASETS = {
    "shakespeare": get_shakespeare,
    "alpaca": get_alpaca,
    # May return None when the sample cannot be produced.
    "openwebtext": get_openwebtext_sample,
}
def load_text(name, data_dir="data", custom_file=None):
    """Return the raw text of the dataset called ``name``.

    Args:
        name: ``"custom"`` or one of the keys of ``DATASETS``.
        data_dir: Directory where fetched datasets are cached.
        custom_file: Path of a local text file; required when ``name`` is
            ``"custom"`` and ignored otherwise.

    Returns:
        The dataset as a single string. Alpaca records are rendered with
        ``_format_alpaca`` and concatenated. An empty string is returned when
        the dataset helper yields ``None`` instead of a path.

    Raises:
        ValueError: If ``name`` is not a known dataset, or ``name`` is
            ``"custom"`` and ``custom_file`` was not given.
    """
    if name == "custom":
        # A real exception rather than ``assert``: asserts vanish under -O.
        if not custom_file:
            raise ValueError("--custom_file required for custom dataset")
        path = custom_file
    elif name == "alpaca":
        path = get_alpaca(data_dir)
    elif name in DATASETS:
        path = DATASETS[name](data_dir)
    else:
        raise ValueError(f"Unknown dataset '{name}'. Choose from: {list(DATASETS)}, custom")

    if path is None:
        return ""

    # Decode explicitly as UTF-8 instead of using the locale default, which
    # would garble non-ASCII text on platforms with a different default.
    with open(path, encoding="utf-8") as f:
        if name == "alpaca":
            return "".join(_format_alpaca(r) for r in json.load(f))
        return f.read()
def build_combined_text(names, data_dir="data", custom_file=None, weights=None):
    """Load several datasets and join them into one training text.

    Args:
        names: Dataset names understood by ``load_text``.
        data_dir: Directory where fetched datasets are cached.
        custom_file: Path forwarded to ``load_text`` for the ``"custom"``
            dataset.
        weights: Optional list of floats, one per name, giving the leading
            fraction of each source to keep. ``[1.0, 0.5]`` keeps all of
            ``names[0]`` and the first half of ``names[1]``. Values of 1.0
            or more keep the whole source. Defaults to 1.0 for every source.

    Returns:
        The kept texts joined with a blank line. Sources that are
        unavailable or end up empty are left out.

    Raises:
        ValueError: If ``weights`` and ``names`` differ in length, or a
            weight is negative.
    """
    if weights is None:
        weights = [1.0] * len(names)
    if len(weights) != len(names):
        raise ValueError(
            f"weights has {len(weights)} entries but names has {len(names)}"
        )
    # A negative weight would make the slice below count from the end of the
    # text, silently keeping almost everything instead of a fraction.
    if any(w < 0 for w in weights):
        raise ValueError(f"weights must be non-negative, got {list(weights)}")

    parts = []
    for name, w in zip(names, weights):
        print(f"Loading dataset: {name} (weight {w})")
        text = load_text(name, data_dir=data_dir, custom_file=custom_file)
        if w < 1.0:
            text = text[: int(len(text) * w)]
        # Checked after truncation so a source cut down to nothing does not
        # add a stray separator to the combined text.
        if not text:
            continue
        parts.append(text)
        print(f" {name}: {len(text):,} chars")

    combined = "\n\n".join(parts)
    print(f"Total combined: {len(combined):,} chars")
    return combined
# ── Token tensor builder ──────────────────────────────────────────────────────
def tokenize_and_split(text, tokenizer, split_ratio=0.9):
    """Encode ``text`` and split the token ids into two consecutive tensors.

    Args:
        text: String handed to ``tokenizer.encode``.
        tokenizer: Object whose ``encode`` method returns the token ids.
        split_ratio: Fraction of the tokens that goes into the first tensor.

    Returns:
        A ``(head, tail)`` pair of ``torch.long`` tensors: the first
        ``int(split_ratio * total)`` tokens and the remaining ones.
    """
    token_ids = torch.tensor(tokenizer.encode(text), dtype=torch.long)
    boundary = int(split_ratio * len(token_ids))
    head = token_ids[:boundary]
    tail = token_ids[boundary:]
    return head, tail