| import os |
| import time |
| import math |
| import torch |
| from model import GPTConfig, GPT |
|
|
| import numpy as np |
|
|
| |
# ---------------------------------------------------------------------------
# Paths: output directory, base checkpoint to fine-tune, dataset folder name
# (resolved below as data/<dataset>).
# ---------------------------------------------------------------------------
out_dir = '/media/leo/Data/checkpoints/350m_Apex_1.5_Code'
init_from_file = '/media/leo/Data/checkpoints/350m_Apex_1.5_Final_NEW_More_Anti_Forgetting/Apex_1.5_Final.pt'
dataset = 'apex_code_boost'

# ---------------------------------------------------------------------------
# Batch geometry: one optimizer step sees batch_size * gradient_accumulation_steps
# windows of block_size tokens (4 * 32 * 1024 = 131072 tokens).
# ---------------------------------------------------------------------------
batch_size = 4
gradient_accumulation_steps = 32
block_size = 1024

# Learning-rate schedule: linear warmup over warmup_iters steps, then cosine
# decay from learning_rate down to min_lr at max_iters.
learning_rate = 1e-5
min_lr = 1e-6
warmup_iters = 50
max_iters = 1000

# Optimizer / regularisation settings.
weight_decay = 0.1
beta1 = 0.9
beta2 = 0.95
dropout = 0.1

# Hardware / runtime.
device = 'cuda'
dtype = 'bfloat16'  # one of 'float32', 'bfloat16', 'float16'
compile = True  # wrap the model in torch.compile (note: shadows the builtin)

# Write an intermediate checkpoint every save_interval optimizer steps.
save_interval = 500
| |
|
|
# Runtime setup: make sure the output folder exists, fix the RNG seed (which
# also drives get_batch's torch.randint sampling), and build the autocast
# context every forward pass runs in.
os.makedirs(out_dir, exist_ok=True)
torch.manual_seed(1337)

# Any device string mentioning 'cuda' (e.g. 'cuda:1') is treated as CUDA.
device_type = 'cpu'
if 'cuda' in device:
    device_type = 'cuda'

# Map the configured dtype name to the torch dtype; unknown names raise KeyError.
ptdtype = dict(
    float32=torch.float32,
    bfloat16=torch.bfloat16,
    float16=torch.float16,
)[dtype]
# NOTE(review): CUDA autocast supports float16/bfloat16; with dtype='float32'
# PyTorch is expected to warn and disable autocast — confirm if that config is used.
ctx = torch.amp.autocast(device_type=device_type, dtype=ptdtype)
|
|
| |
# Training data under data/<dataset>/: train.bin holds uint16 token ids and
# train_mask.bin one uint8 flag per token (get_batch turns targets whose flag
# is 0 into -100). Both are opened read-only as memory maps.
data_dir = os.path.join('data', dataset)
train_data = np.memmap(os.path.join(data_dir, 'train.bin'), dtype=np.uint16, mode='r')
train_mask = np.memmap(os.path.join(data_dir, 'train_mask.bin'), dtype=np.uint8, mode='r')

# get_batch slices both arrays with the same offsets, so they must line up
# token for token. A length mismatch means the mask was built for different
# data; fail now instead of with a random torch.stack size error mid-training.
if len(train_mask) != len(train_data):
    raise ValueError(
        f"train_mask.bin has {len(train_mask)} entries but train.bin has "
        f"{len(train_data)} tokens in {data_dir}; they must be the same length."
    )
# get_batch draws start offsets with torch.randint(len(train_data) - block_size),
# which needs a strictly positive upper bound.
if len(train_data) <= block_size:
    raise ValueError(
        f"train.bin in {data_dir} has only {len(train_data)} tokens; "
        f"need more than block_size={block_size}."
    )
|
|
def get_batch():
    """Sample ``batch_size`` random training windows and move them to ``device``.

    Returns ``(x, y)``, two int64 tensors of shape ``(batch_size, block_size)``:
    ``x`` is a window of token ids from ``train_data`` and ``y`` is the same
    window shifted one token to the right. Every target position whose entry
    in ``train_mask`` is 0 is replaced by -100.

    NOTE(review): -100 only excludes those positions from the loss if the
    model's cross-entropy uses ignore_index=-100 (upstream nanoGPT's
    GPT.forward uses ignore_index=-1) — confirm against model.py.
    """
    starts = torch.randint(len(train_data) - block_size, (batch_size,))
    inputs, targets, masks = [], [], []
    for s in starts:
        # One read of block_size + 1 tokens yields both the input window and
        # its one-token-shifted target window.
        chunk = train_data[s:s + block_size + 1].astype(np.int64)
        inputs.append(torch.from_numpy(chunk[:-1]))
        targets.append(torch.from_numpy(chunk[1:]))
        # The mask is aligned with the targets, i.e. offset by one as well.
        masks.append(torch.from_numpy(train_mask[s + 1:s + 1 + block_size].astype(np.int64)))

    x = torch.stack(inputs)
    y = torch.stack(targets).masked_fill(torch.stack(masks) == 0, -100)
    return x.to(device), y.to(device)
|
|
| |
# ---------------------------------------------------------------------------
# Build the model from the base checkpoint.
# ---------------------------------------------------------------------------
print(f"📥 Lade Apex 1.5 Final als Basis...")
# Load onto the CPU: the weights are copied into a freshly built model which
# is only then moved to `device`. Mapping straight to the GPU kept a second,
# never-freed copy of the checkpoint in GPU memory for the whole run.
checkpoint = torch.load(init_from_file, map_location='cpu')

# Architecture comes from the checkpoint. Apply the fine-tuning `dropout`
# setting on top, otherwise it would be silently ignored.
model_args = dict(checkpoint['model_args'])
if 'dropout' in model_args:
    model_args['dropout'] = dropout
gptconf = GPTConfig(**model_args)
model = GPT(gptconf)
state_dict = checkpoint['model']

# Checkpoints saved from a torch.compile'd model prefix every key with
# '_orig_mod.'; strip it so the keys match the plain module.
unwanted_prefix = '_orig_mod.'
for k in list(state_dict):
    if k.startswith(unwanted_prefix):
        state_dict[k[len(unwanted_prefix):]] = state_dict.pop(k)

model.load_state_dict(state_dict)
model.to(device)

# The weights now live in `model`; release the checkpoint's tensor payloads
# (and optimizer state, if present). 'model_args' and 'config' stay because
# they are written into the checkpoints saved later.
checkpoint.pop('model', None)
checkpoint.pop('optimizer', None)
del state_dict
|
|
# Optional graph compilation; the eager module stays reachable as
# model._orig_mod, which the checkpoint-saving code relies on.
if compile:
    print("🚀 Kompiliere Modell...")
    model = torch.compile(model)

# Optimizer comes from the model's own factory. The loss scaler is enabled
# only for float16; for bfloat16/float32 all its methods are pass-throughs.
betas = (beta1, beta2)
optimizer = model.configure_optimizers(weight_decay, learning_rate, betas, device_type)
scaler = torch.cuda.amp.GradScaler(enabled=dtype == 'float16')
|
|
| |
def get_lr(it, *, base_lr=None, lr_floor=None, warmup=None, total=None):
    """Return the learning rate for optimizer step ``it``.

    Linear warmup over the first ``warmup`` steps, then cosine decay from
    ``base_lr`` to ``lr_floor`` at step ``total``; ``lr_floor`` from then on.

    The keyword arguments default (when None) to the module-level
    ``learning_rate``, ``min_lr``, ``warmup_iters`` and ``max_iters`` settings,
    so ``get_lr(it)`` keeps working as before.
    """
    base_lr = learning_rate if base_lr is None else base_lr
    lr_floor = min_lr if lr_floor is None else lr_floor
    warmup = warmup_iters if warmup is None else warmup
    total = max_iters if total is None else total

    if it < warmup:
        # (it + 1) / (warmup + 1) rather than it / warmup: the latter gives
        # lr == 0 at step 0, wasting the first fully accumulated batch.
        return base_lr * (it + 1) / (warmup + 1)
    if it >= total:
        # At it == total the cosine already reaches lr_floor; returning it
        # directly also avoids dividing by zero when total == warmup.
        return lr_floor
    decay_ratio = (it - warmup) / (total - warmup)
    coeff = 0.5 * (1.0 + math.cos(math.pi * decay_ratio))  # 1 -> 0
    return lr_floor + coeff * (base_lr - lr_floor)
|
|
| |
# ---------------------------------------------------------------------------
# Training loop: max_iters + 1 optimizer steps, each accumulating gradients
# over gradient_accumulation_steps micro-batches.
# ---------------------------------------------------------------------------
print(f"🛠️ Starte Finetuning: Apex 1.5 lernt Coden...")
model.train()
t0 = time.time()
iters_since_log = 0

for iter_num in range(max_iters + 1):
    # Apply this step's scheduled learning rate to every parameter group.
    lr = get_lr(iter_num)
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr

    # loss_accum sums the scaled micro-batch losses on-device (no GPU sync per
    # micro-step), so the log shows the mean over the whole accumulated step
    # instead of only the last micro-batch.
    loss_accum = 0.0
    for micro_step in range(gradient_accumulation_steps):
        X, Y = get_batch()
        with ctx:
            logits, loss = model(X, Y)
            # Scale so the summed gradients equal those of the averaged loss.
            loss = loss / gradient_accumulation_steps
        loss_accum += loss.detach()
        scaler.scale(loss).backward()

    # Unscale before clipping so the 1.0 threshold applies to the real
    # gradient norm (a pass-through while the scaler is disabled).
    scaler.unscale_(optimizer)
    torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
    scaler.step(optimizer)
    scaler.update()
    optimizer.zero_grad(set_to_none=True)
    iters_since_log += 1

    if iter_num % 10 == 0:
        # Reading the loss first forces a sync, so the timing below includes
        # the queued GPU work; report the average time per optimizer step.
        loss_val = float(loss_accum)
        dt = (time.time() - t0) / iters_since_log
        print(f"Iter {iter_num}: Loss {loss_val:.4f}, Zeit {dt*1000:.2f}ms/Iter, LR {lr:.2e}")
        t0 = time.time()
        iters_since_log = 0

    if iter_num > 0 and iter_num % save_interval == 0:
        checkpoint_name = f'Apex_1.5_Code_iter_{iter_num}.pt'
        save_path = os.path.join(out_dir, checkpoint_name)
        print(f"💾 Speichere Zwischen-Checkpoint: {checkpoint_name}")
        # Save the underlying module so the state-dict keys carry no
        # '_orig_mod.' prefix from torch.compile.
        raw_model = model._orig_mod if compile else model
        checkpoint_data = {
            'model': raw_model.state_dict(),
            'model_args': checkpoint['model_args'],
            # Same 'config' entry the final checkpoint stores, so intermediate
            # checkpoints are equally self-describing.
            'config': checkpoint.get('config', {}),
            'iter_num': iter_num,
            'lr': lr,
        }
        torch.save(checkpoint_data, save_path)
|
|
| |
# Persist the fine-tuned weights together with the architecture arguments and
# the base checkpoint's 'config' entry (an empty dict if it had none).
print(f"💾 Finetuning beendet. Speichere Apex 1.5 Code...")
# Unwrap torch.compile so the saved keys match the plain module.
trained_model = model._orig_mod if compile else model
final_path = os.path.join(out_dir, 'Apex_1.5_Code_Final.pt')
torch.save(
    {
        'model': trained_model.state_dict(),
        'model_args': checkpoint['model_args'],
        'config': checkpoint.get('config', {}),
    },
    final_path,
)
print("✅ Apex 1.5 Code wurde erfolgreich gespeichert!")
|
|