initial commit
Browse files
app.py
CHANGED
|
@@ -13,16 +13,40 @@ from classy.utils.streamlit import get_md_200_random_color_generator
|
|
| 13 |
|
| 14 |
def main(
|
| 15 |
model_checkpoint_path: str,
|
| 16 |
-
|
| 17 |
cuda_device: int,
|
| 18 |
):
|
| 19 |
# setup examples
|
| 20 |
examples = [
|
| 21 |
-
"
|
| 22 |
-
"Japan began the defence of their
|
| 23 |
"The project was coded in Java.",
|
| 24 |
]
|
| 25 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 26 |
# css rules
|
| 27 |
st.write(
|
| 28 |
"""
|
|
@@ -69,13 +93,13 @@ def main(
|
|
| 69 |
Given the sentence *After a long fight Superman saved Metropolis*, where *Superman* is the mention
|
| 70 |
to disambiguate, ExtEnD first concatenates the descriptions of all the possible candidates of *Superman* in the
|
| 71 |
inventory and then selects the span whose description best suits the mention in its context.
|
| 72 |
-
|
| 73 |
-
To convert this task to end2end entity linking, as we do in *Model demo*, we leverage spaCy
|
| 74 |
-
(more specifically, its NER) and run ExtEnD on each named entity spaCy identifies
|
| 75 |
-
(if the corresponding mention is contained in the inventory).
|
| 76 |
|
| 77 |
-
|
| 78 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 79 |
* [GitHub](https://github.com/SapienzaNLP/extend)
|
| 80 |
"""
|
| 81 |
)
|
|
@@ -84,25 +108,6 @@ def main(
|
|
| 84 |
def demo():
|
| 85 |
st.markdown("## Demo")
|
| 86 |
|
| 87 |
-
@st.cache(allow_output_mutation=True)
|
| 88 |
-
def load_resources(inventory_path):
|
| 89 |
-
|
| 90 |
-
# load nlp
|
| 91 |
-
nlp = spacy.load("en_core_web_sm")
|
| 92 |
-
extend_config = dict(
|
| 93 |
-
checkpoint_path=model_checkpoint_path,
|
| 94 |
-
mentions_inventory_path=inventory_path,
|
| 95 |
-
device=cuda_device,
|
| 96 |
-
tokens_per_batch=10_000,
|
| 97 |
-
)
|
| 98 |
-
nlp.add_pipe("extend", after="ner", config=extend_config)
|
| 99 |
-
|
| 100 |
-
# mock call to load resources
|
| 101 |
-
nlp(examples[0])
|
| 102 |
-
|
| 103 |
-
# return
|
| 104 |
-
return nlp
|
| 105 |
-
|
| 106 |
# read input
|
| 107 |
placeholder = st.selectbox(
|
| 108 |
"Examples",
|
|
@@ -111,24 +116,14 @@ def main(
|
|
| 111 |
)
|
| 112 |
input_text = st.text_area("Input text to entity-disambiguate", placeholder)
|
| 113 |
|
| 114 |
-
#
|
| 115 |
-
|
| 116 |
-
"[Optional] Upload custom inventory (tsv file, mention \\t desc1 \\t desc2 \\t)",
|
| 117 |
-
accept_multiple_files=False,
|
| 118 |
-
type=["tsv"],
|
| 119 |
-
)
|
| 120 |
-
if uploaded_inventory_path is not None:
|
| 121 |
-
inventory_path = f"data/inventories/{uploaded_inventory_path.name}"
|
| 122 |
-
with open(inventory_path, "wb") as f:
|
| 123 |
-
f.write(uploaded_inventory_path.getbuffer())
|
| 124 |
-
else:
|
| 125 |
-
inventory_path = default_inventory_path
|
| 126 |
|
| 127 |
# load model and color generator
|
| 128 |
nlp = load_resources(inventory_path)
|
| 129 |
color_generator = get_md_200_random_color_generator()
|
| 130 |
|
| 131 |
-
if
|
| 132 |
|
| 133 |
# tag sentence
|
| 134 |
time_start = time.perf_counter()
|
|
@@ -184,7 +179,6 @@ def main(
|
|
| 184 |
hiw()
|
| 185 |
|
| 186 |
|
| 187 |
-
|
| 188 |
if __name__ == "__main__":
|
| 189 |
main(
|
| 190 |
"experiments/extend-longformer-large/2021-10-22/09-11-39/checkpoints/best.ckpt",
|
|
|
|
| 13 |
|
| 14 |
def main(
|
| 15 |
model_checkpoint_path: str,
|
| 16 |
+
inventory_path: str,
|
| 17 |
cuda_device: int,
|
| 18 |
):
|
| 19 |
# setup examples
|
| 20 |
examples = [
|
| 21 |
+
"Rome is in Italy",
|
| 22 |
+
"Japan began the defence of their title with a lucky 2-1 win against Syria in a Group C championship match on Friday.",
|
| 23 |
"The project was coded in Java.",
|
| 24 |
]
|
| 25 |
|
| 26 |
+
# define load_resources
|
| 27 |
+
|
| 28 |
+
@st.cache(allow_output_mutation=True)
|
| 29 |
+
def load_resources(inventory_path):
|
| 30 |
+
|
| 31 |
+
# load nlp
|
| 32 |
+
nlp = spacy.load("en_core_web_sm")
|
| 33 |
+
extend_config = dict(
|
| 34 |
+
checkpoint_path=model_checkpoint_path,
|
| 35 |
+
mentions_inventory_path=inventory_path,
|
| 36 |
+
device=cuda_device,
|
| 37 |
+
tokens_per_batch=10_000,
|
| 38 |
+
)
|
| 39 |
+
nlp.add_pipe("extend", after="ner", config=extend_config)
|
| 40 |
+
|
| 41 |
+
# mock call to load resources
|
| 42 |
+
nlp(examples[0])
|
| 43 |
+
|
| 44 |
+
# return
|
| 45 |
+
return nlp
|
| 46 |
+
|
| 47 |
+
# preload default resources
|
| 48 |
+
load_resources(inventory_path)
|
| 49 |
+
|
| 50 |
# css rules
|
| 51 |
st.write(
|
| 52 |
"""
|
|
|
|
| 93 |
Given the sentence *After a long fight Superman saved Metropolis*, where *Superman* is the mention
|
| 94 |
to disambiguate, ExtEnD first concatenates the descriptions of all the possible candidates of *Superman* in the
|
| 95 |
inventory and then selects the span whose description best suits the mention in its context.
|
|
|
|
|
|
|
|
|
|
|
|
|
| 96 |
|
| 97 |
+
To use ExtEnD for full end2end entity linking, as we do in *Demo*, we just need to leverage a mention
|
| 98 |
+
identifier. Here [we use spaCy](https://github.com/SapienzaNLP/extend#spacy) (more specifically, its NER) and run ExtEnD on each named
|
| 99 |
+
entity spaCy identifies (if the corresponding mention is contained in the inventory).
|
| 100 |
+
|
| 101 |
+
##### Links:
|
| 102 |
+
* [Full Paper](https://www.researchgate.net/publication/359392427_ExtEnD_Extractive_Entity_Disambiguation)
|
| 103 |
* [GitHub](https://github.com/SapienzaNLP/extend)
|
| 104 |
"""
|
| 105 |
)
|
|
|
|
| 108 |
def demo():
|
| 109 |
st.markdown("## Demo")
|
| 110 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 111 |
# read input
|
| 112 |
placeholder = st.selectbox(
|
| 113 |
"Examples",
|
|
|
|
| 116 |
)
|
| 117 |
input_text = st.text_area("Input text to entity-disambiguate", placeholder)
|
| 118 |
|
| 119 |
+
# button
|
| 120 |
+
should_disambiguate = st.button("Disambiguate", key="classify")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 121 |
|
| 122 |
# load model and color generator
|
| 123 |
nlp = load_resources(inventory_path)
|
| 124 |
color_generator = get_md_200_random_color_generator()
|
| 125 |
|
| 126 |
+
if should_disambiguate:
|
| 127 |
|
| 128 |
# tag sentence
|
| 129 |
time_start = time.perf_counter()
|
|
|
|
| 179 |
hiw()
|
| 180 |
|
| 181 |
|
|
|
|
| 182 |
if __name__ == "__main__":
|
| 183 |
main(
|
| 184 |
"experiments/extend-longformer-large/2021-10-22/09-11-39/checkpoints/best.ckpt",
|