Spaces:
Sleeping
Sleeping
| import torch | |
| import json | |
| import streamlit as st | |
| from typing import List | |
| from pydantic import BaseModel, Field | |
| from transformers import AutoTokenizer, AutoModelForCausalLM | |
| from gliner import GLiNER # N'oubliez pas d'ajouter 'gliner' dans requirements.txt | |
| # --- SCHÉMAS DE DONNÉES --- | |
| class Entity(BaseModel): | |
| id: str = Field(description="ID unique (ex: E1).") | |
| name: str = Field(description="Nom exact trouvé.") | |
| type: str = Field(description="Catégorie détectée.") | |
| description: str = Field(description="Rôle ou contexte.") | |
| class Relationship(BaseModel): | |
| source: str = Field(alias="from", description="ID source.") | |
| target: str = Field(alias="to", description="ID cible.") | |
| type: str = Field(description="Verbe d'action court.") | |
| description: str = Field(description="Détails du lien.") | |
| class KnowledgeGraph(BaseModel): | |
| entities: List[Entity] | |
| relationships: List[Relationship] | |
| class ExtractorEngine: | |
| def __init__(self): | |
| self.model_name = "Qwen/Qwen2.5-1.5B-Instruct" | |
| if 'llm_model' not in st.session_state: | |
| with st.spinner("🚀 Chargement des cerveaux IA (CPU)..."): | |
| # Chargement Qwen (Compréhension & Relations) | |
| st.session_state.llm_tokenizer = AutoTokenizer.from_pretrained(self.model_name) | |
| st.session_state.llm_model = AutoModelForCausalLM.from_pretrained( | |
| self.model_name, torch_dtype=torch.float32, device_map=None, low_cpu_mem_usage=False | |
| ) | |
| # Chargement GLiNER (Extraction de précision) | |
| st.session_state.gliner_model = GLiNER.from_pretrained("urchade/gliner_medium-v2.1") | |
| self.tokenizer = st.session_state.llm_tokenizer | |
| self.model = st.session_state.llm_model | |
| self.gliner = st.session_state.gliner_model | |
| self.json_schema = json.dumps(KnowledgeGraph.model_json_schema(), indent=2) | |
| def extract_long_text(self, text: str, temperature: float, chunk_size: int = 3500): | |
| chunks = [text[i:i + chunk_size] for i in range(0, len(text), chunk_size)] | |
| final_graph = {"entities": [], "relationships": []} | |
| entity_map = {} | |
| for chunk in chunks: | |
| # 1. Le LLM identifie dynamiquement les catégories importantes | |
| dynamic_labels = self._get_labels_from_llm(chunk) | |
| # 2. GLiNER extrait les entités avec ces labels | |
| gliner_entities = self.gliner.predict_entities(chunk, dynamic_labels, threshold=0.4) | |
| # 3. Le LLM tisse les relations basées sur les entités GLiNER | |
| raw_res = self._run_inference_with_entities(chunk, gliner_entities, temperature) | |
| if raw_res: | |
| current_chunk_map = {} | |
| for ent in raw_res.get("entities", []): | |
| name_key = ent["name"].lower().strip() | |
| if name_key not in entity_map: | |
| new_id = f"E{len(entity_map) + 1}" | |
| entity_map[name_key] = new_id | |
| ent["id"] = new_id | |
| final_graph["entities"].append(ent) | |
| current_chunk_map[ent["id"]] = entity_map[name_key] | |
| for rel in raw_res.get("relationships", []): | |
| rel["from"] = current_chunk_map.get(rel["from"], rel["from"]) | |
| rel["to"] = current_chunk_map.get(rel["to"], rel["to"]) | |
| final_graph["relationships"].append(rel) | |
| return final_graph | |
| def _get_dynamic_labels(self, text: str): | |
| """ | |
| Analyse le texte intégral pour générer des catégories d'extraction | |
| exhaustives et uniques. | |
| """ | |
| # Prompt pour une analyse totale et sans perte | |
| prompt = f"""Tu es un analyste expert en extraction de connaissances. | |
| Analyse l'intégralité du texte ci-dessous et liste tous les types d'entités (catégories) | |
| nécessaires pour reconstruire ce document sous forme de graphe sans perte de précision. | |
| Cherche : Acteurs, Méthodologies, Chiffres clés, Unités de mesure, Dates, Lieux, | |
| Variables, Fichiers sources, et Conditions contractuelles. | |
| TEXTE COMPLET : | |
| {text} | |
| Réponds uniquement par une liste de mots simples séparés par des virgules :""" | |
| inputs = self.tokenizer(prompt, return_tensors="pt").to("cpu") | |
| with torch.no_grad(): | |
| # On laisse un peu plus de tokens pour une liste riche | |
| outputs = self.model.generate(**inputs, max_new_tokens=150) | |
| res = self.tokenizer.decode(outputs[0][inputs.shape[1]:], skip_special_tokens=True) | |
| # --- LOGIQUE DE NETTOYAGE ET DÉDUPLICATION --- | |
| raw_labels = res.split(",") | |
| clean_labels = [] | |
| seen = set() | |
| for l in raw_labels: | |
| # Nettoyage : retrait des espaces, mise en minuscule pour comparer | |
| label = l.strip().replace(".", "").replace("\n", "") | |
| if len(label) > 2: | |
| # On normalise (singulier et minuscule) pour éviter les doublons | |
| norm_label = label.lower().rstrip('s') | |
| if norm_label not in seen: | |
| seen.add(norm_label) | |
| clean_labels.append(label.capitalize()) # On garde un joli format (ex: "Montant") | |
| return clean_labels | |
| def _run_inference_with_entities(self, text: str, gliner_ents: list, temperature: float): | |
| """Phase de liaison : le LLM crée le graphe JSON final.""" | |
| # On injecte les entités détectées par GLiNER dans le prompt | |
| ents_str = "\n".join([f"- {e['text']} ({e['label']})" for e in gliner_ents]) | |
| system_prompt = """Tu es un expert en graphes de connaissance. | |
| Utilise les ENTITÉS extraites pour créer des RELATIONS précises basées sur le TEXTE. | |
| Les relations doivent être des verbes courts en MAJUSCULES. | |
| Utilise uniquement les verbes présents dans le texte source. | |
| Utilise EXCLUSIVEMENT les identifiants fournis dans la liste des entités pour remplir les champs 'from' et 'to'. | |
| Ne réutilise jamais le nom complet de l'entité dans une relation. | |
| Réponds strictement en JSON sans explications.""" | |
| user_prompt = f"SCHÉMA:\n{self.json_schema}\n\nENTITÉS DÉTECTÉES:\n{ents_str}\n\nTEXTE:\n{text}\n\nJSON:" | |
| # Inférence classique | |
| inputs = self.tokenizer.apply_chat_template( | |
| [{"role": "system", "content": system_prompt}, {"role": "user", "content": user_prompt}], | |
| tokenize=True, add_generation_prompt=True, return_tensors="pt" | |
| ).to("cpu") | |
| with torch.no_grad(): | |
| outputs = self.model.generate(inputs, max_new_tokens=1500, temperature=temperature, do_sample=True if temperature > 0.1 else False) | |
| res_text = self.tokenizer.decode(outputs[0][inputs.shape[1]:], skip_special_tokens=True) | |
| try: | |
| return json.loads(self._clean(res_text)) | |
| except: | |
| return None | |
| def _clean(self, t): | |
| t = t.strip() | |
| start, end = t.find('{'), t.rfind('}') + 1 | |
| return t[start:end] if start != -1 and end != 0 else t |