Chen42 committed on
Commit
5bffd18
·
verified ·
1 Parent(s): 59ca806

Upload /my_pipeline_wo_ckpt/segment_sent_results/inference.py with huggingface_hub

Browse files
my_pipeline_wo_ckpt/segment_sent_results/inference.py ADDED
@@ -0,0 +1,104 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
# Module-level setup: pin the visible GPU, then load trankit and expose its
# sentence splitter for the functions below.
import os
# Must be set before any CUDA-using library initializes; restricts this
# process to physical GPU 3.
os.environ["CUDA_VISIBLE_DEVICES"] = "3"

import json

from trankit import Pipeline
# Local cache directory for trankit's downloaded sentence-splitting models.
trankit_cache_dir = "./sen_split_models"

# utils
# Heavyweight import-time side effect: instantiating the Pipeline loads an
# XLM-RoBERTa-large model (downloading it on first run) and claims GPU memory.
trankit_pipe = Pipeline(lang="english", gpu=True, cache_dir=trankit_cache_dir, embedding='xlm-roberta-large')
# Sentence-split entry point used by sen_split_model:
# callable text -> {"sentences": [...]}.
sen_spliter = trankit_pipe.ssplit
13
def sen_split_model(word_list, spliter=None):
    """Split a flat list of words into per-sentence word lists.

    The words are joined with single spaces, run through the trankit sentence
    splitter, and each resulting sentence is split back into words.  Because
    every output word must map one-to-one onto an input word, we do not want a
    sentence like "Hunan Biological Medicine Factory, China.through Hunan
    Yahua Seed Corporation Ltd." cut into two sentences at "China." — a
    boundary is only kept when the character right after it is a space.
    Boundaries placed at a non-space position are undone by merging the two
    sentences (two-pointer sweep below).

    Args:
        word_list: list of word strings (assumed to contain no internal
            whitespace — TODO confirm against callers).
        spliter: optional callable ``text -> {"sentences": [...]}`` used in
            place of the module-level trankit ``sen_spliter`` (useful for
            testing).  Each sentence dict carries "text" and "dspan", the
            character span within the joined text, e.g.::

                [{'id': 1, 'text': 'Hello!', 'dspan': (0, 6)},
                 {'id': 2, 'text': 'This is Trankit.', 'dspan': (7, 23)}]

    Returns:
        A list of word lists, one per (post-merge) sentence.  Empty input
        yields an empty list.
    """
    if spliter is None:
        spliter = sen_spliter

    text = " ".join(word_list)
    text_split_list = spliter(text)["sentences"]

    # Fix: the original indexed text_split_list[0] unconditionally and raised
    # IndexError when the splitter returned no sentences (e.g. empty input).
    if not text_split_list:
        return []

    # Two-pointer sweep: head accumulates merges, tail scans forward.  Two
    # sentences were cut at a non-space position exactly when the head's span
    # ends where the tail's span begins (no separating character).
    post_processed_text_split_list = []
    head_ptr, tail_ptr = 0, 0
    while True:
        head_item = text_split_list[head_ptr]
        if tail_ptr + 1 < len(text_split_list):
            tail_ptr += 1
            tail_item = text_split_list[tail_ptr]
            if head_item["dspan"][1] == tail_item["dspan"][0]:
                # Cut at a non-space position: glue the sentences back
                # together (no space was lost, so plain concatenation).
                head_item["text"] = head_item["text"] + tail_item["text"]
                head_item["dspan"] = (head_item["dspan"][0], tail_item["dspan"][1])
            else:
                post_processed_text_split_list.append(head_item)
                head_ptr = tail_ptr
        else:
            # The final (possibly merged) sentence is easy to miss — flush it.
            post_processed_text_split_list.append(head_item)
            break

    # The joined text uses single spaces only, so split(" ") recovers exactly
    # the original words.  (Also drops the original's unused word_idx.)
    return [sen_dict["text"].split(" ") for sen_dict in post_processed_text_split_list]
60
+
61
+
62
def infer_one_block(reordered_block_textlines_list):
    """Segment one reordered text block into sentences.

    Joins the "text" field of every line dict into one block string, splits it
    into per-sentence word lists via sen_split_model, and returns the
    sentences re-joined as strings.
    """
    block_text = " ".join(line["text"] for line in reordered_block_textlines_list)
    word_groups = sen_split_model(block_text.split())
    return [" ".join(words) for words in word_groups]
69
+
70
+
71
+
72
def infer_one_img(img_line_dict):
    """Run sentence segmentation over every block of one image record.

    Reads img_line_dict["reordered_blocks_textlines"], segments each block
    with infer_one_block, stores the result under "reordered_block_sents",
    and returns the (mutated) record.
    """
    blocks = img_line_dict["reordered_blocks_textlines"]
    img_line_dict["reordered_block_sents"] = [infer_one_block(b) for b in blocks]
    return img_line_dict
85
+
86
+
87
if __name__ == "__main__":

    # Document domain to process; the literal values are the actual directory
    # names: ["标牌标语", "电脑屏拍", "商品包装", "手机屏拍", "纸质文档"].
    domain = "纸质文档"
    tgt_dir = f"./results/{domain}"
    os.makedirs(tgt_dir, exist_ok=True)

    src_filepath = f"../blocked_reordered_results/results/{domain}/reordered_blocks_textlines.json"
    tgt_filepath = f"{tgt_dir}/segment_sent.json"

    # Fix 1: the output file is now also opened with encoding="utf8" — the
    # JSON is written with ensure_ascii=False, so relying on the platform
    # default encoding could corrupt or crash on non-UTF-8 locales.
    # Fix 2: iterate the file instead of a readline loop that broke at the
    # first blank line; blank lines are skipped, not treated as EOF.
    with open(src_filepath, "r", encoding="utf8") as src_file, \
            open(tgt_filepath, "w", encoding="utf8") as tgt_file:
        for line in src_file:  # JSONL: one image record per line
            line_str = line.strip()
            if not line_str:
                continue
            line_dict = json.loads(line_str)
            new_line_dict = infer_one_img(line_dict)

            tgt_file.write(f"{json.dumps(new_line_dict, ensure_ascii=False)}\n")