Spaces:
Runtime error
Runtime error
Update app.py
Browse files
app.py
CHANGED
|
@@ -21,12 +21,16 @@ class ExampleDocument(TextDocument):
|
|
| 21 |
|
| 22 |
ner_model_name_or_path = "pie/example-ner-spanclf-conll03"
|
| 23 |
re_model_name_or_path = "DFKI-SLT/relation_classification_tacred_revisited"
|
| 24 |
-
#"pie/example-re-textclf-tacred"
|
| 25 |
-
#"DFKI-SLT/relation_classification_tacred_revisited"
|
| 26 |
|
| 27 |
ner_pipeline = AutoPipeline.from_pretrained(ner_model_name_or_path, device=-1, num_workers=0)
|
| 28 |
re_pipeline = AutoPipeline.from_pretrained(re_model_name_or_path, device=-1, num_workers=0)
|
| 29 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 30 |
|
| 31 |
def predict(text):
|
| 32 |
document = ExampleDocument(text)
|
|
@@ -34,10 +38,16 @@ def predict(text):
|
|
| 34 |
ner_pipeline(document)
|
| 35 |
|
| 36 |
while len(document.entities.predictions) > 0:
|
| 37 |
-
document.entities.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 38 |
|
| 39 |
re_pipeline(document)
|
| 40 |
|
|
|
|
| 41 |
t = PrettyTable()
|
| 42 |
t.field_names = ["head", "tail", "relation"]
|
| 43 |
t.align = "l"
|
|
|
|
| 21 |
|
| 22 |
ner_model_name_or_path = "pie/example-ner-spanclf-conll03"
|
| 23 |
re_model_name_or_path = "DFKI-SLT/relation_classification_tacred_revisited"
|
|
|
|
|
|
|
| 24 |
|
| 25 |
ner_pipeline = AutoPipeline.from_pretrained(ner_model_name_or_path, device=-1, num_workers=0)
|
| 26 |
re_pipeline = AutoPipeline.from_pretrained(re_model_name_or_path, device=-1, num_workers=0)
|
| 27 |
|
| 28 |
+
ner_tag_mapping = {
|
| 29 |
+
'ORG': 'ORGANIZATION',
|
| 30 |
+
'PER': 'PERSON',
|
| 31 |
+
'LOC': 'LOCATION'
|
| 32 |
+
}
|
| 33 |
+
|
| 34 |
|
| 35 |
def predict(text):
|
| 36 |
document = ExampleDocument(text)
|
|
|
|
| 38 |
ner_pipeline(document)
|
| 39 |
|
| 40 |
while len(document.entities.predictions) > 0:
|
| 41 |
+
entity = document.entities.predictions.pop(0)
|
| 42 |
+
if entity.label in ner_tag_mapping:
|
| 43 |
+
entity = LabeledSpan(start=entity.start, end=entity.end, label=ner_tag_mapping[entity.label],
|
| 44 |
+
score=entity.score)
|
| 45 |
+
if entity.label in re_pipeline.taskmodule.entity_labels:
|
| 46 |
+
document.entities.append(entity)
|
| 47 |
|
| 48 |
re_pipeline(document)
|
| 49 |
|
| 50 |
+
|
| 51 |
t = PrettyTable()
|
| 52 |
t.field_names = ["head", "tail", "relation"]
|
| 53 |
t.align = "l"
|