文银龙
commited on
Commit
·
b1c66ba
1
Parent(s):
ca5fd54
update
Browse files
README.md
CHANGED
|
@@ -35,10 +35,12 @@ python init_custdata_model.py \
|
|
| 35 |
--cust_vocab ./cust-data/vocab.txt \
|
| 36 |
--pretrain_model ./weights \
|
| 37 |
--cust_data_init_weights_path ./cust-data/weights
|
| 38 |
-
|
| 39 |
## cust_vocab 词库文件
|
| 40 |
## pretrain_model 预训练模型权重
|
| 41 |
-
##
|
|
|
|
|
|
|
| 42 |
|
| 43 |
### 训练模型
|
| 44 |
#### 数据准备,数据结构如下图所示
|
|
|
|
| 35 |
--cust_vocab ./cust-data/vocab.txt \
|
| 36 |
--pretrain_model ./weights \
|
| 37 |
--cust_data_init_weights_path ./cust-data/weights
|
| 38 |
+
|
| 39 |
## cust_vocab 词库文件
|
| 40 |
## pretrain_model 预训练模型权重
|
| 41 |
+
## cust_data_init_weights_path 自定义模型初始化模型权重保存位置
|
| 42 |
+
|
| 43 |
+
```
|
| 44 |
|
| 45 |
### 训练模型
|
| 46 |
#### 数据准备,数据结构如下图所示
|
eval.py
CHANGED
|
@@ -30,7 +30,7 @@ def compute_metrics(pred_str, label_str):
|
|
| 30 |
|
| 31 |
if __name__ == '__main__':
|
| 32 |
parser = argparse.ArgumentParser(description='trocr fine-tune训练')
|
| 33 |
-
parser.add_argument('--
|
| 34 |
help="初始化训练权重,用于自己数据集上fine-tune权重")
|
| 35 |
parser.add_argument('--CUDA_VISIBLE_DEVICES', default='-1', type=str, help="GPU设置")
|
| 36 |
parser.add_argument('--test_img', default='test/test.jpg', type=str, help="img path")
|
|
@@ -50,11 +50,11 @@ if __name__ == '__main__':
|
|
| 50 |
|
| 51 |
print("train num:", len(train_paths), "test num:", len(test_paths))
|
| 52 |
|
| 53 |
-
processor = TrOCRProcessor.from_pretrained(args.
|
| 54 |
vocab = processor.tokenizer.get_vocab()
|
| 55 |
|
| 56 |
vocab_inp = {vocab[key]: key for key in vocab}
|
| 57 |
-
model = VisionEncoderDecoderModel.from_pretrained(args.
|
| 58 |
model.eval()
|
| 59 |
model.cuda()
|
| 60 |
|
|
|
|
| 30 |
|
| 31 |
if __name__ == '__main__':
|
| 32 |
parser = argparse.ArgumentParser(description='trocr fine-tune训练')
|
| 33 |
+
parser.add_argument('--cust_data_init_weights_path', default='./cust-data/weights', type=str,
|
| 34 |
help="初始化训练权重,用于自己数据集上fine-tune权重")
|
| 35 |
parser.add_argument('--CUDA_VISIBLE_DEVICES', default='-1', type=str, help="GPU设置")
|
| 36 |
parser.add_argument('--test_img', default='test/test.jpg', type=str, help="img path")
|
|
|
|
| 50 |
|
| 51 |
print("train num:", len(train_paths), "test num:", len(test_paths))
|
| 52 |
|
| 53 |
+
processor = TrOCRProcessor.from_pretrained(args.cust_data_init_weights_path)
|
| 54 |
vocab = processor.tokenizer.get_vocab()
|
| 55 |
|
| 56 |
vocab_inp = {vocab[key]: key for key in vocab}
|
| 57 |
+
model = VisionEncoderDecoderModel.from_pretrained(args.cust_data_init_weights_path)
|
| 58 |
model.eval()
|
| 59 |
model.cuda()
|
| 60 |
|