ajkndfjsdfasdf commited on
Commit
aac46f2
·
verified ·
1 Parent(s): 0d78d6d

Update train.py

Browse files
Files changed (1) hide show
  1. train.py +26 -20
train.py CHANGED
@@ -1,18 +1,21 @@
1
  from transformers import MT5Tokenizer, MT5ForConditionalGeneration, Trainer, TrainingArguments
2
  from transformers import ByT5Tokenizer, T5ForConditionalGeneration
3
- from transformers import T5ForConditionalGeneration
4
- from accelerate import init_empty_weights, infer_auto_device_map
5
  from datasets import load_dataset
6
  import os
7
  import wandb
8
 
 
 
 
 
 
 
 
 
 
9
  # Загружаем модель и токенизатор
10
- model = T5ForConditionalGeneration.from_pretrained(
11
- "google/byt5-small",
12
- device_map="auto",
13
- low_cpu_mem_usage=True
14
- )
15
- tokenizer = ByT5Tokenizer.from_pretrained("google/byt5-small")
16
 
17
  # Загружаем датасет
18
  data_files = {
@@ -30,15 +33,18 @@ def tokenize_function(examples):
30
 
31
  tokenized_datasets = dataset.map(tokenize_function, batched=True)
32
 
 
33
  wandb.login(key="5f028bc0142fb7fa45bdacdde3c00dbbaf8bf98e")
34
 
 
35
  training_args = TrainingArguments(
36
- output_dir="./mt5-finetuned",
37
  evaluation_strategy="steps",
38
  eval_steps=100,
39
  learning_rate=5e-5,
40
- per_device_train_batch_size=140,
41
- per_device_eval_batch_size=140,
 
42
  num_train_epochs=3,
43
  logging_steps=100,
44
  warmup_ratio=0.06,
@@ -47,13 +53,13 @@ training_args = TrainingArguments(
47
  logging_dir="./logs",
48
  save_total_limit=2,
49
  save_strategy="epoch",
50
- report_to="wandb",
51
- run_name="mt5-finetuning-run",
52
  disable_tqdm=False,
53
- max_grad_norm=1.0
54
  )
55
 
56
-
57
  trainer = Trainer(
58
  model=model,
59
  args=training_args,
@@ -63,9 +69,9 @@ trainer = Trainer(
63
 
64
  # Обучение
65
  trainer.train()
66
- #trainer.train(resume_from_checkpoint=True)
67
 
68
- # Сохраняем локально
69
- model.save_pretrained("./mt5-finetuned")
70
- tokenizer.save_pretrained("./mt5-finetuned")
71
- print("✅ Модель сохранена локально в ./mt5-finetuned")
 
1
  from transformers import MT5Tokenizer, MT5ForConditionalGeneration, Trainer, TrainingArguments
2
  from transformers import ByT5Tokenizer, T5ForConditionalGeneration
 
 
3
  from datasets import load_dataset
4
  import os
5
  import wandb
6
 
7
+ # 🔧 Название запуска (используется и как run_name, и как output_dir)
8
+ run_name = "byt5-finetuning-run"
9
+
10
+ # 🧠 Название модели для фантюнинга
11
+ model_id = "google/byt5-small"
12
+
13
+ # 📂 Куда сохранять результат обучения
14
+ output_dir = f"./{run_name}"
15
+
16
  # Загружаем модель и токенизатор
17
+ model = T5ForConditionalGeneration.from_pretrained(model_id)
18
+ tokenizer = ByT5Tokenizer.from_pretrained(model_id)
 
 
 
 
19
 
20
  # Загружаем датасет
21
  data_files = {
 
33
 
34
  tokenized_datasets = dataset.map(tokenize_function, batched=True)
35
 
36
+ # Авторизация в Weights & Biases
37
  wandb.login(key="5f028bc0142fb7fa45bdacdde3c00dbbaf8bf98e")
38
 
39
+ # Аргументы обучения
40
  training_args = TrainingArguments(
41
+ output_dir=output_dir,
42
  evaluation_strategy="steps",
43
  eval_steps=100,
44
  learning_rate=5e-5,
45
+ per_device_train_batch_size=200,
46
+ per_device_eval_batch_size=200,
47
+ fp16=True,
48
  num_train_epochs=3,
49
  logging_steps=100,
50
  warmup_ratio=0.06,
 
53
  logging_dir="./logs",
54
  save_total_limit=2,
55
  save_strategy="epoch",
56
+ report_to="wandb",
57
+ run_name=run_name,
58
  disable_tqdm=False,
59
+ max_grad_norm=1.0
60
  )
61
 
62
+ # Инициализируем Trainer
63
  trainer = Trainer(
64
  model=model,
65
  args=training_args,
 
69
 
70
  # Обучение
71
  trainer.train()
72
+ # trainer.train(resume_from_checkpoint=True)
73
 
74
+ # Сохраняем модель
75
+ model.save_pretrained(output_dir)
76
+ tokenizer.save_pretrained(output_dir)
77
+ print(f"✅ Модель сохранена локально в {output_dir}")