"""GLiDRE relation-extraction demo (Gradio app).

At start-up the app clones and installs GLiDRE and its runtime packages,
loads the multilingual spaCy NER pipeline used for optional auto-annotation,
loads the ``cea-list-ia/glidre_multi`` checkpoint, and builds a Gradio UI in
which the user enters a document, relation labels and entity mentions, then
sees the predicted relations as a table, a graph image and raw JSON.
"""
import io
import json
import os
import re
import subprocess
import sys
from typing import Any, Dict, List, Optional, Set, Tuple

import gradio as gr

GLIDRE_REPO_URL = "https://github.com/robinarmingaud/glidre"
GLIDRE_DIR = "glidre"
SPACY_MODEL_NAME = "xx_ent_wiki_sm"
SPACY_MODEL_WHEEL = (
    "https://github.com/explosion/spacy-models/releases/download/"
    "xx_ent_wiki_sm-3.8.0/xx_ent_wiki_sm-3.8.0-py3-none-any.whl"
)
RUNTIME_PACKAGES = [
    f"./{GLIDRE_DIR}",
    "networkx",
    "matplotlib",
    "spacy",
    SPACY_MODEL_WHEEL,
]


def _install_runtime_dependencies() -> None:
    """Clone GLiDRE and pip-install the demo's runtime packages (best effort).

    Failures are not raised, matching the original fire-and-forget setup;
    a missing package will surface as an ImportError on the imports below.
    """
    # `git clone` into an existing directory fails, so only clone once.
    if not os.path.isdir(GLIDRE_DIR):
        subprocess.run(["git", "clone", GLIDRE_REPO_URL, GLIDRE_DIR])
    for package in RUNTIME_PACKAGES:
        # `python -m pip` targets the interpreter running this app; a bare
        # `pip` on PATH may belong to a different environment.
        subprocess.run([sys.executable, "-m", "pip", "install", package])


_install_runtime_dependencies()

# These imports must follow the installation step above.
import matplotlib.pyplot as plt  # noqa: E402
import networkx as nx  # noqa: E402
import pandas as pd  # noqa: E402
import spacy  # noqa: E402
from PIL import Image  # noqa: E402

from glidre import GLiDRE  # noqa: E402

nlp = spacy.load(SPACY_MODEL_NAME)

MODEL_NAME = "cea-list-ia/glidre_multi"
model = GLiDRE.from_pretrained(MODEL_NAME)


def parse_labels(labels_text: str) -> List[str]:
    """Split comma- or newline-separated relation labels into upper-case names.

    Empty entries are dropped and duplicates removed (first occurrence wins),
    e.g. ``"part_of,\\nPART_OF, founder"`` -> ``["PART_OF", "FOUNDER"]``.
    """
    if not labels_text:
        return []
    parts = (p.strip().upper() for p in labels_text.replace("\n", ",").split(","))
    # dict.fromkeys de-duplicates while preserving order.
    return list(dict.fromkeys(p for p in parts if p))


def _is_blank(value: Any) -> bool:
    """Return True for None, NaN/NA (as produced by pandas), or whitespace-only cells."""
    if value is None:
        return True
    if pd.api.types.is_scalar(value) and pd.isna(value):
        return True
    return str(value).strip() == ""


def _to_int(value: Any) -> Optional[int]:
    """Parse a table cell as an int (accepting ``3``, ``3.0``, ``"3"``), else None."""
    if _is_blank(value):
        return None
    try:
        return int(float(value))
    except (TypeError, ValueError, OverflowError):
        return None


def _locate_span(value: str, text: str, taken: Set[Tuple[int, int]]) -> Optional[Tuple[int, int]]:
    """Find ``value`` in ``text``, preferring an occurrence not already in ``taken``.

    Falls back to the first occurrence when every match is taken, and returns
    None when ``value`` does not occur at all.
    """
    first: Optional[Tuple[int, int]] = None
    for match in re.finditer(re.escape(value), text):
        span = match.span()
        if span not in taken:
            return span
        if first is None:
            first = span
    return first


def df_to_mentions(df_rows, text: str = "") -> Tuple[List[Dict[str, Any]], List[str]]:
    """Group rows of the Mentions table into GLiDRE entity dicts.

    Each row is ``[id, type, value, start, end]``; rows sharing an ``id`` are
    mentions of the same entity. Blank ids receive the next free integer,
    blank types default to ``"MISC"``. A missing or invalid character span is
    recovered by searching ``text`` for ``value`` (repeated values map to
    successive occurrences); a missing value is filled from
    ``text[start:end]``. Rows with neither a value nor a span are skipped.

    Args:
        df_rows: a pandas DataFrame or an iterable of row lists; None is
            treated as an empty table.
        text: the document the spans index into.

    Returns:
        ``(entities, warnings)``: ``entities`` is a list of
        ``{"id", "type", "mentions": [{"value", "start", "end"}]}`` dicts in
        first-seen id order; ``warnings`` lists human-readable problems.
    """
    if df_rows is None:
        return [], []
    if isinstance(df_rows, pd.DataFrame):
        df_rows = df_rows.values.tolist()

    entity_map: Dict[int, Dict[str, Any]] = {}
    warnings: List[str] = []
    taken_spans: Set[Tuple[int, int]] = set()
    auto_id = 0

    for row in df_rows:
        cells = list(row) if row is not None else []
        cells += [None] * (5 - len(cells))  # pad short rows to 5 columns
        if all(_is_blank(v) for v in cells):
            continue
        raw_id, raw_type, raw_value, raw_start, raw_end = cells[:5]

        rid = _to_int(raw_id)
        if rid is None:
            rid = auto_id
        auto_id = max(auto_id, rid) + 1

        rtype = "MISC" if _is_blank(raw_type) else str(raw_type).strip()
        value = "" if _is_blank(raw_value) else str(raw_value).strip()
        start, end = _to_int(raw_start), _to_int(raw_end)

        # Accept the explicit span only if it is non-empty and (when the text
        # is known) lies inside the document.
        span: Optional[Tuple[int, int]] = None
        if start is not None and end is not None and 0 <= start < end and (not text or end <= len(text)):
            span = (start, end)

        if span is None and value and text:
            span = _locate_span(value, text, taken_spans)
            if span is None:
                warnings.append(f"Could not find '{value}' in text; span left as None.")

        if not value and span is not None and text:
            value = text[span[0]:span[1]]
        if not value and span is None:
            warnings.append(f"Skipped a mention of entity {rid}: it has neither a value nor a valid span.")
            continue

        if span is not None:
            taken_spans.add(span)
        mention: Dict[str, Any] = {
            "value": value,
            "start": span[0] if span else None,
            "end": span[1] if span else None,
        }
        entity = entity_map.setdefault(rid, {"id": rid, "type": rtype, "mentions": []})
        entity["mentions"].append(mention)

    return list(entity_map.values()), warnings


def auto_annotate(text: str) -> List[List[Any]]:
    """Pre-fill the Mentions table with the spaCy pipeline's named entities.

    Entities with identical surface text share one id, so repeated strings
    are grouped as mentions of the same entity. Returns rows of
    ``[id, label, text, start_char, end_char]``; empty text yields ``[]``.
    """
    if not text or not text.strip():
        return []
    ids_by_surface: Dict[str, int] = {}
    rows: List[List[Any]] = []
    for ent in nlp(text).ents:
        eid = ids_by_surface.setdefault(ent.text, len(ids_by_surface))
        rows.append([eid, ent.label_, ent.text, ent.start_char, ent.end_char])
    return rows


def _first_mention(endpoint: Any) -> Dict[str, Any]:
    """Return the first mention dict of a relation endpoint, or ``{}`` if absent.

    The code indexes endpoints as lists of dicts (``entity_1[0].get("id")``);
    anything else is treated as missing.
    """
    if isinstance(endpoint, (list, tuple)) and endpoint and isinstance(endpoint[0], dict):
        return endpoint[0]
    return {}


def relations_to_table(relations: List[Dict[str, Any]]) -> List[List[Any]]:
    """Flatten predicted relations into ``[e1_id, e1_text, relation, e2_id, e2_text, score]`` rows."""
    rows = []
    for r in relations:
        head = _first_mention(r.get("entity_1"))
        tail = _first_mention(r.get("entity_2"))
        rows.append([
            head.get("id", ""),
            head.get("text", ""),
            r.get("relation_type", ""),
            tail.get("id", ""),
            tail.get("text", ""),
            r.get("score", ""),
        ])
    return rows


def draw_relation_graph(relations: List[Dict[str, Any]]):
    """Render relations as a directed graph and return it as a PIL image.

    Nodes are entity ids labelled ``"text\\n(id)"``; each relation is an edge
    from ``entity_1`` to ``entity_2``. Several relations between the same
    ordered pair are shown as one label with one type per line. Returns None
    when there is nothing to draw.
    """
    if not relations:
        return None

    G = nx.MultiDiGraph()
    node_labels: Dict[str, str] = {}
    edge_labels: Dict[Tuple[str, str], List[str]] = {}
    for r in relations:
        head = _first_mention(r.get("entity_1"))
        tail = _first_mention(r.get("entity_2"))
        id1, id2 = str(head.get("id")), str(tail.get("id"))
        node_labels[id1] = f"{head.get('text', id1)}\n({id1})"
        node_labels[id2] = f"{tail.get('text', id2)}\n({id2})"
        relation_label = r.get("relation_type", "")
        G.add_edge(id1, id2, label=relation_label)  # also adds both nodes
        # Collect per pair: a plain dict keyed by (u, v) would keep only the
        # last of several parallel relations.
        edge_labels.setdefault((id1, id2), []).append(str(relation_label))

    fig = plt.figure(figsize=(8, 6))
    try:
        pos = nx.spring_layout(G, k=0.8, seed=42)  # fixed seed: stable layout across runs
        nx.draw_networkx_nodes(G, pos, node_size=1200)
        nx.draw_networkx_labels(G, pos, labels=node_labels)
        nx.draw_networkx_edges(G, pos, arrows=True)
        nx.draw_networkx_edge_labels(
            G, pos, edge_labels={pair: "\n".join(labels) for pair, labels in edge_labels.items()}
        )
        plt.axis("off")
        plt.tight_layout()
        buf = io.BytesIO()
        fig.savefig(buf, format="png", bbox_inches="tight")
    finally:
        plt.close(fig)  # always release the figure, even if drawing fails
    buf.seek(0)
    return Image.open(buf)


def predict(text: str, labels_text: str, entities_df, threshold: float, multi_label: bool):
    """Run GLiDRE on the document and mentions, returning the three UI outputs.

    Returns:
        ``(table_rows, graph_image_or_None, raw_text)``. On invalid input the
        table is empty, the image is None and ``raw_text`` holds the error
        message instead of JSON. Non-fatal mention warnings are shown as
        Gradio warning toasts.
    """
    if not text or not text.strip():
        return [], None, "⚠️ Please enter a document text."
    labels = parse_labels(labels_text)
    if not labels:
        return [], None, "⚠️ Please provide at least one relation label."

    mentions, warnings = df_to_mentions(entities_df, text=text)
    if not mentions:
        msg = "⚠️ No valid mentions found. Please fill in the Mentions table with id, value for each entity."
        if warnings:
            msg += "\n" + "\n".join(warnings)
        # Exactly three values: the click handler has three outputs.
        return [], None, msg

    for warning in warnings:
        gr.Warning(warning)

    relations = model.predict_entities(
        text=text,
        labels=labels,
        mentions=mentions,
        threshold=threshold,
        multi_label=multi_label,
    )
    table = relations_to_table(relations)
    graph_img = draw_relation_graph(relations)
    # ensure_ascii=False keeps non-Latin entity text readable (multilingual model).
    raw_json = json.dumps(relations, indent=2, ensure_ascii=False)
    return table, graph_img, raw_json


# Each example: [document, comma-separated labels, mention rows
# [id, type, value, start, end]] with character offsets into the document.
EXAMPLES = [
    [
        "Rihanna released her album in 2016. The artist won several awards that year.",
        "AWARD_RECEIVED, RELEASE_DATE",
        [[0, "PERSON", "Rihanna", 0, 7], [0, "PERSON", "artist", 40, 46],
         [1, "MISC", "album", 21, 26], [2, "DATE", "2016", 30, 34]],
    ],
    [
        "Steve Jobs co-founded Apple. Jobs also served as its CEO until 2011.",
        "FOUNDER, EMPLOYEE_OF, END_DATE",
        [[0, "PERSON", "Steve Jobs", 0, 10], [0, "PERSON", "Jobs", 29, 33],
         [1, "ORG", "Apple", 22, 27], [2, "DATE", "2011", 63, 67]],
    ],
    [
        "Marie Curie, born in Warsaw, won the Nobel Prize in Physics in 1903.",
        "PLACE_OF_BIRTH, AWARD_RECEIVED",
        [[0, "PERSON", "Marie Curie", 0, 11], [1, "LOC", "Warsaw", 21, 27],
         [2, "MISC", "Nobel Prize in Physics", 37, 59], [3, "DATE", "1903", 63, 67]],
    ],
]

with gr.Blocks(title="GLiDRE — Gradio demo") as demo:
    # NOTE(review): this HTML component's content is empty in the visible
    # source — confirm whether custom CSS/markup was intended here.
    gr.HTML(""" """)
    gr.Markdown("# GLiDRE DEMO")
    with gr.Row():
        with gr.Column(scale=6):
            text_input = gr.Textbox(
                label="Document text",
                value="The Loud Tour was the fourth overall and third world concert tour by Barbadian recording artist Rihanna.",
                lines=6,
            )
            labels_input = gr.Textbox(
                label="Relation labels (comma-separated)",
                value="COUNTRY_OF_CITIZENSHIP, PUBLICATION_DATE, PART_OF",
            )
            annotate_btn = gr.Button("🔍 Auto-annotate using spaCy", variant="secondary")
            entities_df = gr.Dataframe(
                headers=["id", "type", "value", "start", "end"],
                datatype=["number", "text", "text", "number", "number"],
                interactive=True,
                label="Mentions",
                column_count=5,
            )
            with gr.Row():
                threshold = gr.Slider(0.0, 1.0, value=0.3, label="Threshold", step=0.05)
                multi_label = gr.Checkbox(
                    label="Allow multi-label (one mention pair can have multiple relations)",
                    value=True,
                )
            run = gr.Button("▶ Run prediction", variant="primary", elem_id="run-btn")
            gr.Examples(
                label="Examples (click to load)",
                examples=EXAMPLES,
                inputs=[text_input, labels_input, entities_df],
            )
        with gr.Column(scale=4):
            relations_table = gr.Dataframe(
                headers=["e1_id", "e1_text", "relation", "e2_id", "e2_text", "score"],
                interactive=False,
            )
            graph_out = gr.Image(label="Relation graph", type="pil")
            raw_json_out = gr.Textbox(label="Raw JSON output", lines=12)

    annotate_btn.click(fn=auto_annotate, inputs=[text_input], outputs=[entities_df])
    run.click(
        fn=predict,
        inputs=[text_input, labels_input, entities_df, threshold, multi_label],
        outputs=[relations_table, graph_out, raw_json_out],
    )

if __name__ == "__main__":
    demo.launch(favicon_path="images/LIST.svg", debug=True, allowed_paths=["images"])