kgrabko committed on
Commit
2c58360
·
verified ·
1 Parent(s): ce972e7

Update fine_tune10b_with_validation_no_torchscript.py

Browse files
fine_tune10b_with_validation_no_torchscript.py CHANGED
@@ -1,169 +1,184 @@
1
- # Copyright (c) 2025 CMS Manhattan
2
- # All rights reserved.
3
- #
4
- # This file is part of a project authored by CMS Manhattan. You may use, distribute, and modify
5
- # this code under the terms of the GNU GENERAL PUBLIC LICENSE, Version 3, 29 June 2007
6
- # please read <http://www.gnu.org/licenses/>.
7
-
8
- import os
9
- import torch
10
- import torch.nn as nn
11
- import torch.optim as optim
12
- from torch.utils.data import Dataset, DataLoader
13
- from transformers import GPT2TokenizerFast
14
- from tqdm import tqdm
15
- import shutil
16
- import math
17
- from pathlib import Path
18
- import re
19
-
20
- from gpt_jit_modern_10b import JiRackPyTorch
21
-
22
# ============================= SETTINGS =============================
# Hyper-parameters of the fine-tuning run.
TRAIN_SEQ_LEN = 256       # input tokens per training sample
BATCH_SIZE = 2            # samples per micro-batch
ACCUM_STEPS = 16          # micro-batches accumulated before an optimizer step
EPOCHS = 500
LEARNING_RATE = 3e-5
WEIGHT_DECAY = 0.01
GRAD_CLIP = 1.0           # maximum gradient norm passed to clip_grad_norm_
VAL_SPLIT_RATIO = 0.05    # trailing fraction of sequences used for validation
KEEP_LAST_EPOCHS = 3      # per-epoch checkpoint directories to retain

# === PATHS ===
_MODELS_DIR = Path("models")
_DATASETS_DIR = Path("datasets")

BASE_MODEL_PATH = _MODELS_DIR / "gpt_modern_10b_class.state_dict.pt"
LAST_TRAINED_PATH = _MODELS_DIR / "gpt_last_modern_10b_class.state_dict.pt"
BACKUP_DIR = _MODELS_DIR / "backups"
BACKUP_DIR.mkdir(exist_ok=True, parents=True)

RAW_PATH = _DATASETS_DIR / "dialogues_text.txt"
CLEAN_PATH = _DATASETS_DIR / "dialogues_text_clean.txt"

# Use the GPU when one is available, otherwise the CPU.
# To force CPU execution instead: device = torch.device("cpu")
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(f"Using device: {device}")

# === DATASET CLEANING ===
# Rebuild the cleaned copy when it is missing or older than the raw file.
if CLEAN_PATH.exists():
    _clean_is_stale = RAW_PATH.stat().st_mtime > CLEAN_PATH.stat().st_mtime
else:
    _clean_is_stale = True

if _clean_is_stale:
    print("Cleaning dataset...")
    text = RAW_PATH.read_text(encoding="utf-8")
    # Collapse every run of two or more spaces into a single space.
    text = re.compile(r' {2,}').sub(' ', text)
    # Drop one space directly before, then one directly after, each newline.
    text = text.replace(" \n", "\n")
    text = text.replace("\n ", "\n")
    CLEAN_PATH.write_text(text, encoding="utf-8")
    print(f"Done → {CLEAN_PATH}")

DATASET_PATH = CLEAN_PATH
OUTPUT_DIR = Path("build") / "fine_tuning_output"
MODEL_SAVE_NAME = "pytorch_model.bin"
58
-
59
- # ============================= DATASET =============================
60
class TextDataset(Dataset):
    """Next-token-prediction samples cut from a single text file.

    The file is tokenized with the pretrained GPT-2 tokenizer and sliced into
    windows of ``seq_len + 1`` token ids that start every ``seq_len`` tokens
    (so neighbouring windows share one token). The window list is then divided
    into a leading train part and a trailing validation part according to the
    module-level ``VAL_SPLIT_RATIO``.
    """

    def __init__(self, text_file, seq_len=TRAIN_SEQ_LEN, split='train'):
        """Tokenize ``text_file`` and keep the windows that belong to ``split``.

        Args:
            text_file: Path of the UTF-8 text file to read.
            seq_len: Number of input tokens per sample; every stored window
                holds one extra token so targets can be shifted by one.
            split: ``'train'`` keeps the leading windows; any other value
                keeps the trailing (validation) windows.
        """
        self.seq_len = seq_len

        gpt2_tok = GPT2TokenizerFast.from_pretrained("gpt2")
        gpt2_tok.pad_token = gpt2_tok.eos_token
        token_ids = gpt2_tok.encode(Path(text_file).read_text(encoding="utf-8"))

        # Each window carries seq_len inputs plus one trailing token for labels.
        window_len = seq_len + 1
        starts = range(0, len(token_ids) - seq_len, seq_len)
        windows = [token_ids[start:start + window_len] for start in starts]

        cut = int(len(windows) * (1 - VAL_SPLIT_RATIO))
        self.data = windows[:cut] if split == 'train' else windows[cut:]

        print(f"{split.upper()} sequences: {len(self.data):,}")

    def __len__(self):
        """Return how many windows this split holds."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return ``(inputs, targets)`` long tensors; targets are inputs shifted by one."""
        window = self.data[idx]
        inputs = torch.tensor(window[:-1], dtype=torch.long)
        targets = torch.tensor(window[1:], dtype=torch.long)
        return inputs, targets
86
-
87
-
88
def evaluate(model, loader):
    """Return the mean per-batch cross-entropy loss of ``model`` over ``loader``.

    The pass runs in eval mode with autograd disabled. The model's previous
    train/eval mode is restored afterwards, even if a batch raises.

    Args:
        model: Module whose call returns a tuple with the logits first.
            Assumes logits are shaped (batch, seq, vocab) — TODO confirm
            against the model class.
        loader: Iterable of ``(inputs, targets)`` tensor batches.

    Returns:
        The mean of the per-batch losses as a float, or ``nan`` when the
        loader yields no batches (previously a ZeroDivisionError).
    """
    was_training = model.training

    # Move batches to wherever the model's weights live; fall back to the
    # module-level ``device`` only for a model without parameters.
    first_param = next(model.parameters(), None)
    target_device = first_param.device if first_param is not None else device

    criterion = nn.CrossEntropyLoss()
    total_loss = 0.0
    num_batches = 0

    model.eval()
    try:
        with torch.no_grad():
            for x, y in loader:
                x, y = x.to(target_device), y.to(target_device)
                logits, _ = model(x)
                # Flatten to (tokens, vocab) vs (tokens,) for CrossEntropyLoss.
                loss = criterion(logits.view(-1, logits.size(-1)), y.view(-1))
                total_loss += loss.item()
                num_batches += 1
    finally:
        model.train(was_training)

    if num_batches == 0:
        return float("nan")
    return total_loss / num_batches
100
-
101
-
102
def _prune_old_checkpoints(output_dir, keep_last):
    """Delete all but the ``keep_last`` highest-numbered ``epoch_<N>`` directories.

    Directories are ordered by the integer epoch number, not by name: a plain
    name sort puts ``epoch_10`` before ``epoch_7`` and would delete the newest
    checkpoint. ``keep_last <= 0`` disables pruning.
    """
    if keep_last <= 0:
        return

    numbered = []
    for entry in output_dir.iterdir():
        match = re.fullmatch(r"epoch_(\d+)", entry.name)
        if match and entry.is_dir():
            numbered.append((int(match.group(1)), entry))
    numbered.sort(key=lambda pair: pair[0])

    stale_count = max(len(numbered) - keep_last, 0)
    for _, stale_dir in numbered[:stale_count]:
        shutil.rmtree(stale_dir)


def train():
    """Fine-tune every weight of the model and checkpoint after each epoch.

    Loads the last trained weights if present, else the base weights, else
    starts from random initialisation. Each epoch runs gradient-accumulated
    AdamW updates with gradient-norm clipping, reports train/validation loss
    and perplexity, saves the state dict to a per-epoch directory and to
    ``LAST_TRAINED_PATH``, and prunes old per-epoch directories.
    """
    OUTPUT_DIR.mkdir(parents=True, exist_ok=True)

    print("Loading model...")
    model = JiRackPyTorch().to(device)

    # NOTE(review): torch.load unpickles the file; only load checkpoints from
    # a trusted source (consider weights_only=True if the torch version
    # in use supports it).
    if LAST_TRAINED_PATH.exists():
        print(f"Resuming from {LAST_TRAINED_PATH}")
        model.load_state_dict(torch.load(LAST_TRAINED_PATH, map_location=device))
    elif BASE_MODEL_PATH.exists():
        print(f"Starting from base model {BASE_MODEL_PATH}")
        model.load_state_dict(torch.load(BASE_MODEL_PATH, map_location=device))
    else:
        print("Starting from scratch — random weights")

    model.train()

    train_dataset = TextDataset(DATASET_PATH, split='train')
    val_dataset = TextDataset(DATASET_PATH, split='val')

    train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE)

    optimizer = optim.AdamW(model.parameters(), lr=LEARNING_RATE, weight_decay=WEIGHT_DECAY)
    criterion = nn.CrossEntropyLoss()

    print("\nFULL TRAINING STARTED! No LoRA, no compromises — we're training the whole thing!\n")

    for epoch in range(1, EPOCHS + 1):
        total_loss = 0
        for step, (x, y) in enumerate(tqdm(train_loader, desc=f"Epoch {epoch}")):
            x, y = x.to(device), y.to(device)

            logits, _ = model(x)
            loss = criterion(logits.view(-1, logits.size(-1)), y.view(-1))
            # Scale so the accumulated gradient is the mean over ACCUM_STEPS
            # micro-batches.
            # NOTE(review): a shorter final group is still divided by
            # ACCUM_STEPS, so its update is proportionally smaller.
            loss = loss / ACCUM_STEPS
            loss.backward()

            # Undo the scaling so the reported loss is the per-batch value.
            total_loss += loss.item() * ACCUM_STEPS

            # Step once per full accumulation group, and on the last batch.
            if (step + 1) % ACCUM_STEPS == 0 or (step + 1) == len(train_loader):
                torch.nn.utils.clip_grad_norm_(model.parameters(), GRAD_CLIP)
                optimizer.step()
                optimizer.zero_grad()

        avg_train_loss = total_loss / len(train_loader)
        val_loss = evaluate(model, val_loader)

        print(f"\nEpoch {epoch}")
        print(f" Train loss: {avg_train_loss:.4f} | PPL: {math.exp(avg_train_loss):.2f}")
        print(f" Val loss: {val_loss:.4f} | PPL: {math.exp(val_loss):.2f}")

        # Save checkpoint
        save_dir = OUTPUT_DIR / f"epoch_{epoch}"
        save_dir.mkdir(exist_ok=True, parents=True)
        torch.save(model.state_dict(), save_dir / MODEL_SAVE_NAME)
        torch.save(model.state_dict(), LAST_TRAINED_PATH)

        # Keep only the last N epochs to save disk space
        _prune_old_checkpoints(OUTPUT_DIR, KEEP_LAST_EPOCHS)

    print("\nDONE! Full model trained. You are now the emperor of fine-tuning.")


if __name__ == "__main__":
    train()
 
1
+ # Copyright (c) 2025 CMS Manhattan
2
+ # All rights reserved.
3
+ # Author: Konstantin Vladimirovich Grabko
4
+ # Email: grabko@cmsmanhattan.com
5
+ # Phone: +1(516)777-0945
6
+ #
7
+ # This program is free software: you can redistribute it and/or modify
8
+ # it under the terms of the GNU General Public License as published by
9
+ # the Free Software Foundation, version 3 of the License.
10
+ #
11
+ # This program is distributed in the hope that it will be useful,
12
+ # but WITHOUT ANY WARRANTY; without even the implied warranty of
13
+ # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
14
+ # GNU General Public License for more details.
15
+ #
16
+ # You should have received a copy of the GNU General Public License
17
+ # along with this program. If not, see <https://www.gnu.org/licenses/>.
18
+ #
19
+ # Additional terms:
20
+ # Any commercial use or distribution of this software or derivative works
21
+ # requires explicit written permission from the copyright holder.
22
+
23
+ import os
24
+ import torch
25
+ import torch.nn as nn
26
+ import torch.optim as optim
27
+ from torch.utils.data import Dataset, DataLoader
28
+ from transformers import GPT2TokenizerFast
29
+ from tqdm import tqdm
30
+ import shutil
31
+ import math
32
+ from pathlib import Path
33
+ import re
34
+
35
+ from gpt_jit_modern_10b import JiRackPyTorch
36
+
37
# ============================= SETTINGS =============================
# Hyper-parameters of the fine-tuning run.
TRAIN_SEQ_LEN = 256       # input tokens per training sample
BATCH_SIZE = 2            # samples per micro-batch
ACCUM_STEPS = 16          # micro-batches accumulated before an optimizer step
EPOCHS = 500
LEARNING_RATE = 3e-5
WEIGHT_DECAY = 0.01
GRAD_CLIP = 1.0           # maximum gradient norm passed to clip_grad_norm_
VAL_SPLIT_RATIO = 0.05    # trailing fraction of sequences used for validation
KEEP_LAST_EPOCHS = 3      # per-epoch checkpoint directories to retain

# === PATHS ===
_MODELS_DIR = Path("models")
_DATASETS_DIR = Path("datasets")

BASE_MODEL_PATH = _MODELS_DIR / "gpt_modern_10b_class.state_dict.pt"
LAST_TRAINED_PATH = _MODELS_DIR / "gpt_last_modern_10b_class.state_dict.pt"
BACKUP_DIR = _MODELS_DIR / "backups"
BACKUP_DIR.mkdir(exist_ok=True, parents=True)

RAW_PATH = _DATASETS_DIR / "dialogues_text.txt"
CLEAN_PATH = _DATASETS_DIR / "dialogues_text_clean.txt"

# Use the GPU when one is available, otherwise the CPU.
# To force CPU execution instead: device = torch.device("cpu")
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(f"Using device: {device}")

# === DATASET CLEANING ===
# Rebuild the cleaned copy when it is missing or older than the raw file.
if CLEAN_PATH.exists():
    _clean_is_stale = RAW_PATH.stat().st_mtime > CLEAN_PATH.stat().st_mtime
else:
    _clean_is_stale = True

if _clean_is_stale:
    print("Cleaning dataset...")
    text = RAW_PATH.read_text(encoding="utf-8")
    # Collapse every run of two or more spaces into a single space.
    text = re.compile(r' {2,}').sub(' ', text)
    # Drop one space directly before, then one directly after, each newline.
    text = text.replace(" \n", "\n")
    text = text.replace("\n ", "\n")
    CLEAN_PATH.write_text(text, encoding="utf-8")
    print(f"Done {CLEAN_PATH}")

DATASET_PATH = CLEAN_PATH
OUTPUT_DIR = Path("build") / "fine_tuning_output"
MODEL_SAVE_NAME = "pytorch_model.bin"
73
+
74
+ # ============================= DATASET =============================
75
class TextDataset(Dataset):
    """Next-token-prediction samples cut from a single text file.

    The file is tokenized with the pretrained GPT-2 tokenizer and sliced into
    windows of ``seq_len + 1`` token ids that start every ``seq_len`` tokens
    (so neighbouring windows share one token). The window list is then divided
    into a leading train part and a trailing validation part according to the
    module-level ``VAL_SPLIT_RATIO``.
    """

    def __init__(self, text_file, seq_len=TRAIN_SEQ_LEN, split='train'):
        """Tokenize ``text_file`` and keep the windows that belong to ``split``.

        Args:
            text_file: Path of the UTF-8 text file to read.
            seq_len: Number of input tokens per sample; every stored window
                holds one extra token so targets can be shifted by one.
            split: ``'train'`` keeps the leading windows; any other value
                keeps the trailing (validation) windows.
        """
        self.seq_len = seq_len

        gpt2_tok = GPT2TokenizerFast.from_pretrained("gpt2")
        gpt2_tok.pad_token = gpt2_tok.eos_token
        token_ids = gpt2_tok.encode(Path(text_file).read_text(encoding="utf-8"))

        # Each window carries seq_len inputs plus one trailing token for labels.
        window_len = seq_len + 1
        starts = range(0, len(token_ids) - seq_len, seq_len)
        windows = [token_ids[start:start + window_len] for start in starts]

        cut = int(len(windows) * (1 - VAL_SPLIT_RATIO))
        self.data = windows[:cut] if split == 'train' else windows[cut:]

        print(f"{split.upper()} sequences: {len(self.data):,}")

    def __len__(self):
        """Return how many windows this split holds."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return ``(inputs, targets)`` long tensors; targets are inputs shifted by one."""
        window = self.data[idx]
        inputs = torch.tensor(window[:-1], dtype=torch.long)
        targets = torch.tensor(window[1:], dtype=torch.long)
        return inputs, targets
101
+
102
+
103
def evaluate(model, loader):
    """Return the mean per-batch cross-entropy loss of ``model`` over ``loader``.

    The pass runs in eval mode with autograd disabled. The model's previous
    train/eval mode is restored afterwards, even if a batch raises.

    Args:
        model: Module whose call returns a tuple with the logits first.
            Assumes logits are shaped (batch, seq, vocab) — TODO confirm
            against the model class.
        loader: Iterable of ``(inputs, targets)`` tensor batches.

    Returns:
        The mean of the per-batch losses as a float, or ``nan`` when the
        loader yields no batches (previously a ZeroDivisionError).
    """
    was_training = model.training

    # Move batches to wherever the model's weights live; fall back to the
    # module-level ``device`` only for a model without parameters.
    first_param = next(model.parameters(), None)
    target_device = first_param.device if first_param is not None else device

    criterion = nn.CrossEntropyLoss()
    total_loss = 0.0
    num_batches = 0

    model.eval()
    try:
        with torch.no_grad():
            for x, y in loader:
                x, y = x.to(target_device), y.to(target_device)
                logits, _ = model(x)
                # Flatten to (tokens, vocab) vs (tokens,) for CrossEntropyLoss.
                loss = criterion(logits.view(-1, logits.size(-1)), y.view(-1))
                total_loss += loss.item()
                num_batches += 1
    finally:
        model.train(was_training)

    if num_batches == 0:
        return float("nan")
    return total_loss / num_batches
115
+
116
+
117
def _prune_old_checkpoints(output_dir, keep_last):
    """Delete all but the ``keep_last`` highest-numbered ``epoch_<N>`` directories.

    Directories are ordered by the integer epoch number, not by name: a plain
    name sort puts ``epoch_10`` before ``epoch_7`` and would delete the newest
    checkpoint. ``keep_last <= 0`` disables pruning.
    """
    if keep_last <= 0:
        return

    numbered = []
    for entry in output_dir.iterdir():
        match = re.fullmatch(r"epoch_(\d+)", entry.name)
        if match and entry.is_dir():
            numbered.append((int(match.group(1)), entry))
    numbered.sort(key=lambda pair: pair[0])

    stale_count = max(len(numbered) - keep_last, 0)
    for _, stale_dir in numbered[:stale_count]:
        shutil.rmtree(stale_dir)


def train():
    """Fine-tune every weight of the model and checkpoint after each epoch.

    Loads the last trained weights if present, else the base weights, else
    starts from random initialisation. Each epoch runs gradient-accumulated
    AdamW updates with gradient-norm clipping, reports train/validation loss
    and perplexity, saves the state dict to a per-epoch directory and to
    ``LAST_TRAINED_PATH``, and prunes old per-epoch directories.
    """
    OUTPUT_DIR.mkdir(parents=True, exist_ok=True)

    print("Loading model...")
    model = JiRackPyTorch().to(device)

    # NOTE(review): torch.load unpickles the file; only load checkpoints from
    # a trusted source (consider weights_only=True if the torch version
    # in use supports it).
    if LAST_TRAINED_PATH.exists():
        print(f"Resuming from {LAST_TRAINED_PATH}")
        model.load_state_dict(torch.load(LAST_TRAINED_PATH, map_location=device))
    elif BASE_MODEL_PATH.exists():
        print(f"Starting from base model {BASE_MODEL_PATH}")
        model.load_state_dict(torch.load(BASE_MODEL_PATH, map_location=device))
    else:
        print("Starting from scratch random weights")

    model.train()

    train_dataset = TextDataset(DATASET_PATH, split='train')
    val_dataset = TextDataset(DATASET_PATH, split='val')

    train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE)

    optimizer = optim.AdamW(model.parameters(), lr=LEARNING_RATE, weight_decay=WEIGHT_DECAY)
    criterion = nn.CrossEntropyLoss()

    print("\nFULL TRAINING STARTED! No LoRA, no compromises — we're training the whole thing!\n")

    for epoch in range(1, EPOCHS + 1):
        total_loss = 0
        for step, (x, y) in enumerate(tqdm(train_loader, desc=f"Epoch {epoch}")):
            x, y = x.to(device), y.to(device)

            logits, _ = model(x)
            loss = criterion(logits.view(-1, logits.size(-1)), y.view(-1))
            # Scale so the accumulated gradient is the mean over ACCUM_STEPS
            # micro-batches.
            # NOTE(review): a shorter final group is still divided by
            # ACCUM_STEPS, so its update is proportionally smaller.
            loss = loss / ACCUM_STEPS
            loss.backward()

            # Undo the scaling so the reported loss is the per-batch value.
            total_loss += loss.item() * ACCUM_STEPS

            # Step once per full accumulation group, and on the last batch.
            if (step + 1) % ACCUM_STEPS == 0 or (step + 1) == len(train_loader):
                torch.nn.utils.clip_grad_norm_(model.parameters(), GRAD_CLIP)
                optimizer.step()
                optimizer.zero_grad()

        avg_train_loss = total_loss / len(train_loader)
        val_loss = evaluate(model, val_loader)

        print(f"\nEpoch {epoch}")
        print(f" Train loss: {avg_train_loss:.4f} | PPL: {math.exp(avg_train_loss):.2f}")
        print(f" Val loss: {val_loss:.4f} | PPL: {math.exp(val_loss):.2f}")

        # Save checkpoint
        save_dir = OUTPUT_DIR / f"epoch_{epoch}"
        save_dir.mkdir(exist_ok=True, parents=True)
        torch.save(model.state_dict(), save_dir / MODEL_SAVE_NAME)
        torch.save(model.state_dict(), LAST_TRAINED_PATH)

        # Keep only the last N epochs to save disk space
        _prune_old_checkpoints(OUTPUT_DIR, KEEP_LAST_EPOCHS)

    print("\nDONE! Full model trained. You are now the emperor of fine-tuning.")


if __name__ == "__main__":
    train()