kgrabko committed on
Commit
917cd09
·
verified ·
1 Parent(s): fdb6562

Upload fine_tune_jit_with_validation_gpt2_cuda.py


Something on my PC blocks downloading the gpt2 tokenizer, so I switched the script to a local tokenizer path and set it up for CUDA training.
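
Since the tokenizer can't be downloaded on the training box, the script reads it from a local ./tokenizer directory with local_files_only=True. A minimal one-time sketch for preparing that directory on a machine that does have network access (assumes the stock gpt2 checkpoint from the Hugging Face Hub):

# run once where downloads work, then copy ./tokenizer to the training machine
from transformers import GPT2Tokenizer

GPT2Tokenizer.from_pretrained("gpt2").save_pretrained("./tokenizer")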

fine_tune_jit_with_validation_gpt2_cuda.py ADDED
#!/usr/bin/env python3
# Copyright (c) 2025 CMS Manhattan
# JiRack JIT fine-tuning: working version for ROCm (PyTorch's CUDA backend)

import math
import re
import shutil
import time  # used for timestamped backups at the end of training
from pathlib import Path

import torch
import torch.nn as nn
import torch.optim as optim
from torch.cuda.amp import autocast, GradScaler
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm
from transformers import GPT2Tokenizer
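
# NOTE (environment assumption): on recent PyTorch (2.4+) the torch.cuda.amp
# imports above still work but emit deprecation warnings; the current spellings
# are torch.amp.autocast("cuda") and torch.amp.GradScaler("cuda").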

# ============================= SETTINGS =============================
TRAIN_SEQ_LEN = 256      # model context is 8192, but the corpus is sliced into 256-token chunks
BATCH_SIZE = 12
EPOCHS = 50
LEARNING_RATE = 6e-6
WEIGHT_DECAY = 0.01
GRAD_CLIP = 1.0
KEEP_LAST_EPOCHS = 3     # keep only the N most recent epoch checkpoints on disk
VAL_SPLIT_RATIO = 0.05   # last 5% of sequences are held out for validation

BASE_MODEL_PATH = Path("models/JiRack_H16_L32_V50257_D768_MSL8192_FF768x4.script.pt")
LAST_TRAINED_PATH = Path("models/JiRack_last_H16_L32_V50257_D768_MSL8192_FF768x4.script.pt")
BACKUP_DIR = Path("models/backups")
BACKUP_DIR.mkdir(parents=True, exist_ok=True)

RAW_PATH = Path("datasets/dialogues_text.txt")
CLEAN_PATH = Path("datasets/dialogues_text_clean.txt")

OUTPUT_DIR = Path("build/fine_tuning_output")
SAVE_NAME = "gpt_finetuned.script.pt"

# ROCm builds of PyTorch expose the GPU through the regular "cuda" device string.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Device: {device}\n")

# ============================= DATASET CLEANING =============================
# Re-clean only when the raw file is newer than the cached clean copy.
if not CLEAN_PATH.exists() or (RAW_PATH.exists() and RAW_PATH.stat().st_mtime > CLEAN_PATH.stat().st_mtime):
    if not RAW_PATH.exists():
        raise FileNotFoundError(f"Missing file: {RAW_PATH}")
    print("Cleaning dataset...")
    text = RAW_PATH.read_text(encoding="utf-8")
    text = re.sub(r" {2,}", " ", text).replace(" \n", "\n").replace("\n ", "\n")
    CLEAN_PATH.write_text(text, encoding="utf-8")
    print(f"Clean dataset saved → {CLEAN_PATH}\n")
else:
    print(f"Using existing clean dataset → {CLEAN_PATH}\n")

# ============================= DATASET =============================
class TextDataset(Dataset):
    def __init__(self, file_path, split='train'):
        # Load the tokenizer from the local ./tokenizer directory; no network access needed.
        self.tokenizer = GPT2Tokenizer.from_pretrained("./tokenizer", local_files_only=True)
        self.tokenizer.pad_token = self.tokenizer.eos_token

        print(f"Tokenizing {file_path} ({split})...")
        text = Path(file_path).read_text(encoding="utf-8")
        tokens = self.tokenizer.encode(text)

        # Slice the token stream into fixed-length blocks; labels are the inputs shifted by one.
        inputs = []
        labels = []
        for i in range(0, len(tokens) - TRAIN_SEQ_LEN, TRAIN_SEQ_LEN):
            inputs.append(tokens[i:i + TRAIN_SEQ_LEN])
            labels.append(tokens[i + 1:i + TRAIN_SEQ_LEN + 1])

        total = len(inputs)
        val_n = int(total * VAL_SPLIT_RATIO)

        # The last VAL_SPLIT_RATIO of the sequences form the validation set.
        if split == "train":
            self.data = list(zip(inputs[:total - val_n], labels[:total - val_n]))
        else:
            self.data = list(zip(inputs[total - val_n:], labels[total - val_n:]))

        print(f"{split.upper()}: {len(self.data):,} sequences\n")

    def __len__(self):
        return len(self.data)

    def __getitem__(self, i):
        x, y = self.data[i]
        return torch.tensor(x, dtype=torch.long), torch.tensor(y, dtype=torch.long)

# ============================= HELPERS =============================
def get_logits(model, x):
    # The JIT-scripted model may return bare logits or a (logits, ...) tuple.
    out = model(x)
    if isinstance(out, tuple):
        return out[0]
    return out

def evaluate(model, loader):
    model.eval()
    total = 0.0
    crit = nn.CrossEntropyLoss()
    with torch.no_grad():
        for x, y in loader:
            x, y = x.to(device), y.to(device)
            with autocast():
                logits = get_logits(model, x)  # one forward pass per batch
                total += crit(logits.view(-1, logits.size(-1)), y.view(-1)).item()
    model.train()
    return total / len(loader)

def cleanup():
    # Drop all but the KEEP_LAST_EPOCHS most recent epoch directories.
    old = sorted(OUTPUT_DIR.glob("epoch*"), key=lambda p: int(p.name[5:]))[:-KEEP_LAST_EPOCHS]
    for d in old:
        shutil.rmtree(d, ignore_errors=True)

# ============================= TRAINING =============================
def train():
    OUTPUT_DIR.mkdir(parents=True, exist_ok=True)

    if LAST_TRAINED_PATH.exists():
        print(f"Resuming training from: {LAST_TRAINED_PATH.name}")
        model = torch.jit.load(LAST_TRAINED_PATH, map_location=device)
    elif BASE_MODEL_PATH.exists():
        print(f"Starting from base model: {BASE_MODEL_PATH.name}")
        model = torch.jit.load(BASE_MODEL_PATH, map_location=device)
    else:
        raise FileNotFoundError("No JIT model found!")

    model.train()

    train_ds = TextDataset(CLEAN_PATH, "train")
    val_ds = TextDataset(CLEAN_PATH, "val")

    train_loader = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True, drop_last=True)
    val_loader = DataLoader(val_ds, batch_size=BATCH_SIZE, drop_last=True)

    optimizer = optim.AdamW(model.parameters(), lr=LEARNING_RATE, weight_decay=WEIGHT_DECAY)
    criterion = nn.CrossEntropyLoss()
    scaler = GradScaler()  # AMP: roughly a 1.5-2x speedup

    print(f"STARTING TRAINING: {EPOCHS} epochs, ~{len(train_loader)*EPOCHS:,} steps\n")

    for epoch in range(1, EPOCHS + 1):
        print(f"EPOCH {epoch}/{EPOCHS}")
        epoch_loss = 0.0

        for x, y in tqdm(train_loader, desc="Train", leave=False):
            x, y = x.to(device), y.to(device)
            optimizer.zero_grad()

            with autocast():
                logits = get_logits(model, x)
                # Flatten to (batch*seq, vocab) vs (batch*seq,) for CrossEntropyLoss.
                loss = criterion(logits.view(-1, logits.size(-1)), y.view(-1))

            scaler.scale(loss).backward()
            scaler.unscale_(optimizer)  # unscale before clipping so the threshold is in real units
            torch.nn.utils.clip_grad_norm_(model.parameters(), GRAD_CLIP)
            scaler.step(optimizer)
            scaler.update()

            epoch_loss += loss.item()

        avg = epoch_loss / len(train_loader)
        print(f"TRAIN → loss: {avg:.4f} | ppl: {math.exp(avg):.1f}")

        val_loss = evaluate(model, val_loader)
        print(f"  VAL → loss: {val_loss:.4f} | ppl: {math.exp(val_loss):.1f}\n")

        # Checkpoint this epoch, then prune old checkpoints.
        epoch_dir = OUTPUT_DIR / f"epoch{epoch}"
        epoch_dir.mkdir(exist_ok=True)
        model.save(str(epoch_dir / SAVE_NAME))
        print(f"Saved → {epoch_dir / SAVE_NAME}")
        cleanup()

    # Final save: model plus tokenizer, so the output directory is self-contained.
    final = OUTPUT_DIR / "final"
    final.mkdir(parents=True, exist_ok=True)
    model.save(str(final / SAVE_NAME))
    train_ds.tokenizer.save_pretrained(final)

    # Back up the previous checkpoint before overwriting it.
    if LAST_TRAINED_PATH.exists():
        shutil.copy(LAST_TRAINED_PATH, BACKUP_DIR / f"backup_{int(time.time())}.pt")
    shutil.copy(final / SAVE_NAME, LAST_TRAINED_PATH)

    print("\nDONE! Model trained and saved:")
    print(f"  → {final / SAVE_NAME}")
    print(f"  → {LAST_TRAINED_PATH}")

if __name__ == "__main__":
    train()
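
To sanity-check a finished run, the TorchScript checkpoint can be reloaded without any of the training code. A minimal sketch, assuming the final/ paths produced by the script above and a single greedy decoding step:

import torch
from transformers import GPT2Tokenizer

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
tok = GPT2Tokenizer.from_pretrained("build/fine_tuning_output/final", local_files_only=True)
model = torch.jit.load("build/fine_tuning_output/final/gpt_finetuned.script.pt", map_location=device)
model.eval()

ids = torch.tensor([tok.encode("Hello")], dtype=torch.long, device=device)
with torch.no_grad():
    out = model(ids)
logits = out[0] if isinstance(out, tuple) else out  # mirrors get_logits() above
print(tok.decode(int(logits[0, -1].argmax())))      # most likely next token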