import random
import re

import numpy as np
import torch
import evaluate
from datasets import load_dataset
from transformers import (
    AutoModelForSeq2SeqLM,
    AutoTokenizer,
    DataCollatorForSeq2Seq,
    Seq2SeqTrainingArguments,
    Seq2SeqTrainer,
)


def breaking_text(original):
    """Randomly corrupt Thai text to create (broken, clean) training pairs."""
    assert isinstance(original, str)
    broken = []
    # Corruption mode: swap SARA AM, swap SARA AA, mangle a vowel/tone mark,
    # or leave the character alone.
    btype = random.choice(['um', 'aa', 'sara', 'none'])
    # Replacement for a corrupted vowel/tone mark: a space or outright
    # deletion. (The original list also contained characters that were lost
    # to encoding damage.)
    bchar = random.choice([' ', ''])
    for c in original:
        # Re-roll the corruption mode for ~30% of characters so the error
        # type varies within a single text.
        if random.random() < 0.3:
            btype = random.choice(['um', 'aa', 'sara', 'none'])
        if btype == 'um' and c == 'ำ':
            broken.append(' า')  # swap SARA AM for a lookalike sequence
        elif btype == 'aa' and c == 'า':
            broken.append('ำ')  # swap SARA AA for SARA AM
        elif btype == 'sara' and re.match(r'[\u0E2F-\u0E4E]', c):
            broken.append(bchar)  # blank out or drop a Thai vowel/tone mark
        else:
            broken.append(c)
    return ''.join(broken)


def metrics_func(eval_arg):
    """Compute character error rate (CER) on generated predictions."""
    preds, labels = eval_arg
    # Replace -100 (the ignore index) with the pad token id before decoding.
    preds = np.where(preds != -100, preds, tokenizer.pad_token_id)
    labels = np.where(labels != -100, labels, tokenizer.pad_token_id)
    # Convert token ids back to text.
    text_preds = tokenizer.batch_decode(preds, skip_special_tokens=True)
    text_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)
    # CER is undefined for empty strings, so drop empty pairs.
    ps, ls = [], []
    for p, l in zip(text_preds, text_labels):
        if p and l:
            ps.append(p)
            ls.append(l)
    return {'cer': cer_metric.compute(predictions=ps, references=ls)}


dataset = load_dataset("pythainlp/thai_wikipedia_clean_20230101")
# Keep short-to-medium articles and hold out 500 examples for evaluation.
dataset = dataset['train'].filter(
    lambda x: 50 < len(x['text']) < 200
).train_test_split(test_size=500)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model_repo = 'google/mt5-base'
tokenizer = AutoTokenizer.from_pretrained(model_repo)
model = AutoModelForSeq2SeqLM.from_pretrained(model_repo).to(device)

text_template = 'Fix the following corrupted text: "{}"'


def preprocess_function(examples):
    """Build (corrupted prompt, clean text) pairs and tokenize them."""
    max_length = 256
    inputs = [text_template.format(breaking_text(ex)) for ex in examples["text"]]
    targets = list(examples["text"])
    model_inputs = tokenizer(inputs, max_length=max_length, truncation=True)
    # mT5 uses the same tokenizer for inputs and targets.
    labels = tokenizer(targets, max_length=max_length, truncation=True)
    model_inputs["labels"] = labels["input_ids"]
    return model_inputs


tokenized_ds = dataset.map(preprocess_function, batched=True)
data_collator = DataCollatorForSeq2Seq(tokenizer, model=model, return_tensors="pt")
cer_metric = evaluate.load('cer')

training_args = Seq2SeqTrainingArguments(
    output_dir="mt5-fixth",
    log_level="error",
    num_train_epochs=10,
    learning_rate=5e-4,
    lr_scheduler_type="linear",
    warmup_steps=90,
    optim="adafactor",
    weight_decay=0.01,
    per_device_train_batch_size=4,
    per_device_eval_batch_size=1,
    gradient_accumulation_steps=16,  # effective train batch size of 64
    evaluation_strategy="steps",
    eval_steps=500,
    predict_with_generate=True,
    generation_max_length=254,
    save_steps=500,
    logging_steps=100,
    push_to_hub=False,
)

trainer = Seq2SeqTrainer(
    model=model,
    args=training_args,
    data_collator=data_collator,
    compute_metrics=metrics_func,
    train_dataset=tokenized_ds["train"],
    eval_dataset=tokenized_ds["test"].select(range(500)),
    tokenizer=tokenizer,
)

# Resume from the saved checkpoint; on a fresh run (no checkpoint on disk
# yet), call trainer.train() with no arguments instead.
trainer.train(resume_from_checkpoint='mt5-fixth/checkpoint-500')
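
# Quick smoke test after training: a minimal sketch, assuming training has
# finished and the fine-tuned weights are still in memory. We corrupt a
# held-out sentence with breaking_text(), wrap it in the same prompt
# template used for training, and decode the model's repair. The beam size
# and max_length below are illustrative assumptions, not tuned values.
model.eval()
sample = dataset['test'][0]['text']
prompt = text_template.format(breaking_text(sample))
inputs = tokenizer(prompt, return_tensors="pt").to(device)
with torch.no_grad():
    output_ids = model.generate(**inputs, max_length=256, num_beams=4)
print("broken:", prompt)
print("fixed :", tokenizer.decode(output_ids[0], skip_special_tokens=True))
print("clean :", sample)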