ajkndfjsdfasdf committed
Commit 9f144f4 · verified · 1 Parent(s): aac46f2

Update train.py

Files changed (1)
  1. train.py +63 -50
train.py CHANGED
@@ -1,30 +1,28 @@
- from transformers import MT5Tokenizer, MT5ForConditionalGeneration, Trainer, TrainingArguments
- from transformers import ByT5Tokenizer, T5ForConditionalGeneration
  from datasets import load_dataset
  import os
  import wandb

- # 🔧 Run name (used both as the run_name and as the output_dir)
- run_name = "byt5-finetuning-run"

- # 🧠 Model to fine-tune
- model_id = "google/byt5-small"

- # 📂 Where to save the training output
- output_dir = f"./{run_name}"
-
- # Load the model and tokenizer
- model = T5ForConditionalGeneration.from_pretrained(model_id)
- tokenizer = ByT5Tokenizer.from_pretrained(model_id)
-
- # Load the dataset
  data_files = {
      "train": "mt5_training_data-1.jsonl",
      "validation": "mt5_validation_data-1.jsonl"
  }
  dataset = load_dataset("json", data_files=data_files)

- # Tokenization
  def tokenize_function(examples):
      model_inputs = tokenizer(examples["text"], max_length=256, truncation=True, padding="max_length")
      labels = tokenizer(examples["target"], max_length=64, truncation=True, padding="max_length")
@@ -33,45 +31,60 @@ def tokenize_function(examples):
  tokenized_datasets = dataset.map(tokenize_function, batched=True)

- # Log in to Weights & Biases
  wandb.login(key="5f028bc0142fb7fa45bdacdde3c00dbbaf8bf98e")

- # Training arguments
- training_args = TrainingArguments(
-     output_dir=output_dir,
-     evaluation_strategy="steps",
-     eval_steps=100,
-     learning_rate=5e-5,
-     per_device_train_batch_size=200,
-     per_device_eval_batch_size=200,
-     fp16=True,
-     num_train_epochs=3,
-     logging_steps=100,
-     warmup_ratio=0.06,
-     logging_first_step=True,
-     weight_decay=0.01,
-     logging_dir="./logs",
-     save_total_limit=2,
-     save_strategy="epoch",
-     report_to="wandb",
-     run_name=run_name,
-     disable_tqdm=False,
-     max_grad_norm=1.0
- )

- # Initialize the Trainer
- trainer = Trainer(
-     model=model,
-     args=training_args,
-     train_dataset=tokenized_datasets["train"],
-     eval_dataset=tokenized_datasets["validation"]
- )

- # Train
- trainer.train()
- # trainer.train(resume_from_checkpoint=True)

- # Save the model
  model.save_pretrained(output_dir)
  tokenizer.save_pretrained(output_dir)
- print(f"✅ Model saved locally to {output_dir}")
+ from transformers import T5ForConditionalGeneration, ByT5Tokenizer, Trainer, TrainingArguments
  from datasets import load_dataset
  import os
  import wandb
+ import torch

+ # 🔧 Model name and path
+ model_name = "google/byt5-small"
+ run_id = "byt5-autobatch"
+ output_dir = f"./{run_id}"
+ start_batch_size = 300
+ step_batch_size = 5

+ # 📦 Load the model and tokenizer
+ model = T5ForConditionalGeneration.from_pretrained(model_name)
+ tokenizer = ByT5Tokenizer.from_pretrained(model_name)

+ # 📂 Load the dataset
  data_files = {
      "train": "mt5_training_data-1.jsonl",
      "validation": "mt5_validation_data-1.jsonl"
  }
  dataset = load_dataset("json", data_files=data_files)

+ # 🔠 Tokenization
  def tokenize_function(examples):
      model_inputs = tokenizer(examples["text"], max_length=256, truncation=True, padding="max_length")
      labels = tokenizer(examples["target"], max_length=64, truncation=True, padding="max_length")

  tokenized_datasets = dataset.map(tokenize_function, batched=True)

+ # 🔑 W&B authentication
  wandb.login(key="5f028bc0142fb7fa45bdacdde3c00dbbaf8bf98e")

+ # 🚀 Automatic batch-size selection: retry training, shrinking the batch on OOM
+ def try_training_with_batch_size(batch_size_start):
+     batch_size = batch_size_start
+     while batch_size > 0:
+         try:
+             print(f"\n🚀 Trying batch_size = {batch_size}")
+             training_args = TrainingArguments(
+                 output_dir=output_dir,
+                 evaluation_strategy="steps",
+                 eval_steps=100,
+                 learning_rate=5e-5,
+                 per_device_train_batch_size=batch_size,
+                 per_device_eval_batch_size=batch_size,
+                 fp16=True,
+                 num_train_epochs=3,
+                 logging_steps=100,
+                 warmup_ratio=0.06,
+                 logging_first_step=True,
+                 weight_decay=0.01,
+                 logging_dir="./logs",
+                 save_total_limit=2,
+                 save_strategy="epoch",
+                 report_to="wandb",
+                 run_name=run_id,
+                 disable_tqdm=False,
+                 max_grad_norm=1.0
+             )
+
+             trainer = Trainer(
+                 model=model,
+                 args=training_args,
+                 train_dataset=tokenized_datasets["train"],
+                 eval_dataset=tokenized_datasets["validation"]
+             )
+
+             trainer.train()
+             return batch_size
+         except RuntimeError as e:
+             if "CUDA out of memory" in str(e):
+                 print(f"❌ OOM at batch_size = {batch_size}, decreasing...")
+                 torch.cuda.empty_cache()
+                 batch_size -= step_batch_size
+             else:
+                 raise e
+     raise RuntimeError("Could not find a suitable batch size 😢")

+ # 🏁 Run with automatic batch-size selection
+ optimal_batch_size = try_training_with_batch_size(start_batch_size)
+ print(f"\n✅ Training succeeded with batch_size = {optimal_batch_size}")

+ # 💾 Save the model
  model.save_pretrained(output_dir)
  tokenizer.save_pretrained(output_dir)
+ print(f"✅ Model saved to {output_dir}")