import torch
from datasets import load_dataset
from transformers import (
    AutoTokenizer,
    AutoModelForSeq2SeqLM,
    DataCollatorForSeq2Seq,
    Seq2SeqTrainer,
    Seq2SeqTrainingArguments,
)

from config import MODEL_NAME, MAX_LENGTH, DATASET_EN_ES
|
|
| |
# Load the tokenizer and the seq2seq model from the same checkpoint name so
# the two stay paired (the tokenizer is later handed to the collator/trainer).
tokenizer, model = (
    AutoTokenizer.from_pretrained(MODEL_NAME),
    AutoModelForSeq2SeqLM.from_pretrained(MODEL_NAME),
)

# Raw parallel-text dataset; it is indexed by split name below and a
# "train" split is expected.
dataset = load_dataset(DATASET_EN_ES)
|
|
| |
| |
| |
def preprocess(example):
    """Tokenize a single source/target pair for seq2seq training.

    Called once per row by ``dataset.map`` (not batched), so ``example`` is
    one record whose ``"term"`` entry maps language codes to text.

    Args:
        example: A dataset row; ``example["term"]["en"]`` is used as the
            source text and ``example["term"]["es"]`` as the target text.

    Returns:
        The tokenizer's encoding of the source text (truncated to
        ``MAX_LENGTH``) with an added ``"labels"`` entry holding the target
        token ids (also truncated to ``MAX_LENGTH``).
    """
    pair = example["term"]
    source_text, target_text = pair["en"], pair["es"]

    encoded = tokenizer(source_text, max_length=MAX_LENGTH, truncation=True)

    # `text_target=` tokenizes in target mode, so label ids come from the
    # target-side tokenization.
    target_encoding = tokenizer(
        text_target=target_text,
        max_length=MAX_LENGTH,
        truncation=True,
    )
    encoded["labels"] = target_encoding["input_ids"]
    return encoded
|
|
| |
# Tokenize every split row by row. The raw columns (taken from the "train"
# split) are dropped, so only what `preprocess` returns is kept.
# NOTE(review): assumes every split shares the train split's columns.
train_columns = dataset["train"].column_names
tokenized_dataset = dataset.map(
    preprocess,
    remove_columns=train_columns,
)

# Pads inputs and labels per batch. Passing the model lets the collator
# build decoder inputs from the labels when the model supports it.
data_collator = DataCollatorForSeq2Seq(
    tokenizer,
    model=model,
)
|
|
| |
| |
| |
# Training hyperparameters. A checkpoint is written to output_dir at the end
# of every epoch, and the loss is logged every 50 steps.
training_args = Seq2SeqTrainingArguments(
    output_dir="./my-translation-model",
    learning_rate=2e-5,
    per_device_train_batch_size=4,
    num_train_epochs=3,
    save_strategy="epoch",
    logging_steps=50,
    # Evaluation stays at its default ("no"); no eval split is given to the
    # trainer. The old `evaluation_strategy` keyword was deprecated in favor
    # of `eval_strategy` and later removed, so omitting it works on both old
    # and new transformers releases.
    #
    # fp16 mixed precision requires a CUDA device and raises on CPU-only
    # machines. Enable it only when a GPU is available; otherwise train in fp32.
    fp16=torch.cuda.is_available(),
)
|
|
| |
| |
| |
# Wire the model, training split, tokenizer and collator together.
# NOTE(review): `tokenizer=` is accepted throughout transformers 4.x; newer
# releases prefer `processing_class=`. Switch once the minimum version allows.
trainer = Seq2SeqTrainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_dataset["train"],
    tokenizer=tokenizer,
    data_collator=data_collator,
)

trainer.train()

# Save through the trainer instead of calling model.save_pretrained directly.
# save_model writes only from the main process in distributed runs and copes
# with models the trainer has wrapped or sharded. Both saves reuse the
# configured output_dir so the path is defined in one place.
final_dir = training_args.output_dir
trainer.save_model(final_dir)
if trainer.is_world_process_zero():
    # Explicit tokenizer save, kept to match the original script's output
    # and guarded so that only one process writes the files.
    tokenizer.save_pretrained(final_dir)