klydekushy commited on
Commit
c360f43
·
verified ·
1 Parent(s): c5ef5c1

Update core/extractor.py

Browse files
Files changed (1) hide show
  1. core/extractor.py +83 -24
core/extractor.py CHANGED
@@ -15,8 +15,8 @@ class Entity(BaseModel):
15
  class Relationship(BaseModel):
16
  source: str = Field(alias="from", description="ID de l'entité source.")
17
  target: str = Field(alias="to", description="ID de l'entité cible.")
18
- type: str = Field(description="Action ou lien (ex: USES, CORRELATED_WITH, AUTHORED_BY).")
19
- description: str = Field(description="Explication du lien.")
20
 
21
  class KnowledgeGraph(BaseModel):
22
  entities: List[Entity]
@@ -35,30 +35,89 @@ class ExtractorEngine:
35
  self.model = st.session_state.llm_model
36
  self.json_schema = json.dumps(KnowledgeGraph.model_json_schema(), indent=2)
37
 
38
- def extract_graph(self, text: str):
39
- # Votre version du prompt renforcée
40
- system_prompt = """Tu es un système d'extraction de graphe de connaissance hautement fiable.
41
- Ton objectif est d'extraire toutes les entités et relations pertinentes du texte fourni.
42
- Réponds TOUJOURS uniquement en JSON. Le JSON DOIT respecter le schéma spécifié ci-dessous
43
- SANS AUCUNE EXPLICATION SUPPLÉMENTAIRE."""
44
 
45
- user_prompt = f"Schéma:\n{self.json_schema}\n\nTexte:\n{text[:4000]}\n\nRéponse JSON:"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
46
 
47
- try:
48
- inputs = self.tokenizer.apply_chat_template(
49
- [{"role": "system", "content": system_prompt}, {"role": "user", "content": user_prompt}],
50
- tokenize=True, add_generation_prompt=True, return_tensors="pt"
51
- ).to("cpu")
52
 
53
- with torch.no_grad():
54
- outputs = self.model.generate(inputs, max_new_tokens=1500, temperature=0.1)
 
 
 
 
 
 
 
 
 
 
 
 
55
 
56
- result = self.tokenizer.decode(outputs[0][inputs.shape[1]:], skip_special_tokens=True)
57
- return json.loads(self._clean(result))
58
- except Exception as e:
59
- st.error(f"Erreur IA : {e}")
60
- return None
61
-
62
- def _clean(self, t):
63
- return t.strip().replace("```json", "").replace("```", "")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
64
 
 
15
  class Relationship(BaseModel):
16
  source: str = Field(alias="from", description="ID de l'entité source.")
17
  target: str = Field(alias="to", description="ID de l'entité cible.")
18
+ type: str = Field(description="Action ou lien sémantique (ex: USES, CORRELATED_WITH, AUTHORED_BY).")
19
+ description: str = Field(description="Explication du lien ou Détails.")
20
 
21
  class KnowledgeGraph(BaseModel):
22
  entities: List[Entity]
 
35
  self.model = st.session_state.llm_model
36
  self.json_schema = json.dumps(KnowledgeGraph.model_json_schema(), indent=2)
37
 
38
+
39
+ def extract_long_text(self, text: str, temperature: float, chunk_size: int = 3000):
40
+ # Découpage du texte
41
+ chunks = [text[i:i + chunk_size] for i in range(0, len(text), chunk_size)]
42
+ final_graph = {"entities": [], "relationships": []}
43
+ entity_map = {} # Pour fusionner les noms identiques
44
 
45
+ for chunk in chunks:
46
+ raw_res = self._run_inference(chunk, temperature)
47
+ if raw_res:
48
+ # Fusion des entités par nom
49
+ for ent in raw_res.get("entities", []):
50
+ name_key = ent["name"].lower().strip()
51
+ if name_key not in entity_map:
52
+ new_id = f"E{len(entity_map) + 1}"
53
+ entity_map[name_key] = new_id
54
+ ent["id"] = new_id
55
+ final_graph["entities"].append(ent)
56
+
57
+ # Ré-attribution des IDs dans les relations
58
+ for rel in raw_res.get("relationships", []):
59
+ # (Note: simplifiée ici, nécessite que l'IA respecte les noms dans le chunk)
60
+ final_graph["relationships"].append(rel)
61
+
62
+ return final_graph
63
 
 
 
 
 
 
64
 
65
+
66
+ def _run_inference(self, text: str, temperature: float):
67
+ """
68
+ Exécute l'inférence sur un segment de texte en utilisant le prompt renforcé
69
+ et la température réglable depuis l'interface.
70
+ """
71
+ # Utilisation de votre prompt système renforcé pour une fiabilité maximale
72
+ system_prompt = """Tu es un système d'extraction de graphe de connaissance hautement fiable.
73
+ Ton objectif est d'extraire toutes les entités et relations pertinentes du texte fourni.
74
+ Réponds TOUJOURS uniquement en JSON. Le JSON DOIT respecter le schéma spécifié ci-dessous
75
+ SANS AUCUNE EXPLICATION SUPPLÉMENTAIRE."""
76
+
77
+ # Construction du prompt utilisateur avec le segment de texte
78
+ user_prompt = f"Schéma JSON STRICT à respecter:\n{self.json_schema}\n\nTexte source:\n<<<{text}>>>\n\nRéponse JSON:"
79
 
80
+ try:
81
+ # Application du template de chat spécifique à Qwen
82
+ inputs = self.tokenizer.apply_chat_template(
83
+ [
84
+ {"role": "system", "content": system_prompt},
85
+ {"role": "user", "content": user_prompt}
86
+ ],
87
+ tokenize=True,
88
+ add_generation_prompt=True,
89
+ return_tensors="pt"
90
+ ).to("cpu")
91
+
92
+ # Génération avec les paramètres optimisés
93
+ with torch.no_grad():
94
+ outputs = self.model.generate(
95
+ inputs,
96
+ max_new_tokens=1500, # Augmenté pour ne pas couper les gros JSON
97
+ temperature=temperature, # Dynamique via le slider Streamlit
98
+ do_sample=True if temperature > 0.1 else False,
99
+ pad_token_id=self.tokenizer.eos_token_id
100
+ )
101
+
102
+ # Décodage de la réponse
103
+ generated_text = self.tokenizer.decode(
104
+ outputs[0][inputs.shape[1]:],
105
+ skip_special_tokens=True
106
+ )
107
+
108
+ # Nettoyage et conversion en dictionnaire Python
109
+ clean_content = self._clean(generated_text)
110
+ return json.loads(clean_content)
111
+
112
+ except Exception as e:
113
+ # En cas d'erreur de parsing ou de génération sur ce segment
114
+ st.warning(f"Avertissement sur un segment : {e}")
115
+ return None
116
+
117
+ def _clean(self, t):
118
+ """Nettoyage rigoureux des balises Markdown et espaces superflus."""
119
+ t = t.strip()
120
+ if t.startswith("```"):
121
+ t = t.replace("```json", "").replace("```", "")
122
+ return t.strip()
123