"""
Sample from a trained model
REQUIRED:
1. You must specify a config file from the config/ directory
2. All configuration must be in the config file. No CLI overrides allowed
Usage:
python sample.py <config_file>
Examples:
python sample.py config/sample_gpt2.py
"""
import sys
# -----------------------------------------------------------------------------
# Configuration loading (BEFORE imports to validate config first)
# Usage:
# python sample.py <config_file>
# Note: All configuration must be specified in the config file.
# -----------------------------------------------------------------------------
# Parse command line - only accept config file, no --key=value allowed
if len(sys.argv) != 2:
    print("ERROR: Invalid arguments!")
    print("Usage: python sample.py <config_file>")
    print("Available configs in config/:")
    print(" - sample_gpt2.py")
    sys.exit(1)
config_file = sys.argv[1]
# Disallow --key=value arguments
for arg in sys.argv[1:]:
    if arg.startswith('--'):
        print(f"ERROR: CLI overrides are not supported. All config must be in file: {config_file}")
        sys.exit(1)
# Load the specified config file
print(f"Loading config from: {config_file}")
# Execute the config file at module scope so its assignments become
# top-level names. A context manager is used so the file handle is closed
# promptly (the previous exec(open(...).read()) leaked it); exec inside a
# `with` block still runs in module globals.
with open(config_file) as cf:
    exec(cf.read())
# Validate required config keys.
# Beyond the three structural keys, every name this script reads
# unconditionally further down (dtype, seed, device, compile, start,
# num_samples, max_new_tokens, temperature, top_k) must be supplied by the
# config file — there are no defaults here, so validate them up front and
# fail with a clear message instead of a NameError deep in the script.
required_keys = [
    'out_dir', 'init_from', 'model_config',
    'dtype', 'seed', 'device', 'compile',
    'start', 'num_samples', 'max_new_tokens', 'temperature', 'top_k',
]
missing_keys = [k for k in required_keys if k not in globals()]
if missing_keys:
    print(f"ERROR: Missing required config keys: {missing_keys}")
    sys.exit(1)
# Load model configuration
model_config = globals()['model_config']
# Execute the model definition file (e.g. models/<model_config>.py) at
# module scope so GPTConfig / GPT become top-level names.
model_file = f"models/{model_config}.py"
try:
    # Context manager closes the handle promptly (the previous
    # exec(open(...).read()) leaked it); exec still runs in module globals.
    with open(model_file) as mf:
        exec(mf.read())
except FileNotFoundError:
    print(f"ERROR: Model file not found: {model_file}")
    print(f"Available models in models/:")
    import os
    for f in os.listdir('models'):
        if f.endswith('.py') and not f.startswith('_'):
            print(f" - {f[:-3]}")
    sys.exit(1)
# Collect the model-specific required config keys: every field declared on
# the model file's GPTConfig dataclass (if the model file defined one).
model_required_keys = []
if 'GPTConfig' in globals():
    import dataclasses
    config_class = globals()['GPTConfig']
    model_required_keys = [field.name for field in dataclasses.fields(config_class)]
# Validate model-specific config keys
# Skip validation for 'resume' mode (loads from checkpoint) and 'gpt2*' mode (loads pretrained)
# Only require model config when init_from='scratch'
if init_from == 'scratch':
    absent = [name for name in model_required_keys if name not in globals()]
    if absent:
        print(f"ERROR: Missing required model config keys for {model_config}: {absent}")
        print(f"Required keys: {model_required_keys}")
        sys.exit(1)
# Print configuration (exclude internal variables)
exclude_keys = {'config_file', 'model_file', 'model_config', 'model_required_keys', 'config_class'}
print("\n" + "=" * 60)
print("SAMPLE CONFIGURATION")
print("=" * 60)
# Echo every scalar/string top-level name so the run is reproducible from
# its log; skip bookkeeping names and underscore-prefixed module machinery.
for name in sorted(globals().keys()):
    value = globals().get(name)
    if name.startswith('_') or name in exclude_keys:
        continue
    if isinstance(value, (int, float, bool, str)):
        print(f" {name:30s} = {value}")
print("=" * 60 + "\n")
# Now import dependencies
import os
import pickle
from contextlib import nullcontext
import torch
import tiktoken
# Import GPTConfig and GPT from the model file
# (these were placed into module globals by the exec of model_file above;
# the explicit lookup makes the dependency visible to linters and readers).
GPTConfig = globals()['GPTConfig']
GPT = globals()['GPT']
# Auto-detect dtype
# Fall back to float16 when bfloat16 was requested but the GPU lacks
# bf16 support (or no CUDA device is present).
if dtype == 'bfloat16' and not (torch.cuda.is_available() and torch.cuda.is_bf16_supported()):
    dtype = 'float16'
# Seed CPU and CUDA RNGs so sampling is reproducible for a given config.
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
# Allow TF32 kernels on matmul and cuDNN for speed on Ampere+ GPUs.
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True
# 'device' comes from the config file (e.g. 'cuda', 'cuda:0', 'cpu').
device_type = 'cuda' if 'cuda' in device else 'cpu'
# Map the dtype string to a torch dtype; autocast context is a no-op on CPU.
ptdtype = {'float32': torch.float32, 'bfloat16': torch.bfloat16, 'float16': torch.float16}[dtype]
ctx = nullcontext() if device_type == 'cpu' else torch.amp.autocast(device_type=device_type, dtype=ptdtype)
# model
# Build the model either from a local checkpoint ('resume') or from
# pretrained GPT-2 weights ('gpt2', 'gpt2-medium', ...).
checkpoint = None
if init_from == 'resume':
    # init from a model saved in a specific directory
    ckpt_path = os.path.join(out_dir, 'ckpt.pt')
    checkpoint = torch.load(ckpt_path, map_location=device)
    gptconf = GPTConfig(**checkpoint['model_args'])
    model = GPT(gptconf)
    state_dict = checkpoint['model']
    # Checkpoints saved from a torch.compile()d model store weights under a
    # '_orig_mod.' prefix; strip it so the plain module accepts them.
    unwanted_prefix = '_orig_mod.'
    for k, v in list(state_dict.items()):
        if k.startswith(unwanted_prefix):
            state_dict[k[len(unwanted_prefix):]] = state_dict.pop(k)
    model.load_state_dict(state_dict)
elif init_from.startswith('gpt2'):
    # init from a given GPT-2 model
    model = GPT.from_pretrained(init_from, dict(dropout=0.0))
else:
    # Any other value (including 'scratch') cannot produce a model to sample
    # from; fail clearly instead of hitting a NameError on 'model' below.
    print(f"ERROR: Unsupported init_from for sampling: {init_from}")
    sys.exit(1)
model.eval()
model.to(device)
if compile:
    model = torch.compile(model)  # requires PyTorch 2.0; 'compile' comes from config
# look for the meta pickle in case it is available in the dataset folder
load_meta = False
if init_from == 'resume' and checkpoint is not None and 'config' in checkpoint and 'dataset' in checkpoint['config']:
    meta_path = os.path.join('data', checkpoint['config']['dataset'], 'meta.pkl')
    load_meta = os.path.exists(meta_path)
if load_meta:
    print(f"Loading meta from {meta_path}...")
    with open(meta_path, 'rb') as f:
        meta = pickle.load(f)
    stoi, itos = meta['stoi'], meta['itos']
    # Character-level codec built from the dataset's own vocabulary.
    def encode(s):
        return [stoi[ch] for ch in s]
    def decode(ids):
        return ''.join(itos[i] for i in ids)
else:
    print("No meta.pkl found, assuming GPT-2 encodings...")
    # Fall back to the GPT-2 BPE tokenizer via tiktoken.
    enc = tiktoken.get_encoding("gpt2")
    def encode(s):
        return enc.encode(s, allowed_special={"<|endoftext|>"})
    def decode(ids):
        return enc.decode(ids)
# encode the beginning of the prompt
# A 'FILE:' prefix means the prompt text lives on disk; read it in.
if start.startswith('FILE:'):
    with open(start[5:], 'r', encoding='utf-8') as f:
        start = f.read()
start_ids = encode(start)
# Shape (1, T): a batch containing the single prompt.
x = torch.tensor(start_ids, dtype=torch.long, device=device).unsqueeze(0)
# run generation
with torch.no_grad(), ctx:
    for _ in range(num_samples):
        y = model.generate(x, max_new_tokens, temperature=temperature, top_k=top_k)
        print(decode(y[0].tolist()))
        print('---------------')
|