from datasets import load_dataset
import random
import re
import evaluate
import torch
import numpy as np
from transformers import (
    pipeline,
    AutoModelForSeq2SeqLM,
    AutoTokenizer,
    DataCollatorForSeq2Seq,
    Seq2SeqTrainingArguments,
    Seq2SeqTrainer,
)

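# This script fine-tunes mT5 to restore synthetically corrupted Thai text:
# breaking_text() injects character-level noise into clean Thai Wikipedia
# passages, and the model is trained to reconstruct the original string.
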
def breaking_text(original):
    """Randomly corrupt a Thai string at the character level."""
    assert isinstance(original, str)

    broken = []
    # Corruption mode, re-drawn as we scan, and a noise character drawn once:
    # a space, an empty string (drops the character; weighted 3 of 5), or the
    # Unicode replacement character.
    btype = random.choice(['um', 'aa', 'sara', 'none'])
    bchar = random.choice([' ', '', '', '', '�'])
    for c in original:
        # Switch the corruption mode with 30% probability per character.
        if random.random() < 0.3:
            btype = random.choice(['um', 'aa', 'sara', 'none'])
        if btype == 'um' and c == 'ำ':
            # Split sara am into a visually similar two-character sequence.
            broken.append(' า')
        elif btype == 'aa' and c == 'า':
            # Swap sara aa for sara am.
            broken.append('ำ')
        elif btype == 'sara' and re.match(r'[\u0E2F-\u0E4E]', c):
            # Replace Thai vowels, tone marks, and signs (U+0E2F-U+0E4E)
            # with the drawn noise character.
            broken.append(bchar)
        else:
            broken.append(c)
    return ''.join(broken)

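# Illustrative only (output is random, so exact corruption varies per call):
#   breaking_text('ภาษาไทย')  ->  e.g. 'ภำษำไทย' (sara aa swapped for sara am)
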
def metrics_func(eval_arg):
    preds, labels = eval_arg

    # Replace -100 (positions ignored by the loss) with the pad token id so
    # the sequences can be decoded.
    preds = np.where(preds != -100, preds, tokenizer.pad_token_id)
    labels = np.where(labels != -100, labels, tokenizer.pad_token_id)

    text_preds = tokenizer.batch_decode(preds, skip_special_tokens=True)
    text_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)

    # CER is undefined for empty references, so drop pairs where either side
    # decoded to an empty string.
    ps = []
    ls = []
    for p, l in zip(text_preds, text_labels):
        if p and l:
            ps.append(p)
            ls.append(l)

    return {'cer': cer_metric.compute(predictions=ps, references=ls)}

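# metrics_func reads the module-level tokenizer and cer_metric defined below;
# both exist by the time the Trainer first calls compute_metrics.
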
dataset = load_dataset("pythainlp/thai_wikipedia_clean_20230101")
dataset = dataset['train'].filter(lambda x: 50 < len(x['text']) < 200).train_test_split(500)

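# Note: train_test_split's positional argument is the test size, so this
# holds out 500 short passages (50-200 characters) for evaluation.
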
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model_repo = 'google/mt5-base'
tokenizer = AutoTokenizer.from_pretrained(model_repo)
model = AutoModelForSeq2SeqLM.from_pretrained(model_repo).to(device)

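# mT5 is a multilingual T5 variant whose SentencePiece vocabulary covers Thai.
# Moving the model to the device here is optional; Seq2SeqTrainer places the
# model on the available device itself.
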
text_template = 'Fix the following corrupted text: "{}"'

def preprocess_function(examples):
    max_length = 256

    # The corrupted text wrapped in the instruction prompt is the model
    # input; the clean text is the target.
    inputs = [text_template.format(breaking_text(ex)) for ex in examples["text"]]
    targets = [ex for ex in examples["text"]]
    model_inputs = tokenizer(inputs, max_length=max_length, truncation=True)
    labels = tokenizer(targets, max_length=max_length, truncation=True)

    model_inputs["labels"] = labels["input_ids"]
    return model_inputs

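# For a passage like 'ภาษาไทย', the model input might look like
#   Fix the following corrupted text: "ภำษำไทย"
# and the label is the clean original passage.
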
tokenized_ds = dataset.map(preprocess_function, batched=True)
data_collator = DataCollatorForSeq2Seq(tokenizer, model=model, return_tensors="pt")
cer_metric = evaluate.load('cer')

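# CER (character error rate) is edit distance divided by reference length, so
# lower is better; e.g. cer_metric.compute(predictions=['abc'],
# references=['abd']) is 1/3. A character-level metric suits Thai, which is
# written without spaces between words.
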
training_args = Seq2SeqTrainingArguments(
    output_dir="mt5-fixth",
    log_level="error",
    num_train_epochs=10,
    learning_rate=5e-4,
    lr_scheduler_type="linear",
    warmup_steps=90,
    optim="adafactor",
    weight_decay=0.01,
    per_device_train_batch_size=4,
    per_device_eval_batch_size=1,
    gradient_accumulation_steps=16,
    evaluation_strategy="steps",
    eval_steps=500,
    predict_with_generate=True,
    generation_max_length=254,
    save_steps=500,
    logging_steps=100,
    push_to_hub=False,
)

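# Effective train batch size is 4 x 16 = 64 sequences per optimizer step.
# predict_with_generate=True makes evaluation decode with model.generate(),
# which metrics_func needs in order to compute CER on generated text.
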
trainer = Seq2SeqTrainer(
    model=model,
    args=training_args,
    data_collator=data_collator,
    compute_metrics=metrics_func,
    train_dataset=tokenized_ds["train"],
    eval_dataset=tokenized_ds["test"].select(range(500)),
    tokenizer=tokenizer,
)

trainer.train(resume_from_checkpoint='mt5-fixth/checkpoint-500')
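# resume_from_checkpoint assumes a previous run already saved
# mt5-fixth/checkpoint-500; for a fresh run, call trainer.train() instead.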