kgrabko committed
Commit ae3a2ac · verified · 1 Parent(s): 744cbe0

Update fine_tune_jit_with_validation_gpt2_cuda.py

fine_tune_jit_with_validation_gpt2_cuda.py CHANGED
@@ -1,186 +1,204 @@
- #!/usr/bin/env python3
- # Copyright (c) 2025 CMS Manhattan
- # JiRack JIT fine-tuning: fully working version for ROCm
-
- import os
- import torch
- import torch.nn as nn
- import torch.optim as optim
- from torch.utils.data import Dataset, DataLoader
- from transformers import GPT2Tokenizer
- from tqdm import tqdm
- import shutil
- import math
- from pathlib import Path
- import re
- from torch.cuda.amp import autocast, GradScaler
-
- # ============================= SETTINGS =============================
- TRAIN_SEQ_LEN = 256  # the model context is 8192, but training slices it into 256-token chunks
- BATCH_SIZE = 12
- EPOCHS = 50
- LEARNING_RATE = 6e-6
- WEIGHT_DECAY = 0.01
- GRAD_CLIP = 1.0
- KEEP_LAST_EPOCHS = 3
- VAL_SPLIT_RATIO = 0.05
-
- BASE_MODEL_PATH = Path("models/JiRack_H16_L32_V50257_D768_MSL8192_FF768x4.script.pt")
- LAST_TRAINED_PATH = Path("models/JiRack_last_H16_L32_V50257_D768_MSL8192_FF768x4.script.pt")
- BACKUP_DIR = Path("models/backups")
- BACKUP_DIR.mkdir(parents=True, exist_ok=True)
-
- RAW_PATH = Path("datasets/dialogues_text.txt")
- CLEAN_PATH = Path("datasets/dialogues_text_clean.txt")
-
- OUTPUT_DIR = Path("build/fine_tuning_output")
- SAVE_NAME = "gpt_finetuned.script.pt"
-
- device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
- print(f"Device: {device}\n")
-
- # ============================= CLEANUP =============================
- if not CLEAN_PATH.exists() or (RAW_PATH.exists() and RAW_PATH.stat().st_mtime > CLEAN_PATH.stat().st_mtime):
-     if not RAW_PATH.exists():
-         raise FileNotFoundError(f"File not found: {RAW_PATH}")
-     print("Cleaning dataset...")
-     text = RAW_PATH.read_text(encoding="utf-8")
-     text = re.sub(r" {2,}", " ", text).replace(" \n", "\n").replace("\n ", "\n")
-     CLEAN_PATH.write_text(text, encoding="utf-8")
-     print(f"Clean dataset saved → {CLEAN_PATH}\n")
- else:
-     print(f"Using existing clean dataset → {CLEAN_PATH}\n")
-
- # ============================= DATASET =============================
- class TextDataset(Dataset):
-     def __init__(self, file_path, split='train'):
-         self.tokenizer = GPT2Tokenizer.from_pretrained("./tokenizer", local_files_only=True)
-         self.tokenizer.pad_token = self.tokenizer.eos_token
-
-         print(f"Tokenizing {file_path} ({split})...")
-         text = Path(file_path).read_text(encoding="utf-8")
-         tokens = self.tokenizer.encode(text)
-
-         inputs = []
-         labels = []
-         for i in range(0, len(tokens) - TRAIN_SEQ_LEN, TRAIN_SEQ_LEN):
-             inputs.append(tokens[i:i + TRAIN_SEQ_LEN])
-             labels.append(tokens[i + 1:i + TRAIN_SEQ_LEN + 1])
-
-         total = len(inputs)
-         val_n = int(total * VAL_SPLIT_RATIO)
-
-         if split == "train":
-             self.data = list(zip(inputs[:total - val_n], labels[:total - val_n]))
-         else:
-             self.data = list(zip(inputs[total - val_n:], labels[total - val_n:]))
-
-         print(f"{split.upper()}: {len(self.data):,} sequences\n")
-
-     def __len__(self): return len(self.data)
-     def __getitem__(self, i):
-         x, y = self.data[i]
-         return torch.tensor(x, dtype=torch.long), torch.tensor(y, dtype=torch.long)
-
- # ============================= HELPERS =============================
- def get_logits(model, x):
-     try:
-         logits, _ = model(x)
-     except:
-         logits = model(x)
-     return logits
-
- def evaluate(model, loader):
-     model.eval()
-     total = 0.0
-     crit = nn.CrossEntropyLoss()
-     with torch.no_grad():
-         for x, y in loader:
-             x, y = x.to(device), y.to(device)
-             with autocast():
-                 total += crit(get_logits(model, x).view(-1, get_logits(model, x).size(-1)), y.view(-1)).item()
-     model.train()
-     return total / len(loader)
-
- def cleanup():
-     old = sorted(OUTPUT_DIR.glob("epoch*"), key=lambda p: int(p.name[5:]))[:-KEEP_LAST_EPOCHS]
-     for d in old:
-         shutil.rmtree(d, ignore_errors=True)
-
- # ============================= TRAINING =============================
- def train():
-     OUTPUT_DIR.mkdir(parents=True, exist_ok=True)
-
-     if LAST_TRAINED_PATH.exists():
-         print(f"Resuming training from: {LAST_TRAINED_PATH.name}")
-         model = torch.jit.load(LAST_TRAINED_PATH, map_location=device)
-     elif BASE_MODEL_PATH.exists():
-         print(f"Starting from base model: {BASE_MODEL_PATH.name}")
-         model = torch.jit.load(BASE_MODEL_PATH, map_location=device)
-     else:
-         raise FileNotFoundError("No JIT model found!")
-
-     model.train()
-
-     train_ds = TextDataset(CLEAN_PATH, "train")
-     val_ds = TextDataset(CLEAN_PATH, "val")
-
-     train_loader = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True, drop_last=True)
-     val_loader = DataLoader(val_ds, batch_size=BATCH_SIZE, drop_last=True)
-
-     optimizer = optim.AdamW(model.parameters(), lr=LEARNING_RATE, weight_decay=WEIGHT_DECAY)
-     criterion = nn.CrossEntropyLoss()
-     scaler = GradScaler()  # AMP: roughly 1.5-2x speedup
-
-     print(f"STARTING TRAINING: {EPOCHS} epochs, ~{len(train_loader)*EPOCHS:,} steps\n")
-
-     for epoch in range(1, EPOCHS + 1):
-         print(f"EPOCH {epoch}/{EPOCHS}")
-         epoch_loss = 0.0
-
-         for x, y in tqdm(train_loader, desc="Train", leave=False):
-             x, y = x.to(device), y.to(device)
-             optimizer.zero_grad()
-
-             with autocast():
-                 logits = get_logits(model, x)
-                 loss = criterion(logits.view(-1, logits.size(-1)), y.view(-1))  # ← FIXED!
-
-             scaler.scale(loss).backward()
-             scaler.unscale_(optimizer)
-             torch.nn.utils.clip_grad_norm_(model.parameters(), GRAD_CLIP)
-             scaler.step(optimizer)
-             scaler.update()
-
-             loss_val = loss.item()
-             epoch_loss += loss_val
-
-         avg = epoch_loss / len(train_loader)
-         print(f"TRAIN loss: {avg:.4f} | ppl: {math.exp(avg):.1f}")
-
-         val_loss = evaluate(model, val_loader)
-         print(f"  VAL → loss: {val_loss:.4f} | ppl: {math.exp(val_loss):.1f}\n")
-
-         # Save checkpoint
-         epoch_dir = OUTPUT_DIR / f"epoch{epoch}"
-         epoch_dir.mkdir(exist_ok=True)
-         model.save(epoch_dir / SAVE_NAME)
-         print(f"Saved → {epoch_dir / SAVE_NAME}")
-         cleanup()
-
-     # Final save
-     final = OUTPUT_DIR / "final"
-     final.mkdir(parents=True, exist_ok=True)
-     model.save(final / SAVE_NAME)
-     train_ds.tokenizer.save_pretrained(final)
-
-     if LAST_TRAINED_PATH.exists():
-         shutil.copy(LAST_TRAINED_PATH, BACKUP_DIR / f"backup_{int(time.time())}.pt")
-     shutil.copy(final / SAVE_NAME, LAST_TRAINED_PATH)
-
-     print("\nDONE! Model trained and saved:")
-     print(f"  → {final / SAVE_NAME}")
-     print(f"{LAST_TRAINED_PATH}")
-
- if __name__ == "__main__":
-     train()
+ # Copyright (c) 2025 CMS Manhattan
+ # All rights reserved.
+ # Author: Konstantin Vladimirovich Grabko
+ # Email: grabko@cmsmanhattan.com
+ # Phone: +1(516)777-0945
+ #
+ # This program is free software: you can redistribute it and/or modify
+ # it under the terms of the GNU General Public License as published by
+ # the Free Software Foundation, version 3 of the License.
+ #
+ # This program is distributed in the hope that it will be useful,
+ # but WITHOUT ANY WARRANTY; without even the implied warranty of
+ # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ # GNU General Public License for more details.
+ #
+ # You should have received a copy of the GNU General Public License
+ # along with this program. If not, see <https://www.gnu.org/licenses/>.
+ #
+ # Additional terms:
+ # Any commercial use or distribution of this software or derivative works
+ # requires explicit written permission from the copyright holder.
+
+ import os
+ import time  # needed for the backup timestamp in train()
+ import torch
+ import torch.nn as nn
+ import torch.optim as optim
+ from torch.utils.data import Dataset, DataLoader
+ from transformers import GPT2Tokenizer
+ from tqdm import tqdm
+ import shutil
+ import math
+ from pathlib import Path
+ import re
+ from torch.cuda.amp import autocast, GradScaler
+
+ # ============================= SETTINGS =============================
+ TRAIN_SEQ_LEN = 256  # the model context is 8192, but training slices it into 256-token chunks
+ BATCH_SIZE = 12
+ EPOCHS = 50
+ LEARNING_RATE = 6e-6
+ WEIGHT_DECAY = 0.01
+ GRAD_CLIP = 1.0
+ KEEP_LAST_EPOCHS = 3
+ VAL_SPLIT_RATIO = 0.05
+
+ BASE_MODEL_PATH = Path("models/JiRack_H16_L32_V50257_D768_MSL8192_FF768x4.script.pt")
+ LAST_TRAINED_PATH = Path("models/JiRack_last_H16_L32_V50257_D768_MSL8192_FF768x4.script.pt")
+ BACKUP_DIR = Path("models/backups")
+ BACKUP_DIR.mkdir(parents=True, exist_ok=True)
+
+ RAW_PATH = Path("datasets/dialogues_text.txt")
+ CLEAN_PATH = Path("datasets/dialogues_text_clean.txt")
+
+ OUTPUT_DIR = Path("build/fine_tuning_output")
+ SAVE_NAME = "gpt_finetuned.script.pt"
+
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+ print(f"Device: {device}\n")
+
+ # ============================= CLEANUP =============================
+ if not CLEAN_PATH.exists() or (RAW_PATH.exists() and RAW_PATH.stat().st_mtime > CLEAN_PATH.stat().st_mtime):
+     if not RAW_PATH.exists():
+         raise FileNotFoundError(f"File not found: {RAW_PATH}")
+     print("Cleaning dataset...")
+     text = RAW_PATH.read_text(encoding="utf-8")
+     text = re.sub(r" {2,}", " ", text).replace(" \n", "\n").replace("\n ", "\n")
+     CLEAN_PATH.write_text(text, encoding="utf-8")
+     print(f"Clean dataset saved → {CLEAN_PATH}\n")
+ else:
+     print(f"Using existing clean dataset → {CLEAN_PATH}\n")
+
+ # ============================= DATASET =============================
+ class TextDataset(Dataset):
+     # Tokenizes the whole file once, then slices it into fixed-length
+     # (input, shifted-by-one label) pairs for next-token prediction.
+     def __init__(self, file_path, split='train'):
+         self.tokenizer = GPT2Tokenizer.from_pretrained("./tokenizer", local_files_only=True)
+         self.tokenizer.pad_token = self.tokenizer.eos_token
+
+         print(f"Tokenizing {file_path} ({split})...")
+         text = Path(file_path).read_text(encoding="utf-8")
+         tokens = self.tokenizer.encode(text)
+
+         inputs = []
+         labels = []
+         for i in range(0, len(tokens) - TRAIN_SEQ_LEN, TRAIN_SEQ_LEN):
+             inputs.append(tokens[i:i + TRAIN_SEQ_LEN])
+             labels.append(tokens[i + 1:i + TRAIN_SEQ_LEN + 1])
+
+         total = len(inputs)
+         val_n = int(total * VAL_SPLIT_RATIO)
+
+         if split == "train":
+             self.data = list(zip(inputs[:total - val_n], labels[:total - val_n]))
+         else:
+             self.data = list(zip(inputs[total - val_n:], labels[total - val_n:]))
+
+         print(f"{split.upper()}: {len(self.data):,} sequences\n")
+
+     def __len__(self): return len(self.data)
+     def __getitem__(self, i):
+         x, y = self.data[i]
+         return torch.tensor(x, dtype=torch.long), torch.tensor(y, dtype=torch.long)
+
+ # ============================= HELPERS =============================
+ def get_logits(model, x):
+     # Scripted models may return either a bare logits tensor or a
+     # (logits, ...) tuple; check explicitly instead of using a bare
+     # try/except, which can mis-split a tensor along its batch dim.
+     out = model(x)
+     return out[0] if isinstance(out, tuple) else out
+
+ def evaluate(model, loader):
+     model.eval()
+     total = 0.0
+     crit = nn.CrossEntropyLoss()
+     with torch.no_grad():
+         for x, y in loader:
+             x, y = x.to(device), y.to(device)
+             with autocast():
+                 logits = get_logits(model, x)  # run the forward pass once per batch
+                 total += crit(logits.view(-1, logits.size(-1)), y.view(-1)).item()
+     model.train()
+     return total / len(loader)
+
+ def cleanup():
+     # Keep only the KEEP_LAST_EPOCHS most recent epoch directories
+     old = sorted(OUTPUT_DIR.glob("epoch*"), key=lambda p: int(p.name[5:]))[:-KEEP_LAST_EPOCHS]
+     for d in old:
+         shutil.rmtree(d, ignore_errors=True)
+
+ # ============================= TRAINING =============================
+ def train():
+     OUTPUT_DIR.mkdir(parents=True, exist_ok=True)
+
+     if LAST_TRAINED_PATH.exists():
+         print(f"Resuming training from: {LAST_TRAINED_PATH.name}")
+         model = torch.jit.load(LAST_TRAINED_PATH, map_location=device)
+     elif BASE_MODEL_PATH.exists():
+         print(f"Starting from base model: {BASE_MODEL_PATH.name}")
+         model = torch.jit.load(BASE_MODEL_PATH, map_location=device)
+     else:
+         raise FileNotFoundError("No JIT model found!")
+
+     model.train()
+
+     train_ds = TextDataset(CLEAN_PATH, "train")
+     val_ds = TextDataset(CLEAN_PATH, "val")
+
+     train_loader = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True, drop_last=True)
+     val_loader = DataLoader(val_ds, batch_size=BATCH_SIZE, drop_last=True)
+
+     optimizer = optim.AdamW(model.parameters(), lr=LEARNING_RATE, weight_decay=WEIGHT_DECAY)
+     criterion = nn.CrossEntropyLoss()
+     scaler = GradScaler()  # AMP: roughly 1.5-2x speedup
+
+     print(f"STARTING TRAINING: {EPOCHS} epochs, ~{len(train_loader)*EPOCHS:,} steps\n")
+
+     for epoch in range(1, EPOCHS + 1):
+         print(f"EPOCH {epoch}/{EPOCHS}")
+         epoch_loss = 0.0
+
+         for x, y in tqdm(train_loader, desc="Train", leave=False):
+             x, y = x.to(device), y.to(device)
+             optimizer.zero_grad()
+
+             with autocast():
+                 logits = get_logits(model, x)
+                 loss = criterion(logits.view(-1, logits.size(-1)), y.view(-1))  # ← FIXED: flatten to (B*T, V) and (B*T)
+
+             scaler.scale(loss).backward()
+             scaler.unscale_(optimizer)
+             torch.nn.utils.clip_grad_norm_(model.parameters(), GRAD_CLIP)
+             scaler.step(optimizer)
+             scaler.update()
+
+             loss_val = loss.item()
+             epoch_loss += loss_val
+
+         avg = epoch_loss / len(train_loader)
+         print(f"TRAIN → loss: {avg:.4f} | ppl: {math.exp(avg):.1f}")
+
+         val_loss = evaluate(model, val_loader)
+         print(f"  VAL → loss: {val_loss:.4f} | ppl: {math.exp(val_loss):.1f}\n")
+
+         # Save checkpoint
+         epoch_dir = OUTPUT_DIR / f"epoch{epoch}"
+         epoch_dir.mkdir(exist_ok=True)
+         model.save(epoch_dir / SAVE_NAME)
+         print(f"Saved → {epoch_dir / SAVE_NAME}")
+         cleanup()
+
+     # Final save
+     final = OUTPUT_DIR / "final"
+     final.mkdir(parents=True, exist_ok=True)
+     model.save(final / SAVE_NAME)
+     train_ds.tokenizer.save_pretrained(final)
+
+     if LAST_TRAINED_PATH.exists():
+         shutil.copy(LAST_TRAINED_PATH, BACKUP_DIR / f"backup_{int(time.time())}.pt")  # timestamped backup
+     shutil.copy(final / SAVE_NAME, LAST_TRAINED_PATH)
+
+     print("\nDONE! Model trained and saved:")
+     print(f"  → {final / SAVE_NAME}")
+     print(f"  → {LAST_TRAINED_PATH}")
+
+ if __name__ == "__main__":
+     train()
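
A minimal smoke-test sketch for the exported checkpoint, not part of the commit itself: the paths come from the script above, while the "Hello" prompt, the 20-token budget, and greedy decoding are illustrative assumptions; adjust to your model's actual output signature.

import torch
from transformers import GPT2Tokenizer

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Load the final TorchScript checkpoint and the tokenizer saved by train()
model = torch.jit.load("build/fine_tuning_output/final/gpt_finetuned.script.pt", map_location=device)
model.eval()
tokenizer = GPT2Tokenizer.from_pretrained("build/fine_tuning_output/final", local_files_only=True)

ids = tokenizer.encode("Hello", return_tensors="pt").to(device)
with torch.no_grad():
    for _ in range(20):  # generate 20 tokens greedily
        out = model(ids)
        logits = out[0] if isinstance(out, tuple) else out  # tuple or bare tensor, as in get_logits()
        next_id = logits[:, -1, :].argmax(dim=-1, keepdim=True)
        ids = torch.cat([ids, next_id], dim=1)
print(tokenizer.decode(ids[0]))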