import io
import json
import re
from typing import Any, Dict, List, Tuple

import gradio as gr
import subprocess

# Runtime bootstrap: the packages imported further down are installed here,
# at import time, so these calls must stay ahead of those imports.
# NOTE(review): return codes are not checked and `pip` is resolved through
# PATH rather than the running interpreter -- confirm this is intended for the
# deployment target before tightening it.
subprocess.run(["git", "clone", "https://github.com/robinarmingaud/glidre"])
subprocess.run(["pip", "install", "./glidre"])
subprocess.run(["pip", "install", "networkx"])
subprocess.run(["pip", "install", "matplotlib"])
subprocess.run(["pip", "install", "spacy"])
subprocess.run([
    "pip",
    "install",
    "https://github.com/explosion/spacy-models/releases/download/xx_ent_wiki_sm-3.8.0/xx_ent_wiki_sm-3.8.0-py3-none-any.whl",
])

import networkx as nx
import matplotlib.pyplot as plt
from PIL import Image
import pandas as pd
import spacy

# spaCy pipeline used by auto_annotate() to pre-fill the mentions table.
nlp = spacy.load("xx_ent_wiki_sm")

from glidre import GLiDRE

MODEL_NAME = "cea-list-ia/glidre_multi"
model = GLiDRE.from_pretrained(MODEL_NAME)


def parse_labels(labels_text: str) -> List[str]:
    """Split a comma- and/or newline-separated string into upper-cased labels.

    Surrounding whitespace is stripped and blank entries are dropped. An empty
    (or ``None``) input yields an empty list.
    """
    if not labels_text:
        return []
    candidates = (
        piece.strip() for piece in labels_text.replace("\n", ",").split(",")
    )
    return [label.upper() for label in candidates if label]


def _is_blank(cell: Any) -> bool:
    """Return True for ``None``, a float NaN, or a whitespace-only value."""
    if cell is None:
        return True
    # NaN is the only float that is not equal to itself; pandas uses it for
    # missing cells, and str(nan) would otherwise leak through as "nan".
    if isinstance(cell, float) and cell != cell:
        return True
    return str(cell).strip() == ""


def _cell(row: Any, index: int) -> Any:
    """Return ``row[index]``, or ``None`` when the row is too short."""
    return row[index] if len(row) > index else None


def df_to_mentions(
    df_rows, text: str = ""
) -> Tuple[List[Dict[str, Any]], List[str]]:
    """Group table rows into entities with their mentions.

    Each row is read positionally as ``[id, type, value, start, end]``; short
    rows are allowed. Rows sharing an id are merged into one entity whose type
    is taken from the first row seen for that id.

    Args:
        df_rows: A ``pandas.DataFrame`` or an iterable of row sequences.
            ``None`` is treated as an empty table.
        text: Source text, used to locate ``value`` when a row carries no
            usable ``start``/``end`` span.

    Returns:
        ``(entities, warnings)`` where ``entities`` is a list of
        ``{"id", "type", "mentions"}`` dicts (each mention being
        ``{"value", "start", "end"}``) and ``warnings`` lists the values that
        could not be found in ``text``.
    """
    if df_rows is None:
        return [], []
    if isinstance(df_rows, pd.DataFrame):
        df_rows = df_rows.values.tolist()

    entity_map: Dict[int, Dict[str, Any]] = {}
    issues: List[str] = []
    auto_id = 0

    for row in df_rows:
        # Skip rows that are entirely empty (including all-NaN pandas rows).
        if not row or all(_is_blank(cell) for cell in row):
            continue

        # A missing or unparsable id falls back to the next automatic id.
        try:
            rid = auto_id if _is_blank(row[0]) else int(row[0])
        except (TypeError, ValueError, OverflowError):
            rid = auto_id
        auto_id = max(auto_id, rid) + 1

        type_cell = _cell(row, 1)
        rtype = "MISC" if _is_blank(type_cell) else str(type_cell).strip()

        value_cell = _cell(row, 2)
        value = "" if _is_blank(value_cell) else str(value_cell).strip()

        # The span is all-or-nothing: if either bound cannot be parsed, both
        # are discarded and the value is looked up in the text instead.
        start_cell, end_cell = _cell(row, 3), _cell(row, 4)
        try:
            start = None if _is_blank(start_cell) else int(start_cell)
            end = None if _is_blank(end_cell) else int(end_cell)
        except (TypeError, ValueError, OverflowError):
            start = None
            end = None

        if (start is None or end is None) and value and text:
            # Literal search; always resolves to the first occurrence.
            found_at = text.find(value)
            if found_at >= 0:
                start, end = found_at, found_at + len(value)
            else:
                issues.append(
                    f"Could not find '{value}' in text; span left as None."
                )

        mention: Dict[str, Any] = {"value": value, "start": start, "end": end}
        if rid not in entity_map:
            entity_map[rid] = {"id": rid, "type": rtype, "mentions": []}
        entity_map[rid]["mentions"].append(mention)

    return list(entity_map.values()), issues


def auto_annotate(text: str) -> List[List[Any]]:
    """Run the spaCy pipeline on ``text`` and return mention-table rows.

    Each row is ``[id, label, surface, start_char, end_char]``. Entities with
    the same surface string share one id; ids are assigned in order of first
    appearance starting at 0. Blank input yields an empty list.
    """
    if not text or not text.strip():
        return []

    surface_ids: Dict[str, int] = {}
    rows: List[List[Any]] = []
    for ent in nlp(text).ents:
        # setdefault with the current size hands out 0, 1, 2, ... on first use.
        entity_id = surface_ids.setdefault(ent.text, len(surface_ids))
        rows.append(
            [entity_id, ent.label_, ent.text, ent.start_char, ent.end_char]
        )
    return rows


def _first_mention(relation: Dict[str, Any], key: str) -> Dict[str, Any]:
    """Return the first dict stored under ``relation[key]``, or ``{}``.

    Guards against the key being absent, the container being empty, or the
    first item not being a dict, so callers can use ``.get`` unconditionally.
    """
    try:
        first = relation.get(key)[0]
    except (TypeError, KeyError, IndexError):
        return {}
    return first if isinstance(first, dict) else {}


def relations_to_table(relations: List[Dict[str, Any]]) -> List[List[Any]]:
    """Flatten predicted relations into rows for tabular display.

    Each row is ``[id_1, text_1, relation_type, id_2, text_2, score]``, read
    from the first item of ``entity_1`` / ``entity_2``. Missing fields become
    ``""`` instead of raising.

    NOTE(review): assumes ``entity_1`` / ``entity_2`` are sequences of dicts
    with ``id`` and ``text`` keys, as implied by the indexing here -- confirm
    against the model's output schema.
    """
    rows: List[List[Any]] = []
    for relation in relations:
        head = _first_mention(relation, "entity_1")
        tail = _first_mention(relation, "entity_2")
        rows.append([
            head.get("id", ""),
            head.get("text", ""),
            relation.get("relation_type", ""),
            tail.get("id", ""),
            tail.get("text", ""),
            relation.get("score", ""),
        ])
    return rows


def draw_relation_graph(relations: List[Dict[str, Any]]):
    """Render the relations as a directed graph and return it as a PIL image.

    Nodes are entity ids labelled ``"<text>\\n(<id>)"``; every relation adds
    one edge from entity 1 to entity 2. When several relations connect the
    same ordered pair, their distinct labels are joined with ``", "`` so that
    none is lost. The figure is always closed, even if drawing fails.
    """
    graph = nx.MultiDiGraph()
    node_labels: Dict[str, str] = {}
    pair_labels: Dict[Tuple[str, str], List[str]] = {}

    for relation in relations:
        head = _first_mention(relation, "entity_1")
        tail = _first_mention(relation, "entity_2")
        head_id = str(head.get("id"))
        tail_id = str(tail.get("id"))
        head_text = head.get("text", head_id)
        tail_text = tail.get("text", tail_id)
        node_labels[head_id] = f"{head_text}\n({head_id})"
        node_labels[tail_id] = f"{tail_text}\n({tail_id})"

        relation_label = relation.get("relation_type", "")
        graph.add_node(head_id)
        graph.add_node(tail_id)
        graph.add_edge(head_id, tail_id, label=relation_label)
        pair_labels.setdefault((head_id, tail_id), []).append(
            str(relation_label)
        )

    # One label per ordered node pair; dict.fromkeys de-duplicates in order.
    edge_labels = {
        pair: ", ".join(dict.fromkeys(labels))
        for pair, labels in pair_labels.items()
    }

    # Draw on an explicit figure/axes instead of pyplot's implicit "current"
    # figure, so closing it cannot affect a figure owned by other code.
    fig = plt.figure(figsize=(8, 6))
    buf = io.BytesIO()
    try:
        ax = fig.add_subplot(111)
        pos = nx.spring_layout(graph, k=0.8)
        nx.draw_networkx_nodes(graph, pos, node_size=1200, ax=ax)
        nx.draw_networkx_labels(graph, pos, labels=node_labels, ax=ax)
        nx.draw_networkx_edges(graph, pos, arrows=True, ax=ax)
        nx.draw_networkx_edge_labels(
            graph, pos, edge_labels=edge_labels, ax=ax
        )
        ax.axis("off")
        fig.tight_layout()
        fig.savefig(buf, format="png", bbox_inches="tight")
    finally:
        plt.close(fig)
    buf.seek(0)
    # NOTE(review): this function's closing `return Image.open(buf)` statement
    # sits on the following source line, outside this block.
return Image.open(buf) def predict(text: str, labels_text: str, entities_df, threshold: float, multi_label: bool): labels = parse_labels(labels_text) mentions, warnings = df_to_mentions(entities_df, text=text) if not mentions: msg = "⚠️ No valid mentions found. Please fill in the Mentions table with id, value for each entity." if warnings: msg += "\n" + "\n".join(warnings) return ([], None, f"
{msg}", "")
relations = model.predict_entities(text=text, labels=labels, mentions=mentions, threshold=threshold, multi_label=multi_label)
table = relations_to_table(relations)
graph_img = draw_relation_graph(relations)
warning_html = ""
if warnings:
warning_html = "⚠️ " + "
".join(warnings) + "