# aT5translation / aT5translation.py
# Souta0919's picture
# Training in progress, step 500
# 1dd25ea verified
# Import modules; the model and tokenizer classes are specified here.
from transformers import (
T5Tokenizer,
MT5ForConditionalGeneration,
Text2TextGenerationPipeline,
)
import sys
from datasets import load_dataset, load_metric
# Local JSON files backing each split; passed to load_dataset() as data_files.
dataset_files = {
    "train": ["2train.json"],
    "validation": ["2validation.json"],
    "test": ["2test.json"],
}
raw_datasets = load_dataset("json", data_files=dataset_files)
print(raw_datasets)

# compute_metrics() reads the module-level name `metric`, and evaluation runs
# every epoch, so it has to be defined before training starts.
# NOTE(review): loading "sacrebleu" presumably needs the `sacrebleu` package
# installed in the environment -- confirm.
metric = load_metric("sacrebleu")

# Checkpoint name used for both the tokenizer here and the model further down.
path = "K024/mt5-zh-ja-en-trimmed"
tokenizer = T5Tokenizer.from_pretrained(path)

# Token limits applied (with truncation) in preprocess_function().
max_input_length = 128
max_target_length = 128

# Keys read from each record of the "translation" column.
source_lang = "en"
target_lang = "ja"

ratio = 0.7  # not referenced anywhere in the visible part of this script

# Directory handed to the training arguments and to trainer.save_model().
result_dir = "C:\\Users\\kagen\\OneDrive\\デスクトップ\\sotuken\\aT5translation\\aT5translation"
# Data preprocessing
def preprocess_function(examples):
    """Tokenize one batch of translation records for seq2seq training.

    Reads the module-level names ``tokenizer``, ``source_lang``,
    ``target_lang``, ``max_input_length`` and ``max_target_length``.

    Args:
        examples: Batch mapping whose ``"translation"`` entry is iterated as a
            sequence of records. Each record is indexed with ``'agetag'``,
            ``source_lang`` and ``target_lang``.

    Returns:
        The tokenizer output for the source side, with a ``"labels"`` entry
        holding the ``input_ids`` of the tokenized targets.
    """
    records = examples["translation"]

    # The age tag is concatenated directly in front of the source sentence;
    # no separator is inserted here.
    # NOTE(review): assumes both values are strings and that the tag carries
    # its own delimiter -- confirm against the JSON data.
    inputs = [ex['agetag'] + ex[source_lang] for ex in records]
    targets = [ex[target_lang] for ex in records]

    model_inputs = tokenizer(inputs, max_length=max_input_length, truncation=True)

    # Set up the tokenizer for targets
    with tokenizer.as_target_tokenizer():
        labels = tokenizer(targets, max_length=max_target_length, truncation=True)

    model_inputs["labels"] = labels["input_ids"]
    return model_inputs
# code start
print(raw_datasets['train'][0])
# Debug aid: run the preprocessing on the first two training records.
# preprocess_function(raw_datasets['train'][:2])
tokenized_datasets = raw_datasets.map(preprocess_function, batched=True)

# Fine-tuning
from transformers import AutoModelForSeq2SeqLM, DataCollatorForSeq2Seq, Seq2SeqTrainingArguments, Seq2SeqTrainer
model = MT5ForConditionalGeneration.from_pretrained(path)

# Definition of the training variables
batch_size = 8
print("args")
args = Seq2SeqTrainingArguments(
    output_dir=result_dir,
    evaluation_strategy="epoch",
    learning_rate=2e-5,
    per_device_train_batch_size=batch_size,
    per_device_eval_batch_size=batch_size,
    weight_decay=0.01,
    save_total_limit=3,
    num_train_epochs=1,
    # compute_metrics() decodes the predictions with tokenizer.batch_decode(),
    # so evaluation must produce generated token ids rather than logits.
    predict_with_generate=True,
    # NOTE(review): T5-family checkpoints are reported to hit NaN/inf losses
    # under fp16 -- confirm the training loss stays finite.
    fp16=True,
    push_to_hub=True,
)
data_collator = DataCollatorForSeq2Seq(tokenizer, model=model)
import numpy as np
def postprocess_text(preds, labels):
    """Strip surrounding whitespace from decoded predictions and references.

    Args:
        preds: Iterable of decoded prediction strings.
        labels: Iterable of decoded reference strings.

    Returns:
        A ``(preds, labels)`` tuple: a list of stripped prediction strings and
        a list of one-element lists, each holding one stripped reference.
    """
    preds = [pred.strip() for pred in preds]
    # Every reference is wrapped in its own list; this is the shape that
    # compute_metrics() passes on as `references` to metric.compute().
    labels = [[label.strip()] for label in labels]
    return preds, labels
def compute_metrics(eval_preds):
    """Decode predictions and labels, then report BLEU and mean output length.

    Reads the module-level names ``tokenizer`` and ``metric``; ``metric`` must
    be defined at module level before evaluation runs.

    Args:
        eval_preds: A ``(preds, labels)`` pair. If ``preds`` is a tuple, only
            its first element is used.
            NOTE(review): both are assumed to be integer token-id arrays --
            confirm against the trainer configuration.

    Returns:
        A dict with ``"bleu"`` (the ``"score"`` entry of ``metric.compute``)
        and ``"gen_len"`` (mean count of non-pad tokens per prediction), both
        rounded to 4 decimal places.
    """
    preds, labels = eval_preds
    if isinstance(preds, tuple):
        preds = preds[0]

    # Replace -100 with the pad id as -100 can't be decoded. Predictions get
    # the same treatment as labels, since they may also carry -100 padding.
    preds = np.where(preds != -100, preds, tokenizer.pad_token_id)
    labels = np.where(labels != -100, labels, tokenizer.pad_token_id)

    decoded_preds = tokenizer.batch_decode(preds, skip_special_tokens=True)
    decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)

    # Some simple post-processing
    decoded_preds, decoded_labels = postprocess_text(decoded_preds, decoded_labels)

    result = metric.compute(predictions=decoded_preds, references=decoded_labels)
    result = {"bleu": result["score"]}

    # Length of each prediction, not counting pad tokens.
    prediction_lens = [np.count_nonzero(pred != tokenizer.pad_token_id) for pred in preds]
    result["gen_len"] = np.mean(prediction_lens)

    return {k: round(v, 4) for k, v in result.items()}
print("trainingstart")

# Wire the model, training arguments, tokenized splits, collator and metric
# callback into one trainer. Training uses the "train" split and evaluation
# uses the "validation" split.
trainer = Seq2SeqTrainer(
    model=model,
    args=args,
    data_collator=data_collator,
    train_dataset=tokenized_datasets["train"],
    eval_dataset=tokenized_datasets["validation"],
    tokenizer=tokenizer,
    compute_metrics=compute_metrics,
)

# Run fine-tuning, then save the resulting model into result_dir.
trainer.train()
trainer.save_model(result_dir)