Add related files
- app.py +127 -0
- pipeline.py +289 -0
- requirements.txt +3 -0
app.py
ADDED
@@ -0,0 +1,127 @@
+import json
+from functools import partial
+from typing import Callable, Dict, List
+import transformers
+from transformers import (
+    AutoModelForTokenClassification,
+    AutoTokenizer
+)
+
+from pipeline import (
+    TokenClassificationPipeline
+)
+import pythainlp
+from pprint import pprint
+from itertools import chain
+
+import gradio as gr
+
+
+ner_pipeline_group = TokenClassificationPipeline(
+    model=AutoModelForTokenClassification.from_pretrained(
+        'airesearch/wangchanberta-base-att-spm-uncased',
+        revision='finetuned@thainer-ner'
+    ),
+    tokenizer=AutoTokenizer.from_pretrained(
+        'airesearch/wangchanberta-base-att-spm-uncased',
+        revision='finetuned@thainer-ner'
+    ),
+    space_token='<_>',
+    lowercase=True,
+    group_entities=True,
+    strict=False,
+)
+
+# One highlight color per ThaiNER entity type.
+color_mapper = {
+    "DATE": "#f94144",
+    "EMAIL": "#f3722c",
+    "LAW": "#f8961e",
+    "LEN": "#f9844a",
+    "LOCATION": "#f9c74f",
+    "MONEY": "#ffcb77",
+    "ORGANIZATION": "#f5cac3",
+    "PERCENT": "#90be6d",
+    "PERSON": "#bfd200",
+    "PHONE": "#43aa8b",
+    "TIME": "#4d908e",
+    "URL": "#577590",
+    "ZIP": "#90e0ef",
+}
+
+
+# Build one CSS rule pair per entity type: a tinted span plus an ::after
+# pseudo-element that renders the tag name in superscript.
+css_text = 'p{width: 700px; color: #333; border-radius: 3px; border: solid 1.5px #DDD; background-color: #FFF;\n margin: 10px;\n padding: 30px}\n'
+for k, v in color_mapper.items():
+    css_text += "span." + f"{k.lower()}" \
+        + "{\n background-color: " \
+        + f"{v}" + "50;\n color: #333;\n border-right: 4px solid " \
+        + f"{v}" + ";" \
+        + "\n align-items: center;" \
+        + "\n margin: 0;" \
+        + "\n padding: 2px 8px;" \
+        + "\n border-radius: 3px;\n}\n" \
+        + "span." + f"{k.lower()}" + "::after {" \
+        + "\npadding: 2px 1px;" \
+        + "font-size: 9.5px;" \
+        + "font-weight: bold;" \
+        + "font-family: Monaco;" \
+        + "vertical-align: super;" \
+        + "content: \"" + k.upper() + "\";" \
+        + "}\n"
+
+
+def modify_segment(text, tag, start, end):
+    # Wrap text[start:end] in a <span> whose class is the entity tag, and return
+    # the new string plus the number of extra characters the markup added.
+    replaced_text = text[:start] + f'<span class="{tag}">' + text[start:end] + '</span>' + text[end:]
+    return replaced_text, len(f'<span class="{tag}">') + len('</span>')
+
+
+def render_doc_with_label(label: List[Dict], doc: str):
+    attribute_items = []
+    for i, ne_span in enumerate(label):
+        if ne_span['entity_group'] != 'O':
+            attribute_name = ne_span['entity_group']
+            attribute_name = attribute_name.lower()
+
+            begin_char_idx = ne_span['begin_char_index']
+
+            tagged_text = ne_span['word']
+            end_char_idx = begin_char_idx + len(tagged_text)
+
+            attribute_items.append((attribute_name, begin_char_idx, end_char_idx))
+
+    attribute_items = sorted(attribute_items, key=lambda x: (x[1]))
+    print(f'attribute_items: {attribute_items}')
+
+    # Every inserted <span> shifts the offsets of all later entities, so keep a
+    # running total of the extra characters and add it to each span's position.
+    acc_n_extra_chars = 0
+    modified_segment = doc
+    for _selected_attribute_item in attribute_items:
+        tag, start, end = _selected_attribute_item[0], _selected_attribute_item[1], _selected_attribute_item[2]
+        modified_segment, n_extra_chars = modify_segment(modified_segment, tag, start + acc_n_extra_chars, end + acc_n_extra_chars)
+        acc_n_extra_chars += n_extra_chars
+
+    return f'<style>{css_text}</style><p>{modified_segment}</p>'
+
+
+def ner_tagging(text: str):
+    results = ner_pipeline_group(text)
+    print(f'results:\n{results}')
+    html_text = render_doc_with_label(results, text)
+
+    return json.dumps(results, ensure_ascii=False, indent=4), html_text
+
+
+demo = gr.Interface(fn=ner_tagging,
+                    inputs=gr.Textbox(lines=5, placeholder='Input text in Thai', label='Input text'),
+                    examples=[
+                        ["ไมโครซอฟท์ได้จัดจำหน่ายบนแพลตฟอร์มไมโครซอฟท์ วินโดวส์ ในเดือนเมษายน 2020"],
+                        ['ชัชชาติ สิทธิพันธุ์ ผู้ว่าราชการกรุงเทพมหานคร (กทม.) คนที่ 17 เตรียมเข้ารับตำแหน่งอย่างเป็นทางการและเปิดตัวทีมงานในช่วงบ่ายวันนี้ (1 มิ.ย.) หลังรับมอบหนังสือรับรองการเป็นผู้ว่าฯ กทม. ที่สำนักงานคณะกรรมการการเลือกตั้ง (กกต.)'],
+                        ["สถาบันวิทยาศาสตร์ทางทะเล มหาวิทยาลัยบูรพา เปิดให้บริการมายาวนานกว่า 30 ปี ตั้งอยู่บริเวณด้านหน้า มหาวิทยาลัยบูรพา บนเนื้อที่กว่า 30 ไร่ เป็นสถานที่ท่องเที่ยว ที่จัดแสดงเพื่อให้ความรู้เกี่ยวกับวิทยาศาสตร์ทางทะเล สิ่งมีชีวิตและความเป็นอยู่ของสัตว์ทะเลชนิดต่างๆที่อาศัยอยู่ในเขตน่านน้ำของไทย"],
+                    ],
+                    outputs=[gr.Textbox(), gr.HTML()])
+
+print(f'\nINFO: transformers.__version__: {transformers.__version__}')
+print(f'\nINFO: pythainlp.__version__: {pythainlp.__version__}')
+demo.launch()
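
For intuition about the offset bookkeeping in render_doc_with_label above, here is a small self-contained sketch (toy English input, not part of the commit): each inserted <span> lengthens the string, so every later span must be shifted by the accumulated markup length.

    # Toy illustration of the offset accumulation (hypothetical data).
    text = 'Alice met Bob'
    spans = [('person', 0, 5), ('person', 10, 13)]  # (tag, start, end) in the original text

    acc = 0
    for tag, start, end in sorted(spans, key=lambda s: s[1]):
        open_tag, close_tag = f'<span class="{tag}">', '</span>'
        text = text[:start + acc] + open_tag + text[start + acc:end + acc] + close_tag + text[end + acc:]
        acc += len(open_tag) + len(close_tag)

    print(text)
    # <span class="person">Alice</span> met <span class="person">Bob</span>
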
pipeline.py
ADDED
@@ -0,0 +1,289 @@
+import torch
+
+from typing import Callable, List, Tuple, Union
+from functools import partial
+import itertools
+
+from seqeval.scheme import Tokens, IOB2, IOBES
+
+from transformers.modeling_utils import PreTrainedModel
+from transformers.tokenization_utils import PreTrainedTokenizerBase
+from pythainlp.tokenize import word_tokenize as pythainlp_word_tokenize
+newmm_word_tokenizer = partial(pythainlp_word_tokenize, keep_whitespace=True, engine='newmm')
+
+from thai2transformers.preprocess import rm_useless_spaces
+
+SPIECE = '▁'
+
+class TokenClassificationPipeline:
+
+    def __init__(self,
+                 model: PreTrainedModel,
+                 tokenizer: PreTrainedTokenizerBase,
+                 pretokenizer: Callable[[str], List[str]] = newmm_word_tokenizer,
+                 lowercase=False,
+                 space_token='<_>',
+                 device: int = -1,
+                 group_entities: bool = False,
+                 strict: bool = False,
+                 tag_delimiter: str = '-',
+                 scheme: str = 'IOB',
+                 use_crf=False,
+                 remove_spiece=True):
+
+        super().__init__()
+
+        assert isinstance(tokenizer, PreTrainedTokenizerBase)
+        # assert isinstance(model, PreTrainedModel)
+
+        self.model = model
+        self.tokenizer = tokenizer
+        self.pretokenizer = pretokenizer
+        self.lowercase = lowercase
+        self.space_token = space_token
+        self.device = 'cpu' if device == -1 or not torch.cuda.is_available() else f'cuda:{device}'
+        self.group_entities = group_entities
+        self.strict = strict
+        self.tag_delimiter = tag_delimiter
+        self.scheme = scheme
+        self.id2label = self.model.config.id2label
+        self.label2id = self.model.config.label2id
+        self.use_crf = use_crf
+        self.remove_spiece = remove_spiece
+        self.model.to(self.device)
+
+    def preprocess(self, inputs: Union[str, List[str]]) -> Union[List[str], List[List[str]]]:
+
+        if self.lowercase:
+            inputs = inputs.lower() if type(inputs) == str else list(map(str.lower, inputs))
+
+        inputs = rm_useless_spaces(inputs) if type(inputs) == str else list(map(rm_useless_spaces, inputs))
+
+        tokens = self.pretokenizer(inputs) if type(inputs) == str else list(map(self.pretokenizer, inputs))
+
+        tokens = list(map(lambda x: x.replace(' ', self.space_token), tokens)) if type(inputs) == str else \
+                 list(map(lambda _tokens: list(map(lambda x: x.replace(' ', self.space_token), _tokens)), tokens))
+
+        return tokens
+
+    def _inference(self, input: str):
+
+        tokens = [[self.tokenizer.bos_token]] + \
+                 [self.tokenizer.tokenize(tok) if tok != SPIECE else [SPIECE] for tok in self.preprocess(input)] + \
+                 [[self.tokenizer.eos_token]]
+        ids = [self.tokenizer.convert_tokens_to_ids(token) for token in tokens]
+        flatten_tokens = list(itertools.chain(*tokens))
+        flatten_ids = list(itertools.chain(*ids))
+
+        input_ids = torch.LongTensor([flatten_ids]).to(self.device)
+
+        if self.use_crf:
+            out = self.model(input_ids=input_ids)
+        else:
+            out = self.model(input_ids=input_ids, return_dict=True)
+        probs = torch.softmax(out['logits'], dim=-1)
+        vals, indices = probs.topk(1)
+        indices_np = indices.detach().cpu().numpy().reshape(-1)
+
+        list_of_token_label_tuple = list(zip(flatten_tokens, [self.id2label[idx] for idx in indices_np]))
+        merged_preds = self._merged_pred(preds=list_of_token_label_tuple, ids=ids)
+        if self.remove_spiece:
+            merged_preds = list(map(lambda x: (x[0].replace(SPIECE, ''), x[1]), merged_preds))
+
+        # remove start and end tokens
+        merged_preds_removed_bos_eos = merged_preds[1:-1]
+        # convert to list of Dict objects
+        merged_preds_return_dict = [{'word': word if word != self.space_token else ' ', 'entity': tag, 'index': idx}
+                                    for idx, (word, tag) in enumerate(merged_preds_removed_bos_eos)]
+
+        if (not self.group_entities or self.scheme is None) and self.strict:
+            return merged_preds_return_dict
+        elif not self.group_entities and not self.strict:
+            tags = list(map(lambda x: x['entity'], merged_preds_return_dict))
+            processed_tags = self._fix_incorrect_tags(tags)
+            for i, item in enumerate(merged_preds_return_dict):
+                merged_preds_return_dict[i]['entity'] = processed_tags[i]
+            return merged_preds_return_dict
+        elif self.group_entities:
+            return self._group_entities(merged_preds_removed_bos_eos)
+
+    def __call__(self, inputs: Union[str, List[str]]):
+        """Run the pipeline on a single string or on a list of strings."""
+        if type(inputs) == str:
+            return self._inference(inputs)
+
+        if type(inputs) == list:
+            results = [self._inference(text) for text in inputs]
+            return results
+
+    def _merged_pred(self, preds: List[Tuple[str, str]], ids: List[List[int]]):
+
+        # token_mapping[i] is the index of the pre-token that subtoken i belongs to.
+        token_mapping = []
+        for i in range(0, len(ids)):
+            for j in range(0, len(ids[i])):
+                token_mapping.append(i)
+
+        grouped_subtokens = []
+        _subtoken = []
+        prev_idx = 0
+
+        for i, (subtoken, label) in enumerate(preds):
+
+            current_idx = token_mapping[i]
+            if prev_idx != current_idx:
+                grouped_subtokens.append(_subtoken)
+                _subtoken = [(subtoken, label)]
+                if i == len(preds) - 1:
+                    _subtoken = [(subtoken, label)]
+                    grouped_subtokens.append(_subtoken)
+            elif i == len(preds) - 1:
+                _subtoken += [(subtoken, label)]
+                grouped_subtokens.append(_subtoken)
+            else:
+                _subtoken += [(subtoken, label)]
+            prev_idx = current_idx
+
+        merged_subtokens = []
+        _merged_subtoken = ''
+        for subtoken_group in grouped_subtokens:
+            # A merged pre-token inherits the label predicted for its first subtoken.
+            first_token_pred = subtoken_group[0][1]
+            _merged_subtoken = ''.join(list(map(lambda x: x[0], subtoken_group)))
+            merged_subtokens.append((_merged_subtoken, first_token_pred))
+        return merged_subtokens
+
+    def _fix_incorrect_tags(self, tags: List[str]) -> List[str]:
+
+        I_PREFIX = f'I{self.tag_delimiter}'
+        E_PREFIX = f'E{self.tag_delimiter}'
+        B_PREFIX = f'B{self.tag_delimiter}'
+        O_PREFIX = 'O'
+
+        previous_tag_ne = None
+        for i, current_tag in enumerate(tags):
+
+            current_tag_ne = current_tag.split(self.tag_delimiter)[-1] if current_tag != O_PREFIX else O_PREFIX
+
+            if i == 0 and (current_tag.startswith(I_PREFIX) or
+                           current_tag.startswith(E_PREFIX)):
+                # If an NE tag (with I- or E- prefix) occurs at the beginning of the sentence,
+                # e.g. (I-LOC, I-LOC), (E-LOC, B-PER), (I-LOC, O, O),
+                # then change the prefix of the current tag to B{tag_delimiter}.
+                tags[i] = B_PREFIX + tags[i][2:]
+            elif i >= 1 and tags[i-1] == O_PREFIX and (
+                    current_tag.startswith(I_PREFIX) or
+                    current_tag.startswith(E_PREFIX)):
+                # If an NE tag (with I- or E- prefix) occurs after an O tag,
+                # e.g. (O, I-LOC, I-LOC), (O, E-LOC, B-PER), (O, I-LOC, O, O),
+                # then change the prefix of the current tag to B{tag_delimiter}.
+                tags[i] = B_PREFIX + tags[i][2:]
+            elif i >= 1 and (tags[i-1].startswith(I_PREFIX) or
+                             tags[i-1].startswith(E_PREFIX) or
+                             tags[i-1].startswith(B_PREFIX)) and \
+                    (current_tag.startswith(I_PREFIX) or current_tag.startswith(E_PREFIX)) and \
+                    previous_tag_ne != current_tag_ne:
+                # If an NE tag (with I- or E- prefix) occurs after an NE tag of a different type,
+                # e.g. (B-LOC, I-PER), (B-LOC, E-LOC, E-PER), (B-LOC, I-LOC, I-PER),
+                # then change the prefix of the current tag to B{tag_delimiter}.
+                tags[i] = B_PREFIX + tags[i][2:]
+            elif i == len(tags) - 1 and tags[i-1] == O_PREFIX and (
+                    current_tag.startswith(I_PREFIX) or
+                    current_tag.startswith(E_PREFIX)):
+                # If an NE tag (with I- or E- prefix) occurs at the end of the sentence,
+                # e.g. (O, O, I-LOC), (O, O, E-LOC),
+                # then change the prefix of the current tag to B{tag_delimiter}.
+                tags[i] = B_PREFIX + tags[i][2:]
+
+            previous_tag_ne = current_tag_ne
+
+        return tags
+
+    def _group_entities(self, ner_tags: List[Tuple[str, str]]) -> List[Tuple[str, str]]:
+
+        if self.scheme not in ['IOB', 'IOBES', 'IOBE']:
+            raise AttributeError('scheme must be one of IOB, IOBE, or IOBES')
+
+        tokens, tags = zip(*ner_tags)
+        tokens, tags = list(tokens), list(tags)
+
+        if self.scheme == 'IOBE':
+            # Replace E prefix with I prefix
+            tags = list(map(lambda x: x.replace(f'E{self.tag_delimiter}', f'I{self.tag_delimiter}'), tags))
+        if self.scheme == 'IOBES':
+            # Replace E prefix with I prefix and S prefix with B prefix
+            tags = list(map(lambda x: x.replace(f'E{self.tag_delimiter}', f'I{self.tag_delimiter}'), tags))
+            tags = list(map(lambda x: x.replace(f'S{self.tag_delimiter}', f'B{self.tag_delimiter}'), tags))
+
+        if not self.strict:
+            tags = self._fix_incorrect_tags(tags)
+
+        ent = Tokens(tokens=tags, scheme=IOB2,
+                     suffix=False, delimiter=self.tag_delimiter)
+
+        ne_position_mappings = ent.entities
+        token_positions = []
+        curr_len = 0
+        # Restore spaces and normalize the decomposed sara am (nikhahit + sara aa).
+        tokens = list(map(lambda x: x.replace('<_>', ' ').replace('ํา', 'ำ'), tokens))
+        for i, token in enumerate(tokens):
+            token_len = len(token)
+            if i == 0:
+                token_positions.append((0, curr_len + token_len))
+            else:
+                token_positions.append((curr_len, curr_len + token_len))
+            curr_len += token_len
+        print(f'token_positions: {list(zip(tokens, token_positions))}')
+        begin_end_pos = []
+        begin_end_char_pos = []
+        accum_char_len = 0
+        for i, ne_position_mapping in enumerate(ne_position_mappings):
+            print(f'ne_position_mapping.start: {ne_position_mapping.start}')
+            print(f'ne_position_mapping.end: {ne_position_mapping.end}\n')
+            begin_end_pos.append((ne_position_mapping.start, ne_position_mapping.end))
+            begin_end_char_pos.append((token_positions[ne_position_mapping.start][0], token_positions[ne_position_mapping.end - 1][1]))
+        print(f'begin_end_pos: {begin_end_pos}')
+        print(f'begin_end_char_pos: {begin_end_char_pos}')
+
+        j = 0
+        # Insert filler 'O' spans between recognized entities so that the
+        # concatenation of all groups reconstructs the full input text.
+        for i, pos_tuple in enumerate(begin_end_pos):
+            if pos_tuple[0] > 0 and i == 0:
+                ne_position_mappings.insert(0, (None, 'O', 0, pos_tuple[0]))
+                j += 1
+            if begin_end_pos[i-1][1] != begin_end_pos[i][0] and len(begin_end_pos) > 1 and i > 0:
+                ne_position_mappings.insert(j, (None, 'O', begin_end_pos[i-1][1], begin_end_pos[i][0]))
+                j += 1
+
+            j += 1
+        print('ne_position_mappings', ne_position_mappings)
+
+        groups = []
+        k = 0
+        for i, ne_position_mapping in enumerate(ne_position_mappings):
+            if type(ne_position_mapping) != tuple:
+                ne_position_mapping = ne_position_mapping.to_tuple()
+            ne = ne_position_mapping[1]
+
+            text = ''
+            for ne_position in range(ne_position_mapping[2], ne_position_mapping[3]):
+                _token = tokens[ne_position]
+                text += _token if _token != self.space_token else ' '
+            if ne.lower() != 'o':
+                groups.append({
+                    'entity_group': ne,
+                    'word': text,
+                    'begin_char_index': begin_end_char_pos[k][0]
+                })
+                k += 1
+            else:
+                groups.append({
+                    'entity_group': ne,
+                    'word': text,
+                })
+        return groups
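
A minimal usage sketch for the pipeline above (a hypothetical standalone script, not part of the commit, mirroring how app.py constructs it; assumes the Hub checkpoint and revision are reachable):

    # Hypothetical smoke test for pipeline.py.
    from transformers import AutoModelForTokenClassification, AutoTokenizer
    from pipeline import TokenClassificationPipeline

    model_name = 'airesearch/wangchanberta-base-att-spm-uncased'
    revision = 'finetuned@thainer-ner'
    ner = TokenClassificationPipeline(
        model=AutoModelForTokenClassification.from_pretrained(model_name, revision=revision),
        tokenizer=AutoTokenizer.from_pretrained(model_name, revision=revision),
        space_token='<_>', lowercase=True, group_entities=True, strict=False)

    for group in ner('ชัชชาติ สิทธิพันธุ์ ผู้ว่าราชการกรุงเทพมหานคร'):
        # With group_entities=True, each dict carries 'entity_group', 'word',
        # and, for non-O groups, 'begin_char_index'.
        print(group)
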
requirements.txt
ADDED
@@ -0,0 +1,3 @@
+gradio
+git+https://github.com/vistec-ai/thai2transformers.git@feature/add_ner_scheme
+pythainlp==2.2.4
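
With these three files in place, running pip install -r requirements.txt followed by python app.py should start the Gradio demo locally (assuming a Python 3 environment with network access to the Hugging Face Hub for the model download).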