# Source: Hugging Face Space page (status shown: "Runtime error"; file size 2,263 bytes; revision 88d9f70).
import inspect

import evaluate
import numpy as np
import torch
from datasets import load_dataset
from transformers import (
    MT5ForConditionalGeneration,
    MT5Tokenizer,
    Seq2SeqTrainingArguments,
    Seq2SeqTrainer,
    DataCollatorForSeq2Seq,
)
# Evaluation metrics: character error rate and word error rate, loaded in
# that order from the `evaluate` hub.
cer_metric, wer_metric = (evaluate.load(metric_id) for metric_id in ("cer", "wer"))

# Pretrained checkpoint shared by the tokenizer and the model so that their
# vocabularies agree.
model_nm = "google/mt5-small"
tokenizer = MT5Tokenizer.from_pretrained(model_nm)
model = MT5ForConditionalGeneration.from_pretrained(model_nm)
def compute_metrics(eval_preds):
    """Compute character and word error rates for a trainer evaluation pass.

    Args:
        eval_preds: ``(predictions, label_ids)`` pair supplied by the trainer.
            ``predictions`` may be a tuple whose first element holds the
            token ids.

    Returns:
        dict with keys ``"cer"`` and ``"wer"``, the values returned by the
        module-level ``cer_metric`` / ``wer_metric``.
    """
    preds, labels = eval_preds
    if isinstance(preds, tuple):
        preds = preds[0]
    pad_id = tokenizer.pad_token_id
    # -100 is the "ignore" marker, not a real token id, and cannot be decoded.
    # It appears in the labels (loss masking) and also in generated predictions:
    # the trainer pads generations of different lengths with -100 when it
    # concatenates batches. Map it to the pad token, which
    # skip_special_tokens then drops.
    preds = np.where(preds != -100, preds, pad_id)
    labels = np.where(labels != -100, labels, pad_id)
    decoded_preds = tokenizer.batch_decode(preds, skip_special_tokens=True)
    decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)
    cer = cer_metric.compute(predictions=decoded_preds, references=decoded_labels)
    wer = wer_metric.compute(predictions=decoded_preds, references=decoded_labels)
    return {"cer": cer, "wer": wer}
def tokenize_fn(batch):
    """Tokenize a batch of source/target pairs for seq2seq training.

    Args:
        batch: batched ``datasets`` row mapping with ``'source'`` and
            ``'target'`` columns (read from the CSV files).

    Returns:
        The tokenizer encoding of ``source``, padded/truncated to 64 tokens,
        plus ``labels``. ``labels`` holds the ``target`` token ids, padded and
        truncated the same way, with every padding position replaced by -100.
    """
    inputs = tokenizer(batch['source'], padding="max_length", truncation=True, max_length=64)
    labels = tokenizer(batch['target'], padding="max_length", truncation=True, max_length=64)
    pad_id = tokenizer.pad_token_id
    # The loss ignores label positions equal to -100. Without this replacement
    # the padding would count as real targets and the model would learn to
    # emit pad tokens.
    inputs["labels"] = [
        [token if token != pad_id else -100 for token in seq]
        for seq in labels["input_ids"]
    ]
    return inputs
# Read the train/validation CSVs (val.csv is exposed as the 'test' split),
# then run tokenize_fn over them batch by batch.
dataset = load_dataset(
    'csv',
    data_files={'train': 'train.csv', 'test': 'val.csv'},
)
tokenized_dataset = dataset.map(
    tokenize_fn,
    batched=True,
)
# Newer transformers releases renamed `evaluation_strategy` to `eval_strategy`
# (and later dropped the old name). Pass whichever keyword this version accepts.
_eval_strategy_kw = (
    "eval_strategy"
    if "eval_strategy" in inspect.signature(Seq2SeqTrainingArguments).parameters
    else "evaluation_strategy"
)
args = Seq2SeqTrainingArguments(
    output_dir="./translit-results",
    learning_rate=2e-4,
    per_device_train_batch_size=16,
    per_device_eval_batch_size=16,
    weight_decay=0.01,
    save_total_limit=2,
    num_train_epochs=3,
    predict_with_generate=True,
    # fp16 mixed precision needs a CUDA device. Requesting it on a CPU-only
    # host raises, so enable it only when a GPU is present.
    # NOTE(review): mT5 checkpoints are reported to hit NaN losses under fp16.
    # If that happens, consider bf16 or fp32 instead.
    fp16=torch.cuda.is_available(),
    logging_steps=100,
    **{_eval_strategy_kw: "epoch"},
)
# Newer transformers releases take the tokenizer as `processing_class`, and
# `tokenizer=` is deprecated there. Pass whichever keyword this version accepts.
_tokenizer_kw = (
    "processing_class"
    if "processing_class" in inspect.signature(Seq2SeqTrainer).parameters
    else "tokenizer"
)
trainer = Seq2SeqTrainer(
    model=model,
    args=args,
    train_dataset=tokenized_dataset["train"],
    eval_dataset=tokenized_dataset["test"],
    data_collator=DataCollatorForSeq2Seq(tokenizer, model=model),
    compute_metrics=compute_metrics,
    **{_tokenizer_kw: tokenizer},
)
trainer.train()
# Save the fine-tuned weights and the tokenizer into the same directory so it
# can be reloaded with from_pretrained().
model.save_pretrained("./final_model")
tokenizer.save_pretrained("./final_model")