import os
import argparse
from glob import glob

from sklearn.model_selection import train_test_split
from datasets import load_metric
from transformers import (
    TrOCRProcessor,
    VisionEncoderDecoderModel,
    Seq2SeqTrainer,
    Seq2SeqTrainingArguments,
    default_data_collator,
)

from dataset import trocrDataset, decode_text
|
|
def compute_metrics(pred):
    """
    Compute CER (character error rate) and exact-match accuracy.

    :param pred: EvalPrediction with .predictions (generated ids) and .label_ids
    :return: dict with "cer" and "acc"
    """
    labels_ids = pred.label_ids
    pred_ids = pred.predictions

    pred_str = [decode_text(pred_id, vocab, vocab_inp) for pred_id in pred_ids]
    # Restore padding: the Trainer masks padded label positions with -100.
    labels_ids[labels_ids == -100] = processor.tokenizer.pad_token_id
    label_str = [decode_text(labels_id, vocab, vocab_inp) for labels_id in labels_ids]
    cer = cer_metric.compute(predictions=pred_str, references=label_str)

    acc = [pred == label for pred, label in zip(pred_str, label_str)]
    # Log one sample prediction/reference pair as a quick sanity check.
    print([pred_str[0], label_str[0]])
    acc = sum(acc) / max(len(acc), 1)

    return {"cer": cer, "acc": acc}


if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='TrOCR fine-tuning')
    parser.add_argument('--cust_data_init_weights_path', default='./cust-data/weights', type=str,
                        help="initial weights to fine-tune on your own dataset")
    parser.add_argument('--checkpoint_path', default='./checkpoint/trocr', type=str, help="where to save checkpoints")
    parser.add_argument('--dataset_path', default='./dataset/cust-data/*/*.jpg', type=str, help="glob pattern for training images")
    parser.add_argument('--per_device_train_batch_size', default=32, type=int, help="train batch size")
    parser.add_argument('--per_device_eval_batch_size', default=8, type=int, help="eval batch size")
    parser.add_argument('--max_target_length', default=128, type=int, help="maximum number of target characters")

    parser.add_argument('--num_train_epochs', default=10, type=int, help="number of training epochs")
    parser.add_argument('--eval_steps', default=1000, type=int, help="steps between evaluations")
    parser.add_argument('--save_steps', default=1000, type=int, help="steps between checkpoint saves")

    parser.add_argument('--CUDA_VISIBLE_DEVICES', default='0,1', type=str, help="GPUs to use")

    args = parser.parse_args()
    print("train params:")
    print(args)
    os.environ["CUDA_VISIBLE_DEVICES"] = args.CUDA_VISIBLE_DEVICES

    print("loading data .................")
    paths = glob(args.dataset_path)

    train_paths, test_paths = train_test_split(paths, test_size=0.05, random_state=10086)
    print("train num:", len(train_paths), "test num:", len(test_paths))

    processor = TrOCRProcessor.from_pretrained(args.cust_data_init_weights_path)
    vocab = processor.tokenizer.get_vocab()
    vocab_inp = {vocab[key]: key for key in vocab}

    # Identity transform; replace with image augmentation for training if needed.
    transformer = lambda x: x
    train_dataset = trocrDataset(paths=train_paths, processor=processor,
                                 max_target_length=args.max_target_length, transformer=transformer)
    eval_dataset = trocrDataset(paths=test_paths, processor=processor,
                                max_target_length=args.max_target_length, transformer=transformer)
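
    # Assumption: each trocrDataset item is a dict with "pixel_values" (the
    # processed image tensor) and "labels" (padded target token ids), which is
    # the format default_data_collator expects; see dataset.trocrDataset.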

    model = VisionEncoderDecoderModel.from_pretrained(args.cust_data_init_weights_path)
    # Tie the decoder's special tokens to the tokenizer.
    model.config.decoder_start_token_id = processor.tokenizer.cls_token_id
    model.config.pad_token_id = processor.tokenizer.pad_token_id
    model.config.vocab_size = model.config.decoder.vocab_size
    model.config.eos_token_id = processor.tokenizer.sep_token_id

    # Beam-search generation settings used during evaluation.
    model.config.max_length = 256
    model.config.early_stopping = True
    model.config.no_repeat_ngram_size = 3
    model.config.length_penalty = 2.0
    model.config.num_beams = 4
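
    # Note: recent transformers versions read generation settings from
    # model.generation_config rather than model.config; a hedged equivalent
    # (assuming transformers >= 4.26) would be:
    #
    #   model.generation_config.max_length = 256
    #   model.generation_config.num_beams = 4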

    # CER metric implemented by the local cer.py script.
    cer_metric = load_metric("./cer.py")
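
    # Note: datasets.load_metric is deprecated in newer versions of the
    # datasets library; a hedged alternative (assuming the `evaluate` and
    # `jiwer` packages are installed) is:
    #
    #   import evaluate
    #   cer_metric = evaluate.load("cer")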

    training_args = Seq2SeqTrainingArguments(
        predict_with_generate=True,
        evaluation_strategy="steps",
        per_device_train_batch_size=args.per_device_train_batch_size,
        per_device_eval_batch_size=args.per_device_eval_batch_size,
        fp16=False,
        output_dir=args.checkpoint_path,
        logging_steps=10,
        num_train_epochs=args.num_train_epochs,
        save_steps=args.save_steps,
        eval_steps=args.eval_steps,
        save_total_limit=5,
        # use_mps_device=True,  # enable only on Apple Silicon; conflicts
        # with the CUDA_VISIBLE_DEVICES setting above
    )

    trainer = Seq2SeqTrainer(
        model=model,
        tokenizer=processor.feature_extractor,
        args=training_args,
        compute_metrics=compute_metrics,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        data_collator=default_data_collator,
    )
    trainer.train()

    # Save the final model and processor alongside the step checkpoints.
    trainer.save_model(os.path.join(args.checkpoint_path, 'last'))
    processor.save_pretrained(os.path.join(args.checkpoint_path, 'last'))
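
    # Hedged usage sketch: loading the saved checkpoint for inference
    # (assumes a PIL-readable image path; not part of the training run):
    #
    #   from PIL import Image
    #   processor = TrOCRProcessor.from_pretrained('./checkpoint/trocr/last')
    #   model = VisionEncoderDecoderModel.from_pretrained('./checkpoint/trocr/last')
    #   pixel_values = processor(images=Image.open('example.jpg').convert('RGB'),
    #                            return_tensors='pt').pixel_values
    #   generated_ids = model.generate(pixel_values)
    #   text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]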