from datasets import load_dataset
import random
import re
import evaluate
import torch
import numpy as np
from transformers import (
    pipeline,
    AutoModelForSeq2SeqLM,
    AutoTokenizer,
    DataCollatorForSeq2Seq,
    Seq2SeqTrainingArguments,
    Seq2SeqTrainer,
)
def breaking_text(original):
    """Randomly corrupt a Thai string to build a (broken, clean) training pair."""
    assert isinstance(original, str)

    broken = []
    btype = random.choice(['um', 'aa', 'sara', 'none'])
    # Substitute for stripped vowel/tone marks: a space, or drop the mark entirely.
    bchar = random.choice([' ', ''])
    for c in original:
        # Re-roll the corruption type for roughly 30% of the characters.
        if random.random() < 0.3:
            btype = random.choice(['um', 'aa', 'sara', 'none'])
        if btype == 'um' and c == 'ำ':
            # Decompose sara am (ำ) into nikhahit + sara aa, a common Thai encoding error.
            broken.append('\u0e4d\u0e32')
        elif btype == 'aa' and c == 'า':
            broken.append('ำ')
        elif btype == 'sara' and re.match(r'[\u0E2F-\u0E4E]', c):
            # Replace Thai vowels, tone marks, and other signs with bchar.
            broken.append(bchar)
        else:
            broken.append(c)
    return ''.join(broken)
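# Illustrative check (arbitrary sample string; the output varies with the RNG state):
# random.seed(13)
# print(breaking_text('น้ำใสไหลเย็นเห็นตัวปลา'))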
def metrics_func(eval_arg):
    """Compute the character error rate (CER) on generated predictions."""
    preds, labels = eval_arg

    # Replace -100 (ignored positions) with the pad token id before decoding.
    preds = np.where(preds != -100, preds, tokenizer.pad_token_id)
    labels = np.where(labels != -100, labels, tokenizer.pad_token_id)

    text_preds = tokenizer.batch_decode(preds, skip_special_tokens=True)
    text_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)

    # CER is undefined for empty strings, so drop pairs where either side is empty.
    ps, ls = [], []
    for p, l in zip(text_preds, text_labels):
        if p and l:
            ps.append(p)
            ls.append(l)

    return {'cer': cer_metric.compute(predictions=ps, references=ls)}
# Thai Wikipedia text; keep short passages (50-200 characters) and hold out 500 examples for eval.
dataset = load_dataset("pythainlp/thai_wikipedia_clean_20230101")
dataset = dataset['train'].filter(lambda x: 50 < len(x['text']) < 200).train_test_split(500)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model_repo = 'google/mt5-base'
tokenizer = AutoTokenizer.from_pretrained(model_repo)
model = AutoModelForSeq2SeqLM.from_pretrained(model_repo).to(device)

# Prompt wrapper for the encoder input; the target is the original clean text.
text_template = 'Fix the following corrupted text: "{}"'
def preprocess_function(examples):
    max_length = 256

    # The corrupted text goes into the prompt; the clean text is the label.
    inputs = [text_template.format(breaking_text(ex)) for ex in examples["text"]]
    targets = [ex for ex in examples["text"]]
    model_inputs = tokenizer(inputs, max_length=max_length, truncation=True)
    labels = tokenizer(targets, max_length=max_length, truncation=True)

    model_inputs["labels"] = labels["input_ids"]
    return model_inputs
tokenized_ds = dataset.map(preprocess_function, batched=True)
data_collator = DataCollatorForSeq2Seq(tokenizer, model=model, return_tensors="pt")
cer_metric = evaluate.load('cer')
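# CER = character-level edit distance / reference length; a quick sanity check:
# cer_metric.compute(predictions=['กาแฟ'], references=['กาแพ'])  # 1 substitution / 4 chars = 0.25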
training_args = Seq2SeqTrainingArguments(
    output_dir="mt5-fixth",
    log_level="error",
    num_train_epochs=10,
    learning_rate=5e-4,
    lr_scheduler_type="linear",
    warmup_steps=90,
    optim="adafactor",
    weight_decay=0.01,
    per_device_train_batch_size=4,
    per_device_eval_batch_size=1,
    gradient_accumulation_steps=16,  # effective train batch size of 64 per device
    evaluation_strategy="steps",
    eval_steps=500,
    predict_with_generate=True,  # generate sequences during eval so CER can be computed
    generation_max_length=254,
    save_steps=500,
    logging_steps=100,
    push_to_hub=False,
)
trainer = Seq2SeqTrainer(
    model=model,
    args=training_args,
    data_collator=data_collator,
    compute_metrics=metrics_func,
    train_dataset=tokenized_ds["train"],
    eval_dataset=tokenized_ds["test"].select(range(500)),
    tokenizer=tokenizer,
)
# Resumes from an existing checkpoint; on a fresh run, call trainer.train() instead.
trainer.train(resume_from_checkpoint='mt5-fixth/checkpoint-500')
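# A minimal post-training smoke test, assuming training produced a usable model
# (the sample sentence is arbitrary): wrap the fine-tuned model in a
# text2text-generation pipeline and ask it to repair a freshly corrupted string.
fixer = pipeline(
    'text2text-generation',
    model=trainer.model,
    tokenizer=tokenizer,
    device=0 if torch.cuda.is_available() else -1,
)
broken = breaking_text('ตัวอย่างประโยคภาษาไทยสำหรับทดสอบ')
print(fixer(text_template.format(broken), max_length=254)[0]['generated_text'])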