from transformers import T5ForConditionalGeneration, T5Tokenizer, Trainer, TrainingArguments
from datasets import load_dataset
import os
import wandb
import torch

# 🔧 Model name and output path
model_name = "google/flan-t5-large"
run_id = "flan-t5-large-ru-autobatch"
output_dir = f"./{run_id}"
start_batch_size = 20   # ⚠️ Upper bound for the batch-size search; reduced step by step on OOM
step_batch_size = 1     # How much to shrink the batch size after each OOM

# 📦 Load the model and tokenizer
model = T5ForConditionalGeneration.from_pretrained(model_name)
tokenizer = T5Tokenizer.from_pretrained(model_name)

# 📂 Load the dataset
data_files = {
    "train": "mt5_ru_gen_async.jsonl",
    "validation": "mt5_ru_gen_eval.jsonl"
}
dataset = load_dataset("json", data_files=data_files)
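
# Each JSONL record is assumed to look like this (inferred from the field names
# used in tokenize_function below; adjust if your files use different keys):
#   {"text": "<source/prompt string>", "target": "<reference output string>"}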

# 🔠 Tokenization
def tokenize_function(examples):
    model_inputs = tokenizer(
        examples["text"], max_length=256, truncation=True, padding="max_length"
    )
    # text_target replaces the deprecated as_target_tokenizer() context manager
    labels = tokenizer(
        text_target=examples["target"], max_length=256, truncation=True, padding="max_length"
    )
    # Replace PAD tokens with -100 so they are ignored when computing the loss
    labels["input_ids"] = [
        [(token if token != tokenizer.pad_token_id else -100) for token in label]
        for label in labels["input_ids"]
    ]
    model_inputs["labels"] = labels["input_ids"]
    return model_inputs

tokenized_datasets = dataset.map(tokenize_function, batched=True)
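
# Optional sanity check: decode the first training example to confirm the
# input text survived tokenization intact
print(tokenizer.decode(tokenized_datasets["train"][0]["input_ids"], skip_special_tokens=True))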

# 🔑 W&B authentication (keep the API key out of source control; export WANDB_API_KEY)
wandb.login(key=os.environ["WANDB_API_KEY"])

# 🚀 Batch-size auto-search: retry training with a smaller batch after each OOM.
# Caveat: an OOM that hits mid-training leaves `model` with partially updated
# weights, so the retry resumes from those weights rather than from scratch.
def try_training_with_batch_size(batch_size_start):
    batch_size = batch_size_start
    while batch_size > 0:
        try:
            print(f"\n🚀 Пробуем batch_size = {batch_size}")
            training_args = TrainingArguments(
                output_dir=output_dir,
                evaluation_strategy="steps",
                eval_steps=100,
                learning_rate=3e-5,
                per_device_train_batch_size=batch_size,
                per_device_eval_batch_size=batch_size,
                # fp16=True,  # Enable if you have a GPU with good fp16 support (A100 / V100 / T4)
                num_train_epochs=10,
                logging_steps=100,
                warmup_ratio=0.06,
                logging_first_step=True,
                weight_decay=0.01,
                logging_dir="./logs",
                save_total_limit=2,
                save_strategy="epoch",
                report_to="wandb",
                run_name=run_id,
                disable_tqdm=False,
                max_grad_norm=1.0
            )

            trainer = Trainer(
                model=model,
                args=training_args,
                train_dataset=tokenized_datasets["train"],
                eval_dataset=tokenized_datasets["validation"]
            )
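
            # Note: no explicit data_collator is needed here because
            # padding="max_length" in tokenize_function already makes every
            # example a fixed 256 tokens.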

            trainer.train()
            return batch_size
        except RuntimeError as e:
            if "CUDA out of memory" in str(e):
                print(f"❌ OOM at batch_size = {batch_size}, decreasing...")
                torch.cuda.empty_cache()
                batch_size -= step_batch_size
            else:
                raise
    raise RuntimeError("Could not find a workable batch size 😢")

# 🏁 Run with batch-size auto-search
optimal_batch_size = try_training_with_batch_size(start_batch_size)
print(f"\n✅ Training succeeded with batch_size = {optimal_batch_size}")

# 💾 Save the model
model.save_pretrained(output_dir)
tokenizer.save_pretrained(output_dir)
print(f"✅ Model saved to {output_dir}")