from transformers import MT5Tokenizer, MT5ForConditionalGeneration, Trainer, TrainingArguments
from datasets import load_dataset
import os
import wandb

# Setup commands (run once on the machine):
#   cd workspace && pip install --no-cache-dir -r requirements.txt
#   apt-get update && apt-get install -y screen && apt-get install -y git-lfs
#   screen -S train
#   python train.py
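# Assumed requirements.txt contents (a hedged sketch inferred from the imports below;
# the actual file is not shown in this script):
#   torch
#   transformers
#   datasets
#   sentencepiece   # needed by MT5Tokenizer
#   accelerate      # needed by Trainer in recent transformers releases
#   wandb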

# Load the model and tokenizer
model = MT5ForConditionalGeneration.from_pretrained("google/mt5-small")
tokenizer = MT5Tokenizer.from_pretrained("google/mt5-small")

# Load the dataset (JSON Lines files)
data_files = {
    "train": "mt5_training_data-1.jsonl",
    "validation": "mt5_validation_data-1.jsonl"
}
dataset = load_dataset("json", data_files=data_files)
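
# Expected record format (an assumption based on tokenize_function below, not verified
# against the actual data files); each JSONL line would look like:
#   {"text": "source/input text", "target": "expected output text"}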

# Tokenization: each example provides "text" (encoder input) and "target" (label)
def tokenize_function(examples):
    model_inputs = tokenizer(examples["text"], max_length=256, truncation=True, padding="max_length")
    labels = tokenizer(examples["target"], max_length=64, truncation=True, padding="max_length")
    # Replace padding token ids in the labels with -100 so padded positions are ignored by the loss
    model_inputs["labels"] = [
        [(tok if tok != tokenizer.pad_token_id else -100) for tok in seq]
        for seq in labels["input_ids"]
    ]
    return model_inputs

tokenized_datasets = dataset.map(tokenize_function, batched=True)

# Log in to Weights & Biases; read the API key from the environment rather than hardcoding it
wandb.login(key=os.environ["WANDB_API_KEY"])

training_args = TrainingArguments(
    output_dir="./mt5-finetuned",
    evaluation_strategy="steps",
    eval_steps=100,
    learning_rate=5e-5,
    per_device_train_batch_size=250,
    per_device_eval_batch_size=250,
    num_train_epochs=3,
    logging_steps=100,
    warmup_ratio=0.06,
    logging_first_step=True,
    weight_decay=0.01,
    logging_dir="./logs",
    save_total_limit=2,
    save_strategy="epoch",
    report_to="wandb",                     
    run_name="mt5-finetuning-run",   
    disable_tqdm=False,
    max_grad_norm=1.0                       
)


trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_datasets["train"],
    eval_dataset=tokenized_datasets["validation"]
)

# Training: resume from the latest checkpoint in output_dir
# (use trainer.train() instead for a fresh run with no existing checkpoint)
trainer.train(resume_from_checkpoint=True)

# Save the model and tokenizer locally
model.save_pretrained("./mt5-finetuned")
tokenizer.save_pretrained("./mt5-finetuned")
print("✅ Модель сохранена локально в ./mt5-finetuned")