klydekushy commited on
Commit
75dd810
·
verified ·
1 Parent(s): 245eae5

Update core/extractor.py

Browse files
Files changed (1) hide show
  1. core/extractor.py +38 -27
core/extractor.py CHANGED
@@ -5,6 +5,7 @@ from typing import List
5
  from pydantic import BaseModel, Field
6
  from transformers import AutoTokenizer, AutoModelForCausalLM
7
  from gliner import GLiNER # N'oubliez pas d'ajouter 'gliner' dans requirements.txt
 
8
 
9
  # --- SCHÉMAS DE DONNÉES ---
10
  class Entity(BaseModel):
@@ -75,36 +76,46 @@ class ExtractorEngine:
75
  return final_graph
76
 
77
  def _get_dynamic_labels(self, text: str):
78
- # Prompt universel qui ne change JAMAIS, quel que soit le document
79
- prompt = (
80
- "En tant qu'expert en analyse de données, identifie TOUS les types d'entités "
81
- "nécessaires pour reconstruire ce document sans perte d'information. "
82
- "Inclus les acteurs, les objets, les actions, les chiffres clés, les dates et les lieux. "
83
- "Réponds uniquement par une liste de mots séparés par des virgules."
84
- )
 
 
 
 
 
 
 
 
 
 
85
  inputs = self.tokenizer(prompt, return_tensors="pt").to("cpu")
86
- with torch.no_grad():
87
- # On laisse un peu plus de tokens pour une liste riche
88
- outputs = self.model.generate(**inputs, max_new_tokens=150)
89
-
90
- res = self.tokenizer.decode(outputs[0][inputs.shape[1]:], skip_special_tokens=True)
91
-
92
- # --- LOGIQUE DE NETTOYAGE ET DÉDUPLICATION ---
93
- raw_labels = res.split(",")
94
- clean_labels = []
95
- seen = set()
96
 
97
- for l in raw_labels:
98
- # Nettoyage : retrait des espaces, mise en minuscule pour comparer
99
- label = l.strip().replace(".", "").replace("\n", "")
100
- if len(label) > 2:
101
- # On normalise (singulier et minuscule) pour éviter les doublons
102
- norm_label = label.lower().rstrip('s')
103
- if norm_label not in seen:
104
- seen.add(norm_label)
105
- clean_labels.append(label.capitalize()) # On garde un joli format (ex: "Montant")
106
 
107
- return clean_labels
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
108
 
109
  def _run_inference_with_entities(self, text: str, gliner_ents: list, temperature: float):
110
  """Phase de liaison : le LLM crée le graphe JSON final."""
 
5
  from pydantic import BaseModel, Field
6
  from transformers import AutoTokenizer, AutoModelForCausalLM
7
  from gliner import GLiNER # N'oubliez pas d'ajouter 'gliner' dans requirements.txt
8
+ from core.extractor import ExtractorEngine
9
 
10
  # --- SCHÉMAS DE DONNÉES ---
11
  class Entity(BaseModel):
 
76
  return final_graph
77
 
78
  def _get_dynamic_labels(self, text: str):
79
+ """
80
+ Analyse le texte intégral pour générer des catégories d'extraction
81
+ exhaustives et uniques.
82
+ """
83
+ # Prompt pour une analyse totale et sans perte
84
+ prompt = f"""Tu es un analyste expert en extraction de connaissances.
85
+ Analyse l'intégralité du texte ci-dessous et liste tous les types d'entités (catégories)
86
+ nécessaires pour reconstruire ce document sous forme de graphe sans perte de précision.
87
+
88
+ Cherche : Acteurs, Méthodologies, Chiffres clés, Unités de mesure, Dates, Lieux,
89
+ Variables, Fichiers sources, et Conditions contractuelles.
90
+
91
+ TEXTE COMPLET :
92
+ {text}
93
+
94
+ Réponds uniquement par une liste de mots simples séparés par des virgules :"""
95
+
96
  inputs = self.tokenizer(prompt, return_tensors="pt").to("cpu")
97
+ with torch.no_grad():
98
+ # On laisse un peu plus de tokens pour une liste riche
99
+ outputs = self.model.generate(**inputs, max_new_tokens=150)
 
 
 
 
 
 
 
100
 
101
+ res = self.tokenizer.decode(outputs[0][inputs.shape[1]:], skip_special_tokens=True)
 
 
 
 
 
 
 
 
102
 
103
+ # --- LOGIQUE DE NETTOYAGE ET DÉDUPLICATION ---
104
+ raw_labels = res.split(",")
105
+ clean_labels = []
106
+ seen = set()
107
+
108
+ for l in raw_labels:
109
+ # Nettoyage : retrait des espaces, mise en minuscule pour comparer
110
+ label = l.strip().replace(".", "").replace("\n", "")
111
+ if len(label) > 2:
112
+ # On normalise (singulier et minuscule) pour éviter les doublons
113
+ norm_label = label.lower().rstrip('s')
114
+ if norm_label not in seen:
115
+ seen.add(norm_label)
116
+ clean_labels.append(label.capitalize()) # On garde un joli format (ex: "Montant")
117
+
118
+ return clean_labels
119
 
120
  def _run_inference_with_entities(self, text: str, gliner_ents: list, temperature: float):
121
  """Phase de liaison : le LLM crée le graphe JSON final."""