OCR_PROSPECTUS / core /extractor.py
klydekushy's picture
Update core/extractor.py
f52f333 verified
import torch
import json
import streamlit as st
from typing import List
from pydantic import BaseModel, Field
from transformers import AutoTokenizer, AutoModelForCausalLM
from gliner import GLiNER # N'oubliez pas d'ajouter 'gliner' dans requirements.txt
# --- SCHÉMAS DE DONNÉES ---
class Entity(BaseModel):
id: str = Field(description="ID unique (ex: E1).")
name: str = Field(description="Nom exact trouvé.")
type: str = Field(description="Catégorie détectée.")
description: str = Field(description="Rôle ou contexte.")
class Relationship(BaseModel):
source: str = Field(alias="from", description="ID source.")
target: str = Field(alias="to", description="ID cible.")
type: str = Field(description="Verbe d'action court.")
description: str = Field(description="Détails du lien.")
class KnowledgeGraph(BaseModel):
entities: List[Entity]
relationships: List[Relationship]
class ExtractorEngine:
def __init__(self):
self.model_name = "Qwen/Qwen2.5-1.5B-Instruct"
if 'llm_model' not in st.session_state:
with st.spinner("🚀 Chargement des cerveaux IA (CPU)..."):
# Chargement Qwen (Compréhension & Relations)
st.session_state.llm_tokenizer = AutoTokenizer.from_pretrained(self.model_name)
st.session_state.llm_model = AutoModelForCausalLM.from_pretrained(
self.model_name, torch_dtype=torch.float32, device_map=None, low_cpu_mem_usage=False
)
# Chargement GLiNER (Extraction de précision)
st.session_state.gliner_model = GLiNER.from_pretrained("urchade/gliner_medium-v2.1")
self.tokenizer = st.session_state.llm_tokenizer
self.model = st.session_state.llm_model
self.gliner = st.session_state.gliner_model
self.json_schema = json.dumps(KnowledgeGraph.model_json_schema(), indent=2)
def extract_long_text(self, text: str, temperature: float, chunk_size: int = 3500):
chunks = [text[i:i + chunk_size] for i in range(0, len(text), chunk_size)]
final_graph = {"entities": [], "relationships": []}
entity_map = {}
for chunk in chunks:
# 1. Le LLM identifie dynamiquement les catégories importantes
dynamic_labels = self._get_labels_from_llm(chunk)
# 2. GLiNER extrait les entités avec ces labels
gliner_entities = self.gliner.predict_entities(chunk, dynamic_labels, threshold=0.4)
# 3. Le LLM tisse les relations basées sur les entités GLiNER
raw_res = self._run_inference_with_entities(chunk, gliner_entities, temperature)
if raw_res:
current_chunk_map = {}
for ent in raw_res.get("entities", []):
name_key = ent["name"].lower().strip()
if name_key not in entity_map:
new_id = f"E{len(entity_map) + 1}"
entity_map[name_key] = new_id
ent["id"] = new_id
final_graph["entities"].append(ent)
current_chunk_map[ent["id"]] = entity_map[name_key]
for rel in raw_res.get("relationships", []):
rel["from"] = current_chunk_map.get(rel["from"], rel["from"])
rel["to"] = current_chunk_map.get(rel["to"], rel["to"])
final_graph["relationships"].append(rel)
return final_graph
def _get_dynamic_labels(self, text: str):
"""
Analyse le texte intégral pour générer des catégories d'extraction
exhaustives et uniques.
"""
# Prompt pour une analyse totale et sans perte
prompt = f"""Tu es un analyste expert en extraction de connaissances.
Analyse l'intégralité du texte ci-dessous et liste tous les types d'entités (catégories)
nécessaires pour reconstruire ce document sous forme de graphe sans perte de précision.
Cherche : Acteurs, Méthodologies, Chiffres clés, Unités de mesure, Dates, Lieux,
Variables, Fichiers sources, et Conditions contractuelles.
TEXTE COMPLET :
{text}
Réponds uniquement par une liste de mots simples séparés par des virgules :"""
inputs = self.tokenizer(prompt, return_tensors="pt").to("cpu")
with torch.no_grad():
# On laisse un peu plus de tokens pour une liste riche
outputs = self.model.generate(**inputs, max_new_tokens=150)
res = self.tokenizer.decode(outputs[0][inputs.shape[1]:], skip_special_tokens=True)
# --- LOGIQUE DE NETTOYAGE ET DÉDUPLICATION ---
raw_labels = res.split(",")
clean_labels = []
seen = set()
for l in raw_labels:
# Nettoyage : retrait des espaces, mise en minuscule pour comparer
label = l.strip().replace(".", "").replace("\n", "")
if len(label) > 2:
# On normalise (singulier et minuscule) pour éviter les doublons
norm_label = label.lower().rstrip('s')
if norm_label not in seen:
seen.add(norm_label)
clean_labels.append(label.capitalize()) # On garde un joli format (ex: "Montant")
return clean_labels
def _run_inference_with_entities(self, text: str, gliner_ents: list, temperature: float):
"""Phase de liaison : le LLM crée le graphe JSON final."""
# On injecte les entités détectées par GLiNER dans le prompt
ents_str = "\n".join([f"- {e['text']} ({e['label']})" for e in gliner_ents])
system_prompt = """Tu es un expert en graphes de connaissance.
Utilise les ENTITÉS extraites pour créer des RELATIONS précises basées sur le TEXTE.
Les relations doivent être des verbes courts en MAJUSCULES.
Utilise uniquement les verbes présents dans le texte source.
Utilise EXCLUSIVEMENT les identifiants fournis dans la liste des entités pour remplir les champs 'from' et 'to'.
Ne réutilise jamais le nom complet de l'entité dans une relation.
Réponds strictement en JSON sans explications."""
user_prompt = f"SCHÉMA:\n{self.json_schema}\n\nENTITÉS DÉTECTÉES:\n{ents_str}\n\nTEXTE:\n{text}\n\nJSON:"
# Inférence classique
inputs = self.tokenizer.apply_chat_template(
[{"role": "system", "content": system_prompt}, {"role": "user", "content": user_prompt}],
tokenize=True, add_generation_prompt=True, return_tensors="pt"
).to("cpu")
with torch.no_grad():
outputs = self.model.generate(inputs, max_new_tokens=1500, temperature=temperature, do_sample=True if temperature > 0.1 else False)
res_text = self.tokenizer.decode(outputs[0][inputs.shape[1]:], skip_special_tokens=True)
try:
return json.loads(self._clean(res_text))
except:
return None
def _clean(self, t):
t = t.strip()
start, end = t.find('{'), t.rfind('}') + 1
return t[start:end] if start != -1 and end != 0 else t