File size: 7,247 Bytes
f015124
 
e9ce464
c5ef5c1
f015124
 
2133bc9
f52f333
f015124
2133bc9
f015124
c5ef5c1
2133bc9
 
 
f015124
 
2133bc9
 
 
 
f015124
 
 
 
e9ce464
 
 
f015124
c5ef5c1
2133bc9
 
c5ef5c1
 
eb8b6cf
c5ef5c1
2133bc9
 
 
c5ef5c1
 
2133bc9
f015124
e9ce464
f3c07b5
c360f43
 
2133bc9
 
c360f43
2133bc9
 
 
 
 
 
 
 
 
c360f43
f3c07b5
c360f43
 
 
 
 
 
 
f3c07b5
2133bc9
c360f43
f3c07b5
 
c360f43
 
 
f015124
f433cad
75dd810
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2133bc9
75dd810
 
 
245eae5
75dd810
245eae5
75dd810
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2133bc9
 
 
 
 
c360f43
2133bc9
 
9fb8957
4781044
f433cad
 
2133bc9
 
 
 
 
 
 
 
 
 
 
 
faeffd5
2133bc9
 
 
 
 
 
faeffd5
2133bc9
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
import torch
import json
import streamlit as st
from typing import List
from pydantic import BaseModel, Field
from transformers import AutoTokenizer, AutoModelForCausalLM
from gliner import GLiNER # N'oubliez pas d'ajouter 'gliner' dans requirements.txt


# --- SCHÉMAS DE DONNÉES ---
class Entity(BaseModel):
    id: str = Field(description="ID unique (ex: E1).")
    name: str = Field(description="Nom exact trouvé.")
    type: str = Field(description="Catégorie détectée.")
    description: str = Field(description="Rôle ou contexte.")

class Relationship(BaseModel):
    source: str = Field(alias="from", description="ID source.")
    target: str = Field(alias="to", description="ID cible.")
    type: str = Field(description="Verbe d'action court.")
    description: str = Field(description="Détails du lien.")

class KnowledgeGraph(BaseModel):
    entities: List[Entity]
    relationships: List[Relationship]

class ExtractorEngine:
    def __init__(self):
        self.model_name = "Qwen/Qwen2.5-1.5B-Instruct"
        if 'llm_model' not in st.session_state:
            with st.spinner("🚀 Chargement des cerveaux IA (CPU)..."):
                # Chargement Qwen (Compréhension & Relations)
                st.session_state.llm_tokenizer = AutoTokenizer.from_pretrained(self.model_name)
                st.session_state.llm_model = AutoModelForCausalLM.from_pretrained(
                    self.model_name, torch_dtype=torch.float32, device_map=None, low_cpu_mem_usage=False
                )
                # Chargement GLiNER (Extraction de précision)
                st.session_state.gliner_model = GLiNER.from_pretrained("urchade/gliner_medium-v2.1")
        
        self.tokenizer = st.session_state.llm_tokenizer
        self.model = st.session_state.llm_model
        self.gliner = st.session_state.gliner_model
        self.json_schema = json.dumps(KnowledgeGraph.model_json_schema(), indent=2)

    def extract_long_text(self, text: str, temperature: float, chunk_size: int = 3500):
        chunks = [text[i:i + chunk_size] for i in range(0, len(text), chunk_size)]
        final_graph = {"entities": [], "relationships": []}
        entity_map = {}

        for chunk in chunks:
            # 1. Le LLM identifie dynamiquement les catégories importantes
            dynamic_labels = self._get_labels_from_llm(chunk)
            
            # 2. GLiNER extrait les entités avec ces labels
            gliner_entities = self.gliner.predict_entities(chunk, dynamic_labels, threshold=0.4)
            
            # 3. Le LLM tisse les relations basées sur les entités GLiNER
            raw_res = self._run_inference_with_entities(chunk, gliner_entities, temperature)
            
            if raw_res:
                current_chunk_map = {}
                for ent in raw_res.get("entities", []):
                    name_key = ent["name"].lower().strip()
                    if name_key not in entity_map:
                        new_id = f"E{len(entity_map) + 1}"
                        entity_map[name_key] = new_id
                        ent["id"] = new_id
                        final_graph["entities"].append(ent)
                    current_chunk_map[ent["id"]] = entity_map[name_key]

                for rel in raw_res.get("relationships", []):
                    rel["from"] = current_chunk_map.get(rel["from"], rel["from"])
                    rel["to"] = current_chunk_map.get(rel["to"], rel["to"])
                    final_graph["relationships"].append(rel)
        
        return final_graph

    def _get_dynamic_labels(self, text: str):
        """
        Analyse le texte intégral pour générer des catégories d'extraction 
        exhaustives et uniques.
        """
        # Prompt pour une analyse totale et sans perte
        prompt = f"""Tu es un analyste expert en extraction de connaissances. 
    Analyse l'intégralité du texte ci-dessous et liste tous les types d'entités (catégories) 
    nécessaires pour reconstruire ce document sous forme de graphe sans perte de précision.
    
    Cherche : Acteurs, Méthodologies, Chiffres clés, Unités de mesure, Dates, Lieux, 
    Variables, Fichiers sources, et Conditions contractuelles.
    
    TEXTE COMPLET :
    {text}
    
    Réponds uniquement par une liste de mots simples séparés par des virgules :"""
    
        inputs = self.tokenizer(prompt, return_tensors="pt").to("cpu")
        with torch.no_grad():
            # On laisse un peu plus de tokens pour une liste riche
            outputs = self.model.generate(**inputs, max_new_tokens=150)
        
        res = self.tokenizer.decode(outputs[0][inputs.shape[1]:], skip_special_tokens=True)
        
        # --- LOGIQUE DE NETTOYAGE ET DÉDUPLICATION ---
        raw_labels = res.split(",")
        clean_labels = []
        seen = set()
    
        for l in raw_labels:
            # Nettoyage : retrait des espaces, mise en minuscule pour comparer
            label = l.strip().replace(".", "").replace("\n", "")
            if len(label) > 2:
                # On normalise (singulier et minuscule) pour éviter les doublons
                norm_label = label.lower().rstrip('s') 
                if norm_label not in seen:
                    seen.add(norm_label)
                    clean_labels.append(label.capitalize()) # On garde un joli format (ex: "Montant")
    
        return clean_labels

    def _run_inference_with_entities(self, text: str, gliner_ents: list, temperature: float):
        """Phase de liaison : le LLM crée le graphe JSON final."""
        # On injecte les entités détectées par GLiNER dans le prompt
        ents_str = "\n".join([f"- {e['text']} ({e['label']})" for e in gliner_ents])
        
        system_prompt = """Tu es un expert en graphes de connaissance. 
        Utilise les ENTITÉS extraites pour créer des RELATIONS précises basées sur le TEXTE.
        Les relations doivent être des verbes courts en MAJUSCULES.
        Utilise uniquement les verbes présents dans le texte source.
        Utilise EXCLUSIVEMENT les identifiants fournis dans la liste des entités pour remplir les champs 'from' et 'to'. 
        Ne réutilise jamais le nom complet de l'entité dans une relation.
        Réponds strictement en JSON sans explications."""

        user_prompt = f"SCHÉMA:\n{self.json_schema}\n\nENTITÉS DÉTECTÉES:\n{ents_str}\n\nTEXTE:\n{text}\n\nJSON:"

        # Inférence classique
        inputs = self.tokenizer.apply_chat_template(
            [{"role": "system", "content": system_prompt}, {"role": "user", "content": user_prompt}],
            tokenize=True, add_generation_prompt=True, return_tensors="pt"
        ).to("cpu")

        with torch.no_grad():
            outputs = self.model.generate(inputs, max_new_tokens=1500, temperature=temperature, do_sample=True if temperature > 0.1 else False)
        
        res_text = self.tokenizer.decode(outputs[0][inputs.shape[1]:], skip_special_tokens=True)
        try:
            return json.loads(self._clean(res_text))
        except:
            return None

    def _clean(self, t):
        t = t.strip()
        start, end = t.find('{'), t.rfind('}') + 1
        return t[start:end] if start != -1 and end != 0 else t