# mt5-fixpdftext / train-fix-text.py
# Author: napatswift — "Upload train-fix-text.py" (commit c9afc58)
import os
import random
import re

import evaluate
import numpy as np
import torch
from datasets import load_dataset
from transformers import (
    pipeline,
    AutoModelForSeq2SeqLM,
    AutoTokenizer,
    DataCollatorForSeq2Seq,
    Seq2SeqTrainingArguments,
    Seq2SeqTrainer,
)
from transformers.trainer_utils import get_last_checkpoint

def breaking_text(original, change_prob=0.3, rng=None):
    """Return a copy of *original* with simulated Thai text-extraction damage.

    Characters are processed one by one under a corruption mode. The mode is
    drawn once up front and re-drawn before each character with probability
    *change_prob*:

    * ``'um'``   -- SARA AM (``ำ``) becomes a space plus SARA AA (``' า'``).
    * ``'aa'``   -- SARA AA (``า``) becomes SARA AM (``ำ``).
    * ``'sara'`` -- any character in U+0E2F..U+0E4E (Thai vowels, tone marks
      and other signs) is replaced by a filler string that is chosen once
      per call from a fixed list.
    * ``'none'`` -- the character is kept as-is.

    Args:
        original: Clean text to corrupt.
        change_prob: Per-character probability of re-drawing the mode.
            Defaults to 0.3 (the value previously hard-coded).
        rng: Object providing ``choice`` and ``random`` (e.g. a seeded
            ``random.Random``) for reproducible corruption. Defaults to the
            global ``random`` module, so existing calls behave as before.

    Returns:
        The corrupted string.

    Raises:
        TypeError: If *original* is not a ``str``.
    """
    if rng is None:
        rng = random
    # An ``assert`` would be stripped under ``python -O``; raise explicitly.
    if not isinstance(original, str):
        raise TypeError(f"original must be str, got {type(original).__name__}")
    modes = ['um', 'aa', 'sara', 'none']
    btype = rng.choice(modes)
    # Filler used for every 'sara' substitution during this call.
    bchar = rng.choice([' ', '', '', '', '�'])
    broken = []
    for c in original:
        if rng.random() < change_prob:
            btype = rng.choice(modes)
        if btype == 'um' and c == 'ำ':
            broken.append(' า')
        elif btype == 'aa' and c == 'า':
            broken.append('ำ')
        elif btype == 'sara' and '\u0E2F' <= c <= '\u0E4E':
            # Same test as re.match(r'[\u0E2F-\u0E4E]', c) on one character.
            broken.append(bchar)
        else:
            broken.append(c)
    return ''.join(broken)
def metrics_func(eval_arg):
    """Compute the character error rate (CER) of generated predictions.

    Used as ``compute_metrics`` for the ``Seq2SeqTrainer``. It reads the
    module-level ``tokenizer`` and ``cer_metric``.

    Args:
        eval_arg: ``(predictions, label_ids)`` pair passed by the Trainer.
            With ``predict_with_generate=True`` the predictions are generated
            token ids.

    Returns:
        A dict with two keys:

        * ``'cer'`` -- CER over examples whose decoded prediction and decoded
          label are both non-empty, or ``nan`` when there is no such example.
        * ``'empty_pred_ratio'`` -- the fraction of examples whose decoded
          prediction was empty. These examples are excluded from ``'cer'``,
          so a high ratio means ``'cer'`` is optimistic.
    """
    preds, labels = eval_arg
    # Defensive: unwrap if the predictions arrive as a tuple of arrays.
    if isinstance(preds, tuple):
        preds = preds[0]
    # -100 is not a valid token id (it is the ignore/padding value), so map
    # it to the pad token before decoding.
    preds = np.where(preds != -100, preds, tokenizer.pad_token_id)
    labels = np.where(labels != -100, labels, tokenizer.pad_token_id)
    text_preds = tokenizer.batch_decode(preds, skip_special_tokens=True)
    text_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)

    # Keep only pairs where both sides decoded to non-empty text.
    pairs = [(p, l) for p, l in zip(text_preds, text_labels) if p and l]
    num_empty_preds = sum(1 for p in text_preds if not p)
    empty_pred_ratio = num_empty_preds / len(text_preds) if text_preds else 0.0

    if not pairs:
        # Nothing scoreable (e.g. every prediction decoded to empty text).
        # Calling the CER metric on empty lists would fail, which would abort
        # the evaluation step, so report nan instead.
        return {'cer': float('nan'), 'empty_pred_ratio': empty_pred_ratio}
    ps, ls = zip(*pairs)
    cer = cer_metric.compute(predictions=list(ps), references=list(ls))
    return {'cer': cer, 'empty_pred_ratio': empty_pred_ratio}
# Thai Wikipedia passages strictly between 50 and 200 characters long,
# split into a training set and a 500-example test set.
dataset = load_dataset("pythainlp/thai_wikipedia_clean_20230101")
# Without a seed, train_test_split reshuffles on every run. Training below
# resumes from a checkpoint, so an unseeded split would let examples that an
# earlier run held out for evaluation leak into training. A fixed seed keeps
# the split stable across runs.
SPLIT_SEED = 42
dataset = (
    dataset['train']
    .filter(lambda x: 50 < len(x['text']) < 200)
    .train_test_split(test_size=500, seed=SPLIT_SEED)
)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model_repo = 'google/mt5-base'
tokenizer = AutoTokenizer.from_pretrained(model_repo)
model = AutoModelForSeq2SeqLM.from_pretrained(model_repo).to(device)
# Prompt wrapped around each corrupted passage to form the model input.
text_template = 'Fix the following corrupted text: "{}"'
def preprocess_function(examples, max_length=256):
    """Tokenize a batch of clean texts into (corrupted input, clean label) pairs.

    Intended for ``Dataset.map(..., batched=True)``. Each text is corrupted
    with ``breaking_text`` and wrapped in ``text_template`` to form the model
    input. The original, uncorrupted text becomes the label sequence.

    Args:
        examples: Batch dict with a ``"text"`` column (list of str).
        max_length: Token limit, with truncation, applied to both inputs and
            labels. Defaults to 256 (the value previously hard-coded).

    Returns:
        The tokenizer output for the inputs, with ``"labels"`` set to the
        target token ids.
    """
    texts = examples["text"]
    inputs = [text_template.format(breaking_text(text)) for text in texts]
    model_inputs = tokenizer(inputs, max_length=max_length, truncation=True)
    # text_target= applies any target-side tokenizer settings, so labels stay
    # correct if model_repo is switched to a model that has them.
    labels = tokenizer(text_target=texts, max_length=max_length, truncation=True)
    model_inputs["labels"] = labels["input_ids"]
    return model_inputs
# Corrupt and tokenize both splits, one batch of examples at a time.
tokenized_ds = dataset.map(preprocess_function, batched=True)
# Pads each batch at collation time.
data_collator = DataCollatorForSeq2Seq(tokenizer, model=model, return_tensors="pt")
cer_metric = evaluate.load('cer')
training_args = Seq2SeqTrainingArguments(
    output_dir="mt5-fixth",
    log_level="error",
    num_train_epochs=10,
    learning_rate=5e-4,
    lr_scheduler_type="linear",
    warmup_steps=90,
    optim="adafactor",
    weight_decay=0.01,
    # Effective train batch per device: 4 x 16 accumulation steps = 64.
    per_device_train_batch_size=4,
    per_device_eval_batch_size=1,
    gradient_accumulation_steps=16,
    evaluation_strategy="steps",
    eval_steps=500,
    # Evaluate with generate() so metrics_func receives generated token ids.
    predict_with_generate=True,
    generation_max_length=254,
    save_steps=500,
    logging_steps=100,
    push_to_hub=False,
)
trainer = Seq2SeqTrainer(
    model=model,
    args=training_args,
    data_collator=data_collator,
    compute_metrics=metrics_func,
    train_dataset=tokenized_ds["train"],
    eval_dataset=tokenized_ds["test"].select(range(500)),
    tokenizer=tokenizer,
)
# Resume from the newest checkpoint in output_dir if one exists; otherwise
# train from scratch. The previous hard-coded 'mt5-fixth/checkpoint-500'
# failed on a first run, and once later checkpoints had been saved it rolled
# training back to step 500.
last_checkpoint = None
if os.path.isdir(training_args.output_dir):
    last_checkpoint = get_last_checkpoint(training_args.output_dir)
trainer.train(resume_from_checkpoint=last_checkpoint)