Spaces:
Build error
Build error
| #!/usr/bin/env python3 | |
| # -*- coding: utf-8 -*- | |
| """ | |
| Created on Mon Sep 11 09:46:51 2023 | |
| @author: peter | |
| """ | |
| from allennlp.predictors.predictor import Predictor | |
| import pandas | |
def clean(sentence):
    """
    Terminate a sentence with a full stop when it does not already have one.

    Parameters
    ----------
    sentence : str
        Sentence to be checked.

    Returns
    -------
    str
        The sentence unchanged if, ignoring surrounding whitespace, it
        already ends with '.'; otherwise the sentence with '.' appended
        (after any trailing whitespace, which is kept as is).
    """
    # Whitespace is ignored only for the check; the returned text is never
    # stripped.
    already_terminated = sentence.strip().endswith('.')
    if already_terminated:
        return sentence
    return sentence + '.'
class CoreferenceResolver(object):
    """
    Callable that resolves coreferences across a group of sentences.

    The sentences of a group are concatenated into one whitespace-tokenized
    document, the predictor is asked for coreference clusters, and every
    mention of a cluster is replaced by the longest mention of that cluster.
    """

    # Model archive loaded when no explicit ``model_url`` is supplied.
    MODEL_URL = "https://storage.googleapis.com/allennlp-public-models/coref-spanbert-large-2020.02.27.tar.gz"

    def __init__(self, model_url=None):
        """
        Creates the Coreference resolver

        Parameters
        ----------
        model_url : str, optional
            Location passed to ``Predictor.from_path``. Defaults to
            ``CoreferenceResolver.MODEL_URL``.

        Returns
        -------
        None.
        """
        self.predictor = Predictor.from_path(model_url or self.MODEL_URL)

    def __call__(self, group):
        """
        Resolve coreferences in a group of sentences.

        Parameters
        ----------
        group : pandas.Series
            Sentences on which to perform coreference resolution

        Returns
        -------
        pandas.Series
            Sentences with coreferences resolved, one per input sentence and
            carrying the index of ``group``. A sentence that is entirely
            covered by a mention starting in an earlier sentence comes back
            as an empty string, because the replacement text is emitted in
            the sentence where the mention starts.
        """
        # Nothing to resolve; avoid calling the predictor on an empty document.
        if group.empty:
            return pandas.Series([], index=group.index, dtype=object)

        tokenized = group.apply(clean).str.split()
        # line_breaks[i] is the document position one past the last token of
        # sentence i.
        line_breaks = tokenized.apply(len).cumsum().tolist()
        doc = [token for line in tokenized for token in line]
        prediction = self.predictor.predict_tokenized(doc)

        # Maps the start position of each mention to the position just after
        # it and to the tokens that should replace it.
        resolutions = {}
        for cluster in prediction['clusters']:
            if not cluster:
                continue
            # max() returns the first of several equally long mentions.
            first, last = max(cluster, key=lambda span: span[1] - span[0])
            canonical = doc[first:last + 1]
            for start_pos, end_pos in cluster:
                resolutions[start_pos] = (end_pos + 1, canonical)

        results = []
        current = []
        line = 0
        doc_pos = 0
        n_lines = len(line_breaks)
        while doc_pos < len(doc):
            if doc_pos in resolutions:
                mention_end, canonical = resolutions[doc_pos]
                current.extend(canonical)
                doc_pos = mention_end
            else:
                current.append(doc[doc_pos])
                doc_pos += 1
            # A replaced mention may jump past several sentence boundaries at
            # once, so flush every sentence that has been passed; flushing
            # only one would leave fewer results than index entries.
            while line < n_lines and doc_pos >= line_breaks[line]:
                results.append(' '.join(current))
                current = []
                line += 1
        return pandas.Series(results,
                             index=group.index)