File size: 5,252 Bytes
1059a9e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
"""
Dataset loaders for multi-source training.

Supported datasets:
  shakespeare  — Tiny Shakespeare (~1M chars, classic GPT demo)
  alpaca       — Stanford Alpaca 52K instruction-following examples
  openwebtext  — OpenWebText sample streamed from HuggingFace (first 5,000 documents)
  custom       — Any local .txt file passed via --custom_file
"""
import os
import json
import urllib.request

import torch


# ── Download helpers ──────────────────────────────────────────────────────────

def _download(url, dest):
    os.makedirs(os.path.dirname(dest), exist_ok=True)
    if not os.path.exists(dest):
        print(f"  Downloading {os.path.basename(dest)} ...")
        urllib.request.urlretrieve(url, dest)
        print(f"  Saved to {dest}")
    return dest


def get_shakespeare(data_dir="data"):
    """Ensure the Tiny Shakespeare text is present under ``data_dir``.

    Delegates to ``_download`` with the target ``<data_dir>/shakespeare.txt``
    and returns that call's result.
    """
    source_url = "https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt"
    target = os.path.join(data_dir, "shakespeare.txt")
    return _download(source_url, target)


def get_alpaca(data_dir="data"):
    """Ensure the Stanford Alpaca JSON is present under ``data_dir``.

    Calls ``_download`` for ``<data_dir>/alpaca.json`` and returns that path.
    """
    local_json = os.path.join(data_dir, "alpaca.json")
    source_url = "https://raw.githubusercontent.com/tatsu-lab/stanford_alpaca/main/alpaca_data.json"
    _download(source_url, local_json)
    return local_json


def get_openwebtext_sample(data_dir="data", max_docs=5000):
    """Write a sample of OpenWebText to ``<data_dir>/openwebtext_sample.txt``.

    If the file already exists it is reused. Otherwise the first ``max_docs``
    items of the streamed ``Skylion007/openwebtext`` train split are written,
    each stripped and followed by a blank line. The sample is written to a
    ``.part`` file and renamed only once complete, so a failed run never
    leaves a truncated file that later calls would treat as the cache.

    This loader is best-effort: any failure (``datasets`` not installed,
    network error, ...) is reported on stdout and ``None`` is returned.

    Args:
        data_dir: Directory that holds the sample file; created if missing.
        max_docs: Maximum number of documents to write.

    Returns:
        Path to the sample file, or ``None`` when the sample could not be built.
    """
    path = os.path.join(data_dir, "openwebtext_sample.txt")
    if os.path.exists(path):
        return path
    partial = path + ".part"
    try:
        from itertools import islice

        from datasets import load_dataset
        print("  Loading web text sample via HuggingFace datasets ...")
        # Use a small slice of HuggingFace's text datasets that work without scripts
        ds = load_dataset("Skylion007/openwebtext", split="train", streaming=True, trust_remote_code=False)
        if data_dir:
            os.makedirs(data_dir, exist_ok=True)
        count = 0
        # Explicit UTF-8: web text routinely contains characters that a
        # locale-dependent default encoding cannot represent.
        with open(partial, "w", encoding="utf-8") as f:
            for item in islice(ds, max_docs):
                f.write(item["text"].strip() + "\n\n")
                count += 1
        os.replace(partial, path)
        print(f"  Saved {count} documents to {path}")
    except Exception as e:
        # Deliberate best-effort: drop the incomplete file and let callers skip.
        if os.path.exists(partial):
            os.remove(partial)
        print(f"  Skipping openwebtext: {e}")
        return None
    return path


# ── Text formatters ───────────────────────────────────────────────────────────

def _format_alpaca(item):
    """Format one Alpaca record as a prompt/response string."""
    if item.get("input", "").strip():
        return (
            f"### Instruction:\n{item['instruction']}\n\n"
            f"### Input:\n{item['input']}\n\n"
            f"### Response:\n{item['output']}\n\n"
        )
    return (
        f"### Instruction:\n{item['instruction']}\n\n"
        f"### Response:\n{item['output']}\n\n"
    )


# ── Dataset registry ──────────────────────────────────────────────────────────

# Registry mapping a dataset name to the function that fetches it.
# `load_text` calls `DATASETS[name](data_dir)` and expects a local file path
# back; `get_openwebtext_sample` may return None instead when loading fails.
# "custom" is not listed here: `load_text` handles it before consulting the
# registry.
DATASETS = {
    "shakespeare": get_shakespeare,
    "alpaca": get_alpaca,
    "openwebtext": get_openwebtext_sample,
}


def load_text(name, data_dir="data", custom_file=None):
    """Return the raw text for the dataset called ``name``.

    Args:
        name: ``"custom"``, ``"alpaca"``, or any key of ``DATASETS``.
        data_dir: Directory handed to the dataset's fetch function.
        custom_file: Path of the text file to read when ``name == "custom"``.

    Returns:
        The dataset as one string. Alpaca records are rendered with
        ``_format_alpaca`` and concatenated. An empty string is returned when
        the fetch function reports failure by returning ``None``.

    Raises:
        ValueError: If ``name`` is unknown, or ``name == "custom"`` and no
            ``custom_file`` was given.
    """
    if name == "custom":
        # Explicit raise instead of assert: asserts vanish under `python -O`.
        if not custom_file:
            raise ValueError("--custom_file required for custom dataset")
        with open(custom_file, encoding="utf-8") as f:
            return f.read()

    if name == "alpaca":
        path = get_alpaca(data_dir)
        with open(path, encoding="utf-8") as f:
            records = json.load(f)
        return "".join(_format_alpaca(r) for r in records)

    if name not in DATASETS:
        raise ValueError(f"Unknown dataset '{name}'. Choose from: {list(DATASETS)}, custom")

    path = DATASETS[name](data_dir)
    if path is None:
        return ""
    # Files are read as UTF-8 rather than the platform's default encoding so
    # results do not depend on the locale.
    with open(path, encoding="utf-8") as f:
        return f.read()


def build_combined_text(names, data_dir="data", custom_file=None, weights=None):
    """
    Load and concatenate multiple datasets.

    names:   dataset names, each passed to ``load_text``.
    weights: list of floats (same length as names) to sub-sample each source.
             e.g. [1.0, 0.5] uses 100% of names[0] and 50% of names[1].
             A weight below 1.0 keeps that leading fraction of the text;
             weights of 1.0 or more keep the whole text.

    Returns the non-empty parts joined by a blank line. Sources that yield no
    text (failed load, or nothing left after sub-sampling) are skipped so they
    do not add stray separators.

    Raises ValueError if ``weights`` and ``names`` differ in length or a
    weight is negative.
    """
    if weights is None:
        weights = [1.0] * len(names)
    # Validate before any dataset is fetched; explicit raises survive
    # `python -O`, unlike assert, and stop zip() from silently truncating.
    if len(weights) != len(names):
        raise ValueError(
            f"weights has {len(weights)} entries but names has {len(names)}"
        )
    for w in weights:
        if w < 0:
            raise ValueError(f"weights must be non-negative, got {w}")

    parts = []
    for name, w in zip(names, weights):
        print(f"Loading dataset: {name} (weight {w})")
        text = load_text(name, data_dir=data_dir, custom_file=custom_file)
        if w < 1.0:
            text = text[: int(len(text) * w)]
        if not text:
            continue
        parts.append(text)
        print(f"  {name}: {len(text):,} chars")

    combined = "\n\n".join(parts)
    print(f"Total combined: {len(combined):,} chars")
    return combined


# ── Token tensor builder ──────────────────────────────────────────────────────

def tokenize_and_split(text, tokenizer, split_ratio=0.9):
    """Encode ``text`` and split the resulting token tensor in two.

    ``tokenizer.encode(text)`` is wrapped in a ``torch.long`` tensor; the
    first ``int(split_ratio * len(tensor))`` entries form the first returned
    tensor and the remainder forms the second.
    """
    ids = torch.tensor(tokenizer.encode(text), dtype=torch.long)
    boundary = int(split_ratio * len(ids))
    head = ids[:boundary]
    tail = ids[boundary:]
    return head, tail