Spaces:
Running
Running
| import io | |
| import json | |
| import re | |
| from typing import Any, Dict, List | |
| import gradio as gr | |
| import subprocess | |
# Install the runtime dependencies at startup; this script doubles as its own
# installer. The commands are best-effort: return codes are not checked, so a
# failed step only surfaces later as an ImportError on the imports below.
import sys

# Fetch the GLiDRE sources, then install them from the local checkout.
subprocess.run(["git", "clone", "https://github.com/robinarmingaud/glidre"])
# Invoke pip through the running interpreter so the packages land in the same
# environment that will import them; a bare "pip" resolved from PATH may belong
# to a different Python installation.
subprocess.run([sys.executable, "-m", "pip", "install", "./glidre"])
subprocess.run([sys.executable, "-m", "pip", "install", "networkx"])
subprocess.run([sys.executable, "-m", "pip", "install", "matplotlib"])
subprocess.run([sys.executable, "-m", "pip", "install", "spacy"])
# spaCy pipeline package that is loaded below with spacy.load("xx_ent_wiki_sm").
subprocess.run([
    sys.executable, "-m", "pip", "install",
    "https://github.com/explosion/spacy-models/releases/download/xx_ent_wiki_sm-3.8.0/xx_ent_wiki_sm-3.8.0-py3-none-any.whl"
])
| import networkx as nx | |
| import matplotlib.pyplot as plt | |
| from PIL import Image | |
| import pandas as pd | |
| import spacy | |
# spaCy pipeline used by auto_annotate() to propose entity mentions (doc.ents).
nlp = spacy.load("xx_ent_wiki_sm")
# glidre is pip-installed by the subprocess calls earlier in this file, so it
# can only be imported after they have run.
from glidre import GLiDRE
# Identifier handed to GLiDRE.from_pretrained() below.
MODEL_NAME = "cea-list-ia/glidre_multi"
# Loaded once at import time; predict() reuses this single instance.
model = GLiDRE.from_pretrained(MODEL_NAME)
def parse_labels(labels_text: str) -> List[str]:
    """Turn the free-form label box content into a list of relation labels.

    Labels may be separated by commas and/or newlines. Surrounding whitespace
    is removed, empty entries are dropped and every label is upper-cased.

    Args:
        labels_text: Raw text typed by the user; may be empty or None.

    Returns:
        The cleaned labels in their original order ([] when nothing usable
        was supplied).
    """
    if not labels_text:
        return []
    labels: List[str] = []
    # Commas and newlines are interchangeable separators.
    for chunk in re.split(r"[,\n]", labels_text):
        cleaned = chunk.strip()
        if cleaned:
            labels.append(cleaned.upper())
    return labels
def _is_blank_cell(cell: Any) -> bool:
    """Return True when a table cell carries no value (None, NaN or whitespace)."""
    if cell is None:
        return True
    # NaN is the only float that is not equal to itself; pandas stores it in
    # empty cells, and str(nan) == "nan" would otherwise look like real data.
    if isinstance(cell, float) and cell != cell:
        return True
    return str(cell).strip() == ""


def _cell_to_int(cell: Any):
    """Convert a table cell to int; return None for a blank cell.

    Raises ValueError, TypeError or OverflowError when the cell is neither
    blank nor convertible.
    """
    if _is_blank_cell(cell):
        return None
    try:
        return int(cell)
    except ValueError:
        # Accept float-formatted integers such as "3.0".
        return int(float(cell))


def df_to_mentions(df_rows, text: str = "") -> tuple:
    """Group the rows of the mentions table into entities.

    Each row is read by position as ``[id, type, value, start, end]``; rows
    sharing an id become mentions of the same entity.

    Args:
        df_rows: A pandas DataFrame or a list of rows. None is treated as an
            empty table.
        text: Document text used to locate a mention whose start/end offsets
            are missing (first literal occurrence of its value).

    Returns:
        A ``(entities, warnings)`` pair. ``entities`` is a list of
        ``{"id", "type", "mentions": [{"value", "start", "end"}, ...]}`` dicts
        in order of first appearance. ``warnings`` lists human-readable
        messages for mentions that could not be located in ``text``.
    """
    if df_rows is None:
        return [], []
    if isinstance(df_rows, pd.DataFrame):
        df_rows = df_rows.values.tolist()

    entity_map: Dict[int, Dict[str, Any]] = {}
    warnings: List[str] = []
    auto_id = 0  # next id handed to a row whose id cell is blank or unparsable
    for row in df_rows:
        # Skip placeholder rows (including all-NaN rows coming from pandas).
        if not row or all(_is_blank_cell(cell) for cell in row):
            continue

        try:
            rid = _cell_to_int(row[0])
        except (TypeError, ValueError, OverflowError):
            rid = None
        if rid is None:
            rid = auto_id
        # Keep generated ids above every id seen so far.
        auto_id = max(auto_id, rid) + 1

        has_type = len(row) > 1 and not _is_blank_cell(row[1])
        rtype = str(row[1]).strip() if has_type else "MISC"
        has_value = len(row) > 2 and not _is_blank_cell(row[2])
        value = str(row[2]).strip() if has_value else ""

        try:
            start = _cell_to_int(row[3]) if len(row) > 3 else None
            end = _cell_to_int(row[4]) if len(row) > 4 else None
        except (TypeError, ValueError, OverflowError):
            # An unparsable offset invalidates the whole span.
            start = None
            end = None

        if (start is None or end is None) and value and text:
            offset = text.find(value)
            if offset >= 0:
                start, end = offset, offset + len(value)
            else:
                warnings.append(f"Could not find '{value}' in text; span left as None.")

        mention: Dict[str, Any] = {"value": value, "start": start, "end": end}
        # The type of the first row seen for an id defines the entity type.
        entity = entity_map.setdefault(rid, {"id": rid, "type": rtype, "mentions": []})
        entity["mentions"].append(mention)
    return list(entity_map.values()), warnings
def auto_annotate(text: str):
    """Propose rows for the mentions table from the spaCy entities of *text*.

    Entities with the same surface string share one id; ids are assigned in
    order of first appearance, starting at 0.

    Args:
        text: Document text; empty, None or whitespace-only text yields no rows.

    Returns:
        A list of ``[id, label, surface, start_char, end_char]`` rows, one per
        entity found by the module-level ``nlp`` pipeline.
    """
    if not (text and text.strip()):
        return []

    surface_ids: Dict[str, int] = {}
    table_rows = []
    for entity in nlp(text).ents:
        mention_text = entity.text
        # A new surface form receives the next free id; a known one keeps its id.
        entity_id = surface_ids.setdefault(mention_text, len(surface_ids))
        table_rows.append(
            [entity_id, entity.label_, mention_text, entity.start_char, entity.end_char]
        )
    return table_rows
def _first_mention(entity: Any) -> Dict[str, Any]:
    """Return ``entity[0]`` when it is a dict-like mention, otherwise ``{}``."""
    try:
        first = entity[0]
    except (IndexError, KeyError, TypeError):
        # Missing key (None), empty list, or a container without index 0.
        return {}
    return first if hasattr(first, "get") else {}


def relations_to_table(relations: List[Dict[str, Any]]) -> List[List[Any]]:
    """Flatten predicted relations into rows for the results table.

    Each relation is read through the keys ``entity_1``, ``entity_2``,
    ``relation_type`` and ``score``; only the first element of each entity is
    shown. NOTE(review): the entity values are presumably lists of mention
    dicts with ``id`` and ``text`` keys; confirm against GLiDRE's output.

    Args:
        relations: Relation dicts as returned by the model; None or an empty
            list yields no rows.

    Returns:
        One ``[e1_id, e1_text, relation, e2_id, e2_text, score]`` row per
        relation. A relation whose entity is missing or empty still produces
        a row, with blank cells for that entity, instead of raising.
    """
    rows = []
    for relation in relations or []:
        head = _first_mention(relation.get("entity_1"))
        tail = _first_mention(relation.get("entity_2"))
        rows.append([
            head.get("id"),
            head.get("text", ""),
            relation.get("relation_type", ""),
            tail.get("id", ""),
            tail.get("text", ""),
            relation.get("score", ""),
        ])
    return rows
def _lead_mention(entity: Any):
    """Return ``entity[0]`` when it is a dict-like mention, otherwise None."""
    try:
        first = entity[0]
    except (IndexError, KeyError, TypeError):
        return None
    return first if hasattr(first, "get") else None


def draw_relation_graph(relations: List[Dict[str, Any]]):
    """Render the predicted relations as a directed graph image.

    Every relation contributes an edge from the first element of ``entity_1``
    to the first element of ``entity_2``, labelled with ``relation_type``.
    Nodes are keyed by the stringified entity id and labelled with the
    mention text followed by the id.

    Args:
        relations: Relation dicts as returned by the model. Relations without
            a usable ``entity_1`` or ``entity_2`` are skipped because no edge
            can be drawn for them.

    Returns:
        A PIL image holding the PNG rendering of the graph.
    """
    graph = nx.MultiDiGraph()
    node_labels: Dict[str, str] = {}
    # (source, target) -> distinct relation names in order of appearance.
    # Edge labels are keyed by node pair, so parallel edges must be merged
    # into one label or all but the last relation of a pair would be lost.
    pair_labels: Dict[tuple, List[str]] = {}
    for relation in relations or []:
        head = _lead_mention(relation.get("entity_1"))
        tail = _lead_mention(relation.get("entity_2"))
        if head is None or tail is None:
            continue
        id1 = str(head.get("id"))
        id2 = str(tail.get("id"))
        node_labels[id1] = f"{head.get('text', id1)}\n({id1})"
        node_labels[id2] = f"{tail.get('text', id2)}\n({id2})"
        relation_label = relation.get("relation_type", "")
        graph.add_node(id1)
        graph.add_node(id2)
        graph.add_edge(id1, id2, label=relation_label)
        names = pair_labels.setdefault((id1, id2), [])
        if str(relation_label) not in names:
            names.append(str(relation_label))

    buf = io.BytesIO()
    fig = plt.figure(figsize=(8, 6))
    try:
        pos = nx.spring_layout(graph, k=0.8)
        nx.draw_networkx_nodes(graph, pos, node_size=1200)
        nx.draw_networkx_labels(graph, pos, labels=node_labels)
        nx.draw_networkx_edges(graph, pos, arrows=True)
        edge_labels = {pair: ", ".join(names) for pair, names in pair_labels.items()}
        nx.draw_networkx_edge_labels(graph, pos, edge_labels=edge_labels)
        plt.axis("off")
        plt.tight_layout()
        plt.savefig(buf, format="png", bbox_inches="tight")
    finally:
        # Always release the figure, even when drawing fails, so figures do
        # not pile up in pyplot's global registry across requests.
        plt.close(fig)
    buf.seek(0)
    return Image.open(buf)
def predict(text: str, labels_text: str, entities_df, threshold: float, multi_label: bool):
    """Run relation extraction and build the three outputs of the Run button.

    Args:
        text: Document text.
        labels_text: Comma/newline separated relation labels (see parse_labels).
        entities_df: Content of the Mentions table (see df_to_mentions).
        threshold: Score threshold forwarded to the model.
        multi_label: Forwarded to the model as ``multi_label``.

    Returns:
        A 3-tuple ``(table_rows, graph_image, raw_text)`` matching the
        ``[relations_table, graph_out, raw_json_out]`` outputs this callback
        is wired to. When the mentions table is empty, the table and graph are
        cleared and ``raw_text`` carries an explanatory message instead of JSON.
    """
    labels = parse_labels(labels_text)
    mentions, warnings = df_to_mentions(entities_df, text=text)
    if not mentions:
        msg = "⚠️ No valid mentions found. Please fill in the Mentions table with id, value for each entity."
        if warnings:
            msg += "\n" + "\n".join(warnings)
        # Return exactly one value per wired output. The message goes to a
        # plain Textbox, so it is sent as text rather than HTML markup.
        return ([], None, msg)

    relations = model.predict_entities(text=text, labels=labels, mentions=mentions, threshold=threshold, multi_label=multi_label)
    table = relations_to_table(relations)
    graph_img = draw_relation_graph(relations)
    if warnings:
        # No output component is dedicated to warnings, so surface them as a
        # Gradio toast instead of silently dropping them.
        gr.Warning("\n".join(warnings))
    # ensure_ascii=False keeps non-ASCII mention text readable in the textbox.
    raw_json = json.dumps(relations, indent=2, ensure_ascii=False)
    return (table, graph_img, raw_json)
# Build the demo UI. Components are created in display order: a logo banner and
# title, then one row holding an input column (scale=6) and an output column
# (scale=4). Event handlers are wired at the end of the Blocks context.
with gr.Blocks(title="GLiDRE — Gradio demo") as demo:
    # Logo banner; the image is served from the "images" directory, which
    # demo.launch() exposes through allowed_paths.
    gr.HTML("""
<div id="logo" style="display: flex;align-items: center">
<img src="/gradio_api/file=images/LIST.png" alt="Logo" width="200">
</div>
""")
    gr.Markdown("# GLiDRE DEMO")
    with gr.Row():
        # Input column: document text, relation labels, mentions table, options.
        with gr.Column(scale=6):
            text_input = gr.Textbox(label="Document text", value="The Loud Tour was the fourth overall and third world concert tour by Barbadian recording artist Rihanna.", lines=6)
            labels_input = gr.Textbox(label="Relation labels (comma-separated)", value="COUNTRY_OF_CITIZENSHIP, PUBLICATION_DATE, PART_OF")
            # Fills the Mentions table with spaCy entities (see auto_annotate).
            annotate_btn = gr.Button("🔍 Auto-annotate using spaCy", variant="secondary")
            # Editable table; df_to_mentions() reads its columns by position
            # in exactly this order: id, type, value, start, end.
            entities_df = gr.Dataframe(headers=["id", "type", "value", "start", "end"], datatype=["number", "text", "text", "number", "number"], interactive=True, label="Mentions", column_count=5)
            with gr.Row():
                threshold = gr.Slider(0.0, 1.0, value=0.3, label="Threshold", step = 0.05)
                multi_label = gr.Checkbox(label="Allow multi-label (one mention pair can have multiple relations)", value=True)
            run = gr.Button("▶ Run prediction", variant="primary", elem_id="run-btn")
            # Each example provides [document text, relation labels, mention
            # rows], in the same order as the `inputs` list below. Rows that
            # share an id are mentions of the same entity.
            gr.Examples(
                label="Examples (click to load)",
                examples=[
                    [
                        "Rihanna released her album in 2016. The artist won several awards that year.",
                        "AWARD_RECEIVED, RELEASE_DATE",
                        [[0, "PERSON", "Rihanna", 0, 7], [0, "PERSON", "artist", 40, 46], [1, "MISC", "album", 21, 26], [2, "DATE", "2016", 30, 34]]
                    ],
                    [
                        "Steve Jobs co-founded Apple. Jobs also served as its CEO until 2011.",
                        "FOUNDER, EMPLOYEE_OF, END_DATE",
                        [[0, "PERSON", "Steve Jobs", 0, 10], [0, "PERSON", "Jobs", 29, 33], [1, "ORG", "Apple", 22, 27], [2, "DATE", "2011", 63, 67]]
                    ],
                    [
                        "Marie Curie, born in Warsaw, won the Nobel Prize in Physics in 1903.",
                        "PLACE_OF_BIRTH, AWARD_RECEIVED",
                        [[0, "PERSON", "Marie Curie", 0, 11], [1, "LOC", "Warsaw", 21, 27], [2, "MISC", "Nobel Prize in Physics", 37, 59], [3, "DATE", "1903", 63, 67]]
                    ]
                ],
                inputs=[text_input, labels_input, entities_df]
            )
        # Output column: relation table, rendered graph and raw JSON dump.
        with gr.Column(scale=4):
            relations_table = gr.Dataframe(headers=["e1_id", "e1_text", "relation", "e2_id", "e2_text", "score"], interactive=False)
            graph_out = gr.Image(label="Relation graph", type="pil")
            raw_json_out = gr.Textbox(label="Raw JSON output", lines=12)
    # Wiring: auto-annotation fills the Mentions table; the Run button calls
    # predict(), which must return one value per component listed in `outputs`.
    annotate_btn.click(fn=auto_annotate, inputs=[text_input], outputs=[entities_df])
    run.click(fn=predict, inputs=[text_input, labels_input, entities_df, threshold, multi_label], outputs=[relations_table, graph_out, raw_json_out])
# Start the Gradio server only when this file is executed as a script.
if __name__ == "__main__":
    # allowed_paths exposes the "images" directory so the logo referenced in
    # the HTML banner ("/gradio_api/file=images/LIST.png") can be served.
    demo.launch(favicon_path="images/LIST.svg", debug=True, allowed_paths=["images"])