文银龙
commited on
Commit
·
4db1ad4
1
Parent(s):
b1c66ba
fix train acc=0 with filter <pad>
Browse files- .gitignore +1 -0
- README.md +38 -9
- app.py +1 -1
- dataset.py +13 -3
- eval.py +3 -6
- gen_vocab.py +25 -0
- img/im2latex.png +0 -0
- train.py +10 -4
.gitignore
CHANGED
|
@@ -24,4 +24,5 @@ __pycache__/*
|
|
| 24 |
dataset/test.json
|
| 25 |
dataset/train.json
|
| 26 |
app/
|
|
|
|
| 27 |
|
|
|
|
| 24 |
dataset/test.json
|
| 25 |
dataset/train.json
|
| 26 |
app/
|
| 27 |
+
models/
|
| 28 |
|
README.md
CHANGED
|
@@ -19,13 +19,16 @@ docker run --gpus all -it -v /tmp/trocr-chinese:/trocr-chinese trocr-chinese:lat
|
|
| 19 |
vocab.txt
|
| 20 |
1
|
| 21 |
2
|
| 22 |
-
3
|
| 23 |
-
4
|
| 24 |
-
5
|
| 25 |
...
|
| 26 |
a
|
| 27 |
b
|
| 28 |
c
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 29 |
```
|
| 30 |
### 初始化自定义数据集模型
|
| 31 |
#### 下载预训练模型trocr模型权重
|
|
@@ -62,28 +65,54 @@ python train.py \
|
|
| 62 |
--CUDA_VISIBLE_DEVICES 1
|
| 63 |
```
|
| 64 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 65 |
## 测试模型
|
| 66 |
```
|
| 67 |
## 拷贝训练完成的pytorch_model.bin 到 ./cust-data/weights 目录下
|
| 68 |
-
index = 2300
|
| 69 |
cp ./checkpoint/trocr-custdata/checkpoint-$index/pytorch_model.bin ./cust-data/weights
|
| 70 |
-
python app.py --test_img test/test.jpg
|
| 71 |
```
|
| 72 |
|
| 73 |
## 预训练模型
|
| 74 |
| 模型 | cer(字符错误率) | acc(文本行) | 下载地址 |训练数据来源 |训练耗时(GPU:3090) |
|
| 75 |
| ------------- |:-------------:| -----:|-----:|-----:|-----:|
|
| 76 |
-
| hand-write(中文手写) |0.011 | 0.940
|
| 77 |
-
| 印章识别 |- | - |- |- |
|
| 78 |
-
| im2latex(数学公式识别) |- | - |- |https://zenodo.org/record/56198#.YkniL25Bx_S
|
| 79 |
-
| 表格识别
|
|
|
|
|
|
|
|
|
|
| 80 |
|
| 81 |
备注:后续所有模型会开源在这个目录下链接,可以自由下载. https://pan.baidu.com/s/1uSdWQhJPEy2CYoEULoOhRA 密码: vwi2
|
| 82 |
### 模型调用
|
|
|
|
|
|
|
| 83 |
```
|
| 84 |
unzip hand-write.zip
|
| 85 |
python app.py --cust_data_init_weights_path hand-write --test_img test/hand.png
|
|
|
|
|
|
|
| 86 |
```
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 87 |
## 捐助
|
| 88 |
如果此项目给您的工作带来了帮忙,希望您能贡献自己微薄的爱心,
|
| 89 |
该项目的每一份收入将用着福利事业,每一季度在issues上公布捐赠明细!
|
|
|
|
| 19 |
vocab.txt
|
| 20 |
1
|
| 21 |
2
|
|
|
|
|
|
|
|
|
|
| 22 |
...
|
| 23 |
a
|
| 24 |
b
|
| 25 |
c
|
| 26 |
+
```
|
| 27 |
+
```[python]
|
| 28 |
+
python gen_vocab.py \
|
| 29 |
+
--dataset_dataset_path "dataset/cust-data/0/*.txt" \
|
| 30 |
+
--cust_vocab ./cust-data/vocab.txt
|
| 31 |
+
|
| 32 |
```
|
| 33 |
### 初始化自定义数据集模型
|
| 34 |
#### 下载预训练模型trocr模型权重
|
|
|
|
| 65 |
--CUDA_VISIBLE_DEVICES 1
|
| 66 |
```
|
| 67 |
|
| 68 |
+
#### 评估模型
|
| 69 |
+
##### 拷贝checkpoint/trocr-custdata训练完成的pytorch_model.bin 到 ./cust-data/weights 目录下
|
| 70 |
+
|
| 71 |
+
```[python]
|
| 72 |
+
python eval.py \
|
| 73 |
+
--dataset_path "./data/cust-data/test/*/*.jpg" \
|
| 74 |
+
--cust_data_init_weights_path ./cust-data/weights
|
| 75 |
+
```
|
| 76 |
+
|
| 77 |
## 测试模型
|
| 78 |
```
|
| 79 |
## 拷贝训练完成的pytorch_model.bin 到 ./cust-data/weights 目录下
|
| 80 |
+
index = 2300 ##选择最好的或者最后一个step模型
|
| 81 |
cp ./checkpoint/trocr-custdata/checkpoint-$index/pytorch_model.bin ./cust-data/weights
|
| 82 |
+
python app.py --cust_data_init_weights_path ./cust-data/weights --test_img test/test.jpg
|
| 83 |
```
|
| 84 |
|
| 85 |
## 预训练模型
|
| 86 |
| 模型 | cer(字符错误率) | acc(文本行) | 下载地址 |训练数据来源 |训练耗时(GPU:3090) |
|
| 87 |
| ------------- |:-------------:| -----:|-----:|-----:|-----:|
|
| 88 |
+
| hand-write(中文手写) |0.011 | 0.940 |[hand-write](https://pan.baidu.com/s/19f7iu9tLHkcT_zpi3UfqLQ) 密码: punl |[数据集地址](https://aistudio.baidu.com/aistudio/datasetdetail/102884/0) |8.5h(10epoch)|
|
| 89 |
+
| seal(印章识别) |- | - |- |- |
|
| 90 |
+
| im2latex(数学公式识别) |- | - |- |[im2latex](https://zenodo.org/record/56198#.YkniL25Bx_S) ||
|
| 91 |
+
| TAL_OCR_TABLE(表格识别) |- | - |- |[TAL_OCR_TABLE](https://ai.100tal.com/dataset) |
|
| 92 |
+
| TAL_OCR_MATH(小学低年级算式数据集)|- | - |- | [TAL_OCR_MATH](https://ai.100tal.com/dataset) |
|
| 93 |
+
| TAL_OCR_CHN(手写中文数据集)|- | - |- | [TAL_OCR_CHN](https://ai.100tal.com/dataset) ||
|
| 94 |
+
| HME100K(手写公式)|- | - |- | [HME100K](https://ai.100tal.com/dataset) |
|
| 95 |
|
| 96 |
备注:后续所有模型会开源在这个目录下链接,可以自由下载. https://pan.baidu.com/s/1uSdWQhJPEy2CYoEULoOhRA 密码: vwi2
|
| 97 |
### 模型调用
|
| 98 |
+
#### 手写识别
|
| 99 |
+

|
| 100 |
```
|
| 101 |
unzip hand-write.zip
|
| 102 |
python app.py --cust_data_init_weights_path hand-write --test_img test/hand.png
|
| 103 |
+
|
| 104 |
+
## output: '醒我的昏迷,偿还我的天真。'
|
| 105 |
```
|
| 106 |
+
|
| 107 |
+
#### 打印公式识别
|
| 108 |
+

|
| 109 |
+
```
|
| 110 |
+
unzip im2latex.zip
|
| 111 |
+
python app.py --cust_data_init_weights_path im2latex --test_img test/im2latex.png
|
| 112 |
+
|
| 113 |
+
```
|
| 114 |
+
|
| 115 |
+
|
| 116 |
## 捐助
|
| 117 |
如果此项目给您的工作带来了帮忙,希望您能贡献自己微薄的爱心,
|
| 118 |
该项目的每一份收入将用着福利事业,每一季度在issues上公布捐赠明细!
|
app.py
CHANGED
|
@@ -32,4 +32,4 @@ if __name__ == '__main__':
|
|
| 32 |
generated_ids = model.generate(pixel_values[:, :, :].cpu())
|
| 33 |
|
| 34 |
generated_text = decode_text(generated_ids[0].cpu().numpy(), vocab, vocab_inp)
|
| 35 |
-
print('time take:', round(time.time() - t, 2), "s ocr:", [generated_text
|
|
|
|
| 32 |
generated_ids = model.generate(pixel_values[:, :, :].cpu())
|
| 33 |
|
| 34 |
generated_text = decode_text(generated_ids[0].cpu().numpy(), vocab, vocab_inp)
|
| 35 |
+
print('time take:', round(time.time() - t, 2), "s ocr:", [generated_text])
|
dataset.py
CHANGED
|
@@ -29,7 +29,13 @@ class trocrDataset(Dataset):
|
|
| 29 |
txt_file = os.path.splitext(image_file)[0]+'.txt'
|
| 30 |
|
| 31 |
with open(txt_file) as f:
|
| 32 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 33 |
|
| 34 |
image = Image.open(image_file).convert("RGB")
|
| 35 |
image = self.transformer(image) ##图像增强函数
|
|
@@ -45,10 +51,13 @@ class trocrDataset(Dataset):
|
|
| 45 |
|
| 46 |
def encode_text(text, max_target_length=128, vocab=None):
|
| 47 |
"""
|
|
|
|
| 48 |
{'input_ids': [0, 1092, 2, 1, 1],
|
| 49 |
'attention_mask': [1, 1, 1, 0, 0]}
|
| 50 |
"""
|
| 51 |
-
text
|
|
|
|
|
|
|
| 52 |
text = text[:max_target_length - 2]
|
| 53 |
tokens = [vocab.get('<s>')]
|
| 54 |
unk = vocab.get('<unk>')
|
|
@@ -76,9 +85,10 @@ def decode_text(tokens, vocab, vocab_inp):
|
|
| 76 |
s_start = vocab.get('<s>')
|
| 77 |
s_end = vocab.get('</s>')
|
| 78 |
unk = vocab.get('<unk>')
|
|
|
|
| 79 |
text = ''
|
| 80 |
for tk in tokens:
|
| 81 |
-
if tk not in [s_end, s_start]:
|
| 82 |
text += vocab_inp[tk]
|
| 83 |
|
| 84 |
return text
|
|
|
|
| 29 |
txt_file = os.path.splitext(image_file)[0]+'.txt'
|
| 30 |
|
| 31 |
with open(txt_file) as f:
|
| 32 |
+
text = f.read().strip().replace('xa0','')
|
| 33 |
+
if text.startswith('[') and text.endswith(']'):
|
| 34 |
+
##list
|
| 35 |
+
try:
|
| 36 |
+
text = json.loads(text)
|
| 37 |
+
except:
|
| 38 |
+
pass
|
| 39 |
|
| 40 |
image = Image.open(image_file).convert("RGB")
|
| 41 |
image = self.transformer(image) ##图像增强函数
|
|
|
|
| 51 |
|
| 52 |
def encode_text(text, max_target_length=128, vocab=None):
|
| 53 |
"""
|
| 54 |
+
##自持自定义 list: ['<td>',"3","3",'</td>',....]
|
| 55 |
{'input_ids': [0, 1092, 2, 1, 1],
|
| 56 |
'attention_mask': [1, 1, 1, 0, 0]}
|
| 57 |
"""
|
| 58 |
+
if type(text) is not list:
|
| 59 |
+
text = list(text)
|
| 60 |
+
|
| 61 |
text = text[:max_target_length - 2]
|
| 62 |
tokens = [vocab.get('<s>')]
|
| 63 |
unk = vocab.get('<unk>')
|
|
|
|
| 85 |
s_start = vocab.get('<s>')
|
| 86 |
s_end = vocab.get('</s>')
|
| 87 |
unk = vocab.get('<unk>')
|
| 88 |
+
pad = vocab.get('<pad>')
|
| 89 |
text = ''
|
| 90 |
for tk in tokens:
|
| 91 |
+
if tk not in [s_end, s_start , pad, unk]:
|
| 92 |
text += vocab_inp[tk]
|
| 93 |
|
| 94 |
return text
|
eval.py
CHANGED
|
@@ -1,6 +1,4 @@
|
|
| 1 |
import os
|
| 2 |
-
|
| 3 |
-
os.environ["CUDA_VISIBLE_DEVICES"] = '0'
|
| 4 |
from PIL import Image
|
| 5 |
import numpy as np
|
| 6 |
import time
|
|
@@ -29,17 +27,16 @@ def compute_metrics(pred_str, label_str):
|
|
| 29 |
|
| 30 |
|
| 31 |
if __name__ == '__main__':
|
| 32 |
-
parser = argparse.ArgumentParser(description='trocr
|
| 33 |
parser.add_argument('--cust_data_init_weights_path', default='./cust-data/weights', type=str,
|
| 34 |
help="初始化训练权重,用于自己数据集上fine-tune权重")
|
| 35 |
parser.add_argument('--CUDA_VISIBLE_DEVICES', default='-1', type=str, help="GPU设置")
|
| 36 |
-
parser.add_argument('--test_img', default='test/test.jpg', type=str, help="img path")
|
| 37 |
parser.add_argument('--dataset_path', default='dataset/HW-hand-write/HW_Chinese/*/*.[j|p]*', type=str,
|
| 38 |
help="img path")
|
| 39 |
-
parser.add_argument('--random_state', default=
|
| 40 |
|
| 41 |
args = parser.parse_args()
|
| 42 |
-
|
| 43 |
paths = glob(args.dataset_path)
|
| 44 |
if args.random_state is not None:
|
| 45 |
train_paths, test_paths = train_test_split(paths, test_size=0.05, random_state=args.random_state)
|
|
|
|
| 1 |
import os
|
|
|
|
|
|
|
| 2 |
from PIL import Image
|
| 3 |
import numpy as np
|
| 4 |
import time
|
|
|
|
| 27 |
|
| 28 |
|
| 29 |
if __name__ == '__main__':
|
| 30 |
+
parser = argparse.ArgumentParser(description='trocr 模型评估')
|
| 31 |
parser.add_argument('--cust_data_init_weights_path', default='./cust-data/weights', type=str,
|
| 32 |
help="初始化训练权重,用于自己数据集上fine-tune权重")
|
| 33 |
parser.add_argument('--CUDA_VISIBLE_DEVICES', default='-1', type=str, help="GPU设置")
|
|
|
|
| 34 |
parser.add_argument('--dataset_path', default='dataset/HW-hand-write/HW_Chinese/*/*.[j|p]*', type=str,
|
| 35 |
help="img path")
|
| 36 |
+
parser.add_argument('--random_state', default=None, type=int, help="用于训练集划分的随机数")
|
| 37 |
|
| 38 |
args = parser.parse_args()
|
| 39 |
+
os.environ["CUDA_VISIBLE_DEVICES"] = args.CUDA_VISIBLE_DEVICES
|
| 40 |
paths = glob(args.dataset_path)
|
| 41 |
if args.random_state is not None:
|
| 42 |
train_paths, test_paths = train_test_split(paths, test_size=0.05, random_state=args.random_state)
|
gen_vocab.py
ADDED
|
@@ -0,0 +1,25 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os.path
|
| 2 |
+
from glob import glob
|
| 3 |
+
from tqdm import tqdm
|
| 4 |
+
import codecs
|
| 5 |
+
import argparse
|
| 6 |
+
if __name__=='__main__':
|
| 7 |
+
parser = argparse.ArgumentParser(description='trocr vocab生成')
|
| 8 |
+
parser.add_argument('--cust_vocab', default="./cust-data/vocab.txt", type=str, help="自定义vocab文件生成")
|
| 9 |
+
parser.add_argument('--dataset_path', default="./dataset/train/*/*.jpg", type=str, help="自定义训练数字符集")
|
| 10 |
+
args = parser.parse_args()
|
| 11 |
+
paths = glob(args.dataset_path)
|
| 12 |
+
vocab = set()
|
| 13 |
+
for p in tqdm(paths):
|
| 14 |
+
with codecs.open(p, encoding='utf-8') as f:
|
| 15 |
+
txt = f.read().strip()
|
| 16 |
+
vocab.update(txt)
|
| 17 |
+
root_path = os.path.split(args.cust_vocab)
|
| 18 |
+
os.makedirs(root_path, exist_ok=True)
|
| 19 |
+
with open(args.cust_vocab, 'w') as f:
|
| 20 |
+
f.write('\n'.join(list(vocab)))
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
|
img/im2latex.png
ADDED
|
train.py
CHANGED
|
@@ -22,8 +22,11 @@ def compute_metrics(pred):
|
|
| 22 |
labels_ids[labels_ids == -100] = processor.tokenizer.pad_token_id
|
| 23 |
label_str = [decode_text(labels_id, vocab, vocab_inp) for labels_id in labels_ids]
|
| 24 |
cer = cer_metric.compute(predictions=pred_str, references=label_str)
|
|
|
|
| 25 |
acc = [pred == label for pred, label in zip(pred_str, label_str)]
|
|
|
|
| 26 |
acc = sum(acc)/(len(acc)+0.000001)
|
|
|
|
| 27 |
return {"cer": cer, "acc": acc}
|
| 28 |
|
| 29 |
if __name__ == '__main__':
|
|
@@ -34,10 +37,11 @@ if __name__ == '__main__':
|
|
| 34 |
parser.add_argument('--dataset_path', default='./dataset/cust-data/*/*.jpg', type=str, help="训练数据集")
|
| 35 |
parser.add_argument('--per_device_train_batch_size', default=32, type=int, help="train batch size")
|
| 36 |
parser.add_argument('--per_device_eval_batch_size', default=8, type=int, help="eval batch size")
|
|
|
|
| 37 |
|
| 38 |
parser.add_argument('--num_train_epochs', default=10, type=int, help="训练epoch数")
|
| 39 |
-
parser.add_argument('--eval_steps', default=
|
| 40 |
-
parser.add_argument('--save_steps', default=
|
| 41 |
|
| 42 |
parser.add_argument('--CUDA_VISIBLE_DEVICES', default='0,1', type=str, help="GPU设置")
|
| 43 |
|
|
@@ -56,8 +60,8 @@ if __name__ == '__main__':
|
|
| 56 |
vocab = processor.tokenizer.get_vocab()
|
| 57 |
vocab_inp = {vocab[key]: key for key in vocab}
|
| 58 |
|
| 59 |
-
train_dataset = trocrDataset(paths=train_paths, processor=processor)
|
| 60 |
-
eval_dataset = trocrDataset(paths=test_paths, processor=processor)
|
| 61 |
|
| 62 |
model = VisionEncoderDecoderModel.from_pretrained(args.cust_data_init_weights_path)
|
| 63 |
model.config.decoder_start_token_id = processor.tokenizer.cls_token_id
|
|
@@ -99,6 +103,8 @@ if __name__ == '__main__':
|
|
| 99 |
data_collator=default_data_collator,
|
| 100 |
)
|
| 101 |
trainer.train()
|
|
|
|
|
|
|
| 102 |
|
| 103 |
|
| 104 |
|
|
|
|
| 22 |
labels_ids[labels_ids == -100] = processor.tokenizer.pad_token_id
|
| 23 |
label_str = [decode_text(labels_id, vocab, vocab_inp) for labels_id in labels_ids]
|
| 24 |
cer = cer_metric.compute(predictions=pred_str, references=label_str)
|
| 25 |
+
|
| 26 |
acc = [pred == label for pred, label in zip(pred_str, label_str)]
|
| 27 |
+
print([pred_str[0], label_str[0]])
|
| 28 |
acc = sum(acc)/(len(acc)+0.000001)
|
| 29 |
+
|
| 30 |
return {"cer": cer, "acc": acc}
|
| 31 |
|
| 32 |
if __name__ == '__main__':
|
|
|
|
| 37 |
parser.add_argument('--dataset_path', default='./dataset/cust-data/*/*.jpg', type=str, help="训练数据集")
|
| 38 |
parser.add_argument('--per_device_train_batch_size', default=32, type=int, help="train batch size")
|
| 39 |
parser.add_argument('--per_device_eval_batch_size', default=8, type=int, help="eval batch size")
|
| 40 |
+
parser.add_argument('--max_target_length', default=128, type=int, help="训练文字字符数")
|
| 41 |
|
| 42 |
parser.add_argument('--num_train_epochs', default=10, type=int, help="训练epoch数")
|
| 43 |
+
parser.add_argument('--eval_steps', default=1000, type=int, help="模型评估间隔数")
|
| 44 |
+
parser.add_argument('--save_steps', default=1000, type=int, help="模型保存间隔步数")
|
| 45 |
|
| 46 |
parser.add_argument('--CUDA_VISIBLE_DEVICES', default='0,1', type=str, help="GPU设置")
|
| 47 |
|
|
|
|
| 60 |
vocab = processor.tokenizer.get_vocab()
|
| 61 |
vocab_inp = {vocab[key]: key for key in vocab}
|
| 62 |
|
| 63 |
+
train_dataset = trocrDataset(paths=train_paths, processor=processor, max_target_length=args.max_target_length)
|
| 64 |
+
eval_dataset = trocrDataset(paths=test_paths, processor=processor, max_target_length=args.max_target_length)
|
| 65 |
|
| 66 |
model = VisionEncoderDecoderModel.from_pretrained(args.cust_data_init_weights_path)
|
| 67 |
model.config.decoder_start_token_id = processor.tokenizer.cls_token_id
|
|
|
|
| 103 |
data_collator=default_data_collator,
|
| 104 |
)
|
| 105 |
trainer.train()
|
| 106 |
+
trainer.save_model(os.path.join(args.checkpoint_path, 'last'))
|
| 107 |
+
processor.save_pretrained(os.path.join(args.checkpoint_path, 'last'))
|
| 108 |
|
| 109 |
|
| 110 |
|