Spaces:
Sleeping
Sleeping
| from transformers import Seq2SeqTrainer, Seq2SeqTrainingArguments | |
| from core.model import model, processor | |
| from core.data import load | |
| from core.utils import compute_metrics | |
| from config import OUTPUT_DIR, BATCH_SIZE, EPOCHS | |
# The project's data loader yields two datasets, unpacked here as the
# training split and the evaluation split.
train_ds, eval_ds = load()

# Training configuration passed to the Seq2SeqTrainer.
training_args = Seq2SeqTrainingArguments(
    # Where the trainer writes its outputs.
    output_dir=OUTPUT_DIR,
    save_total_limit=1,
    # Optimisation schedule.
    num_train_epochs=EPOCHS,
    learning_rate=5e-5,
    fp16=False,
    # The same batch size is used for training and evaluation.
    per_device_train_batch_size=BATCH_SIZE,
    per_device_eval_batch_size=BATCH_SIZE,
    # Evaluation and logging.
    eval_strategy='epoch',
    predict_with_generate=True,
    logging_steps=50,
    # Dataset columns are passed through untouched rather than pruned.
    remove_unused_columns=False,
)
# Bundle the model, its training configuration, both dataset splits and the
# metric callback into the trainer used by train_save().
trainer = Seq2SeqTrainer(
    model=model,
    args=training_args,
    train_dataset=train_ds,
    eval_dataset=eval_ds,
    compute_metrics=compute_metrics,
    # Only the image-processing half of the processor is handed to the trainer.
    processing_class=processor.image_processor,
)
def train_save(save_dir: str = './model') -> None:
    """Run training, then save the model and the processor to disk.

    Args:
        save_dir: Directory passed to both ``trainer.save_model`` and
            ``processor.save_pretrained``. Defaults to ``'./model'``, the
            path this script previously hard-coded, so calling
            ``train_save()`` with no arguments behaves as before.
    """
    trainer.train()
    # Both artifacts are written to the same directory.
    trainer.save_model(save_dir)
    processor.save_pretrained(save_dir)


if __name__ == '__main__':
    train_save()