kgrabko committed · Commit 8db18a6 (verified) · 1 Parent(s): 017d8ab

Update fine_tune1b_with_validation_no_torchscript.py

fine_tune1b_with_validation_no_torchscript.py CHANGED
@@ -1,253 +1,268 @@
- # Copyright (c) 2025 CMS Manhattan
- # All rights reserved.
- #
- # This file is part of a project authored by CMS Manhattan. You may use, distribute, and modify
- # this code under the terms of the GNU GENERAL PUBLIC LICENSE, Version 3, 29 June 2007
- # please read <http://www.gnu.org/licenses/>.
-
- import os
- import sys
- import torch
- import torch.nn as nn
- import torch.optim as optim
- from torch.utils.data import Dataset, DataLoader
- from transformers import GPT2TokenizerFast
- from tqdm import tqdm
- import shutil
- import math
- from pathlib import Path
- import re
- import logging
- from torch.amp import GradScaler, autocast
-
- # --- ADDED: suppress the long-sequence warning ---
- logging.getLogger("transformers.tokenization_utils_base").setLevel(logging.ERROR)
- # -----------------------------------------------------------------------
-
- # Make sure this file contains the STABILITY FIXES (FP32 Attention, _init_weights)!
- from gpt_jit_modern_1b import JiRackPyTorch
-
- # ============================= SETTINGS =============================
- # --- Settings (device-independent) ---
- TRAIN_SEQ_LEN = 64
- BATCH_SIZE = 1
- ACCUM_STEPS = 32  # Effective batch = 32
- EPOCHS = 500
- LEARNING_RATE = 1e-6
- WEIGHT_DECAY = 0.01
- GRAD_CLIP = 1.0
- VAL_SPLIT_RATIO = 0.05
- KEEP_LAST_EPOCHS = 3
- # ====================================================================
-
- # 💻 Device Configuration: auto-detect
- device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-
- if device.type == 'cuda':
-     USE_AMP = True
-     AUTOCAST_DTYPE = torch.float16
-     print(f"Using device: {device} (GPU). AMP (FP16) enabled for efficiency.")
- elif device.type == 'cpu':
-     USE_AMP = False
-     AUTOCAST_DTYPE = torch.float32
-     print(f"Using device: {device} (CPU). WARNING: Training 1.2B model on CPU will be extremely slow.")
- else:
-     USE_AMP = False
-     AUTOCAST_DTYPE = torch.float32
-     print(f"Using device: {device}. AMP disabled.")
-
- # === PATHS ===
- BASE_MODEL_PATH = Path("models/gpt_modern_1b_class.state_dict.pt")
- LAST_TRAINED_PATH = Path("models/gpt_last_modern_1b_class.state_dict.pt")
- BACKUP_DIR = Path("models/backups")
- BACKUP_DIR.mkdir(exist_ok=True, parents=True)
-
- RAW_PATH = Path("datasets/dialogues_text.txt")
- CLEAN_PATH = Path("datasets/dialogues_text_clean.txt")
-
- # === DATASET CLEANING ===
- if not CLEAN_PATH.exists() or RAW_PATH.stat().st_mtime > CLEAN_PATH.stat().st_mtime:
-     print("Cleaning dataset...")
-     try:
-         text = RAW_PATH.read_text(encoding="utf-8")
-         text = re.sub(r' {2,}', ' ', text)
-         text = text.replace(" \n", "\n").replace("\n ", "\n")
-         CLEAN_PATH.write_text(text, encoding="utf-8")
-         print(f"Done {CLEAN_PATH}")
-     except FileNotFoundError:
-         print(f"ERROR: Raw dataset not found at {RAW_PATH}")
-         sys.exit(1)
-
- DATASET_PATH = CLEAN_PATH
- OUTPUT_DIR = Path("build/fine_tuning_output")
- MODEL_SAVE_NAME = "pytorch_model.bin"
-
- # ============================= DATASET =============================
- class TextDataset(Dataset):
-     def __init__(self, text_file, seq_len=TRAIN_SEQ_LEN, split='train'):
-         self.seq_len = seq_len
-         try:
-             tokenizer = GPT2TokenizerFast.from_pretrained("./tokenizer", local_files_only=True)
-         except Exception:
-             tokenizer = GPT2TokenizerFast.from_pretrained("gpt2")
-
-         tokenizer.pad_token = tokenizer.eos_token
-
-         text = Path(text_file).read_text(encoding="utf-8")
-         tokens = tokenizer.encode(text)
-
-         sequences = []
-         for i in range(0, len(tokens) - seq_len, seq_len):
-             sequences.append(tokens[i:i + seq_len + 1])
-
-         split_idx = int(len(sequences) * (1 - VAL_SPLIT_RATIO))
-         if split == 'train':
-             self.data = sequences[:split_idx]
-         else:
-             self.data = sequences[split_idx:]
-
-         print(f"{split.upper()} sequences: {len(self.data):,}")
-
-     def __len__(self):
-         return len(self.data)
-
-     def __getitem__(self, idx):
-         seq = self.data[idx]
-         return torch.tensor(seq[:-1], dtype=torch.long), torch.tensor(seq[1:], dtype=torch.long)
-
-
- def evaluate(model, loader):
-     model.eval()
-     total_loss = 0
-     criterion = nn.CrossEntropyLoss()
-     # autocast is only active when USE_AMP=True (GPU only)
-     with torch.no_grad(), autocast(device_type=device.type, enabled=USE_AMP, dtype=AUTOCAST_DTYPE):
-         for x, y in loader:
-             x, y = x.to(device), y.to(device)
-
-             logits = model(x)
-             if isinstance(logits, tuple):
-                 logits = logits[0]
-
-             input_logits = logits.contiguous().view(-1, logits.size(-1))
-             target_labels = y.contiguous().view(-1)[:input_logits.size(0)]
-
-             # The loss is always computed in FP32 for accuracy
-             loss = criterion(input_logits.float(), target_labels)
-
-             total_loss += loss.item()
-
-     model.train()
-     return total_loss / len(loader)
-
-
- def train():
-     OUTPUT_DIR.mkdir(parents=True, exist_ok=True)
-
-     print("Loading model...")
-     model = JiRackPyTorch().to(device)
-
-     # The GradScaler is created here but only does anything when USE_AMP=True
-     scaler = GradScaler(enabled=USE_AMP, device=device.type)
-
-     # =========================================================================
-     # 🔥 WEIGHT LOADING TEMPORARILY DISABLED
-     # =========================================================================
-     print("Starting from scratch — random weights (Skipping state_dict load for stability test!)")
-     # =========================================================================
-
-     model.train()
-
-     train_dataset = TextDataset(DATASET_PATH, split='train')
-     val_dataset = TextDataset(DATASET_PATH, split='val')
-
-     train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=0)
-     val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, num_workers=0)
-
-     optimizer = optim.AdamW(model.parameters(), lr=LEARNING_RATE, weight_decay=WEIGHT_DECAY)
-     criterion = nn.CrossEntropyLoss()
-
-     print("\nFULL TRAINING STARTED! No LoRA, no compromises — we're training the whole thing!\n")
-     print(f"Batch size: {BATCH_SIZE * ACCUM_STEPS} (effective) | LR: {LEARNING_RATE} | AMP: {USE_AMP} ({AUTOCAST_DTYPE})")
-
-     for epoch in range(1, EPOCHS + 1):
-         total_loss = 0
-         pbar = tqdm(train_loader, desc=f"Epoch {epoch} [TRAIN]")
-
-         for step, (x, y) in enumerate(pbar):
-             x, y = x.to(device), y.to(device)
-
-             # 1. Forward pass and loss under AMP (GPU only)
-             with autocast(device_type=device.type, enabled=USE_AMP, dtype=AUTOCAST_DTYPE):
-
-                 logits = model(x)
-                 if isinstance(logits, tuple):
-                     logits = logits[0]
-
-                 input_logits = logits.contiguous().view(-1, logits.size(-1))
-                 target_labels = y.contiguous().view(-1)[:input_logits.size(0)]
-
-                 loss = criterion(input_logits.float(), target_labels)
-                 loss = loss / ACCUM_STEPS
-
-             # NaN check
-             if torch.isnan(loss).any():
-                 print(f"\n[FATAL ERROR] Loss became NaN at step {step}. Stopping training.")
-                 raise RuntimeError("Loss became NaN during training, stopping.")
-
-             # 2. Backward pass through the scaler
-             scaler.scale(loss).backward()
-
-             total_loss += loss.item() * ACCUM_STEPS
-
-             if (step + 1) % ACCUM_STEPS == 0 or (step + 1) == len(train_loader):
-                 # 3. Optimizer step through the scaler
-                 if USE_AMP:
-                     scaler.unscale_(optimizer)  # Unscale gradients (GPU only)
-
-                 torch.nn.utils.clip_grad_norm_(model.parameters(), GRAD_CLIP)
-                 scaler.step(optimizer)
-                 scaler.update()
-                 optimizer.zero_grad()
-
-             # tqdm progress update
-             current_avg_loss = total_loss / (step + 1)
-             ppl_val = math.exp(min(current_avg_loss, 10))
-             pbar.set_postfix({"loss (avg)": f"{current_avg_loss:.4f}", "ppl": f"{ppl_val:.2f}"})
-
-
-         avg_train_loss = total_loss / len(train_loader)
-         val_loss = evaluate(model, val_loader)
-
-         print(f"\nEpoch {epoch}")
-         print(f" Train loss: {avg_train_loss:.4f} | PPL: {math.exp(avg_train_loss):.2f}")
-         print(f" Val loss: {val_loss:.4f} | PPL: {math.exp(val_loss):.2f}")
-
-         # Save checkpoint
-         save_dir = OUTPUT_DIR / f"epoch_{epoch}"
-         save_dir.mkdir(exist_ok=True, parents=True)
-
-         torch.save(model.state_dict(), save_dir / MODEL_SAVE_NAME)
-         torch.save(model.state_dict(), LAST_TRAINED_PATH)
-
-         # Keep only the last N epochs to save disk space
-         epochs_dirs = sorted([p for p in OUTPUT_DIR.iterdir() if p.is_dir() and p.name.startswith("epoch_")])
-         for old in epochs_dirs[:-KEEP_LAST_EPOCHS]:
-             shutil.rmtree(old)
-
-     print("\nDONE! Full model trained. You are now the emperor of fine-tuning.")
-
-
- if __name__ == "__main__":
-     try:
-         train()
-     except RuntimeError as e:
-         if "Loss became NaN" in str(e):
-             print("\n[CRITICAL FAILURE] Training stopped due to NaN loss.")
-             print("Action: Revisit JiRackPyTorch weight initialization (reduce STD further) or reduce LEARNING_RATE to 1e-6.")
-             sys.exit(1)
-         elif "CUDA out of memory" in str(e):
-             print("\n[CRITICAL FAILURE] CUDA Out of Memory.")
-             print("Action: Current configuration BATCH_SIZE=1, AMP=FP16 is the minimum memory usage possible. Try reducing TRAIN_SEQ_LEN from 256 to 128.")
-             sys.exit(1)
        raise
 
+ # Copyright (c) 2025 CMS Manhattan
+ # All rights reserved.
+ # Author: Konstantin Vladimirovich Grabko
+ # Email: grabko@cmsmanhattan.com
+ # Phone: +1(516)777-0945
+ #
+ # This program is free software: you can redistribute it and/or modify
+ # it under the terms of the GNU General Public License as published by
+ # the Free Software Foundation, version 3 of the License.
+ #
+ # This program is distributed in the hope that it will be useful,
+ # but WITHOUT ANY WARRANTY; without even the implied warranty of
+ # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ # GNU General Public License for more details.
+ #
+ # You should have received a copy of the GNU General Public License
+ # along with this program. If not, see <https://www.gnu.org/licenses/>.
+ #
+ # Additional terms:
+ # Any commercial use or distribution of this software or derivative works
+ # requires explicit written permission from the copyright holder.
+
+ import os
+ import sys
+ import torch
+ import torch.nn as nn
+ import torch.optim as optim
+ from torch.utils.data import Dataset, DataLoader
+ from transformers import GPT2TokenizerFast
+ from tqdm import tqdm
+ import shutil
+ import math
+ from pathlib import Path
+ import re
+ import logging
+ from torch.amp import GradScaler, autocast
+
+ # --- ADDED: suppress the long-sequence warning ---
+ logging.getLogger("transformers.tokenization_utils_base").setLevel(logging.ERROR)
+ # -----------------------------------------------------------------------
+
+ # Make sure this file contains the STABILITY FIXES (FP32 Attention, _init_weights)!
+ from gpt_jit_modern_1b import JiRackPyTorch
+
+ # ============================= SETTINGS =============================
+ # --- Settings (device-independent) ---
+ TRAIN_SEQ_LEN = 64
+ BATCH_SIZE = 1
+ ACCUM_STEPS = 32  # Effective batch = 32
+ EPOCHS = 500
+ LEARNING_RATE = 1e-6
+ WEIGHT_DECAY = 0.01
+ GRAD_CLIP = 1.0
+ VAL_SPLIT_RATIO = 0.05
+ KEEP_LAST_EPOCHS = 3
+ # ====================================================================
+
+ # 💻 Device Configuration: auto-detect
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+ if device.type == 'cuda':
+     USE_AMP = True
+     AUTOCAST_DTYPE = torch.float16
+     print(f"Using device: {device} (GPU). AMP (FP16) enabled for efficiency.")
+ elif device.type == 'cpu':
+     USE_AMP = False
+     AUTOCAST_DTYPE = torch.float32
+     print(f"Using device: {device} (CPU). WARNING: Training 1.2B model on CPU will be extremely slow.")
+ else:
+     USE_AMP = False
+     AUTOCAST_DTYPE = torch.float32
+     print(f"Using device: {device}. AMP disabled.")
+
+ # === PATHS ===
+ BASE_MODEL_PATH = Path("models/gpt_modern_1b_class.state_dict.pt")
+ LAST_TRAINED_PATH = Path("models/gpt_last_modern_1b_class.state_dict.pt")
+ BACKUP_DIR = Path("models/backups")
+ BACKUP_DIR.mkdir(exist_ok=True, parents=True)
+
+ RAW_PATH = Path("datasets/dialogues_text.txt")
+ CLEAN_PATH = Path("datasets/dialogues_text_clean.txt")
+
+ # === DATASET CLEANING ===
+ if not CLEAN_PATH.exists() or RAW_PATH.stat().st_mtime > CLEAN_PATH.stat().st_mtime:
+     print("Cleaning dataset...")
+     try:
+         text = RAW_PATH.read_text(encoding="utf-8")
+         text = re.sub(r' {2,}', ' ', text)
+         text = text.replace(" \n", "\n").replace("\n ", "\n")
+         CLEAN_PATH.write_text(text, encoding="utf-8")
+         print(f"Done → {CLEAN_PATH}")
+     except FileNotFoundError:
+         print(f"ERROR: Raw dataset not found at {RAW_PATH}")
+         sys.exit(1)
+
+ DATASET_PATH = CLEAN_PATH
+ OUTPUT_DIR = Path("build/fine_tuning_output")
+ MODEL_SAVE_NAME = "pytorch_model.bin"
+
+ # ============================= DATASET =============================
+ class TextDataset(Dataset):
+     def __init__(self, text_file, seq_len=TRAIN_SEQ_LEN, split='train'):
+         self.seq_len = seq_len
+         try:
+             tokenizer = GPT2TokenizerFast.from_pretrained("./tokenizer", local_files_only=True)
+         except Exception:
+             tokenizer = GPT2TokenizerFast.from_pretrained("gpt2")
+
+         tokenizer.pad_token = tokenizer.eos_token
+
+         text = Path(text_file).read_text(encoding="utf-8")
+         tokens = tokenizer.encode(text)
+
+         sequences = []
+         for i in range(0, len(tokens) - seq_len, seq_len):
+             sequences.append(tokens[i:i + seq_len + 1])
+
+         split_idx = int(len(sequences) * (1 - VAL_SPLIT_RATIO))
+         if split == 'train':
+             self.data = sequences[:split_idx]
+         else:
+             self.data = sequences[split_idx:]
+
+         print(f"{split.upper()} sequences: {len(self.data):,}")
+
+     def __len__(self):
+         return len(self.data)
+
+     def __getitem__(self, idx):
+         seq = self.data[idx]
+         return torch.tensor(seq[:-1], dtype=torch.long), torch.tensor(seq[1:], dtype=torch.long)
+
+
+ def evaluate(model, loader):
+     model.eval()
+     total_loss = 0
+     criterion = nn.CrossEntropyLoss()
+     # autocast is only active when USE_AMP=True (GPU only)
+     with torch.no_grad(), autocast(device_type=device.type, enabled=USE_AMP, dtype=AUTOCAST_DTYPE):
+         for x, y in loader:
+             x, y = x.to(device), y.to(device)
+
+             logits = model(x)
+             if isinstance(logits, tuple):
+                 logits = logits[0]
+
+             input_logits = logits.contiguous().view(-1, logits.size(-1))
+             target_labels = y.contiguous().view(-1)[:input_logits.size(0)]
+
+             # The loss is always computed in FP32 for accuracy
+             loss = criterion(input_logits.float(), target_labels)
+
+             total_loss += loss.item()
+
+     model.train()
+     return total_loss / len(loader)
+
+
+ def train():
+     OUTPUT_DIR.mkdir(parents=True, exist_ok=True)
+
+     print("Loading model...")
+     model = JiRackPyTorch().to(device)
+
+     # The GradScaler is created here but only does anything when USE_AMP=True
+     scaler = GradScaler(enabled=USE_AMP, device=device.type)
+
+     # =========================================================================
+     # 🔥 WEIGHT LOADING TEMPORARILY DISABLED
+     # =========================================================================
+     print("Starting from scratch with random weights (Skipping state_dict load for stability test!)")
+     # =========================================================================
+
+     model.train()
+
+     train_dataset = TextDataset(DATASET_PATH, split='train')
+     val_dataset = TextDataset(DATASET_PATH, split='val')
+
+     train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=0)
+     val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, num_workers=0)
+
+     optimizer = optim.AdamW(model.parameters(), lr=LEARNING_RATE, weight_decay=WEIGHT_DECAY)
+     criterion = nn.CrossEntropyLoss()
+
+     print("\nFULL TRAINING STARTED! No LoRA, no compromises — we're training the whole thing!\n")
+     print(f"Batch size: {BATCH_SIZE * ACCUM_STEPS} (effective) | LR: {LEARNING_RATE} | AMP: {USE_AMP} ({AUTOCAST_DTYPE})")
+
+     for epoch in range(1, EPOCHS + 1):
+         total_loss = 0
+         pbar = tqdm(train_loader, desc=f"Epoch {epoch} [TRAIN]")
+
+         for step, (x, y) in enumerate(pbar):
+             x, y = x.to(device), y.to(device)
+
+             # 1. Forward pass and loss under AMP (GPU only)
+             with autocast(device_type=device.type, enabled=USE_AMP, dtype=AUTOCAST_DTYPE):
+
+                 logits = model(x)
+                 if isinstance(logits, tuple):
+                     logits = logits[0]
+
+                 input_logits = logits.contiguous().view(-1, logits.size(-1))
+                 target_labels = y.contiguous().view(-1)[:input_logits.size(0)]
+
+                 loss = criterion(input_logits.float(), target_labels)
+                 loss = loss / ACCUM_STEPS
+
+             # NaN check
+             if torch.isnan(loss).any():
+                 print(f"\n[FATAL ERROR] Loss became NaN at step {step}. Stopping training.")
+                 raise RuntimeError("Loss became NaN during training, stopping.")
+
+             # 2. Backward pass through the scaler
+             scaler.scale(loss).backward()
+
+             total_loss += loss.item() * ACCUM_STEPS
+
+             if (step + 1) % ACCUM_STEPS == 0 or (step + 1) == len(train_loader):
+                 # 3. Optimizer step through the scaler
+                 if USE_AMP:
+                     scaler.unscale_(optimizer)  # Unscale gradients (GPU only)
+
+                 torch.nn.utils.clip_grad_norm_(model.parameters(), GRAD_CLIP)
+                 scaler.step(optimizer)
+                 scaler.update()
+                 optimizer.zero_grad()
+
+             # tqdm progress update
+             current_avg_loss = total_loss / (step + 1)
+             ppl_val = math.exp(min(current_avg_loss, 10))
+             pbar.set_postfix({"loss (avg)": f"{current_avg_loss:.4f}", "ppl": f"{ppl_val:.2f}"})
+
+
+         avg_train_loss = total_loss / len(train_loader)
+         val_loss = evaluate(model, val_loader)
+
+         print(f"\nEpoch {epoch}")
+         print(f" Train loss: {avg_train_loss:.4f} | PPL: {math.exp(avg_train_loss):.2f}")
+         print(f" Val loss: {val_loss:.4f} | PPL: {math.exp(val_loss):.2f}")
+
+         # Save checkpoint
+         save_dir = OUTPUT_DIR / f"epoch_{epoch}"
+         save_dir.mkdir(exist_ok=True, parents=True)
+
+         torch.save(model.state_dict(), save_dir / MODEL_SAVE_NAME)
+         torch.save(model.state_dict(), LAST_TRAINED_PATH)
+
+         # Keep only the last N epochs to save disk space
+         epochs_dirs = sorted([p for p in OUTPUT_DIR.iterdir() if p.is_dir() and p.name.startswith("epoch_")])
+         for old in epochs_dirs[:-KEEP_LAST_EPOCHS]:
+             shutil.rmtree(old)
+
+     print("\nDONE! Full model trained. You are now the emperor of fine-tuning.")
+
+
+ if __name__ == "__main__":
+     try:
+         train()
+     except RuntimeError as e:
+         if "Loss became NaN" in str(e):
+             print("\n[CRITICAL FAILURE] Training stopped due to NaN loss.")
+             print("Action: Revisit JiRackPyTorch weight initialization (reduce STD further) or reduce LEARNING_RATE to 1e-6.")
+             sys.exit(1)
+         elif "CUDA out of memory" in str(e):
+             print("\n[CRITICAL FAILURE] CUDA Out of Memory.")
+             print("Action: Current configuration BATCH_SIZE=1, AMP=FP16 is the minimum memory usage possible. Try reducing TRAIN_SEQ_LEN from 256 to 128.")
+             sys.exit(1)
        raise
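
Note on resuming: the script keeps checkpoint loading disabled (see the "WEIGHT LOADING TEMPORARILY DISABLED" block), even though BASE_MODEL_PATH and LAST_TRAINED_PATH are defined. A minimal sketch of how loading could be wired back in, placed inside train() right after `model = JiRackPyTorch().to(device)`, assuming the saved state_dict keys match the current JiRackPyTorch definition (hypothetical, not part of this commit):

    # Hypothetical resume logic; only useful if the checkpoint was saved from the same model class.
    ckpt_path = LAST_TRAINED_PATH if LAST_TRAINED_PATH.exists() else BASE_MODEL_PATH
    if ckpt_path.exists():
        state_dict = torch.load(ckpt_path, map_location=device)
        # strict=False reports mismatched keys instead of raising
        missing, unexpected = model.load_state_dict(state_dict, strict=False)
        print(f"Resumed from {ckpt_path} (missing keys: {len(missing)}, unexpected keys: {len(unexpected)})")
    else:
        print("No checkpoint found; training from random weights.")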