Spaces:
Build error
Build error
| # -*- coding:utf-8 -*- | |
| """ | |
| @Last modified date : 2020/12/23 | |
| """ | |
| import re | |
| import nltk | |
| from nltk.stem import WordNetLemmatizer | |
| from allennlp.predictors.predictor import Predictor | |
# Make sure the NLTK data used below is available: WordNet (for
# WordNetLemmatizer) and the English stopword list (for SentenceParser).
# nltk.download() contacts the NLTK data server on every call, so only invoke
# it when nltk.data.find() reports the resource as missing.
for _nltk_resource in ('wordnet', 'stopwords'):
    try:
        nltk.data.find(f'corpora/{_nltk_resource}')
    except LookupError:
        nltk.download(_nltk_resource)
def deal_bracket(text, restore, leading_ent=None):
    """Optionally prefix *text* with an entity header and un-escape brackets.

    Args:
        text: the sentence to process.
        restore: when truthy, replace '-LRB-'/'-RRB-' and then the bare
            'LRB'/'RRB' with '(' and ')'.
        leading_ent: optional underscore-joined entity name. When non-empty,
            the text is prefixed with "Things about <entity>: ", where the
            entity has its underscores turned into spaces.

    Returns:
        The processed string.
    """
    if leading_ent:
        readable_ent = leading_ent.replace('_', ' ')
        text = f'Things about {readable_ent}: {text}'
    if not restore:
        return text
    # Order matters: the dashed escapes must go before the bare ones.
    for escaped, plain in (('-LRB-', '('), ('-RRB-', ')'), ('LRB', '('), ('RRB', ')')):
        text = text.replace(escaped, plain)
    return text
def refine_entity(entity):
    """Normalise an entity/page title into a plain, space-separated phrase.

    Steps:
      1. underscores become spaces (so 'Africa_Cup_of_Nations' and
         'Africa Cup of Nations' normalise identically);
      2. a trailing parenthetical written with bracket escapes -- either
         '-LRB- ... -RRB-' or bare 'LRB ... RRB' -- is removed, with or
         without surrounding spaces (e.g. 'Savages_-LRB-2012_film-RRB-' and
         'Savages -LRB- 2012 film -RRB-' both give 'Savages');
      3. whitespace runs are collapsed and the ends stripped.

    A title that consists only of a parenthetical is left as is rather than
    being reduced to an empty string.

    Args:
        entity: the raw entity string.

    Returns:
        The cleaned entity string.
    """
    # Convert underscores first so the bracket patterns below see spaces.
    entity = entity.replace('_', ' ')
    # (?<=\S) requires some text before the parenthetical, so the whole
    # title is never erased.
    entity = re.sub(r'(?<=\S)\s*-LRB-.+-RRB-\s*$', '', entity)
    entity = re.sub(r'(?<=\S)\s*LRB\s.+\sRRB\s*$', '', entity)
    entity = re.sub(r'\s+', ' ', entity)
    return entity.strip()
def find_sub_seq(seq_a, seq_b, shift=0, uncased=False, lemmatizer=None):
    """Find the first occurrence of token list *seq_b* inside *seq_a*.

    Args:
        seq_a: tokens to search in.
        seq_b: tokens to search for.
        shift: first start index in *seq_a* to consider (negative values are
            treated as 0).
        uncased: compare tokens case-insensitively.
        lemmatizer: optional object with a ``lemmatize(token)`` method; when
            given, both sequences are lemmatised before comparison.

    Returns:
        ``(start, end)`` half-open indices into *seq_a* of the first match
        starting at or after *shift*, or ``(-1, -1)`` when there is no match
        or *seq_b* is empty.
    """
    if not seq_b:
        # An empty pattern would "match" with zero width at *shift*; callers
        # that resume searching from the returned end would never advance.
        return -1, -1
    if uncased:
        seq_a = [token.lower() for token in seq_a]
        seq_b = [token.lower() for token in seq_b]
    if lemmatizer is not None:
        seq_a = [lemmatizer.lemmatize(token) for token in seq_a]
        seq_b = [lemmatizer.lemmatize(token) for token in seq_b]
    width = len(seq_b)
    # Only starts that leave room for the whole pattern can match.
    for i in range(max(shift, 0), len(seq_a) - width + 1):
        if seq_a[i:i + width] == seq_b:
            return i, i + width
    return -1, -1
def is_sub_seq(seq_start, seq_end, all_seqs):
    """Return the first span of *all_seqs* that encloses [seq_start, seq_end).

    Args:
        seq_start, seq_end: half-open token range to test; a range with
            ``seq_start >= seq_end`` is never enclosed.
        all_seqs: iterable of ``(start, end, is_candidate)`` triples.

    Returns:
        The enclosing ``(start, end, is_candidate)`` triple, or None.
    """
    enclosing = (
        (start, end, is_candidate)
        for start, end, is_candidate in all_seqs
        if start <= seq_start < seq_end <= end
    )
    return next(enclosing, None)
# Named-entity spans from B/I/L/U/O-encoded tags.
def extract_named_entity(tags):
    """Collect entity spans from a sequence of BILOU tags.

    A span is emitted for every 'U-X' tag and for every 'B-X' ... 'L-X'
    run whose 'I' tags all carry the same type X. A type mismatch or an
    'O' tag abandons the open span.

    Args:
        tags: tag strings, each 'O' or '<prefix>-<type>'.

    Returns:
        List of ``(start, end, False)`` half-open token spans, in order.
    """
    entities = []
    open_type, open_start = '', -1
    for idx, tag in enumerate(tags):
        if tag == 'O':
            open_type, open_start = '', -1
            continue
        prefix, label = tag.split('-')
        if prefix == 'B':
            open_type, open_start = label, idx
        elif prefix == 'U':
            entities.append((idx, idx + 1, False))
            open_type, open_start = '', -1
        elif prefix in ('I', 'L'):
            if label != open_type:
                # Inconsistent continuation: drop the open span.
                open_type, open_start = '', -1
            elif prefix == 'L':
                entities.append((open_start, idx + 1, False))
                open_type, open_start = '', -1
    return entities
def refine_results(tokens, spans, stopwords):
    """Convert token spans into de-duplicated character-offset spans.

    For each ``(start, end, is_candidate)`` span:
      * non-candidate spans lose their leading stopwords (compared
        lower-cased), and spans left empty are dropped;
      * a directly preceding article ('a', 'an', 'the', or capitalised
        'A', 'An', 'The') is pulled into the span;
      * token offsets become character offsets in ``' '.join(tokens)``.

    Spans are then ordered by start (longer first on ties), and any span
    enclosed by an already kept one is discarded.

    Args:
        tokens: sentence tokens.
        spans: iterable of ``(start, end, is_candidate)`` token spans.
        stopwords: container of lower-cased stopwords.

    Returns:
        List of ``(span_text, char_start, char_end)`` tuples.
    """
    articles = ('a', 'an', 'A', 'An', 'the', 'The')
    converted = []
    for start, end, candidate in spans:
        if not candidate:
            while start < end and tokens[start].lower() in stopwords:
                start += 1
        if start >= end:
            continue
        if start > 0 and tokens[start - 1] in articles:
            start -= 1
        text = ' '.join(tokens[start:end])
        # Preceding tokens joined by single blanks, plus the blank that
        # separates them from this span (none when the span starts at 0).
        char_start = len(' '.join(tokens[:start])) + min(start, 1)
        converted.append((text, char_start, char_start + len(text)))

    converted.sort(key=lambda item: (item[1], item[1] - item[2]))

    kept = []
    for text, char_start, char_end in converted:
        enclosed = any(k_start <= char_start < char_end <= k_end
                       for _, k_start, k_end in kept)
        if not enclosed:
            kept.append((text, char_start, char_end))
    return kept
class SentenceParser:
    """Extract noun phrases, verbs and modifiers from a sentence.

    Combines an AllenNLP named-entity tagger and a constituency parser with
    optional caller-supplied candidate phrases; see ``identify_NPs``.
    """

    def __init__(self, device='cuda:0',
                 ner_path="https://storage.googleapis.com/allennlp-public-models/ner-model-2020.02.10.tar.gz",
                 cp_path="https://storage.googleapis.com/allennlp-public-models/elmo-constituency-parser-2020.02.10.tar.gz"):
        """Load both predictors, the lemmatizer and the stopword list.

        Args:
            device: device spec such as 'cpu', 'cuda' or 'cuda:1'; an int is
                also accepted and used as-is (see ``parse_device``).
            ner_path: archive path/URL given to ``Predictor.from_path`` for
                the named-entity tagger.
            cp_path: archive path/URL for the constituency parser.
        """
        self.device = self.parse_device(device)
        self.ner = Predictor.from_path(ner_path, cuda_device=self.device)
        print('* ner loaded')
        self.cp = Predictor.from_path(cp_path, cuda_device=self.device)
        print('* constituency parser loaded')
        self.lemmatizer = WordNetLemmatizer()
        # some heuristic rules can be added here
        self.stopwords = set(nltk.corpus.stopwords.words('english'))
        self.stopwords.update({'-', '\'s', 'try', 'tries', 'tried', 'trying',
                               'become', 'becomes', 'became', 'becoming',
                               'make', 'makes', 'made', 'making', 'call', 'called', 'calling',
                               'put', 'ever', 'something', 'someone', 'sometime'})
        self.special_tokens = ['only', 'most', 'before', 'after', 'behind']
        # identify_NPs reports the special tokens as modifiers, so they must
        # not be stripped as stopwords; 'won' and 'own' are kept as well.
        self.stopwords.difference_update(self.special_tokens)
        self.stopwords.difference_update(('won', 'own'))

    def parse_device(self, device):
        """Convert a device spec into the ``cuda_device`` value for the predictors.

        Any string containing 'cpu' maps to -1. Otherwise the first integer
        in the string is used, defaulting to 0 ('cuda' -> 0, 'cuda:1' -> 1).
        Integers are returned unchanged.
        """
        if isinstance(device, int):
            return device
        if 'cpu' in device:
            return -1
        dev = re.findall(r'\d+', device)
        return int(dev[0]) if dev else 0

    def identify_NPs(self, text, candidate_NPs=None):
        """Find noun phrases, verbs and modifiers in *text*.

        Args:
            text: raw sentence; whitespace runs are collapsed first.
            candidate_NPs: optional phrases (e.g. entity titles). Each is
                cleaned with ``refine_entity`` and matched against the tokens
                case-insensitively after lemmatisation. Matches are kept
                intact as noun phrases.

        Returns:
            dict with keys:
              * 'text': the root 'word' of the constituency tree (presumably
                the space-joined tokens -- confirm against the parser);
              * 'NPs', 'verbs', 'adjs': lists of
                ``(span_text, char_start, char_end)`` from ``refine_results``.
        """
        text = re.sub(r'\s+', ' ', text).strip()
        if not text:
            return {'text': '', 'NPs': [], 'verbs': [], 'adjs': []}

        cp_outputs = self.cp.predict(text)
        ner_outputs = self.ner.predict(text)
        tokens = cp_outputs['tokens']
        pos_tags = cp_outputs['pos_tags']
        ner_tags = ner_outputs['tags']
        tree = cp_outputs['hierplane_tree']['root']

        # extract candidate noun phrases passed by user with token index
        all_NPs = []
        candidates = [refine_entity(cand).split() for cand in candidate_NPs] if candidate_NPs else []
        # An empty candidate would match with zero width everywhere and the
        # resume-from-end search below would never advance.
        candidates = [cand for cand in candidates if cand]
        for cand_tokens in sorted(candidates, key=len, reverse=True):
            np_start, np_end = find_sub_seq(tokens, cand_tokens, 0, uncased=True, lemmatizer=self.lemmatizer)
            while np_start != -1 and np_end != -1:
                if not is_sub_seq(np_start, np_end, all_NPs):
                    all_NPs.append((np_start, np_end, True))
                np_start, np_end = find_sub_seq(tokens, cand_tokens, np_end, uncased=True,
                                                lemmatizer=self.lemmatizer)

        # extract noun phrases from tree
        bottom_NPs = []

        def _get_bottom_NPs(node):
            # Record the words of NP-like nodes (NP/OP/XP/QP) whose children
            # are all leaves; recurse into every other non-leaf node.
            if 'children' not in node:
                return
            if {'NP', 'OP', 'XP', 'QP'} & set(node['attributes']):
                if all('children' not in child for child in node['children']):
                    bottom_NPs.append(node['word'].split())
                    return
            for child in node['children']:
                _get_bottom_NPs(child)

        _get_bottom_NPs(tree)

        # find token indices of noun phrases. The tree is walked left to
        # right (assumed to follow token order), so each search resumes where
        # the previous match ended; np_end is exclusive, so it is already the
        # next unread token.
        search_from = 0
        for np_tokens in bottom_NPs:
            np_start, np_end = find_sub_seq(tokens, np_tokens, search_from)
            if np_start == -1:
                # Not found: record nothing and keep the current position.
                continue
            if not is_sub_seq(np_start, np_end, all_NPs):
                all_NPs.append((np_start, np_end, False))
            search_from = np_end

        # extract named entities with token index
        all_NEs = extract_named_entity(ner_tags)

        def _is_free(i):
            # True when token i lies inside neither a noun phrase nor an NE.
            return not is_sub_seq(i, i + 1, all_NPs) and not is_sub_seq(i, i + 1, all_NEs)

        # extract verbs with token index
        all_verbs = [(i, i + 1, False) for i, pos in enumerate(pos_tags)
                     if pos.startswith('V') and _is_free(i)]
        # extract modifiers (adj./adv. and special tokens) with token index
        all_modifiers = [(i, i + 1, False) for i, (token, pos) in enumerate(zip(tokens, pos_tags))
                         if (pos in ('JJ', 'RB') or token in self.special_tokens) and _is_free(i)]

        # split noun phrases with named entities
        all_spans = []
        for np_start, np_end, np_is_candidate in all_NPs:
            if np_is_candidate:  # candidate noun phrases will be preserved
                all_spans.append((np_start, np_end, np_is_candidate))
                continue
            match = is_sub_seq(np_start, np_end, all_NEs)
            if match:  # the noun phrase lies inside a named entity: keep the entity
                all_spans.append(match)
                continue
            # Otherwise every entity nested in the noun phrase becomes its own
            # span; the text before each entity becomes a modifier, and the
            # remainder after the last entity stays a noun-phrase span.
            index = np_start
            for ne_start, ne_end, ne_is_candidate in all_NEs:
                if np_start <= ne_start < ne_end <= np_end:
                    all_modifiers.append((index, ne_start, False))
                    all_spans.append((ne_start, ne_end, ne_is_candidate))
                    index = ne_end
            all_spans.append((index, np_end, False))

        # named entities without overlapping
        for ne_start, ne_end, is_candidate in all_NEs:
            if not is_sub_seq(ne_start, ne_end, all_spans):
                all_spans.append((ne_start, ne_end, is_candidate))

        all_spans = refine_results(tokens, all_spans, self.stopwords)
        all_verbs = refine_results(tokens, all_verbs, self.stopwords)
        all_modifiers = refine_results(tokens, all_modifiers, self.stopwords)
        return {'text': tree['word'], 'NPs': all_spans, 'verbs': all_verbs, 'adjs': all_modifiers}
if __name__ == '__main__':
    # Demo: load both models on CPU and parse one example claim together
    # with several spellings of the same candidate entity.
    import json

    print('Initializing sentence parser.')
    parser = SentenceParser(device='cpu')

    print('Parsing sentence.')
    demo_claim = ("The Africa Cup of Nations is held in odd - numbered years "
                  "due to conflict with the World Cup . ")
    demo_entities = ['Africa Cup of Nations', 'Africa_Cup_of_Nations',
                     'Africa Cup', 'Africa_Cup']
    parsed = parser.identify_NPs(demo_claim, demo_entities)
    print(json.dumps(parsed, ensure_ascii=False, indent=4))

    # To parse a sample of a dataset instead (needs the project's `utils`
    # helpers and tqdm):
    #
    # import random
    # from tqdm import tqdm
    # from utils import read_json_lines, save_json
    #
    # print('Parsing file.')
    # results = []
    # data = list(read_json_lines('data/train.jsonl'))
    # random.shuffle(data)
    # for entry in tqdm(data[:100]):
    #     results.append(parser.identify_NPs(entry['claim']))
    # save_json(results, 'data/results.json')