File size: 11,746 Bytes
c88fe21
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
# Copyright (c) 2025 CMS Manhattan
# All rights reserved.
# Author: Konstantin Vladimirovich Grabko
# Email: grabko@cmsmanhattan.com
# Phone: +1(516)777-0945
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, version 3 of the License.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <https://www.gnu.org/licenses/>.
#
# Additional terms:
# Any commercial use or distribution of this software or derivative works
# requires explicit written permission from the copyright holder.

import os
from pathlib import Path
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from tqdm import tqdm
import math
from torch.cuda.amp import autocast, GradScaler # 👈 Добавлен импорт AMP

# Параметры (пример)
TRAIN_SEQ_LEN = 256
BATCH_SIZE = 12
EPOCHS = 10
LEARNING_RATE = 1e-6    # 👈 СНИЖЕНО ДЛЯ СТАБИЛЬНОСТИ
WEIGHT_DECAY = 0.01
GRAD_CLIP = 0.5
VAL_SPLIT_RATIO = 0.05

BASE_MODEL_PATH = Path("models/JiRack_H4_L2_V50257_D768_MSL8192_FF768x4.script.pt")
DATASET_PATH = Path("datasets/dialogues_text_clean.txt")

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Устройство: {device}")

def print_model_devices(model):
    sd = model.state_dict()
    devs = set()
    for k, v in sd.items():
        try:
            devs.add(v.device)
        except Exception:
            devs.add(torch.device('cpu'))
    print("Devices present in model.state_dict():", devs)
    return devs

def safe_load_jit_model(path: Path, map_device: torch.device):
    """

    Загружает JIT модель с map_location и пытается привести её к map_device.

    Возвращает (model, model_device) — модель и устройство, на котором находятся её параметры/буферы.

    """
    if not path.exists():
        raise FileNotFoundError(f"JIT model not found: {path}")

    # Попытка загрузки с map_location
    print(f"Loading JIT model from {path} with map_location={map_device} ...")
    model = torch.jit.load(str(path), map_location=str(map_device))
    print("Loaded model. Попытка model.to(...) ...")
    try:
        model = model.to(map_device)
        print("model.to(map_device) выполнен.")
    except Exception as e:
        # У некоторых JIT объектов .to() может не сработать — это нормально, продолжим диагностику
        print("Warning: model.to(map_device) вызвал исключение:", e)

    # Диагностика устройств, где лежат параметры/буферы
    devs = print_model_devices(model)

    # Выберем устройство "модели" — если их несколько, отдаём предпочтение CUDA если есть
    if len(devs) == 0:
        model_device = map_device
    elif len(devs) == 1:
        model_device = list(devs)[0]
    else:
        # если есть смешанные устройства — попробуем приоритет cuda, иначе первый в множестве
        cuda_devs = [d for d in devs if 'cuda' in str(d)]
        model_device = cuda_devs[0] if cuda_devs else list(devs)[0]
        print("Внимание: обнаружены несколько устройств внутри state_dict(). Выбран model_device =", model_device)

    # Если model_device не равен map_device — уведомим пользователя и попытаемся ещё раз загрузить с конкретным map_location
    if str(model_device) != str(map_device):
        print(f"Model tensors are on {model_device} but requested map_device is {map_device}.")
        print("Попробую заново загрузить модель с map_location=model_device ...")
        try:
            model = torch.jit.load(str(path), map_location=str(model_device))
            try:
                model = model.to(model_device)
            except Exception:
                pass
            devs2 = print_model_devices(model)
            if len(devs2) == 1 and list(devs2)[0] == model_device:
                print("Успешно перезагружено на целевое устройство.")
        except Exception as e:
            print("Не удалось перезагрузить модель на желаемое устройство:", e)
            # продолжаем, но предупредим пользователя
    return model, model_device

def get_logits_from_model(model, inputs):
    """

    Вызов модели, допускающий возможные варианты возврата.

    Мы предполагаем, что inputs уже находится на том же устройстве, что и модель.

    """
    try:
        out = model(inputs)
        # model может вернуть logits или (logits, kv)
        if isinstance(out, tuple) or isinstance(out, list):
            return out[0]
        return out
    except RuntimeError as e:
        # Если ошибка связана с устройствами, добавим детальный лог
        msg = str(e)
        if "Expected all tensors to be on the same device" in msg or "but found at least two devices" in msg:
            print("RuntimeError: вероятно есть mismatch устройств (cpu/cuda) внутри model. Диагностика state_dict():")
            try:
                print_model_devices(model)
            except Exception:
                pass
            # Ребросим исключение с более понятным сообщением
            raise RuntimeError("Device mismatch while running the JIT model. See printed diagnostics above.") from e
        else:
            raise

# ----------------- Пример интеграции в train loop -----------------
def train():
    model, model_device = safe_load_jit_model(BASE_MODEL_PATH, device)

    # Подготовьте датасеты здесь как вы уже делаете (замените на свой TextDataset)
    from transformers import GPT2TokenizerFast
    # Замените на ваш реальный TextDataset; здесь лишь заглушка
    class DummyDataset(torch.utils.data.Dataset):
        def __init__(self, n=1000, seq_len=TRAIN_SEQ_LEN, vocab_size=50257):
            self.n = n
            self.seq_len = seq_len
            self.vocab_size = vocab_size
        def __len__(self): return self.n
        def __getitem__(self, i):
            x = torch.randint(0, self.vocab_size, (self.seq_len,), dtype=torch.long)
            y = torch.randint(0, self.vocab_size, (self.seq_len,), dtype=torch.long)
            return x, y

    train_dataset = DummyDataset(n=2000)
    val_dataset = DummyDataset(n=200)

    train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, drop_last=True)
    val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False, drop_last=True)

    # Создаём optimizer
    params = list(model.parameters()) if hasattr(model, 'parameters') else []
    if len(params) == 0:
        print("Warning: model.parameters() пуст. Убедитесь, что JIT-модель содержит параметры для оптимизации.")
    optimizer = optim.AdamW(params, lr=LEARNING_RATE, weight_decay=WEIGHT_DECAY) if params else None
    criterion = nn.CrossEntropyLoss()

    # Инициализация GradScaler для AMP
    scaler = GradScaler()

    model.train() 

    for epoch in range(1, EPOCHS + 1):
        print(f"Эпоха {epoch}/{EPOCHS}")
        epoch_loss = 0.0
        pbar = tqdm(train_loader, desc=f"Epoch {epoch} [TRAIN]", leave=False)
        
        batch_count = 0
        skipped_batches = 0
        
        for xb, yb in pbar:
            # === 1. ПРОВЕРКА ДАННЫХ НА NAN/INF ===
            # Проверяем только если тип данных — float (для LongTensor проверка не нужна)
            if torch.is_floating_point(xb) and (torch.isnan(xb).any() or torch.isinf(xb).any()):
                 print(f"\n[E{epoch}] WARNING: NaN or Inf found in input data (xb). Skipping batch.")
                 skipped_batches += 1
                 continue
            
            # Приводим батчи к устройству модели (model_device)
            xb = xb.to(model_device)
            yb = yb.to(model_device)

            if optimizer:
                optimizer.zero_grad()

            # === 2. AMP: Выполняем forward-pass в half-precision ===
            with autocast():
                logits = get_logits_from_model(model, xb)
                
                # У logits размер [B, seq_len, vocab] — приводим к числу классов
                loss = criterion(logits.view(-1, logits.size(-1)), yb.view(-1))
            # ========================================================
            
            # === 3. ПРОВЕРКА ЛОССА НА NAN/INF ПЕРЕД BACKWARD ===
            # Проверяем лосс, который теперь может быть float16 или float32
            if torch.isnan(loss) or torch.isinf(loss):
                print(f"\n[E{epoch}] CRITICAL: Loss is NaN or Inf. Skipping backward and update.")
                skipped_batches += 1
                continue
                
            # AMP: Вычисляем градиенты, масштабируя их
            scaler.scale(loss).backward()
            
            if optimizer:
                # AMP: Сначала снимаем масштаб
                scaler.unscale_(optimizer) 
                
                # Обрезка градиентов
                torch.nn.utils.clip_grad_norm_(params, GRAD_CLIP)
                
                # AMP: Обновляем веса (scaler сам проверяет, не являются ли градиенты Inf/NaN)
                scaler.step(optimizer)
                scaler.update()

            # Переводим лосс в float32 для записи и отображения
            loss_val = loss.item()
            epoch_loss += loss_val
            batch_count += 1
            
            pbar.set_postfix({"loss": f"{loss_val:.4f}", "ppl": f"{math.exp(min(loss_val, 10)):.2f}"})

        # Средняя потеря считается только по не пропущенным батчам
        avg_loss = epoch_loss / batch_count if batch_count > 0 else float('nan')
        print(f"Средняя потеря за эпоху: {avg_loss:.4f}")
        
        if skipped_batches > 0:
            print(f"Внимание: {skipped_batches} батчей было пропущено из-за NaN/Inf в данных или лоссе.")


if __name__ == "__main__":
    train()