文银龙 committed on
Commit
4db1ad4
·
1 Parent(s): b1c66ba

fix train acc=0 with filter <pad>

Browse files
Files changed (8) hide show
  1. .gitignore +1 -0
  2. README.md +38 -9
  3. app.py +1 -1
  4. dataset.py +13 -3
  5. eval.py +3 -6
  6. gen_vocab.py +25 -0
  7. img/im2latex.png +0 -0
  8. train.py +10 -4
.gitignore CHANGED
@@ -24,4 +24,5 @@ __pycache__/*
24
  dataset/test.json
25
  dataset/train.json
26
  app/
 
27
 
 
24
  dataset/test.json
25
  dataset/train.json
26
  app/
27
+ models/
28
 
README.md CHANGED
@@ -19,13 +19,16 @@ docker run --gpus all -it -v /tmp/trocr-chinese:/trocr-chinese trocr-chinese:lat
19
  vocab.txt
20
  1
21
  2
22
- 3
23
- 4
24
- 5
25
  ...
26
  a
27
  b
28
  c
 
 
 
 
 
 
29
  ```
30
  ### 初始化自定义数据集模型
31
  #### 下载预训练模型trocr模型权重
@@ -62,28 +65,54 @@ python train.py \
62
  --CUDA_VISIBLE_DEVICES 1
63
  ```
64
 
 
 
 
 
 
 
 
 
 
65
  ## 测试模型
66
  ```
67
  ## 拷贝训练完成的pytorch_model.bin 到 ./cust-data/weights 目录下
68
- index = 2300
69
  cp ./checkpoint/trocr-custdata/checkpoint-$index/pytorch_model.bin ./cust-data/weights
70
- python app.py --test_img test/test.jpg
71
  ```
72
 
73
  ## 预训练模型
74
  | 模型 | cer(字符错误率) | acc(文本行) | 下载地址 |训练数据来源 |训练耗时(GPU:3090) |
75
  | ------------- |:-------------:| -----:|-----:|-----:|-----:|
76
- | hand-write(中文手写) |0.011 | 0.940 |链接: https://pan.baidu.com/s/19f7iu9tLHkcT_zpi3UfqLQ 密码: punl |https://aistudio.baidu.com/aistudio/datasetdetail/102884/0 |8.5h|
77
- | 印章识别 |- | - |- |- |
78
- | im2latex(数学公式识别) |- | - |- |https://zenodo.org/record/56198#.YkniL25Bx_S |
79
- | 表格识别 |- | - |- |链接:https://pan.baidu.com/s/1V0NT2XmQDDb0mHQlw7V7_w 提取码:oo4a |
 
 
 
80
 
81
  备注:后续所有模型会开源在这个目录下链接,可以自由下载. https://pan.baidu.com/s/1uSdWQhJPEy2CYoEULoOhRA 密码: vwi2
82
  ### 模型调用
 
 
83
  ```
84
  unzip hand-write.zip
85
  python app.py --cust_data_init_weights_path hand-write --test_img test/hand.png
 
 
86
  ```
 
 
 
 
 
 
 
 
 
 
87
  ## 捐助
88
  如果此项目给您的工作带来了帮助,希望您能贡献自己微薄的爱心,
89
  该项目的每一份收入将用于福利事业,每一季度在issues上公布捐赠明细!
 
19
  vocab.txt
20
  1
21
  2
 
 
 
22
  ...
23
  a
24
  b
25
  c
26
+ ```
27
+ ```python
28
+ python gen_vocab.py \
29
+ --dataset_path "dataset/cust-data/0/*.txt" \
30
+ --cust_vocab ./cust-data/vocab.txt
31
+
32
  ```
33
  ### 初始化自定义数据集模型
34
  #### 下载预训练模型trocr模型权重
 
65
  --CUDA_VISIBLE_DEVICES 1
66
  ```
67
 
68
+ #### 评估模型
69
+ ##### 拷贝checkpoint/trocr-custdata训练完成的pytorch_model.bin 到 ./cust-data/weights 目录下
70
+
71
+ ```python
72
+ python eval.py \
73
+ --dataset_path "./data/cust-data/test/*/*.jpg" \
74
+ --cust_data_init_weights_path ./cust-data/weights
75
+ ```
76
+
77
  ## 测试模型
78
  ```
79
  ## 拷贝训练完成的pytorch_model.bin 到 ./cust-data/weights 目录下
80
+ index = 2300 ##选择最好的或者最后一个step模型
81
  cp ./checkpoint/trocr-custdata/checkpoint-$index/pytorch_model.bin ./cust-data/weights
82
+ python app.py --cust_data_init_weights_path ./cust-data/weights --test_img test/test.jpg
83
  ```
84
 
85
  ## 预训练模型
86
  | 模型 | cer(字符错误率) | acc(文本行) | 下载地址 |训练数据来源 |训练耗时(GPU:3090) |
87
  | ------------- |:-------------:| -----:|-----:|-----:|-----:|
88
+ | hand-write(中文手写) |0.011 | 0.940 |[hand-write](https://pan.baidu.com/s/19f7iu9tLHkcT_zpi3UfqLQ) 密码: punl |[数据集地址](https://aistudio.baidu.com/aistudio/datasetdetail/102884/0) |8.5h(10epoch)|
89
+ | seal(印章识别) |- | - |- |- |
90
+ | im2latex(数学公式识别) |- | - |- |[im2latex](https://zenodo.org/record/56198#.YkniL25Bx_S) ||
91
+ | TAL_OCR_TABLE(表格识别) |- | - |- |[TAL_OCR_TABLE](https://ai.100tal.com/dataset) |
92
+ | TAL_OCR_MATH(小学低年级算式数据集)|- | - |- | [TAL_OCR_MATH](https://ai.100tal.com/dataset) |
93
+ | TAL_OCR_CHN(手写中文数据集)|- | - |- | [TAL_OCR_CHN](https://ai.100tal.com/dataset) ||
94
+ | HME100K(手写公式)|- | - |- | [HME100K](https://ai.100tal.com/dataset) |
95
 
96
  备注:后续所有模型会开源在这个目录下链接,可以自由下载. https://pan.baidu.com/s/1uSdWQhJPEy2CYoEULoOhRA 密码: vwi2
97
  ### 模型调用
98
+ #### 手写识别
99
+ ![image](img/hand.png)
100
  ```
101
  unzip hand-write.zip
102
  python app.py --cust_data_init_weights_path hand-write --test_img test/hand.png
103
+
104
+ ## output: '醒我的昏迷,偿还我的天真。'
105
  ```
106
+
107
+ #### 打印公式识别
108
+ ![image](img/im2latex.png)
109
+ ```
110
+ unzip im2latex.zip
111
+ python app.py --cust_data_init_weights_path im2latex --test_img test/im2latex.png
112
+
113
+ ```
114
+
115
+
116
  ## 捐助
117
  如果此项目给您的工作带来了帮助,希望您能贡献自己微薄的爱心,
118
  该项目的每一份收入将用于福利事业,每一季度在issues上公布捐赠明细!
app.py CHANGED
@@ -32,4 +32,4 @@ if __name__ == '__main__':
32
  generated_ids = model.generate(pixel_values[:, :, :].cpu())
33
 
34
  generated_text = decode_text(generated_ids[0].cpu().numpy(), vocab, vocab_inp)
35
- print('time take:', round(time.time() - t, 2), "s ocr:", [generated_text.replace(' ', '\n')])
 
32
  generated_ids = model.generate(pixel_values[:, :, :].cpu())
33
 
34
  generated_text = decode_text(generated_ids[0].cpu().numpy(), vocab, vocab_inp)
35
+ print('time take:', round(time.time() - t, 2), "s ocr:", [generated_text])
dataset.py CHANGED
@@ -29,7 +29,13 @@ class trocrDataset(Dataset):
29
  txt_file = os.path.splitext(image_file)[0]+'.txt'
30
 
31
  with open(txt_file) as f:
32
- text = f.read().strip().replace('xa0','')
 
 
 
 
 
 
33
 
34
  image = Image.open(image_file).convert("RGB")
35
  image = self.transformer(image) ##图像增强函数
@@ -45,10 +51,13 @@ class trocrDataset(Dataset):
45
 
46
  def encode_text(text, max_target_length=128, vocab=None):
47
  """
 
48
  {'input_ids': [0, 1092, 2, 1, 1],
49
  'attention_mask': [1, 1, 1, 0, 0]}
50
  """
51
- text = list(text)
 
 
52
  text = text[:max_target_length - 2]
53
  tokens = [vocab.get('<s>')]
54
  unk = vocab.get('<unk>')
@@ -76,9 +85,10 @@ def decode_text(tokens, vocab, vocab_inp):
76
  s_start = vocab.get('<s>')
77
  s_end = vocab.get('</s>')
78
  unk = vocab.get('<unk>')
 
79
  text = ''
80
  for tk in tokens:
81
- if tk not in [s_end, s_start]:
82
  text += vocab_inp[tk]
83
 
84
  return text
 
29
  txt_file = os.path.splitext(image_file)[0]+'.txt'
30
 
31
  with open(txt_file) as f:
32
+ text = f.read().strip().replace('xa0','')
33
+ if text.startswith('[') and text.endswith(']'):
34
+ ##list
35
+ try:
36
+ text = json.loads(text)
37
+ except:
38
+ pass
39
 
40
  image = Image.open(image_file).convert("RGB")
41
  image = self.transformer(image) ##图像增强函数
 
51
 
52
  def encode_text(text, max_target_length=128, vocab=None):
53
  """
54
+ ##支持自定义 list: ['<td>',"3","3",'</td>',....]
55
  {'input_ids': [0, 1092, 2, 1, 1],
56
  'attention_mask': [1, 1, 1, 0, 0]}
57
  """
58
+ if type(text) is not list:
59
+ text = list(text)
60
+
61
  text = text[:max_target_length - 2]
62
  tokens = [vocab.get('<s>')]
63
  unk = vocab.get('<unk>')
 
85
  s_start = vocab.get('<s>')
86
  s_end = vocab.get('</s>')
87
  unk = vocab.get('<unk>')
88
+ pad = vocab.get('<pad>')
89
  text = ''
90
  for tk in tokens:
91
+ if tk not in [s_end, s_start , pad, unk]:
92
  text += vocab_inp[tk]
93
 
94
  return text
eval.py CHANGED
@@ -1,6 +1,4 @@
1
  import os
2
-
3
- os.environ["CUDA_VISIBLE_DEVICES"] = '0'
4
  from PIL import Image
5
  import numpy as np
6
  import time
@@ -29,17 +27,16 @@ def compute_metrics(pred_str, label_str):
29
 
30
 
31
  if __name__ == '__main__':
32
- parser = argparse.ArgumentParser(description='trocr fine-tune训练')
33
  parser.add_argument('--cust_data_init_weights_path', default='./cust-data/weights', type=str,
34
  help="初始化训练权重,用于自己数据集上fine-tune权重")
35
  parser.add_argument('--CUDA_VISIBLE_DEVICES', default='-1', type=str, help="GPU设置")
36
- parser.add_argument('--test_img', default='test/test.jpg', type=str, help="img path")
37
  parser.add_argument('--dataset_path', default='dataset/HW-hand-write/HW_Chinese/*/*.[j|p]*', type=str,
38
  help="img path")
39
- parser.add_argument('--random_state', default=10086, type=int, help="用于训练集划分的随机数")
40
 
41
  args = parser.parse_args()
42
-
43
  paths = glob(args.dataset_path)
44
  if args.random_state is not None:
45
  train_paths, test_paths = train_test_split(paths, test_size=0.05, random_state=args.random_state)
 
1
  import os
 
 
2
  from PIL import Image
3
  import numpy as np
4
  import time
 
27
 
28
 
29
  if __name__ == '__main__':
30
+ parser = argparse.ArgumentParser(description='trocr 模型评估')
31
  parser.add_argument('--cust_data_init_weights_path', default='./cust-data/weights', type=str,
32
  help="初始化训练权重,用于自己数据集上fine-tune权重")
33
  parser.add_argument('--CUDA_VISIBLE_DEVICES', default='-1', type=str, help="GPU设置")
 
34
  parser.add_argument('--dataset_path', default='dataset/HW-hand-write/HW_Chinese/*/*.[j|p]*', type=str,
35
  help="img path")
36
+ parser.add_argument('--random_state', default=None, type=int, help="用于训练集划分的随机数")
37
 
38
  args = parser.parse_args()
39
+ os.environ["CUDA_VISIBLE_DEVICES"] = args.CUDA_VISIBLE_DEVICES
40
  paths = glob(args.dataset_path)
41
  if args.random_state is not None:
42
  train_paths, test_paths = train_test_split(paths, test_size=0.05, random_state=args.random_state)
gen_vocab.py ADDED
@@ -0,0 +1,25 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
import os.path
import codecs
import argparse
from glob import glob

try:
    from tqdm import tqdm
except ImportError:
    # tqdm is only a progress-bar nicety; degrade gracefully when absent.
    def tqdm(iterable, **kwargs):
        return iterable


def build_vocab(paths):
    """Collect the set of distinct characters found in the given text files.

    Args:
        paths: iterable of paths to label .txt files.

    Returns:
        set[str]: every character appearing in any file, after each file's
        content is read as UTF-8 and stripped of surrounding whitespace.
    """
    charset = set()
    for path in tqdm(paths):
        with codecs.open(path, encoding='utf-8') as f:
            charset.update(f.read().strip())
    return charset


if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='trocr vocab生成')
    parser.add_argument('--cust_vocab', default="./cust-data/vocab.txt", type=str, help="自定义vocab文件生成")
    parser.add_argument('--dataset_path', default="./dataset/train/*/*.jpg", type=str, help="自定义训练数据字符集")
    args = parser.parse_args()

    vocab = build_vocab(glob(args.dataset_path))

    # Bug fix: the original passed the whole (head, tail) tuple from
    # os.path.split() to os.makedirs(), which raises TypeError. Use dirname,
    # and skip creation when the vocab file lives in the current directory.
    root_path = os.path.dirname(args.cust_vocab)
    if root_path:
        os.makedirs(root_path, exist_ok=True)

    # Sort for a deterministic vocab file (token ids depend on line order),
    # and write UTF-8 explicitly so output does not depend on the platform's
    # default encoding (input is read as UTF-8 above).
    with open(args.cust_vocab, 'w', encoding='utf-8') as f:
        f.write('\n'.join(sorted(vocab)))
img/im2latex.png ADDED
train.py CHANGED
@@ -22,8 +22,11 @@ def compute_metrics(pred):
22
  labels_ids[labels_ids == -100] = processor.tokenizer.pad_token_id
23
  label_str = [decode_text(labels_id, vocab, vocab_inp) for labels_id in labels_ids]
24
  cer = cer_metric.compute(predictions=pred_str, references=label_str)
 
25
  acc = [pred == label for pred, label in zip(pred_str, label_str)]
 
26
  acc = sum(acc)/(len(acc)+0.000001)
 
27
  return {"cer": cer, "acc": acc}
28
 
29
  if __name__ == '__main__':
@@ -34,10 +37,11 @@ if __name__ == '__main__':
34
  parser.add_argument('--dataset_path', default='./dataset/cust-data/*/*.jpg', type=str, help="训练数据集")
35
  parser.add_argument('--per_device_train_batch_size', default=32, type=int, help="train batch size")
36
  parser.add_argument('--per_device_eval_batch_size', default=8, type=int, help="eval batch size")
 
37
 
38
  parser.add_argument('--num_train_epochs', default=10, type=int, help="训练epoch数")
39
- parser.add_argument('--eval_steps', default=5000, type=int, help="模型评估间隔数")
40
- parser.add_argument('--save_steps', default=500, type=int, help="模型保存间隔步数")
41
 
42
  parser.add_argument('--CUDA_VISIBLE_DEVICES', default='0,1', type=str, help="GPU设置")
43
 
@@ -56,8 +60,8 @@ if __name__ == '__main__':
56
  vocab = processor.tokenizer.get_vocab()
57
  vocab_inp = {vocab[key]: key for key in vocab}
58
 
59
- train_dataset = trocrDataset(paths=train_paths, processor=processor)
60
- eval_dataset = trocrDataset(paths=test_paths, processor=processor)
61
 
62
  model = VisionEncoderDecoderModel.from_pretrained(args.cust_data_init_weights_path)
63
  model.config.decoder_start_token_id = processor.tokenizer.cls_token_id
@@ -99,6 +103,8 @@ if __name__ == '__main__':
99
  data_collator=default_data_collator,
100
  )
101
  trainer.train()
 
 
102
 
103
 
104
 
 
22
  labels_ids[labels_ids == -100] = processor.tokenizer.pad_token_id
23
  label_str = [decode_text(labels_id, vocab, vocab_inp) for labels_id in labels_ids]
24
  cer = cer_metric.compute(predictions=pred_str, references=label_str)
25
+
26
  acc = [pred == label for pred, label in zip(pred_str, label_str)]
27
+ print([pred_str[0], label_str[0]])
28
  acc = sum(acc)/(len(acc)+0.000001)
29
+
30
  return {"cer": cer, "acc": acc}
31
 
32
  if __name__ == '__main__':
 
37
  parser.add_argument('--dataset_path', default='./dataset/cust-data/*/*.jpg', type=str, help="训练数据集")
38
  parser.add_argument('--per_device_train_batch_size', default=32, type=int, help="train batch size")
39
  parser.add_argument('--per_device_eval_batch_size', default=8, type=int, help="eval batch size")
40
+ parser.add_argument('--max_target_length', default=128, type=int, help="训练文字字符数")
41
 
42
  parser.add_argument('--num_train_epochs', default=10, type=int, help="训练epoch数")
43
+ parser.add_argument('--eval_steps', default=1000, type=int, help="模型评估间隔数")
44
+ parser.add_argument('--save_steps', default=1000, type=int, help="模型保存间隔步数")
45
 
46
  parser.add_argument('--CUDA_VISIBLE_DEVICES', default='0,1', type=str, help="GPU设置")
47
 
 
60
  vocab = processor.tokenizer.get_vocab()
61
  vocab_inp = {vocab[key]: key for key in vocab}
62
 
63
+ train_dataset = trocrDataset(paths=train_paths, processor=processor, max_target_length=args.max_target_length)
64
+ eval_dataset = trocrDataset(paths=test_paths, processor=processor, max_target_length=args.max_target_length)
65
 
66
  model = VisionEncoderDecoderModel.from_pretrained(args.cust_data_init_weights_path)
67
  model.config.decoder_start_token_id = processor.tokenizer.cls_token_id
 
103
  data_collator=default_data_collator,
104
  )
105
  trainer.train()
106
+ trainer.save_model(os.path.join(args.checkpoint_path, 'last'))
107
+ processor.save_pretrained(os.path.join(args.checkpoint_path, 'last'))
108
 
109
 
110