ajkndfjsdfasdf commited on
Commit
567deeb
·
verified ·
1 Parent(s): 3250810

Create train.py

Browse files
Files changed (1) hide show
  1. train.py +71 -1
train.py CHANGED
@@ -1 +1,71 @@
1
- print("✅ Training script is ready. Customize this file as needed.")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
# Standard library
import os

# Third-party
import wandb
from datasets import load_dataset
from transformers import (
    ByT5Tokenizer,
    T5ForConditionalGeneration,
    Trainer,
    TrainingArguments,
)
8
+
# Load the pretrained ByT5 checkpoint and its byte-level tokenizer.
checkpoint = "google/byt5-small"
model = T5ForConditionalGeneration.from_pretrained(
    checkpoint,
    device_map="auto",       # place weights automatically on available devices
    low_cpu_mem_usage=True,  # stream weights instead of a full in-RAM copy
)
tokenizer = ByT5Tokenizer.from_pretrained(checkpoint)
16
+
17
# Load train/validation splits from local JSONL files.
data_files = {
    "train": "mt5_training_data-1.jsonl",
    "validation": "mt5_validation_data-1.jsonl",
}
dataset = load_dataset("json", data_files=data_files)
23
+
24
# Tokenization (ByT5 operates directly on bytes).
def tokenize_function(examples):
    """Tokenize batched source/target pairs for seq2seq training.

    Expects examples with "text" (source) and "target" (label) fields.
    Sources are padded/truncated to 256 tokens, targets to 64. Label
    positions holding the pad token are replaced with -100 so that the
    cross-entropy loss ignores padding instead of training on it (the
    original code trained on pad tokens, diluting the loss signal).
    """
    model_inputs = tokenizer(
        examples["text"], max_length=256, truncation=True, padding="max_length"
    )
    labels = tokenizer(
        examples["target"], max_length=64, truncation=True, padding="max_length"
    )
    # Mask padding in the labels: -100 is ignored by HF's loss computation.
    pad_id = tokenizer.pad_token_id
    model_inputs["labels"] = [
        [tok if tok != pad_id else -100 for tok in seq]
        for seq in labels["input_ids"]
    ]
    return model_inputs

tokenized_datasets = dataset.map(tokenize_function, batched=True)
32
+
33
# SECURITY: the previous revision hard-coded a W&B API key in source — that key
# is leaked and must be rotated. Read the key from the WANDB_API_KEY environment
# variable instead; with key=None wandb falls back to its normal auth flow.
wandb.login(key=os.environ.get("WANDB_API_KEY"))
34
+
35
# Trainer hyper-parameters and logging configuration.
# NOTE(review): `evaluation_strategy` is the pre-4.46 transformers kwarg name
# (renamed `eval_strategy` later); kept as-is for compatibility with the
# version this script was written against.
train_config = {
    "output_dir": "./mt5-finetuned",
    "evaluation_strategy": "steps",
    "eval_steps": 100,
    "learning_rate": 5e-5,
    "per_device_train_batch_size": 140,
    "per_device_eval_batch_size": 140,
    "num_train_epochs": 3,
    "logging_steps": 100,
    "warmup_ratio": 0.06,
    "logging_first_step": True,
    "weight_decay": 0.01,
    "logging_dir": "./logs",
    "save_total_limit": 2,
    "save_strategy": "epoch",
    "report_to": "wandb",
    "run_name": "mt5-finetuning-run",
    "disable_tqdm": False,
    "max_grad_norm": 1.0,
}
training_args = TrainingArguments(**train_config)
55
+
56
+
57
# Wire model, arguments, and the tokenized splits into a Trainer.
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_datasets["train"],
    eval_dataset=tokenized_datasets["validation"],
)
63
+
64
# Training (swap in the commented call to resume from the latest checkpoint).
trainer.train()
# trainer.train(resume_from_checkpoint=True)

# Persist the fine-tuned weights and tokenizer locally.
save_dir = "./mt5-finetuned"
model.save_pretrained(save_dir)
tokenizer.save_pretrained(save_dir)
print("✅ Модель сохранена локально в ./mt5-finetuned")