import torch import json import streamlit as st from typing import List from pydantic import BaseModel, Field from transformers import AutoTokenizer, AutoModelForCausalLM from gliner import GLiNER # N'oubliez pas d'ajouter 'gliner' dans requirements.txt # --- SCHÉMAS DE DONNÉES --- class Entity(BaseModel): id: str = Field(description="ID unique (ex: E1).") name: str = Field(description="Nom exact trouvé.") type: str = Field(description="Catégorie détectée.") description: str = Field(description="Rôle ou contexte.") class Relationship(BaseModel): source: str = Field(alias="from", description="ID source.") target: str = Field(alias="to", description="ID cible.") type: str = Field(description="Verbe d'action court.") description: str = Field(description="Détails du lien.") class KnowledgeGraph(BaseModel): entities: List[Entity] relationships: List[Relationship] class ExtractorEngine: def __init__(self): self.model_name = "Qwen/Qwen2.5-1.5B-Instruct" if 'llm_model' not in st.session_state: with st.spinner("🚀 Chargement des cerveaux IA (CPU)..."): # Chargement Qwen (Compréhension & Relations) st.session_state.llm_tokenizer = AutoTokenizer.from_pretrained(self.model_name) st.session_state.llm_model = AutoModelForCausalLM.from_pretrained( self.model_name, torch_dtype=torch.float32, device_map=None, low_cpu_mem_usage=False ) # Chargement GLiNER (Extraction de précision) st.session_state.gliner_model = GLiNER.from_pretrained("urchade/gliner_medium-v2.1") self.tokenizer = st.session_state.llm_tokenizer self.model = st.session_state.llm_model self.gliner = st.session_state.gliner_model self.json_schema = json.dumps(KnowledgeGraph.model_json_schema(), indent=2) def extract_long_text(self, text: str, temperature: float, chunk_size: int = 3500): chunks = [text[i:i + chunk_size] for i in range(0, len(text), chunk_size)] final_graph = {"entities": [], "relationships": []} entity_map = {} for chunk in chunks: # 1. Le LLM identifie dynamiquement les catégories importantes dynamic_labels = self._get_labels_from_llm(chunk) # 2. GLiNER extrait les entités avec ces labels gliner_entities = self.gliner.predict_entities(chunk, dynamic_labels, threshold=0.4) # 3. Le LLM tisse les relations basées sur les entités GLiNER raw_res = self._run_inference_with_entities(chunk, gliner_entities, temperature) if raw_res: current_chunk_map = {} for ent in raw_res.get("entities", []): name_key = ent["name"].lower().strip() if name_key not in entity_map: new_id = f"E{len(entity_map) + 1}" entity_map[name_key] = new_id ent["id"] = new_id final_graph["entities"].append(ent) current_chunk_map[ent["id"]] = entity_map[name_key] for rel in raw_res.get("relationships", []): rel["from"] = current_chunk_map.get(rel["from"], rel["from"]) rel["to"] = current_chunk_map.get(rel["to"], rel["to"]) final_graph["relationships"].append(rel) return final_graph def _get_dynamic_labels(self, text: str): """ Analyse le texte intégral pour générer des catégories d'extraction exhaustives et uniques. """ # Prompt pour une analyse totale et sans perte prompt = f"""Tu es un analyste expert en extraction de connaissances. Analyse l'intégralité du texte ci-dessous et liste tous les types d'entités (catégories) nécessaires pour reconstruire ce document sous forme de graphe sans perte de précision. Cherche : Acteurs, Méthodologies, Chiffres clés, Unités de mesure, Dates, Lieux, Variables, Fichiers sources, et Conditions contractuelles. TEXTE COMPLET : {text} Réponds uniquement par une liste de mots simples séparés par des virgules :""" inputs = self.tokenizer(prompt, return_tensors="pt").to("cpu") with torch.no_grad(): # On laisse un peu plus de tokens pour une liste riche outputs = self.model.generate(**inputs, max_new_tokens=150) res = self.tokenizer.decode(outputs[0][inputs.shape[1]:], skip_special_tokens=True) # --- LOGIQUE DE NETTOYAGE ET DÉDUPLICATION --- raw_labels = res.split(",") clean_labels = [] seen = set() for l in raw_labels: # Nettoyage : retrait des espaces, mise en minuscule pour comparer label = l.strip().replace(".", "").replace("\n", "") if len(label) > 2: # On normalise (singulier et minuscule) pour éviter les doublons norm_label = label.lower().rstrip('s') if norm_label not in seen: seen.add(norm_label) clean_labels.append(label.capitalize()) # On garde un joli format (ex: "Montant") return clean_labels def _run_inference_with_entities(self, text: str, gliner_ents: list, temperature: float): """Phase de liaison : le LLM crée le graphe JSON final.""" # On injecte les entités détectées par GLiNER dans le prompt ents_str = "\n".join([f"- {e['text']} ({e['label']})" for e in gliner_ents]) system_prompt = """Tu es un expert en graphes de connaissance. Utilise les ENTITÉS extraites pour créer des RELATIONS précises basées sur le TEXTE. Les relations doivent être des verbes courts en MAJUSCULES. Utilise uniquement les verbes présents dans le texte source. Utilise EXCLUSIVEMENT les identifiants fournis dans la liste des entités pour remplir les champs 'from' et 'to'. Ne réutilise jamais le nom complet de l'entité dans une relation. Réponds strictement en JSON sans explications.""" user_prompt = f"SCHÉMA:\n{self.json_schema}\n\nENTITÉS DÉTECTÉES:\n{ents_str}\n\nTEXTE:\n{text}\n\nJSON:" # Inférence classique inputs = self.tokenizer.apply_chat_template( [{"role": "system", "content": system_prompt}, {"role": "user", "content": user_prompt}], tokenize=True, add_generation_prompt=True, return_tensors="pt" ).to("cpu") with torch.no_grad(): outputs = self.model.generate(inputs, max_new_tokens=1500, temperature=temperature, do_sample=True if temperature > 0.1 else False) res_text = self.tokenizer.decode(outputs[0][inputs.shape[1]:], skip_special_tokens=True) try: return json.loads(self._clean(res_text)) except: return None def _clean(self, t): t = t.strip() start, end = t.find('{'), t.rfind('}') + 1 return t[start:end] if start != -1 and end != 0 else t