File size: 6,166 Bytes
672259a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
"""

Sample from a trained model



REQUIRED:

  1. You must specify a config file from the config/ directory

  2. All configuration must be in the config file. No CLI overrides allowed



Usage:

  python sample.py <config_file>



Examples:

  python sample.py config/sample_gpt2.py

"""
import sys

# -----------------------------------------------------------------------------
# Configuration loading (BEFORE heavy imports so a bad config fails fast)
# Usage:
#   python sample.py <config_file>
# Note: All configuration must be specified in the config file.
# -----------------------------------------------------------------------------

# Parse command line - only accept a single config-file argument
if len(sys.argv) != 2:
    print("ERROR: Invalid arguments!")
    print("Usage: python sample.py <config_file>")
    print("Available configs in config/:")
    print("  - sample_gpt2.py")
    sys.exit(1)

config_file = sys.argv[1]

# Disallow --key=value arguments (all settings live in the config file)
for arg in sys.argv[1:]:
    if arg.startswith('--'):
        print(f"ERROR: CLI overrides are not supported. All config must be in file: {config_file}")
        sys.exit(1)

# Load the specified config file.
# NOTE(review): exec() runs arbitrary code from the given path -- acceptable
# for a local research script, but never point this at an untrusted file.
print(f"Loading config from: {config_file}")
try:
    # Read via a context manager so the handle is closed (the original
    # exec(open(...).read()) leaked it), and fail cleanly on a bad path
    # instead of dumping a traceback. Underscore names keep these helpers
    # out of the config printout below.
    with open(config_file) as _config_fh:
        _config_source = _config_fh.read()
except FileNotFoundError:
    print(f"ERROR: Config file not found: {config_file}")
    sys.exit(1)
exec(_config_source)

# Validate required top-level config keys before doing anything else
required_keys = ['out_dir', 'init_from', 'model_config']
missing_keys = [k for k in required_keys if k not in globals()]
if missing_keys:
    print(f"ERROR: Missing required config keys: {missing_keys}")
    sys.exit(1)

# Load the model definition file (models/<model_config>.py is expected to
# define GPTConfig and GPT)
import os

model_config = globals()['model_config']
model_file = f"models/{model_config}.py"
try:
    # Read first, exec after: the handle is closed promptly (the original
    # exec(open(...).read()) leaked it), and the FileNotFoundError handler
    # below can only fire for the model file itself -- not for some file the
    # model code happens to open while executing.
    with open(model_file) as _model_fh:
        _model_source = _model_fh.read()
except FileNotFoundError:
    print(f"ERROR: Model file not found: {model_file}")
    print("Available models in models/:")
    for fname in os.listdir('models'):
        if fname.endswith('.py') and not fname.startswith('_'):
            print(f"  - {fname[:-3]}")
    sys.exit(1)
exec(_model_source)

# Model-specific required keys are derived from the fields of the GPTConfig
# dataclass that the model file just defined (require nothing if it is absent).
model_required_keys = []
if 'GPTConfig' in globals():
    config_class = globals()['GPTConfig']
    import dataclasses
    model_required_keys = [field.name for field in dataclasses.fields(config_class)]

# Only a from-scratch init needs the full model config spelled out in the
# file: 'resume' restores model args from the checkpoint and 'gpt2*' pulls
# pretrained weights, so both are exempt from this check.
if init_from == 'scratch':
    missing_model_keys = [k for k in model_required_keys if k not in globals()]
    if missing_model_keys:
        print(f"ERROR: Missing required model config keys for {model_config}: {missing_model_keys}")
        print(f"Required keys: {model_required_keys}")
        sys.exit(1)

# Echo the effective configuration: every plain scalar global that is not
# internal bookkeeping (exclude_keys) and not underscore-prefixed.
exclude_keys = {'config_file', 'model_file', 'model_config', 'model_required_keys', 'config_class'}
print("\n" + "=" * 60)
print("SAMPLE CONFIGURATION")
print("=" * 60)
for key in sorted(globals()):
    if key.startswith('_') or key in exclude_keys:
        continue
    val = globals().get(key)
    if isinstance(val, (int, float, bool, str)):
        print(f"  {key:30s} = {val}")
print("=" * 60 + "\n")

# Now import dependencies (deliberately deferred until after config
# validation, per the header comment above, so a bad config fails before
# paying the heavy torch import cost)
import os
import pickle
from contextlib import nullcontext
import torch
import tiktoken

# Import GPTConfig and GPT from the model file
# (the exec above already placed them in module globals; these explicit
# lookups make the dependency visible and fail fast with a KeyError if the
# model file did not define them)
GPTConfig = globals()['GPTConfig']
GPT = globals()['GPT']

# Auto-detect dtype: bfloat16 needs CUDA hardware support, otherwise
# quietly downgrade to float16.
if dtype == 'bfloat16':
    bf16_ok = torch.cuda.is_available() and torch.cuda.is_bf16_supported()
    if not bf16_ok:
        dtype = 'float16'

# Seed both CPU and CUDA RNGs for reproducible sampling, and enable the
# TF32 matmul/cudnn code paths.
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True

# Build the autocast context from the configured device/dtype strings;
# CPU runs plain (nullcontext), GPU runs under mixed-precision autocast.
if 'cuda' in device:
    device_type = 'cuda'
else:
    device_type = 'cpu'
ptdtype = {'float32': torch.float32, 'bfloat16': torch.bfloat16, 'float16': torch.float16}[dtype]
if device_type == 'cpu':
    ctx = nullcontext()
else:
    ctx = torch.amp.autocast(device_type=device_type, dtype=ptdtype)

# Model construction. Exactly one of the supported init modes must apply;
# the original fell through and died with a NameError on `model` for any
# other init_from value, so fail with an explicit error message instead.
checkpoint = None
if init_from == 'resume':
    # init from a model saved in a specific directory
    ckpt_path = os.path.join(out_dir, 'ckpt.pt')
    checkpoint = torch.load(ckpt_path, map_location=device)
    gptconf = GPTConfig(**checkpoint['model_args'])
    model = GPT(gptconf)
    state_dict = checkpoint['model']
    # torch.compile saves parameters under an '_orig_mod.' prefix; strip it
    # so the state dict matches the uncompiled module we just built.
    unwanted_prefix = '_orig_mod.'
    for k in list(state_dict.keys()):
        if k.startswith(unwanted_prefix):
            state_dict[k[len(unwanted_prefix):]] = state_dict.pop(k)
    model.load_state_dict(state_dict)
elif init_from.startswith('gpt2'):
    # init from a pretrained GPT-2 checkpoint, with dropout disabled for eval
    model = GPT.from_pretrained(init_from, dict(dropout=0.0))
else:
    print(f"ERROR: Unsupported init_from value: {init_from!r} (expected 'resume' or 'gpt2*')")
    sys.exit(1)

model.eval()
model.to(device)
if compile:  # NOTE: this config flag shadows the builtin compile()
    model = torch.compile(model)

# Tokenizer selection: when resuming, prefer the dataset's meta.pkl (written
# at data-prep time) if the checkpoint records which dataset it was trained
# on; otherwise fall back to the GPT-2 BPE tokenizer.
load_meta = False
if init_from == 'resume' and checkpoint is not None and 'config' in checkpoint and 'dataset' in checkpoint['config']:
    meta_path = os.path.join('data', checkpoint['config']['dataset'], 'meta.pkl')
    load_meta = os.path.exists(meta_path)

if load_meta:
    print(f"Loading meta from {meta_path}...")
    with open(meta_path, 'rb') as f:
        meta = pickle.load(f)
    stoi, itos = meta['stoi'], meta['itos']

    def encode(s):
        # one token per character, via the saved string-to-index table
        return [stoi[c] for c in s]

    def decode(l):
        return ''.join(itos[i] for i in l)
else:
    print("No meta.pkl found, assuming GPT-2 encodings...")
    enc = tiktoken.get_encoding("gpt2")

    def encode(s):
        return enc.encode(s, allowed_special={"<|endoftext|>"})

    def decode(l):
        return enc.decode(l)

# Prompt handling: a "FILE:<path>" prefix means read the prompt text from
# that file instead of using the string literally.
if start.startswith('FILE:'):
    prompt_path = start[5:]
    with open(prompt_path, 'r', encoding='utf-8') as f:
        start = f.read()

# Encode the prompt and add a batch dimension of 1.
prompt_ids = encode(start)
x = torch.tensor(prompt_ids, dtype=torch.long, device=device).unsqueeze(0)

# Draw num_samples completions, printing each one followed by a separator.
with torch.no_grad(), ctx:
    for sample_idx in range(num_samples):
        y = model.generate(x, max_new_tokens, temperature=temperature, top_k=top_k)
        print(decode(y[0].tolist()))
        print('---------------')