LH-Tech-AI commited on
Commit
37f5946
Β·
verified Β·
1 Parent(s): 4117fe6

Update README.md

Browse files
Files changed (1) hide show
  1. README.md +627 -1
README.md CHANGED
@@ -1,4 +1,630 @@
1
- Test
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2
 
3
  ---
4
  license: apache-2.0
 
1
+ Welcome to SmaLLMPro 350M, our latest instruct model based on FineWeb-Edu.
2
+
3
+ # 1. Parameters
4
+ SmaLLMPro 350M has 353,550,000 parameters.
5
+ 1. Decayed parameter tensors: 98, with 354,549,760 parameters
6
+ 2. Non-decayed parameter tensors: 49, with 50,176 parameters
7
+
8
+ # 2. Training Code
9
+ ```python
10
+ import os
11
+ import time
12
+ import math
13
+ import pickle
14
+ from contextlib import nullcontext
15
+
16
+ import queue
17
+
18
+ import logging
19
+
20
+ import numpy as np
21
+ import torch
22
+ from torch.nn.parallel import DistributedDataParallel as DDP
23
+ from torch.distributed import init_process_group, destroy_process_group
24
+
25
+ from model import GPTConfig, GPT
26
+
27
+ # -----------------------------------------------------------------------------
28
+ # default config values designed to train a gpt2 (124M) on OpenWebText
29
+ # I/O
30
+ out_dir = 'out'
31
+ eval_interval = 2000
32
+ log_interval = 1
33
+ eval_iters = 200
34
+ eval_only = False # if True, script exits right after the first eval
35
+ always_save_checkpoint = True # if True, always save a checkpoint after each eval
36
+ init_from = 'scratch' # 'scratch' or 'resume' or 'gpt2*'
37
+ # wandb logging
38
+ wandb_log = False # disabled by default
39
+ wandb_project = 'owt'
40
+ wandb_run_name = 'gpt2' # 'run' + str(time.time())
41
+ # data
42
+ dataset = 'openwebtext'
43
+ gradient_accumulation_steps = 5 * 8 # used to simulate larger batch sizes
44
+ batch_size = 12 # if gradient_accumulation_steps > 1, this is the micro-batch size
45
+ block_size = 1024
46
+ # model
47
+ n_layer = 12
48
+ n_head = 12
49
+ n_embd = 768
50
+ dropout = 0.0 # for pretraining 0 is good, for finetuning try 0.1+
51
+ bias = False # do we use bias inside LayerNorm and Linear layers?
52
+ # adamw optimizer
53
+ learning_rate = 6e-4 # max learning rate
54
+ max_iters = 600000 # total number of training iterations
55
+ weight_decay = 1e-1
56
+ beta1 = 0.9
57
+ beta2 = 0.95
58
+ grad_clip = 1.0 # clip gradients at this value, or disable if == 0.0
59
+ # learning rate decay settings
60
+ decay_lr = True # whether to decay the learning rate
61
+ warmup_iters = 2000 # how many steps to warm up for
62
+ lr_decay_iters = 600000 # should be ~= max_iters per Chinchilla
63
+ min_lr = 6e-5 # minimum learning rate, should be ~= learning_rate/10 per Chinchilla
64
+ # DDP settings
65
+ backend = 'nccl' # 'nccl', 'gloo', etc.
66
+ # system
67
+ device = 'cuda' # examples: 'cpu', 'cuda', 'cuda:0', 'cuda:1' etc., or try 'mps' on macbooks
68
+ dtype = 'bfloat16' if torch.cuda.is_available() and torch.cuda.is_bf16_supported() else 'float16' # 'float32', 'bfloat16', or 'float16', the latter will auto implement a GradScaler
69
+ compile = True # use PyTorch 2.0 to compile the model to be faster
70
+ # -----------------------------------------------------------------------------
71
+ config_keys = [k for k,v in globals().items() if not k.startswith('_') and isinstance(v, (int, float, bool, str))]
72
+ exec(open('configurator.py').read()) # overrides from command line or config file
73
+ config = {k: globals()[k] for k in config_keys} # will be useful for logging
74
+ # -----------------------------------------------------------------------------
75
+
76
+ logger = None
77
+ db_conn = None
78
+
79
+ logging.basicConfig(
80
+ level=logging.INFO,
81
+ format='%(asctime)s %(levelname)s: %(message)s',
82
+ handlers=[logging.StreamHandler()]
83
+ )
84
+ logger = logging.getLogger("Train")
85
+
86
+ # various inits, derived attributes, I/O setup
87
+ ddp = int(os.environ.get('RANK', -1)) != -1 # is this a ddp run?
88
+ if ddp:
89
+ init_process_group(backend=backend)
90
+ ddp_rank = int(os.environ['RANK'])
91
+ ddp_local_rank = int(os.environ['LOCAL_RANK'])
92
+ ddp_world_size = int(os.environ['WORLD_SIZE'])
93
+ device = f'cuda:{ddp_local_rank}'
94
+ torch.cuda.set_device(device)
95
+ master_process = ddp_rank == 0 # this process will do logging, checkpointing etc.
96
+ seed_offset = ddp_rank # each process gets a different seed
97
+ # world_size number of processes will be training simultaneously, so we can scale
98
+ # down the desired gradient accumulation iterations per process proportionally
99
+ assert gradient_accumulation_steps % ddp_world_size == 0
100
+ gradient_accumulation_steps //= ddp_world_size
101
+ else:
102
+ # if not ddp, we are running on a single gpu, and one process
103
+ master_process = True
104
+ seed_offset = 0
105
+ ddp_world_size = 1
106
+ tokens_per_iter = gradient_accumulation_steps * ddp_world_size * batch_size * block_size
107
+ logger.info(f"tokens per iteration will be: {tokens_per_iter:,}")
108
+
109
+
110
+ if master_process:
111
+ os.makedirs(out_dir, exist_ok=True)
112
+ log_dir = "/home/350m_fineweb"
113
+ os.makedirs(log_dir, exist_ok=True)
114
+ log_file = os.path.join(log_dir, "training.log")
115
+
116
+ file_handler = logging.FileHandler(log_file)
117
+ file_handler.setFormatter(logging.Formatter('%(asctime)s %(levelname)s: %(message)s'))
118
+ logger.addHandler(file_handler)
119
+
120
+ logger.info(f"File logging started: {log_file}")
121
+
122
+ torch.manual_seed(1337 + seed_offset)
123
+ torch.backends.cuda.matmul.allow_tf32 = True # allow tf32 on matmul
124
+ torch.backends.cudnn.allow_tf32 = True # allow tf32 on cudnn
125
+ device_type = 'cuda' if 'cuda' in device else 'cpu' # for later use in torch.autocast
126
+ # note: float16 data type will automatically use a GradScaler
127
+ ptdtype = {'float32': torch.float32, 'bfloat16': torch.bfloat16, 'float16': torch.float16}[dtype]
128
+ ctx = nullcontext() if device_type == 'cpu' else torch.amp.autocast(device_type=device_type, dtype=ptdtype)
129
+
130
+ # poor man's data loader
131
+
132
+ data_handles = {
133
+ split: {
134
+ name: np.memmap(os.path.join(path, f'{split}.bin'), dtype=np.uint16, mode='r')
135
+ for name, path in data_sources.items()
136
+ }
137
+ for split in ['train', 'val']
138
+ }
139
+
140
+ def get_batch(split):
141
+ source = 'fineweb'
142
+ data = data_handles[split][source]
143
+
144
+ ix = torch.randint(len(data) - block_size, (batch_size,))
145
+ x = torch.stack([torch.from_numpy((data[i:i+block_size]).astype(np.int64)) for i in ix])
146
+ y = torch.stack([torch.from_numpy((data[i+1:i+1+block_size]).astype(np.int64)) for i in ix])
147
+
148
+ if device_type == 'cuda':
149
+ # pin arrays x,y, which allows us to move them to GPU asynchronously (non_blocking=True)
150
+ x, y = x.pin_memory().to(device, non_blocking=True), y.pin_memory().to(device, non_blocking=True)
151
+ else:
152
+ x, y = x.to(device), y.to(device)
153
+ return x, y
154
+
155
+ # init these up here, can override if init_from='resume' (i.e. from a checkpoint)
156
+ iter_num = 0
157
+ best_val_loss = 1e9
158
+
159
+ # attempt to derive vocab_size from the dataset
160
+ meta_path = os.path.join(data_sources['fineweb'], 'meta.pkl')
161
+ meta_vocab_size = None
162
+ if os.path.exists(meta_path):
163
+ with open(meta_path, 'rb') as f:
164
+ meta = pickle.load(f)
165
+ meta_vocab_size = meta['vocab_size']
166
+ logger.info(f"found vocab_size = {meta_vocab_size} (inside {meta_path})")
167
+
168
+ # model init
169
+ model_args = dict(n_layer=n_layer, n_head=n_head, n_embd=n_embd, block_size=block_size,
170
+ bias=bias, vocab_size=None, dropout=dropout) # start with model_args from command line
171
+ if init_from == 'scratch':
172
+ # init a new model from scratch
173
+ logger.info("Initializing a new model from scratch")
174
+ # determine the vocab size we'll use for from-scratch training
175
+ if meta_vocab_size is None:
176
+ logger.info("defaulting to vocab_size of GPT-2 to 50304 (50257 rounded up for efficiency)")
177
+ model_args['vocab_size'] = meta_vocab_size if meta_vocab_size is not None else 50304
178
+ gptconf = GPTConfig(**model_args)
179
+ model = GPT(gptconf)
180
+ elif init_from == 'resume':
181
+ logger.info(f"Resuming training from {out_dir}")
182
+ # resume training from a checkpoint.
183
+ ckpt_path = os.path.join(out_dir, sorted(
184
+ [f for f in os.listdir(out_dir) if f.startswith("ckpt_") and f.endswith(".pt")]
185
+ )[-1])
186
+ checkpoint = torch.load(ckpt_path, map_location=device)
187
+ checkpoint_model_args = checkpoint['model_args']
188
+ # force these config attributes to be equal otherwise we can't even resume training
189
+ # the rest of the attributes (e.g. dropout) can stay as desired from command line
190
+ for k in ['n_layer', 'n_head', 'n_embd', 'block_size', 'bias', 'vocab_size']:
191
+ model_args[k] = checkpoint_model_args[k]
192
+ # create the model
193
+ gptconf = GPTConfig(**model_args)
194
+ model = GPT(gptconf)
195
+ state_dict = checkpoint['model']
196
+ # fix the keys of the state dictionary :(
197
+ # honestly no idea how checkpoints sometimes get this prefix, have to debug more
198
+ unwanted_prefix = '_orig_mod.'
199
+ for k,v in list(state_dict.items()):
200
+ if k.startswith(unwanted_prefix):
201
+ state_dict[k[len(unwanted_prefix):]] = state_dict.pop(k)
202
+ model.load_state_dict(state_dict)
203
+ iter_num = checkpoint['iter_num']
204
+ best_val_loss = checkpoint['best_val_loss']
205
+ elif init_from.startswith('gpt2'):
206
+ logger.info(f"Initializing from OpenAI GPT-2 weights: {init_from}")
207
+ # initialize from OpenAI GPT-2 weights
208
+ override_args = dict(dropout=dropout)
209
+ model = GPT.from_pretrained(init_from, override_args)
210
+ # read off the created config params, so we can store them into checkpoint correctly
211
+ for k in ['n_layer', 'n_head', 'n_embd', 'block_size', 'bias', 'vocab_size']:
212
+ model_args[k] = getattr(model.config, k)
213
+ # crop down the model block size if desired, using model surgery
214
+ if block_size < model.config.block_size:
215
+ model.crop_block_size(block_size)
216
+ model_args['block_size'] = block_size # so that the checkpoint will have the right value
217
+ model.to(device)
218
+
219
+ # initialize a GradScaler. If enabled=False scaler is a no-op
220
+ scaler = torch.cuda.amp.GradScaler(enabled=(dtype == 'float16'))
221
+
222
+ # optimizer
223
+ optimizer = model.configure_optimizers(weight_decay, learning_rate, (beta1, beta2), device_type)
224
+ if init_from == 'resume':
225
+ optimizer.load_state_dict(checkpoint['optimizer'])
226
+ checkpoint = None # free up memory
227
+
228
+ # compile the model
229
+ if compile:
230
+ logger.info("compiling the model... (takes a ~minute)")
231
+ unoptimized_model = model
232
+ model = torch.compile(model) # requires PyTorch 2.0
233
+
234
+ # wrap model into DDP container
235
+ if ddp:
236
+ model = DDP(model, device_ids=[ddp_local_rank])
237
+
238
+ # helps estimate an arbitrarily accurate loss over either split using many batches
239
+ @torch.no_grad()
240
+ def estimate_loss():
241
+ out = {}
242
+ model.eval()
243
+ for split in ['train', 'val']:
244
+ losses = torch.zeros(eval_iters)
245
+ for k in range(eval_iters):
246
+ X, Y = get_batch(split)
247
+ with ctx:
248
+ logits, loss = model(X, Y)
249
+ losses[k] = loss.item()
250
+ out[split] = losses.mean()
251
+ model.train()
252
+ return out
253
+
254
+ # learning rate decay scheduler (cosine with warmup)
255
+ def get_lr(it):
256
+ # 1) linear warmup for warmup_iters steps
257
+ if it < warmup_iters:
258
+ return learning_rate * (it + 1) / (warmup_iters + 1)
259
+ # 2) if it > lr_decay_iters, return min learning rate
260
+ if it > lr_decay_iters:
261
+ return min_lr
262
+ # 3) in between, use cosine decay down to min learning rate
263
+ decay_ratio = (it - warmup_iters) / (lr_decay_iters - warmup_iters)
264
+ assert 0 <= decay_ratio <= 1
265
+ coeff = 0.5 * (1.0 + math.cos(math.pi * decay_ratio)) # coeff ranges 0..1
266
+ return min_lr + coeff * (learning_rate - min_lr)
267
+
268
+ # logging
269
+ if wandb_log and master_process:
270
+ import wandb
271
+ wandb.init(project=wandb_project, name=wandb_run_name, config=config)
272
+
273
+ # training loop
274
+ X, Y = get_batch('train') # fetch the very first batch
275
+ t0 = time.time()
276
+ local_iter_num = 0 # number of iterations in the lifetime of this process
277
+ raw_model = model.module if ddp else model # unwrap DDP container if needed
278
+ running_mfu = -1.0
279
+ while True:
280
+
281
+ # determine and set the learning rate for this iteration
282
+ lr = get_lr(iter_num) if decay_lr else learning_rate
283
+ for param_group in optimizer.param_groups:
284
+ param_group['lr'] = lr
285
+
286
+ # evaluate the loss on train/val sets and write checkpoints
287
+ if iter_num % eval_interval == 0 and master_process:
288
+ losses = estimate_loss()
289
+ logger.info(f"step {iter_num}: train loss {losses['train']:.4f}, val loss {losses['val']:.4f}")
290
+ if wandb_log:
291
+ wandb.log({
292
+ "iter": iter_num,
293
+ "train/loss": losses['train'],
294
+ "val/loss": losses['val'],
295
+ "lr": lr,
296
+ "mfu": running_mfu*100, # convert to percentage
297
+ })
298
+ if losses['val'] < best_val_loss or always_save_checkpoint:
299
+ best_val_loss = losses['val']
300
+ if iter_num > 0:
301
+ checkpoint = {
302
+ 'model': raw_model.state_dict(),
303
+ 'optimizer': optimizer.state_dict(),
304
+ 'model_args': model_args,
305
+ 'iter_num': iter_num,
306
+ 'best_val_loss': best_val_loss,
307
+ 'config': config,
308
+ }
309
+ logger.info(f"πŸ’Ύ SAVING CHECKPOINT TO {out_dir}")
310
+ ckpt_name = f"ckpt_{iter_num:07d}.pt"
311
+ ckpt_path = os.path.join(out_dir, ckpt_name)
312
+ torch.save(checkpoint, ckpt_path)
313
+ if iter_num == 0 and eval_only:
314
+ break
315
+
316
+ # forward backward update, with optional gradient accumulation to simulate larger batch size
317
+ # and using the GradScaler if data type is float16
318
+ for micro_step in range(gradient_accumulation_steps):
319
+ if ddp:
320
+ # in DDP training we only need to sync gradients at the last micro step.
321
+ # the official way to do this is with model.no_sync() context manager, but
322
+ # I really dislike that this bloats the code and forces us to repeat code
323
+ # looking at the source of that context manager, it just toggles this variable
324
+ model.require_backward_grad_sync = (micro_step == gradient_accumulation_steps - 1)
325
+ with ctx:
326
+ logits, loss = model(X, Y)
327
+ loss = loss / gradient_accumulation_steps # scale the loss to account for gradient accumulation
328
+ # immediately async prefetch next batch while model is doing the forward pass on the GPU
329
+ X, Y = get_batch('train')
330
+ # backward pass, with gradient scaling if training in fp16
331
+ scaler.scale(loss).backward()
332
+ # clip the gradient
333
+ if grad_clip != 0.0:
334
+ scaler.unscale_(optimizer)
335
+ torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip)
336
+ # step the optimizer and scaler if training in fp16
337
+ scaler.step(optimizer)
338
+ scaler.update()
339
+ # flush the gradients as soon as we can, no need for this memory anymore
340
+ optimizer.zero_grad(set_to_none=True)
341
+
342
+ # timing and logging
343
+ t1 = time.time()
344
+ dt = t1 - t0
345
+ t0 = t1
346
+ if iter_num % log_interval == 0 and master_process:
347
+ # get loss as float. note: this is a CPU-GPU sync point
348
+ # scale up to undo the division above, approximating the true total loss (exact would have been a sum)
349
+ lossf = loss.item() * gradient_accumulation_steps
350
+ if local_iter_num >= 5: # let the training loop settle a bit
351
+ mfu = raw_model.estimate_mfu(batch_size * gradient_accumulation_steps, dt)
352
+ running_mfu = mfu if running_mfu == -1.0 else 0.9*running_mfu + 0.1*mfu
353
+
354
+ if logger:
355
+ log_msg = f"iter {iter_num}: loss {lossf:.4f}, time {dt*1000:.2f}ms, mfu {running_mfu*100:.2f}%"
356
+ logger.info(log_msg)
357
+
358
+
359
+ if iter_num % 100 == 0:
360
+
361
+ remaining_iters = max_iters - iter_num
362
+ est_seconds = remaining_iters * dt
363
+ days = int(est_seconds // 86400)
364
+ hours = int((est_seconds % 86400) // 3600)
365
+ minutes = int((est_seconds % 3600) // 60)
366
+
367
+ logger.info(f"⏳ ETA: approx. {days}d, {hours}h, {minutes}m remaining until iteration {max_iters}")
368
+ logger.info("πŸ“ LIVE-SAMPLE:")
369
+
370
+ model.eval()
371
+
372
+ with torch.no_grad():
373
+ import tiktoken
374
+ enc = tiktoken.get_encoding("gpt2")
375
+
376
+ prompt = "Artificial Intelligence is "
377
+ start_ids = enc.encode(prompt, allowed_special={""})
378
+ context = torch.tensor(start_ids, dtype=torch.long, device=device).unsqueeze(0)
379
+
380
+ generated_tokens = raw_model.generate(context, max_new_tokens=200)[0].tolist()
381
+
382
+ valid_tokens = [t for t in generated_tokens if t < enc.n_vocab]
383
+
384
+ try:
385
+ decoded_text = enc.decode(valid_tokens, errors='replace')
386
+ logger.info(f"\n{decoded_text}")
387
+ except Exception as e:
388
+ logger.error(f"Sampling error: {e}")
389
+
390
+ model.train()
391
+ logger.info("-" * 50)
392
+ iter_num += 1
393
+ local_iter_num += 1
394
+
395
+ # termination conditions
396
+ if iter_num > max_iters:
397
+ break
398
+
399
+ if ddp:
400
+ destroy_process_group()
401
+ ```
402
+
403
+ To use this code, you'll first have to clone Andrej Karpathy's nanoGPT git repository.
404
+
405
+ Then, run:
406
+
407
+ ```bash
408
+ python3 train.py \
409
+ --dataset=fineweb-edu \
410
+ --n_layer=24 \
411
+ --n_head=16 \
412
+ --n_embd=1024 \
413
+ --block_size=1024 \
414
+ --batch_size=8 \
415
+ --gradient_accumulation_steps=16 \
416
+ --learning_rate=6e-4 \
417
+ --max_iters=300000 \
418
+ --eval_interval=1000 \
419
+ --eval_iters=100 \
420
+ --log_interval=5 \
421
+ --weight_decay=0.1 \
422
+ --warmup_iters=2000 \
423
+ --lr_decay_iters=300000 \
424
+ --min_lr=6e-5 \
425
+ --dtype=bfloat16 \
426
+ --compile=True \
427
+ --always_save_checkpoint=True \
428
+ --init_from=scratch \
429
+ --out_dir=/home/user/350m_fineweb
430
+ ```
431
+
432
+ # 3. Finetuning
433
+ To finally finetune your model to answer your questions, run this code to prepare your data:
434
+ ```python
435
+ import os
436
+ import numpy as np
437
+ import tiktoken
438
+ from datasets import load_dataset
439
+ from tqdm import tqdm
440
+
441
+ OUTPUT_DIR = "data/alpaca_cleaned_mixed"
442
+ OUTPUT_FILE = os.path.join(OUTPUT_DIR, "train.bin")
443
+ enc = tiktoken.get_encoding("gpt2")
444
+
445
+ def format_prompt(instruction, input_text, output):
446
+ sys_msg = "### System:\nYou are SmaLLMPro, a helpful AI Assistant developed by LH-Tech AI.\n\n"
447
+ if input_text and input_text.strip() != "":
448
+ return f"{sys_msg}### Instruction:\n{instruction}\n\n### Input:\n{input_text}\n\n### Response:\n{output}<|endoftext|>"
449
+ else:
450
+ return f"{sys_msg}### Instruction:\n{instruction}\n\n### Response:\n{output}<|endoftext|>"
451
+
452
+ def main():
453
+ print("πŸš€ Starting data preparation for SmaLLMPro...")
454
+
455
+ print("πŸ“₯ Loading Alpaca-Cleaned dataset from Huggingface...")
456
+ alpaca = load_dataset("yahma/alpaca-cleaned", split='train')
457
+
458
+ identity_prompts = [
459
+ {"instruction": "Who are you?", "input": "", "output": "I am SmaLLMPro, a helpful AI Assistant developed by LH-Tech AI. What do you want to talk about?"},
460
+ {"instruction": "Who developed you?", "input": "", "output": "I was developed by LH-Tech AI. Is there anything I can assist you with?"},
461
+ {"instruction": "What is your name?", "input": "", "output": "My name is SmaLLMPro. How can I help you today?"}
462
+ ]
463
+
464
+ print("πŸ“₯ Loading small FineWeb-Edu subset (anti-forgetting)...")
465
+ fineweb = load_dataset("HuggingFaceFW/fineweb-edu", name="sample-10BT", split='train', streaming=True)
466
+ fw_iter = iter(fineweb)
467
+
468
+ all_tokens = []
469
+
470
+ print("πŸ“ Adding identity...")
471
+ for p in identity_prompts:
472
+ text = format_prompt(p['instruction'], p['input'], p['output'])
473
+ all_tokens.extend(enc.encode(text, allowed_special={"<|endoftext|>"}))
474
+
475
+ print("πŸ“ Adding Alpaca-Cleaned...")
476
+ for ex in tqdm(alpaca, desc="Alpaca"):
477
+ text = format_prompt(ex['instruction'], ex['input'], ex['output'])
478
+ all_tokens.extend(enc.encode(text, allowed_special={"<|endoftext|>"}))
479
+
480
+ print("πŸ“ Adding FineWeb-Edu (Anti-Forgetting)...")
481
+ for _ in tqdm(range(2500), desc="FineWeb"):
482
+ try:
483
+ ex = next(fw_iter)
484
+ text = ex['text'] + "<|endoftext|>"
485
+ all_tokens.extend(enc.encode(text, allowed_special={"<|endoftext|>"}))
486
+ except StopIteration:
487
+ break
488
+
489
+ print(f"πŸ’Ύ Converting to numpy (Tokens: {len(all_tokens):,})...")
490
+ arr = np.array(all_tokens, dtype=np.uint16)
491
+
492
+ os.makedirs(OUTPUT_DIR, exist_ok=True)
493
+ arr.tofile(OUTPUT_FILE)
494
+
495
+ print(f"\nβœ… Done! Binary file saved as: {OUTPUT_FILE}")
496
+
497
+ if __name__ == "__main__":
498
+ main()
499
+ ```
500
+
501
+ Finally, run this to start the finetuning based on your prepared finetuning data:
502
+ ```python
503
+ import os
504
+ import time
505
+ import math
506
+ import torch
507
+ from model import GPTConfig, GPT
508
+
509
+ import numpy as np
510
+
511
+ out_dir = '/home/user/checkpoints/350m_SmaLLMPro_Final'
512
+ init_from = '/home/user/350m_fineweb'
513
+ dataset = 'alpaca_cleaned_mixed'
514
+
515
+ batch_size = 4
516
+ gradient_accumulation_steps = 32
517
+ block_size = 1024
518
+ learning_rate = 3e-5
519
+ max_iters = 3000
520
+ weight_decay = 0.1
521
+ dropout = 0.1
522
+ warmup_iters = 100
523
+ min_lr = 3e-6
524
+ beta1, beta2 = 0.9, 0.95
525
+ device = 'cuda'
526
+ dtype = 'bfloat16'
527
+ compile = True
528
+ save_interval = 1000
529
+
530
+ os.makedirs(out_dir, exist_ok=True)
531
+ torch.manual_seed(1337)
532
+ device_type = 'cuda' if 'cuda' in device else 'cpu'
533
+ ptdtype = {'float32': torch.float32, 'bfloat16': torch.bfloat16, 'float16': torch.float16}[dtype]
534
+ ctx = torch.amp.autocast(device_type=device_type, dtype=ptdtype)
535
+
536
+ data_dir = os.path.join('data', dataset)
537
+ train_data = np.memmap(os.path.join(data_dir, 'train.bin'), dtype=np.uint16, mode='r')
538
+
539
+ def get_batch():
540
+ ix = torch.randint(len(train_data) - block_size, (batch_size,))
541
+ x = torch.stack([torch.from_numpy((train_data[i:i+block_size]).astype(np.int64)) for i in ix])
542
+ y = torch.stack([torch.from_numpy((train_data[i+1:i+1+block_size]).astype(np.int64)) for i in ix])
543
+ x, y = x.to(device), y.to(device)
544
+ return x, y
545
+
546
+ print(f"πŸ“₯ Loading Pretraining-Checkpoint from {init_from}...")
547
+ ckpt_files = sorted([f for f in os.listdir(init_from) if f.endswith('.pt')])
548
+ if not ckpt_files:
549
+ raise FileNotFoundError("No checkpoint in init_from folder found!")
550
+
551
+ ckpt_path = os.path.join(init_from, ckpt_files[-1])
552
+ checkpoint = torch.load(ckpt_path, map_location=device)
553
+ gptconf = GPTConfig(**checkpoint['model_args'])
554
+ model = GPT(gptconf)
555
+ state_dict = checkpoint['model']
556
+
557
+ unwanted_prefix = '_orig_mod.'
558
+ for k,v in list(state_dict.items()):
559
+ if k.startswith(unwanted_prefix):
560
+ state_dict[k[len(unwanted_prefix):]] = state_dict.pop(k)
561
+
562
+ model.load_state_dict(state_dict)
563
+ model.to(device)
564
+
565
+ if compile:
566
+ print("πŸš€ Compiling model...")
567
+ model = torch.compile(model)
568
+
569
+ optimizer = model.configure_optimizers(weight_decay, learning_rate, (beta1, beta2), device_type)
570
+ scaler = torch.cuda.amp.GradScaler(enabled=(dtype == 'float16'))
571
+
572
+ def get_lr(it):
573
+ if it < warmup_iters: return learning_rate * it / warmup_iters
574
+ if it > max_iters: return min_lr
575
+ decay_ratio = (it - warmup_iters) / (max_iters - warmup_iters)
576
+ coeff = 0.5 * (1.0 + math.cos(math.pi * decay_ratio))
577
+ return min_lr + coeff * (learning_rate - min_lr)
578
+
579
+ print(f"πŸ› οΈ Starting finetuning...")
580
+ model.train()
581
+ t0 = time.time()
582
+
583
+ for iter_num in range(max_iters + 1):
584
+ lr = get_lr(iter_num)
585
+ for param_group in optimizer.param_groups:
586
+ param_group['lr'] = lr
587
+
588
+ for micro_step in range(gradient_accumulation_steps):
589
+ X, Y = get_batch()
590
+ with ctx:
591
+ logits, loss = model(X, Y)
592
+ loss = loss / gradient_accumulation_steps
593
+ scaler.scale(loss).backward()
594
+
595
+ scaler.unscale_(optimizer)
596
+ torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
597
+ scaler.step(optimizer)
598
+ scaler.update()
599
+ optimizer.zero_grad(set_to_none=True)
600
+
601
+ if iter_num % 10 == 0:
602
+ dt = time.time() - t0
603
+ print(f"Iter {iter_num}: Loss {loss.item()*gradient_accumulation_steps:.4f}, time {dt*1000:.2f}ms, LR {lr:.2e}")
604
+ t0 = time.time()
605
+
606
+ if iter_num > 0 and iter_num % save_interval == 0:
607
+ checkpoint_name = f'SmaLLMPro_iter_{iter_num}.pt'
608
+ save_path = os.path.join(out_dir, checkpoint_name)
609
+ print(f"πŸ’Ύ Saving checkpoint: {checkpoint_name}")
610
+ raw_model = model._orig_mod if compile else model
611
+ checkpoint_data = {
612
+ 'model': raw_model.state_dict(),
613
+ 'model_args': checkpoint['model_args'],
614
+ 'iter_num': iter_num,
615
+ 'lr': lr,
616
+ }
617
+ torch.save(checkpoint_data, save_path)
618
+
619
+ print(f"πŸ’Ύ Finetuning done. Saving SmaLLMPro final checkpoint...")
620
+ final_checkpoint = {
621
+ 'model': model.state_dict() if not compile else model._orig_mod.state_dict(),
622
+ 'model_args': checkpoint['model_args'],
623
+ 'config': checkpoint.get('config', {}),
624
+ }
625
+ torch.save(final_checkpoint, os.path.join(out_dir, 'SmaLLMPro_Final.pt'))
626
+ print("βœ… SmaLLMPro saved successfully!")
627
+ ```
628
 
629
  ---
630
  license: apache-2.0