| from transformers import T5ForConditionalGeneration, ByT5Tokenizer, Trainer, TrainingArguments |
| from datasets import load_dataset |
| import os |
| import wandb |
| import torch |
|
|
| |
# Run configuration: base checkpoint, run naming and batch-size search bounds.
model_name = "google/byt5-small"  # checkpoint the model and tokenizer are loaded from
run_id = "byt5-autobatch"  # also used as the W&B run name and the output folder name
output_dir = "./" + run_id

# Batch-size search: start here and shrink by the step after each CUDA OOM.
start_batch_size, step_batch_size = 300, 5
|
|
| |
# Load the tokenizer and the seq2seq model from the same checkpoint
# (the two loads are independent of each other).
tokenizer = ByT5Tokenizer.from_pretrained(model_name)
model = T5ForConditionalGeneration.from_pretrained(model_name)
|
|
| |
# Local JSON Lines files keyed by split name. Each record is expected to carry
# the "text" and "target" fields that tokenize_function reads.
data_files = dict(
    train="mt5_training_data-1.jsonl",
    validation="mt5_validation_data-1.jsonl",
)
dataset = load_dataset("json", data_files=data_files)
|
|
| |
def tokenize_function(examples, *, tok=None, max_source_length=256, max_target_length=64):
    """Tokenize a batch of ``text``/``target`` pairs for seq2seq training.

    Args:
        examples: Batched mapping (as passed by ``Dataset.map(batched=True)``)
            with parallel lists under ``"text"`` (encoder input) and
            ``"target"`` (decoder target).
        tok: Tokenizer to use; defaults to the module-level ``tokenizer``.
        max_source_length: Length inputs are padded/truncated to.
        max_target_length: Length labels are padded/truncated to.

    Returns:
        The tokenizer output for ``"text"`` with an added ``"labels"`` entry.
        Padded label positions are replaced by -100, the index the Hugging
        Face T5 loss ignores, so the model is not trained to emit padding.
    """
    tok = tokenizer if tok is None else tok

    model_inputs = tok(
        examples["text"],
        max_length=max_source_length,
        truncation=True,
        padding="max_length",
    )
    labels = tok(
        examples["target"],
        max_length=max_target_length,
        truncation=True,
        padding="max_length",
    )

    # Use the attention mask rather than comparing against the pad id, so only
    # positions the tokenizer actually padded are masked out of the loss.
    model_inputs["labels"] = [
        [token if keep else -100 for token, keep in zip(ids, mask)]
        for ids, mask in zip(labels["input_ids"], labels["attention_mask"])
    ]
    return model_inputs
|
|
# Tokenize every split batch-wise; the resulting columns are what the Trainer consumes.
tokenized_datasets = dataset.map(function=tokenize_function, batched=True)
|
|
| |
# Authenticate with Weights & Biases. The API key comes from the environment
# instead of being committed to source control. With key=None, wandb.login()
# falls back to its normal credential lookup (env var, netrc, or a prompt).
# SECURITY: the key previously hard-coded on this line is exposed and should
# be revoked/rotated in the W&B account settings.
wandb.login(key=os.environ.get("WANDB_API_KEY"))
|
|
| |
def try_training_with_batch_size(batch_size_start, step=None):
    """Train the module-level ``model``, shrinking the batch size on CUDA OOM.

    Starting from ``batch_size_start``, build a fresh ``Trainer`` and run
    ``trainer.train()``. Each time training fails with a CUDA out-of-memory
    error, GPU memory held by the failed attempt is released and the batch
    size is reduced by ``step``. Batch size 1 is always tried before giving up.

    Args:
        batch_size_start: First per-device batch size to try (train and eval).
        step: Amount to shrink the batch size by after an OOM; defaults to the
            module-level ``step_batch_size``.

    Returns:
        The batch size with which ``trainer.train()`` completed.

    Raises:
        ValueError: if ``step`` is smaller than 1 (would loop forever).
        RuntimeError: if no batch size >= 1 fits in memory, or re-raised
            unchanged for any RuntimeError that is not a CUDA OOM.
    """
    import gc

    if step is None:
        step = step_batch_size
    if step < 1:
        raise ValueError(f"step must be >= 1, got {step}")

    # torch >= 1.13 raises a dedicated subclass of RuntimeError; the message
    # check keeps older versions working.
    oom_error_type = getattr(torch.cuda, "OutOfMemoryError", ())

    batch_size = batch_size_start
    while batch_size > 0:
        print(f"\n🚀 Пробуем batch_size = {batch_size}")
        training_args = TrainingArguments(
            output_dir=output_dir,
            # NOTE(review): newer transformers releases renamed this argument to
            # `eval_strategy`; switch once the pinned version is confirmed.
            evaluation_strategy="steps",
            eval_steps=100,
            learning_rate=5e-5,
            per_device_train_batch_size=batch_size,
            per_device_eval_batch_size=batch_size,
            num_train_epochs=3,
            logging_steps=100,
            warmup_ratio=0.06,
            logging_first_step=True,
            weight_decay=0.01,
            logging_dir="./logs",
            save_total_limit=2,
            save_strategy="epoch",
            report_to="wandb",
            run_name=run_id,
            disable_tqdm=False,
            max_grad_norm=1.0,
        )

        trainer = None
        try:
            trainer = Trainer(
                model=model,
                args=training_args,
                train_dataset=tokenized_datasets["train"],
                eval_dataset=tokenized_datasets["validation"],
            )
            trainer.train()
            return batch_size
        except RuntimeError as e:
            if not (isinstance(e, oom_error_type) or "CUDA out of memory" in str(e)):
                raise
            print(f"❌ OOM на batch_size = {batch_size}, уменьшаем...")

        # Cleanup runs only after the except block has exited: until then the
        # exception's traceback pins the failed step's frames and tensors, so
        # empty_cache() could not return that memory.
        # NOTE(review): if the OOM happened after some optimizer steps, the
        # global model keeps those partial updates; weights are not reset here.
        trainer = None
        model.zero_grad(set_to_none=True)  # drop gradients left by the failed step
        gc.collect()
        torch.cuda.empty_cache()

        if batch_size == 1:
            break
        # Clamp so that batch size 1 is attempted even when the step overshoots.
        batch_size = max(1, batch_size - step)

    raise RuntimeError("Не удалось подобрать подходящий batch size 😢")
|
|
| |
# Run the OOM-aware training search and report which batch size fit.
optimal_batch_size = try_training_with_batch_size(batch_size_start=start_batch_size)
print("\n✅ Успешно обучено с batch_size = " + str(optimal_batch_size))
|
|
| |
# Write the trained model and its tokenizer into the same output directory.
for artifact in (model, tokenizer):
    artifact.save_pretrained(output_dir)
print(f"✅ Модель сохранена в {output_dir}")
|
|