Upload /my_pipeline_wo_ckpt/segment_sent_results/inference.py with huggingface_hub
Browse files
my_pipeline_wo_ckpt/segment_sent_results/inference.py
ADDED
|
@@ -0,0 +1,104 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import os
# Pin this process to GPU 3; must happen before any CUDA-using library loads.
os.environ["CUDA_VISIBLE_DEVICES"] = "3"

import json

from trankit import Pipeline
# Local cache directory for trankit's downloaded sentence-splitting models.
trankit_cache_dir = "./sen_split_models"

# utils
# NOTE(review): heavy side effect at import time — loads an English pipeline
# backed by xlm-roberta-large onto the GPU. Confirm eager (module-import)
# initialization is intended rather than lazy construction.
trankit_pipe = Pipeline(lang="english", gpu=True, cache_dir=trankit_cache_dir, embedding='xlm-roberta-large')
# Alias for the pipeline's sentence-splitting entry point.
sen_spliter = trankit_pipe.ssplit
def sen_split_model(word_list):
    """Split a flat word list into word-aligned sentences.

    Joins *word_list* with single spaces, runs the trankit sentence
    splitter on the joined text, then post-processes the result so that
    every accepted sentence boundary falls on a space character.  This
    guarantees the words of the returned sentences map one-to-one onto
    the input words.

    Args:
        word_list: list of word strings (assumed to contain no spaces —
            TODO confirm upstream tokenization guarantees this).

    Returns:
        A list of lists of words, one sub-list per detected sentence.
        Empty list when the splitter finds no sentences.
    """
    text = " ".join(word_list)

    text_split_list = sen_spliter(text)["sentences"]
    # Each splitter entry looks like:
    #   {'id': 1, 'text': 'Hello!', 'dspan': (0, 6)}
    # where 'dspan' is the (start, end) character span within `text`.

    # We do not want a fragment like "... Factory, China.through Hunan
    # Yahua Seed Corporation Ltd." cut into two sentences at "China." —
    # a boundary is only acceptable when the character right after it is
    # a space, otherwise the split sentences' words can no longer be
    # aligned with the words of the original document.  So merge every
    # pair of adjacent fragments whose spans touch directly (no gap),
    # using a head/tail two-pointer sweep.
    post_processed_text_split_list = []
    if text_split_list:  # guard: avoid IndexError on empty splitter output
        head_ptr, tail_ptr = 0, 0
        while True:
            head_item = text_split_list[head_ptr]
            if tail_ptr + 1 < len(text_split_list):
                tail_ptr += 1
                tail_item = text_split_list[tail_ptr]
                if head_item["dspan"][1] == tail_item["dspan"][0]:
                    # Cut at a non-space position: fuse the two fragments.
                    head_item["text"] = head_item["text"] + tail_item["text"]
                    head_item["dspan"] = (head_item["dspan"][0], tail_item["dspan"][1])
                else:
                    post_processed_text_split_list.append(head_item)
                    head_ptr = tail_ptr
            else:
                # The final fragment would otherwise be dropped.
                post_processed_text_split_list.append(head_item)
                break

    # One word sub-list per (merged) sentence.
    word_sublist = [sen_dict["text"].split(" ") for sen_dict in post_processed_text_split_list]
    return word_sublist
def infer_one_block(reordered_block_textlines_list):
    """Segment one reordered text block into sentences.

    Each element of *reordered_block_textlines_list* is a dict with a
    "text" field.  The texts are concatenated into one block string and
    pushed through the word-aligned sentence splitter.

    Returns:
        A list of sentence strings.
    """
    block_text = " ".join(line["text"] for line in reordered_block_textlines_list)
    word_groups = sen_split_model(block_text.split())
    return [" ".join(words) for words in word_groups]
def infer_one_img(img_line_dict):
    """Run sentence segmentation over every block of one image record.

    Mutates *img_line_dict* in place, adding a "reordered_block_sents"
    key holding one sentence list per block, and returns the same dict.
    """
    img_line_dict["reordered_block_sents"] = [
        infer_one_block(block_textlines)
        for block_textlines in img_line_dict["reordered_blocks_textlines"]
    ]
    return img_line_dict
if __name__ == "__main__":

    # One of: ["标牌标语", "电脑屏拍", "商品包装", "手机屏拍", "纸质文档"]
    domain = "纸质文档"
    tgt_dir = f"./results/{domain}"
    os.makedirs(tgt_dir, exist_ok=True)

    src_filepath = f"../blocked_reordered_results/results/{domain}/reordered_blocks_textlines.json"
    tgt_filepath = f"{tgt_dir}/segment_sent.json"

    # Input is JSONL: one JSON object per line.  Iterate the file directly
    # instead of the readline()-until-falsy pattern, which would silently
    # stop at the first blank line in the middle of the file.  The output
    # is opened with an explicit encoding because we write non-ASCII JSON
    # (ensure_ascii=False) and must not depend on the locale default.
    with open(src_filepath, "r", encoding="utf8") as src_file, \
            open(tgt_filepath, "w", encoding="utf8") as tgt_file:
        for raw_line in src_file:
            line_str = raw_line.strip()
            if not line_str:
                continue  # skip blank lines rather than truncating the run
            line_dict = json.loads(line_str)
            new_line_dict = infer_one_img(line_dict)

            tgt_file.write(f"{json.dumps(new_line_dict, ensure_ascii=False)}\n")