Spaces:
Build error
Build error
| import torch | |
| from transformers import Trainer, TrainingArguments | |
| from model import get_model_and_tokenizer | |
| from data_loader import get_dataloader | |
| from utils import load_config, set_seed | |
| def main(): | |
| config = load_config('configs/model_config.yaml') | |
| set_seed(config['training']['seed']) | |
| model, tokenizer = get_model_and_tokenizer(config) | |
| train_dataloader = get_dataloader(config, tokenizer, 'train') | |
| val_dataloader = get_dataloader(config, tokenizer, 'validation') | |
| training_args = TrainingArguments( | |
| output_dir="./results", | |
| num_train_epochs=config['training']['num_epochs'], | |
| per_device_train_batch_size=config['training']['batch_size'], | |
| per_device_eval_batch_size=config['training']['batch_size'], | |
| warmup_steps=500, | |
| weight_decay=0.01, | |
| logging_dir='./logs', | |
| logging_steps=100, | |
| evaluation_strategy="steps", | |
| eval_steps=1000, | |
| save_steps=config['training']['save_every'], | |
| load_best_model_at_end=True, | |
| ) | |
| trainer = Trainer( | |
| model=model, | |
| args=training_args, | |
| train_dataset=train_dataloader.dataset, | |
| eval_dataset=val_dataloader.dataset, | |
| ) | |
| trainer.train() | |
| trainer.save_model("./final_model") | |
| if __name__ == "__main__": | |
| main() |