| | import os |
| | from typing import Dict, List, Any |
| | from long_coref.coref.prediction import CorefPredictor |
| | from long_coref.coref.utils import ArchiveContent |
| | from allennlp.common.params import Params |
| |
|
# Name of the extracted model-archive directory, resolved relative to the
# pipeline's ``path``; it is expected to contain ``weights.th`` and
# ``config.json``.
CHECKPOINT: str = "coref-spanbert-large-2021.03.10"
| |
|
| |
|
class PreTrainedPipeline:
    """Coreference-resolution pipeline backed by a ``long_coref`` predictor.

    The predictor is built once in ``__init__`` from an extracted checkpoint
    directory and reused by every ``__call__``.
    """

    def __init__(self, path=""):
        """Load the coreference predictor from an extracted checkpoint.

        Args:
            path: Directory containing the ``CHECKPOINT`` sub-directory, which
                must hold ``weights.th`` and ``config.json``. With the default
                ``""`` the checkpoint directory is resolved relative to the
                current working directory.
        """
        # Join the checkpoint directory once; the weight and config paths
        # below are both located inside it.
        checkpoint_dir = os.path.join(path, CHECKPOINT)
        archive_content = ArchiveContent(
            archive_dir=checkpoint_dir,
            weight_path=os.path.join(checkpoint_dir, "weights.th"),
            config=Params.from_file(os.path.join(checkpoint_dir, "config.json")),
        )
        self.predictor = CorefPredictor.from_extracted_archive(archive_content)

    def __call__(self, data: str) -> Dict[str, Any]:
        """Resolve coreference in ``data`` and return the prediction as a dict.

        Args:
            data: Raw input text. It is split into paragraphs on blank lines
                (the two-character separator ``"\\n\\n"``), and the resulting
                list of strings is passed to the predictor.

        Returns:
            The value of ``to_dict()`` on the prediction returned by
            ``CorefPredictor.resolve_paragraphs``. Presumably the hosting
            inference service serializes it; confirm against the deployment.
        """
        # NOTE(review): only "\n\n" separates paragraphs. CRLF text
        # ("\r\n\r\n") therefore stays a single paragraph, and runs of four
        # newlines yield an empty "" paragraph. This is kept unchanged so the
        # predictor receives exactly the same input as before.
        paragraphs = data.split("\n\n")
        prediction = self.predictor.resolve_paragraphs(paragraphs)
        return prediction.to_dict()
| |
|