napatswift commited on
Commit
c9afc58
·
1 Parent(s): b9adfcd

Upload train-fix-text.py

Browse files
Files changed (1) hide show
  1. train-fix-text.py +110 -0
train-fix-text.py ADDED
@@ -0,0 +1,110 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from datasets import load_dataset
2
+ import random
3
+ import re
4
+ import evaluate
5
+ import torch
6
+ import numpy as np
7
+ from transformers import (
8
+ pipeline,
9
+ AutoModelForSeq2SeqLM,
10
+ AutoTokenizer,
11
+ DataCollatorForSeq2Seq,
12
+ Seq2SeqTrainingArguments,
13
+ Seq2SeqTrainer
14
+ )
15
+
16
def breaking_text(original):
    """Randomly corrupt a Thai string to build (broken, clean) training pairs.

    Walks the text character by character, occasionally re-rolling the active
    corruption mode:
      - 'um'   : split sara am ('ำ') into a decomposed two-character form
      - 'aa'   : replace sara aa ('า') with sara am ('ำ')
      - 'sara' : replace any Thai vowel/tone mark (U+0E2F-U+0E4E) with a
                 randomly chosen junk character
      - 'none' : leave the character untouched

    Args:
        original: the clean source text (must be a str).

    Returns:
        The corrupted text as a new string.

    Raises:
        TypeError: if ``original`` is not a str.
    """
    # Validate explicitly instead of `assert`, which is stripped under -O.
    if not isinstance(original, str):
        raise TypeError(f"original must be str, got {type(original).__name__}")

    # NOTE(review): several entries below look mojibake-damaged in the upload
    # (empty strings and U+FFFD) — confirm against the author's original file.
    # Kept byte-for-byte to preserve behavior.
    junk_chars = [' ', '', '', '', '�']

    # Hoist the Thai-mark pattern out of the per-character loop.
    thai_mark = re.compile(r'[\u0E2F-\u0E4E]')

    broken = []
    btype = random.choice(['um', 'aa', 'sara', 'none'])
    bchar = random.choice(junk_chars)
    for c in original:
        # ~30% chance to switch corruption mode before each character.
        if random.random() < 0.3:
            btype = random.choice(['um', 'aa', 'sara', 'none'])
        if btype == 'um' and c == 'ำ':
            # NOTE(review): ' า' may also be extraction-damaged (leading mark
            # possibly lost) — kept exactly as uploaded.
            broken.append(' า')
        elif btype == 'aa' and c == 'า':
            broken.append('ำ')
        elif btype == 'sara' and thai_mark.match(c):
            broken.append(bchar)
        else:
            broken.append(c)
    return ''.join(broken)
34
+
35
def metrics_func(eval_arg):
    """Compute character error rate (CER) for a batch of generated predictions.

    Args:
        eval_arg: (predictions, labels) pair of token-id arrays where -100
            marks positions to be ignored.

    Returns:
        dict with a single 'cer' entry. Relies on the module-level
        ``tokenizer`` and ``cer_metric`` objects.
    """
    preds, labels = eval_arg

    # Swap the -100 ignore-index for the pad token so decoding works.
    pad = tokenizer.pad_token_id
    preds = np.where(preds == -100, pad, preds)
    labels = np.where(labels == -100, pad, labels)

    # Decode token ids back into plain text.
    text_preds = tokenizer.batch_decode(preds, skip_special_tokens=True)
    text_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)

    # Keep only pairs where both sides decoded to non-empty strings.
    kept = [(p, l) for p, l in zip(text_preds, text_labels) if p and l]
    ps = [p for p, _ in kept]
    ls = [l for _, l in kept]

    return {'cer': cer_metric.compute(predictions=ps, references=ls)}
52
+
53
# ---- Data ----------------------------------------------------------------
# Thai Wikipedia dump (cleaned, 2023-01-01 snapshot) from the HF Hub.
dataset = load_dataset("pythainlp/thai_wikipedia_clean_20230101")
# Keep only short-ish articles (51-199 chars) and hold out 500 rows for eval.
dataset = dataset['train'].filter(lambda x: 50 < len(x['text']) < 200).train_test_split(500)

# ---- Model ---------------------------------------------------------------
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Multilingual T5 base checkpoint, fine-tuned below for Thai text repair.
model_repo = 'google/mt5-base'
tokenizer = AutoTokenizer.from_pretrained(model_repo)
model = AutoModelForSeq2SeqLM.from_pretrained(model_repo).to(device)

# Prompt wrapper applied to every corrupted input before tokenization.
text_template = 'Fix the following corrupted text: "{}"'
64
def preprocess_function(examples):
    """Tokenize one batch: corrupted text as input, original text as label.

    Args:
        examples: batch dict from ``datasets.map(batched=True)`` with a
            "text" column.

    Returns:
        Tokenizer output dict with an added "labels" field (target ids).
    """
    max_length = 256

    texts = examples["text"]
    # Inputs are prompt-wrapped corrupted versions; targets are the originals.
    inputs = [text_template.format(breaking_text(t)) for t in texts]
    targets = list(texts)

    model_inputs = tokenizer(inputs, max_length=max_length, truncation=True)
    labels = tokenizer(targets, max_length=max_length, truncation=True)
    model_inputs["labels"] = labels["input_ids"]
    return model_inputs
74
+
75
tokenized_ds = dataset.map(preprocess_function, batched=True)
# Dynamic per-batch padding; pads labels with -100 so they are ignored by the loss.
data_collator = DataCollatorForSeq2Seq(tokenizer, model=model, return_tensors="pt")
# Character error rate — the eval metric consumed by metrics_func above.
cer_metric = evaluate.load('cer')

training_args = Seq2SeqTrainingArguments(
    output_dir="mt5-fixth",
    log_level="error",
    num_train_epochs=10,
    learning_rate=5e-4,
    lr_scheduler_type="linear",
    warmup_steps=90,
    optim="adafactor",  # memory-friendly optimizer, common choice for T5-family models
    weight_decay=0.01,
    per_device_train_batch_size=4,
    per_device_eval_batch_size=1,
    gradient_accumulation_steps=16,  # effective train batch per device = 4 * 16 = 64
    evaluation_strategy="steps",
    eval_steps=500,
    predict_with_generate=True,  # generate text at eval time so CER can be computed
    generation_max_length=254,
    save_steps=500,
    logging_steps=100,
    push_to_hub=False
)

trainer = Seq2SeqTrainer(
    model=model,
    args=training_args,
    data_collator=data_collator,
    compute_metrics=metrics_func,
    train_dataset=tokenized_ds["train"],
    eval_dataset=tokenized_ds["test"].select(range(500)),
    tokenizer=tokenizer,
)

# NOTE(review): hard-coded checkpoint path — this raises if
# mt5-fixth/checkpoint-500 does not exist yet; `resume_from_checkpoint=True`
# would resume from the latest checkpoint automatically. Confirm intent.
trainer.train(resume_from_checkpoint='mt5-fixth/checkpoint-500')