kgrabko committed
Commit 4bedc60 · verified · 1 parent: 5975847

Upload fine_tune3b_with_validation_no_torchscript.py

fine_tune3b_with_validation_no_torchscript.py ADDED
@@ -0,0 +1,169 @@
+ # Copyright (c) 2025 CMS Manhattan
+ # All rights reserved.
+ #
+ # This file is part of a project authored by CMS Manhattan. You may use, distribute,
+ # and modify this code under the terms of the GNU GENERAL PUBLIC LICENSE, Version 3,
+ # 29 June 2007. See <http://www.gnu.org/licenses/>.
+
+ import torch
+ import torch.nn as nn
+ import torch.optim as optim
+ from torch.utils.data import Dataset, DataLoader
+ from transformers import GPT2TokenizerFast
+ from tqdm import tqdm
+ import shutil
+ import math
+ from pathlib import Path
+ import re
+
+ from gpt_jit_modern_3b import JiRackPyTorch
+
+ # ============================= SETTINGS =============================
+ TRAIN_SEQ_LEN = 256
+ BATCH_SIZE = 2
+ ACCUM_STEPS = 16
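+ # Effective batch size = BATCH_SIZE * ACCUM_STEPS = 2 * 16 = 32 sequences per optimizer step.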
+ EPOCHS = 500
+ LEARNING_RATE = 3e-5
+ WEIGHT_DECAY = 0.01
+ GRAD_CLIP = 1.0
+ VAL_SPLIT_RATIO = 0.05
+ KEEP_LAST_EPOCHS = 3
+
+ # === PATHS ===
+ BASE_MODEL_PATH = Path("models/gpt_modern_3b_class.state_dict.pt")
+ LAST_TRAINED_PATH = Path("models/gpt_last_modern_3b_class.state_dict.pt")
+ BACKUP_DIR = Path("models/backups")
+ BACKUP_DIR.mkdir(exist_ok=True, parents=True)
+
+ RAW_PATH = Path("datasets/dialogues_text.txt")
+ CLEAN_PATH = Path("datasets/dialogues_text_clean.txt")
+
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+ # device = torch.device("cpu")
+ print(f"Using device: {device}")
+
+ # === DATASET CLEANING ===
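+ # Re-clean only when the clean copy is missing or older than the raw dataset file.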
+ if not CLEAN_PATH.exists() or RAW_PATH.stat().st_mtime > CLEAN_PATH.stat().st_mtime:
+     print("Cleaning dataset...")
+     text = RAW_PATH.read_text(encoding="utf-8")
+     text = re.sub(r' {2,}', ' ', text)  # collapse runs of spaces into one
+     text = text.replace(" \n", "\n").replace("\n ", "\n")  # strip spaces around newlines
+     CLEAN_PATH.write_text(text, encoding="utf-8")
+     print(f"Done → {CLEAN_PATH}")
+
+ DATASET_PATH = CLEAN_PATH
+ OUTPUT_DIR = Path("build/fine_tuning_output")
+ MODEL_SAVE_NAME = "pytorch_model.bin"
+
+ # ============================= DATASET =============================
+ class TextDataset(Dataset):
+     def __init__(self, text_file, seq_len=TRAIN_SEQ_LEN, split='train'):
+         self.seq_len = seq_len
+         tokenizer = GPT2TokenizerFast.from_pretrained("gpt2")
+         tokenizer.pad_token = tokenizer.eos_token
+         text = Path(text_file).read_text(encoding="utf-8")
+         tokens = tokenizer.encode(text)
+
+         sequences = []
+         for i in range(0, len(tokens) - seq_len, seq_len):
+             sequences.append(tokens[i:i + seq_len + 1])  # +1 for labels
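+         # Each stored window is seq_len + 1 tokens long: __getitem__ splits it into
+         # (input = window[:-1], target = window[1:]) for next-token prediction.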
+
+         split_idx = int(len(sequences) * (1 - VAL_SPLIT_RATIO))
+         if split == 'train':
+             self.data = sequences[:split_idx]
+         else:
+             self.data = sequences[split_idx:]
+
+         print(f"{split.upper()} sequences: {len(self.data):,}")
+
+     def __len__(self):
+         return len(self.data)
+
+     def __getitem__(self, idx):
+         seq = self.data[idx]
+         return torch.tensor(seq[:-1], dtype=torch.long), torch.tensor(seq[1:], dtype=torch.long)
+
+
+ def evaluate(model, loader):
+     model.eval()
+     total_loss = 0
+     criterion = nn.CrossEntropyLoss()
+     with torch.no_grad():
+         for x, y in loader:
+             x, y = x.to(device), y.to(device)
+             logits, _ = model(x)
+             loss = criterion(logits.view(-1, logits.size(-1)), y.view(-1))
+             total_loss += loss.item()
+     model.train()
+     return total_loss / len(loader)
+
+
+ def train():
+     OUTPUT_DIR.mkdir(parents=True, exist_ok=True)
+
+     print("Loading model...")
+     model = JiRackPyTorch().to(device)
+
+     if LAST_TRAINED_PATH.exists():
+         print(f"Resuming from {LAST_TRAINED_PATH}")
+         model.load_state_dict(torch.load(LAST_TRAINED_PATH, map_location=device))
+     elif BASE_MODEL_PATH.exists():
+         print(f"Starting from base model {BASE_MODEL_PATH}")
+         model.load_state_dict(torch.load(BASE_MODEL_PATH, map_location=device))
+     else:
+         print("Starting from scratch — random weights")
+
+     model.train()
+
+     train_dataset = TextDataset(DATASET_PATH, split='train')
+     val_dataset = TextDataset(DATASET_PATH, split='val')
+
+     train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
+     val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE)  # no shuffle for validation
+
+     optimizer = optim.AdamW(model.parameters(), lr=LEARNING_RATE, weight_decay=WEIGHT_DECAY)
+     criterion = nn.CrossEntropyLoss()
+
+     print("\nFULL TRAINING STARTED! No LoRA, no compromises — we're training the whole thing!\n")
+
+     for epoch in range(1, EPOCHS + 1):
+         total_loss = 0
+         for step, (x, y) in enumerate(tqdm(train_loader, desc=f"Epoch {epoch}")):
+             x, y = x.to(device), y.to(device)
+
+             logits, _ = model(x)
+             loss = criterion(logits.view(-1, logits.size(-1)), y.view(-1))
+             loss = loss / ACCUM_STEPS  # scale so the accumulated gradient matches one large batch
+             loss.backward()
+
+             total_loss += loss.item() * ACCUM_STEPS  # undo the scaling for logging
+
+             if (step + 1) % ACCUM_STEPS == 0 or (step + 1) == len(train_loader):
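+                 # One optimizer update per ACCUM_STEPS micro-batches; the second
+                 # clause flushes any partial accumulation left by the final batch.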
+                 torch.nn.utils.clip_grad_norm_(model.parameters(), GRAD_CLIP)
+                 optimizer.step()
+                 optimizer.zero_grad()
+
+         avg_train_loss = total_loss / len(train_loader)
+         val_loss = evaluate(model, val_loader)
+
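+         # Perplexity (PPL) is exp(mean cross-entropy loss); lower is better.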
+         print(f"\nEpoch {epoch}")
+         print(f"  Train loss: {avg_train_loss:.4f} | PPL: {math.exp(avg_train_loss):.2f}")
+         print(f"  Val loss:   {val_loss:.4f} | PPL: {math.exp(val_loss):.2f}")
+
+         # Save checkpoint
+         save_dir = OUTPUT_DIR / f"epoch_{epoch}"
+         save_dir.mkdir(exist_ok=True, parents=True)
+         torch.save(model.state_dict(), save_dir / MODEL_SAVE_NAME)
+         torch.save(model.state_dict(), LAST_TRAINED_PATH)
+
+         # Keep only the last N epochs to save disk space. Sort numerically:
+         # a plain lexicographic sort puts "epoch_10" before "epoch_2" and
+         # would delete the wrong checkpoints.
+         epochs = sorted(
+             [p for p in OUTPUT_DIR.iterdir() if p.is_dir() and p.name.startswith("epoch_")],
+             key=lambda p: int(p.name.split("_")[1]),
+         )
+         for old in epochs[:-KEEP_LAST_EPOCHS]:
+             shutil.rmtree(old)
+
+     print("\nDONE! Full model trained. You are now the emperor of fine-tuning.")
+
+
+ if __name__ == "__main__":
+     train()