文银龙 commited on
Commit
b1c66ba
·
1 Parent(s): ca5fd54
Files changed (2) hide show
  1. README.md +4 -2
  2. eval.py +3 -3
README.md CHANGED
@@ -35,10 +35,12 @@ python init_custdata_model.py \
35
  --cust_vocab ./cust-data/vocab.txt \
36
  --pretrain_model ./weights \
37
  --cust_data_init_weights_path ./cust-data/weights
38
- ```
39
  ## cust_vocab 词库文件
40
  ## pretrain_model 预训练模型权重
41
- ## cut_data_init_weights_path 自定义模型初始化模型权重保存位置
 
 
42
 
43
  ### 训练模型
44
  #### 数据准备,数据结构如下图所示
 
35
  --cust_vocab ./cust-data/vocab.txt \
36
  --pretrain_model ./weights \
37
  --cust_data_init_weights_path ./cust-data/weights
38
+
39
  ## cust_vocab 词库文件
40
  ## pretrain_model 预训练模型权重
41
+ ## cust_data_init_weights_path 自定义模型初始化模型权重保存位置
42
+
43
+ ```
44
 
45
  ### 训练模型
46
  #### 数据准备,数据结构如下图所示
eval.py CHANGED
@@ -30,7 +30,7 @@ def compute_metrics(pred_str, label_str):
30
 
31
  if __name__ == '__main__':
32
  parser = argparse.ArgumentParser(description='trocr fine-tune训练')
33
- parser.add_argument('--cut_data_init_weights_path', default='./cust-data/weights', type=str,
34
  help="初始化训练权重,用于自己数据集上fine-tune权重")
35
  parser.add_argument('--CUDA_VISIBLE_DEVICES', default='-1', type=str, help="GPU设置")
36
  parser.add_argument('--test_img', default='test/test.jpg', type=str, help="img path")
@@ -50,11 +50,11 @@ if __name__ == '__main__':
50
 
51
  print("train num:", len(train_paths), "test num:", len(test_paths))
52
 
53
- processor = TrOCRProcessor.from_pretrained(args.cut_data_init_weights_path)
54
  vocab = processor.tokenizer.get_vocab()
55
 
56
  vocab_inp = {vocab[key]: key for key in vocab}
57
- model = VisionEncoderDecoderModel.from_pretrained(args.cut_data_init_weights_path)
58
  model.eval()
59
  model.cuda()
60
 
 
30
 
31
  if __name__ == '__main__':
32
  parser = argparse.ArgumentParser(description='trocr fine-tune训练')
33
+ parser.add_argument('--cust_data_init_weights_path', default='./cust-data/weights', type=str,
34
  help="初始化训练权重,用于自己数据集上fine-tune权重")
35
  parser.add_argument('--CUDA_VISIBLE_DEVICES', default='-1', type=str, help="GPU设置")
36
  parser.add_argument('--test_img', default='test/test.jpg', type=str, help="img path")
 
50
 
51
  print("train num:", len(train_paths), "test num:", len(test_paths))
52
 
53
+ processor = TrOCRProcessor.from_pretrained(args.cust_data_init_weights_path)
54
  vocab = processor.tokenizer.get_vocab()
55
 
56
  vocab_inp = {vocab[key]: key for key in vocab}
57
+ model = VisionEncoderDecoderModel.from_pretrained(args.cust_data_init_weights_path)
58
  model.eval()
59
  model.cuda()
60