"""Fine-tune ``google/mt5-small`` as a seq2seq transliteration model.

Reads ``train.csv`` (training split) and ``val.csv`` (evaluation split); both
must provide ``source`` and ``target`` text columns. Evaluation reports
character error rate (CER) and word error rate (WER) on generated outputs.
The fine-tuned model and tokenizer are saved to ``./final_model``.
"""

import inspect

import evaluate
import numpy as np
import torch
from datasets import load_dataset
from transformers import (
    DataCollatorForSeq2Seq,
    MT5ForConditionalGeneration,
    MT5Tokenizer,
    Seq2SeqTrainer,
    Seq2SeqTrainingArguments,
)

# Maximum token length for both source and target sequences. Also used as the
# generation length at evaluation time so predictions are not cut short.
MAX_LENGTH = 64


def _supported_kwarg(callable_, new_name, old_name):
    """Return ``new_name`` if ``callable_`` accepts it, otherwise ``old_name``.

    Some transformers keyword arguments were renamed between releases
    (``evaluation_strategy`` -> ``eval_strategy``, ``tokenizer`` ->
    ``processing_class``). Choosing the accepted spelling keeps this script
    working on both older and newer installs.
    """
    params = inspect.signature(callable_).parameters
    return new_name if new_name in params else old_name


# Load metrics
cer_metric = evaluate.load("cer")
wer_metric = evaluate.load("wer")

model_nm = "google/mt5-small"
tokenizer = MT5Tokenizer.from_pretrained(model_nm)
model = MT5ForConditionalGeneration.from_pretrained(model_nm)


def compute_metrics(eval_preds):
    """Decode predictions and references and score them with CER and WER.

    Args:
        eval_preds: ``(predictions, label_ids)`` from ``Seq2SeqTrainer``. With
            ``predict_with_generate=True`` the predictions are generated
            token ids (possibly wrapped in a tuple).

    Returns:
        Dict with ``"cer"`` and ``"wer"`` scores.
    """
    preds, labels = eval_preds
    if isinstance(preds, tuple):
        preds = preds[0]
    # When generated batches of different lengths are concatenated the trainer
    # pads them with -100, which is not a valid token id and breaks decoding.
    preds = np.where(preds != -100, preds, tokenizer.pad_token_id)
    decoded_preds = tokenizer.batch_decode(preds, skip_special_tokens=True)
    # Label positions ignored by the loss are -100; map them back to pad to decode.
    labels = np.where(labels != -100, labels, tokenizer.pad_token_id)
    decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)

    cer = cer_metric.compute(predictions=decoded_preds, references=decoded_labels)
    wer = wer_metric.compute(predictions=decoded_preds, references=decoded_labels)
    return {"cer": cer, "wer": wer}


def tokenize_fn(batch):
    """Tokenize a batch of ``source``/``target`` pairs for seq2seq training.

    Sequences are truncated to ``MAX_LENGTH`` but deliberately NOT padded here:
    ``DataCollatorForSeq2Seq`` pads each batch dynamically and fills label
    padding with -100 so padded positions are excluded from the loss.
    (Padding labels here with ``pad_token_id`` would make the model train on
    padding tokens.)
    """
    inputs = tokenizer(batch["source"], truncation=True, max_length=MAX_LENGTH)
    labels = tokenizer(
        text_target=batch["target"], truncation=True, max_length=MAX_LENGTH
    )
    inputs["labels"] = labels["input_ids"]
    return inputs


# Load and process data
dataset = load_dataset("csv", data_files={"train": "train.csv", "test": "val.csv"})
tokenized_dataset = dataset.map(tokenize_fn, batched=True)

# T5-family models (including mT5) are prone to overflow (inf/NaN loss) under
# fp16, and fp16 training also fails on CPU-only machines. Use bf16 when the
# GPU supports it; otherwise train in full fp32.
use_bf16 = torch.cuda.is_available() and torch.cuda.is_bf16_supported()

eval_strategy_kw = _supported_kwarg(
    Seq2SeqTrainingArguments, "eval_strategy", "evaluation_strategy"
)
args = Seq2SeqTrainingArguments(
    output_dir="./translit-results",
    learning_rate=2e-4,
    per_device_train_batch_size=16,
    per_device_eval_batch_size=16,
    weight_decay=0.01,
    save_total_limit=2,
    num_train_epochs=3,
    predict_with_generate=True,
    # Without this, generation falls back to the model's default max_length
    # (20 tokens), truncating predictions and inflating CER/WER.
    generation_max_length=MAX_LENGTH,
    fp16=False,
    bf16=use_bf16,
    logging_steps=100,
    **{eval_strategy_kw: "epoch"},
)

processing_kw = _supported_kwarg(Seq2SeqTrainer, "processing_class", "tokenizer")
trainer = Seq2SeqTrainer(
    model=model,
    args=args,
    train_dataset=tokenized_dataset["train"],
    eval_dataset=tokenized_dataset["test"],
    data_collator=DataCollatorForSeq2Seq(tokenizer, model=model),
    compute_metrics=compute_metrics,
    **{processing_kw: tokenizer},
)

trainer.train()
model.save_pretrained("./final_model")
tokenizer.save_pretrained("./final_model")