# GLiDRE / app.py
# rarmingaud's picture
# auto annotation
# ea5bb02
import io
import json
import re
from typing import Any, Dict, List
import gradio as gr
import subprocess
# Install the runtime dependencies at import time, before the imports further
# down that rely on them (networkx, matplotlib, spacy, glidre).
# NOTE(review): none of these calls pass check=True or inspect the returned
# CompletedProcess, so a failing step is not detected here and would only
# surface later, presumably as an ImportError.
# Fetch the GLiDRE sources into ./glidre and install that checkout.
subprocess.run(["git", "clone", "https://github.com/robinarmingaud/glidre"])
subprocess.run(["pip", "install", "./glidre"])
# Libraries used for the relation graph rendering and the auto-annotation.
subprocess.run(["pip", "install", "networkx"])
subprocess.run(["pip", "install", "matplotlib"])
subprocess.run(["pip", "install", "spacy"])
# spaCy model wheel, pinned by URL; it is loaded below via spacy.load("xx_ent_wiki_sm").
subprocess.run([
"pip", "install",
"https://github.com/explosion/spacy-models/releases/download/xx_ent_wiki_sm-3.8.0/xx_ent_wiki_sm-3.8.0-py3-none-any.whl"
])
import networkx as nx
import matplotlib.pyplot as plt
from PIL import Image
import pandas as pd
import spacy
# spaCy pipeline loaded once at import time; auto_annotate() runs it on the
# document text to pre-fill the mentions table.
nlp = spacy.load("xx_ent_wiki_sm")
from glidre import GLiDRE
# Identifier handed to GLiDRE.from_pretrained below.
MODEL_NAME = "cea-list-ia/glidre_multi"
# Single module-level model instance; predict() calls model.predict_entities on it.
model = GLiDRE.from_pretrained(MODEL_NAME)
def parse_labels(labels_text: str) -> List[str]:
    """Turn a comma- and/or newline-separated string into upper-cased labels.

    Each entry is stripped of surrounding whitespace; entries that end up
    empty are discarded. A falsy input (empty string or None) yields ``[]``.
    """
    if not labels_text:
        return []
    labels: List[str] = []
    # Newlines act as an additional separator, so normalise them to commas first.
    for chunk in labels_text.replace("\n", ",").split(","):
        cleaned = chunk.strip()
        if cleaned:
            labels.append(cleaned.upper())
    return labels
def _is_blank_cell(cell: Any) -> bool:
    """Return True when a table cell carries no usable value.

    A cell is blank when it is None, a float NaN (what pandas uses for empty
    cells), or when its string form is empty after stripping whitespace.
    """
    if cell is None:
        return True
    # NaN is the only float that is not equal to itself.
    if isinstance(cell, float) and cell != cell:
        return True
    return str(cell).strip() == ""


def df_to_mentions(df_rows, text: str = "") -> tuple:
    """Group the rows of the mentions table into entities.

    Rows are read positionally as ``[id, type, value, start, end]``; missing
    trailing columns are tolerated. Rows sharing an id become mentions of the
    same entity, and the entity type comes from the first row seen for that id.

    Args:
        df_rows: A ``pandas.DataFrame`` or a list of row lists. ``None`` is
            treated as an empty table.
        text: Document text. When a row has no complete ``start``/``end``
            pair, the first occurrence of its value in ``text`` supplies the
            span.

    Returns:
        A ``(entities, warnings)`` tuple. ``entities`` is a list of
        ``{"id", "type", "mentions"}`` dicts, each mention being
        ``{"value", "start", "end"}``. ``warnings`` lists one message per
        value that could not be located in ``text``.
    """
    if df_rows is None:
        return [], []
    if isinstance(df_rows, pd.DataFrame):
        df_rows = df_rows.values.tolist()

    entity_map: Dict[int, Dict[str, Any]] = {}
    warnings: List[str] = []
    auto_id = 0
    for row in df_rows:
        # Skip rows that are entirely empty, including rows made only of NaN.
        if not row or all(_is_blank_cell(cell) for cell in row):
            continue

        # A missing or unparsable id falls back to the next automatic id.
        try:
            rid = auto_id if _is_blank_cell(row[0]) else int(row[0])
        except (TypeError, ValueError, OverflowError):
            rid = auto_id
        auto_id = max(auto_id, rid) + 1

        rtype = "MISC"
        if len(row) > 1 and not _is_blank_cell(row[1]):
            rtype = str(row[1]).strip()
        value = ""
        if len(row) > 2 and not _is_blank_cell(row[2]):
            value = str(row[2]).strip()

        # If either offset cannot be parsed, discard both so the span is
        # recomputed from the text below.
        try:
            start = int(row[3]) if len(row) > 3 and not _is_blank_cell(row[3]) else None
            end = int(row[4]) if len(row) > 4 and not _is_blank_cell(row[4]) else None
        except (TypeError, ValueError, OverflowError):
            start = None
            end = None

        if (start is None or end is None) and value and text:
            # Literal substring lookup; always resolves to the first occurrence.
            found_at = text.find(value)
            if found_at >= 0:
                start, end = found_at, found_at + len(value)
            else:
                warnings.append(f"Could not find '{value}' in text; span left as None.")

        entity = entity_map.setdefault(rid, {"id": rid, "type": rtype, "mentions": []})
        entity["mentions"].append({"value": value, "start": start, "end": end})
    return list(entity_map.values()), warnings
def auto_annotate(text: str):
    """Build mention rows from the entities the spaCy pipeline finds in ``text``.

    Each row is ``[id, label, surface, start_char, end_char]``. Ids are handed
    out in order of first appearance, and entities with the same surface text
    share an id. Empty or whitespace-only text yields an empty list.
    """
    if not (text and text.strip()):
        return []
    ids_by_surface: Dict[str, int] = {}
    table = []
    for entity in nlp(text).ents:
        # On first sight the surface gets the next free id, which equals the
        # number of surfaces registered so far.
        entity_id = ids_by_surface.setdefault(entity.text, len(ids_by_surface))
        table.append([entity_id, entity.label_, entity.text, entity.start_char, entity.end_char])
    return table
def _first_mention(entity: Any) -> Dict[str, Any]:
    """Return the first mention dict of ``entity``, or ``{}`` if there is none.

    Covers a missing entity (None), an empty sequence, and a first item that
    is not a dict, so callers can always use ``.get`` on the result.
    """
    try:
        first = entity[0]
    except (KeyError, IndexError, TypeError):
        return {}
    return first if isinstance(first, dict) else {}


def relations_to_table(relations: List[Dict[str, Any]]) -> List[List[Any]]:
    """Flatten predicted relations into rows for the results table.

    Each relation is read as a dict whose ``entity_1`` and ``entity_2`` are
    sequences of mention dicts; only the first mention of each is shown.

    Args:
        relations: Relation dicts; ``None`` is treated as an empty list.

    Returns:
        One ``[e1_id, e1_text, relation_type, e2_id, e2_text, score]`` row per
        relation. Missing fields fall back to ``""``, except ``e1_id``, which
        falls back to ``None``.
    """
    rows = []
    for rel in relations or []:
        # A relation without a usable entity still gets a row of placeholders
        # instead of aborting the whole table.
        head = _first_mention(rel.get("entity_1"))
        tail = _first_mention(rel.get("entity_2"))
        rows.append([
            head.get("id"),
            head.get("text", ""),
            rel.get("relation_type", ""),
            tail.get("id", ""),
            tail.get("text", ""),
            rel.get("score", ""),
        ])
    return rows
def draw_relation_graph(relations: List[Dict[str, Any]]):
    """Render the predicted relations as a directed graph image.

    Every relation contributes an edge from the first mention of ``entity_1``
    to the first mention of ``entity_2``. Nodes are keyed by the stringified
    entity id and labelled ``"<text>\\n(<id>)"``. Relations that lack either
    entity are skipped.

    Args:
        relations: Relation dicts as consumed by ``relations_to_table``.

    Returns:
        A ``PIL.Image`` holding the PNG rendering of the graph.
    """
    G = nx.MultiDiGraph()
    node_labels: Dict[str, str] = {}
    # Several relations can link the same ordered pair. Edge labels are keyed
    # by (source, target), so collect them per pair instead of letting the
    # last one overwrite the others.
    pair_labels: Dict[tuple, List[str]] = {}
    for rel in relations:
        try:
            head, tail = rel["entity_1"][0], rel["entity_2"][0]
        except (KeyError, IndexError, TypeError):
            continue
        id1 = str(head.get("id"))
        id2 = str(tail.get("id"))
        val1 = head.get("text", id1)
        val2 = tail.get("text", id2)
        node_labels[id1] = f"{val1}\n({id1})"
        node_labels[id2] = f"{val2}\n({id2})"
        relation_label = rel.get("relation_type", "")
        G.add_node(id1)
        G.add_node(id2)
        G.add_edge(id1, id2, label=relation_label)
        labels_for_pair = pair_labels.setdefault((id1, id2), [])
        if str(relation_label) not in labels_for_pair:
            labels_for_pair.append(str(relation_label))
    edge_labels = {pair: ", ".join(labels) for pair, labels in pair_labels.items()}

    # Draw on an explicit figure/axes rather than pyplot's implicit current
    # figure, and always close the figure, even when drawing raises.
    fig = plt.figure(figsize=(8, 6))
    try:
        ax = fig.add_subplot(111)
        pos = nx.spring_layout(G, k=0.8)
        nx.draw_networkx_nodes(G, pos, node_size=1200, ax=ax)
        nx.draw_networkx_labels(G, pos, labels=node_labels, ax=ax)
        # Pass the node size so arrowheads stop at the node border instead of
        # being hidden underneath the enlarged nodes.
        nx.draw_networkx_edges(G, pos, arrows=True, node_size=1200, ax=ax)
        nx.draw_networkx_edge_labels(G, pos, edge_labels=edge_labels, ax=ax)
        ax.set_axis_off()
        buf = io.BytesIO()
        fig.tight_layout()
        fig.savefig(buf, format="png", bbox_inches="tight")
    finally:
        plt.close(fig)
    buf.seek(0)
    return Image.open(buf)
def predict(text: str, labels_text: str, entities_df, threshold: float, multi_label: bool):
    """Run relation extraction and prepare the three UI outputs.

    Args:
        text: Document text.
        labels_text: Comma/newline-separated relation labels.
        entities_df: Mentions table (rows of id, type, value, start, end).
        threshold: Score threshold forwarded to the model.
        multi_label: Forwarded to the model's ``multi_label`` argument.

    Returns:
        A 3-tuple ``(table_rows, graph_image, raw_json_text)`` matching the
        three output components wired to the run button. When no mentions are
        available the table is empty, the image is ``None`` and the text slot
        carries an explanatory message.
    """
    labels = parse_labels(labels_text)
    mentions, warnings = df_to_mentions(entities_df, text=text)
    if not mentions:
        msg = "⚠️ No valid mentions found. Please fill in the Mentions table with id, value for each entity."
        if warnings:
            msg += "\n" + "\n".join(warnings)
        # Return exactly one value per wired output. The message goes to a
        # plain Textbox, so it is sent as text rather than wrapped in HTML.
        return ([], None, msg)
    relations = model.predict_entities(text=text, labels=labels, mentions=mentions, threshold=threshold, multi_label=multi_label)
    table = relations_to_table(relations)
    graph_img = draw_relation_graph(relations)
    if warnings:
        # No output component is reserved for warnings, so surface them as a
        # non-blocking Gradio toast instead of dropping them.
        gr.Warning("\n".join(warnings))
    raw_json = json.dumps(relations, indent=2)
    return (table, graph_img, raw_json)
# Gradio UI: the left column holds the inputs (text, labels, mentions table,
# options, examples); the right column holds the outputs (relations table,
# graph image, raw JSON).
with gr.Blocks(title="GLiDRE — Gradio demo") as demo:
    # Logo served from the local "images" directory (see allowed_paths in launch below).
    gr.HTML("""
<div id="logo" style="display: flex;align-items: center">
<img src="/gradio_api/file=images/LIST.png" alt="Logo" width="200">
</div>
""")
    gr.Markdown("# GLiDRE DEMO")
    with gr.Row():
        with gr.Column(scale=6):
            text_input = gr.Textbox(label="Document text", value="The Loud Tour was the fourth overall and third world concert tour by Barbadian recording artist Rihanna.", lines=6)
            labels_input = gr.Textbox(label="Relation labels (comma-separated)", value="COUNTRY_OF_CITIZENSHIP, PUBLICATION_DATE, PART_OF")
            annotate_btn = gr.Button("🔍 Auto-annotate using spaCy", variant="secondary")
            # Column order must match what df_to_mentions reads positionally: id, type, value, start, end.
            entities_df = gr.Dataframe(headers=["id", "type", "value", "start", "end"], datatype=["number", "text", "text", "number", "number"], interactive=True, label="Mentions", column_count=5)
            with gr.Row():
                threshold = gr.Slider(0.0, 1.0, value=0.3, label="Threshold", step = 0.05)
                multi_label = gr.Checkbox(label="Allow multi-label (one mention pair can have multiple relations)", value=True)
            run = gr.Button("▶ Run prediction", variant="primary", elem_id="run-btn")
            # Each example supplies, in order: document text, relation labels, mention rows.
            # Rows sharing an id (e.g. "Rihanna" / "artist") are mentions of the same entity.
            gr.Examples(
                label="Examples (click to load)",
                examples=[
                    [
                        "Rihanna released her album in 2016. The artist won several awards that year.",
                        "AWARD_RECEIVED, RELEASE_DATE",
                        [[0, "PERSON", "Rihanna", 0, 7], [0, "PERSON", "artist", 40, 46], [1, "MISC", "album", 21, 26], [2, "DATE", "2016", 30, 34]]
                    ],
                    [
                        "Steve Jobs co-founded Apple. Jobs also served as its CEO until 2011.",
                        "FOUNDER, EMPLOYEE_OF, END_DATE",
                        [[0, "PERSON", "Steve Jobs", 0, 10], [0, "PERSON", "Jobs", 29, 33], [1, "ORG", "Apple", 22, 27], [2, "DATE", "2011", 63, 67]]
                    ],
                    [
                        "Marie Curie, born in Warsaw, won the Nobel Prize in Physics in 1903.",
                        "PLACE_OF_BIRTH, AWARD_RECEIVED",
                        [[0, "PERSON", "Marie Curie", 0, 11], [1, "LOC", "Warsaw", 21, 27], [2, "MISC", "Nobel Prize in Physics", 37, 59], [3, "DATE", "1903", 63, 67]]
                    ]
                ],
                inputs=[text_input, labels_input, entities_df]
            )
        with gr.Column(scale=4):
            # Header order mirrors the rows produced by relations_to_table.
            relations_table = gr.Dataframe(headers=["e1_id", "e1_text", "relation", "e2_id", "e2_text", "score"], interactive=False)
            graph_out = gr.Image(label="Relation graph", type="pil")
            raw_json_out = gr.Textbox(label="Raw JSON output", lines=12)
    # Event wiring: auto-annotation fills the mentions table; the run button
    # feeds the five inputs to predict and expects three outputs back.
    annotate_btn.click(fn=auto_annotate, inputs=[text_input], outputs=[entities_df])
    run.click(fn=predict, inputs=[text_input, labels_input, entities_df, threshold, multi_label], outputs=[relations_table, graph_out, raw_json_out])
# Start the app only when executed as a script; "images" is whitelisted so the
# logo referenced in the HTML header can be served.
if __name__ == "__main__":
    demo.launch(favicon_path="images/LIST.svg", debug=True, allowed_paths=["images"])