File size: 3,250 Bytes
c9afc58
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
from datasets import load_dataset
import random
import re
import evaluate
import torch
import numpy as np
from transformers import (
  pipeline,
  AutoModelForSeq2SeqLM,
  AutoTokenizer,
  DataCollatorForSeq2Seq,
  Seq2SeqTrainingArguments,
  Seq2SeqTrainer
)

def breaking_text(original):
  """Randomly corrupt a Thai string to build (broken, clean) training pairs.

  Corruption modes (re-rolled per character with 30% probability):
    - 'um':   decompose sara am ('ำ') into its visual parts
    - 'aa':   replace sara aa ('า') with sara am ('ำ')
    - 'sara': replace any Thai vowel/tone mark (U+0E2F-U+0E4E) with `bchar`
    - 'none': pass the character through unchanged

  Args:
    original: the clean input text.

  Returns:
    The (possibly) corrupted string; non-Thai characters always pass through.

  Raises:
    TypeError: if `original` is not a str (was a bare `assert`, which is
      stripped under `python -O`).
  """
  if not isinstance(original, str):
    raise TypeError(f"expected str, got {type(original).__name__}")

  broken = []
  btype = random.choice(['um', 'aa', 'sara', 'none'])
  # Replacement char for the 'sara' mode, drawn once per call.
  # NOTE(review): the list appears to contain empty strings and U+FFFD —
  # possibly zero-width characters mangled in transit; confirm the intended
  # characters against the original training recipe.
  bchar = random.choice([' ', '', '', '', '�'])
  for c in original:
    # Occasionally switch corruption mode mid-string.
    if random.random() < 0.3:
        btype = random.choice(['um', 'aa', 'sara', 'none'])
    if btype == 'um' and c == 'ำ':
      broken.append(' า')
    elif btype == 'aa' and c == 'า':
      broken.append('ำ')
    elif btype == 'sara' and re.match(r'[\u0E2F-\u0E4E]', c):
      broken.append(bchar)
    else:
      broken.append(c)
  return ''.join(broken)

def metrics_func(eval_arg):
  """Compute character error rate (CER) for a Trainer evaluation round.

  Args:
    eval_arg: an EvalPrediction-like pair of (predictions, label_ids),
      both arrays of token ids with -100 marking ignored positions.

  Returns:
    dict with a single 'cer' entry (lower is better).
  """
  preds, labels = eval_arg
  # Seq2SeqTrainer may hand predictions back as a tuple when the model
  # returns extra outputs; keep only the generated token ids.
  if isinstance(preds, tuple):
    preds = preds[0]
  # Replace the -100 ignore sentinel with the pad id so batch_decode
  # does not choke on invalid token ids.
  preds = np.where(preds != -100, preds, tokenizer.pad_token_id)
  labels = np.where(labels != -100, labels, tokenizer.pad_token_id)
  # Convert id tokens to text
  text_preds = tokenizer.batch_decode(preds, skip_special_tokens=True)
  text_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)

  # CER is undefined on empty strings; keep only non-empty pairs.
  pairs = [(p, l) for p, l in zip(text_preds, text_labels) if p and l]
  if not pairs:
    # Avoid crashing the eval loop when every decoded pair is empty.
    return {'cer': 0.0}
  ps = [p for p, _ in pairs]
  ls = [l for _, l in pairs]

  return {'cer': cer_metric.compute(predictions=ps, references=ls)}

# Cleaned Thai Wikipedia dump (2023-01-01 snapshot).
dataset = load_dataset("pythainlp/thai_wikipedia_clean_20230101")
# Keep mid-length texts only (51-199 chars), then hold out 500 examples as a test split.
dataset = dataset['train'].filter(lambda x: 50 < len(x['text']) < 200).train_test_split(500)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model_repo = 'google/mt5-base'
tokenizer = AutoTokenizer.from_pretrained(model_repo)
model = AutoModelForSeq2SeqLM.from_pretrained(model_repo).to(device)

# Instruction prompt wrapped around every corrupted input sequence.
text_template = 'Fix the following corrupted text: "{}"'

def preprocess_function(examples):
    """Tokenize one batch: corrupted prompt as model input, clean text as label."""
    max_length = 256

    corrupted = [text_template.format(breaking_text(text)) for text in examples["text"]]
    clean = list(examples["text"])

    model_inputs = tokenizer(corrupted, max_length=max_length, truncation=True)
    label_encoding = tokenizer(clean, max_length=max_length, truncation=True)

    model_inputs["labels"] = label_encoding["input_ids"]
    return model_inputs

# Tokenize both splits; breaking_text is re-sampled here, so corruption is fixed per map run.
tokenized_ds = dataset.map(preprocess_function, batched=True)
# Dynamically pads inputs and labels per batch (labels padded with -100).
data_collator = DataCollatorForSeq2Seq(tokenizer, model=model, return_tensors="pt")
cer_metric = evaluate.load('cer')

training_args = Seq2SeqTrainingArguments(
  output_dir="mt5-fixth",
  log_level="error",
  num_train_epochs=10,
  learning_rate=5e-4,
  lr_scheduler_type="linear",
  warmup_steps=90,
  # Adafactor is the memory-friendly optimizer commonly paired with mT5.
  optim="adafactor",
  weight_decay=0.01,
  per_device_train_batch_size=4,
  per_device_eval_batch_size=1,
  gradient_accumulation_steps=16,  # effective train batch size: 4 * 16 = 64 per device
  evaluation_strategy="steps",
  eval_steps=500,
  # Decode with generate() during eval so CER can be computed on text.
  predict_with_generate=True,
  generation_max_length=254,
  save_steps=500,
  logging_steps=100,
  push_to_hub=False
)

trainer = Seq2SeqTrainer(
  model=model,
  args=training_args,
  data_collator=data_collator,
  compute_metrics=metrics_func,
  train_dataset=tokenized_ds["train"],
  # Test split is exactly 500 examples; select(range(500)) keeps it explicit.
  eval_dataset=tokenized_ds["test"].select(range(500)),
  tokenizer=tokenizer,
)

# NOTE(review): hard-coded checkpoint path — this raises on a fresh run where
# mt5-fixth/checkpoint-500 does not exist yet; confirm whether
# resume_from_checkpoint=True (auto-detect latest) was intended.
trainer.train(resume_from_checkpoint='mt5-fixth/checkpoint-500')