# translation-model / train.py
# Uploaded by drixo — commit 7ae3549 ("Update train.py", verified).
import torch
from datasets import load_dataset
from transformers import (
    AutoTokenizer,
    AutoModelForSeq2SeqLM,
    DataCollatorForSeq2Seq,
    Seq2SeqTrainer,
    Seq2SeqTrainingArguments,
)

from config import MODEL_NAME, MAX_LENGTH, DATASET_EN_ES
# Load the dataset first. DATASET_EN_ES is the dataset identifier from config,
# presumably an English/Spanish parallel corpus; confirm against its schema.
dataset = load_dataset(DATASET_EN_ES)

# Load the pretrained seq2seq checkpoint and its matching tokenizer. Both come
# from the same MODEL_NAME, so the token ids agree with the model's vocabulary.
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoModelForSeq2SeqLM.from_pretrained(MODEL_NAME)
# -----------------------------
# Preprocessing
# -----------------------------
def preprocess(example):
    """Tokenize one English -> Spanish pair for seq2seq training.

    Invoked by ``dataset.map`` once per row (the map call does not request
    batched mode).

    Args:
        example: A single dataset row. Assumed to hold a ``"term"`` mapping
            with an ``"en"`` source string and an ``"es"`` target string.
            TODO: confirm against the DATASET_EN_ES schema.

    Returns:
        The tokenizer's encoding of the source text, with an added
        ``"labels"`` entry holding the target token ids. The encoding contains
        ``input_ids`` plus whatever else the tokenizer emits, such as
        ``attention_mask``. Both sides are truncated to ``MAX_LENGTH`` tokens.
        No padding is applied here; padding is done per batch by the data
        collator.
    """
    source = example["term"]["en"]
    target = example["term"]["es"]

    model_inputs = tokenizer(source, max_length=MAX_LENGTH, truncation=True)

    # `text_target` makes the tokenizer encode this text as the *target* side.
    # Some tokenizers use different special tokens or vocabularies for
    # targets. This replaces the older `as_target_tokenizer()` context manager.
    labels = tokenizer(text_target=target, max_length=MAX_LENGTH, truncation=True)

    model_inputs["labels"] = labels["input_ids"]
    return model_inputs
# Apply preprocessing row by row to every split. The raw columns, taken from
# the "train" split, are dropped, so only the tokenizer outputs and the
# "labels" produced by preprocess() remain.
tokenized_dataset = dataset.map(
    function=preprocess,
    remove_columns=dataset["train"].column_names,
)
# -----------------------------
# Data collator
# -----------------------------
# Pads each batch dynamically. Inputs are padded with the tokenizer's pad
# token. "labels" are padded with the collator's default label_pad_token_id
# (-100), so padded positions are ignored by the loss. Passing the model lets
# the collator derive decoder_input_ids from the labels when the model
# supports it.
data_collator = DataCollatorForSeq2Seq(tokenizer=tokenizer, model=model)
# -----------------------------
# Training arguments
# -----------------------------
training_args = Seq2SeqTrainingArguments(
    output_dir="./my-translation-model",
    learning_rate=2e-5,
    per_device_train_batch_size=4,
    num_train_epochs=3,
    save_strategy="epoch",
    logging_steps=50,
    # No evaluation is run: the Trainer receives no eval_dataset.
    # `evaluation_strategy="no"` is intentionally not passed:
    #   - "no" is already the default;
    #   - that keyword was renamed to `eval_strategy` and later removed, so
    #     passing it raises TypeError on recent transformers releases.
    # Use mixed precision only when a CUDA GPU is present. fp16=True on a
    # CPU-only machine makes the argument construction raise.
    # NOTE(review): some checkpoints (e.g. the T5 family) are known to produce
    # NaN losses under fp16. Confirm this is safe for MODEL_NAME.
    fp16=torch.cuda.is_available(),
)
# -----------------------------
# Trainer
# -----------------------------
# Trains on the "train" split only; no eval_dataset is given. The tokenizer is
# handed to the Trainer so it can be saved together with checkpoints.
# NOTE(review): newer transformers releases deprecate `tokenizer=` in favour
# of `processing_class=`. Switch once the pinned version supports it.
trainer = Seq2SeqTrainer(
    model=model,
    args=training_args,
    data_collator=data_collator,
    train_dataset=tokenized_dataset["train"],
    tokenizer=tokenizer,
)
# Train
trainer.train()

# Save the final weights through the Trainer rather than model.save_pretrained():
#   - save_model() is the Trainer's supported save path for multi-process and
#     distributed runs, and it writes only from the main process;
#   - with no argument it writes to training_args.output_dir, so the
#     destination path is defined in one place.
trainer.save_model()

# Also write the tokenizer files explicitly, so the directory can be reloaded
# on its own with from_pretrained(). The guard keeps multi-process runs from
# writing the same files concurrently.
if trainer.is_world_process_zero():
    tokenizer.save_pretrained(training_args.output_dir)