GaetanoParente's picture
Update src/extraction/extractor.py
cfc197c verified
raw
history blame
9.68 kB
import json
import os
import numpy as np
from typing import List, Optional
from pydantic import BaseModel, Field, ValidationError
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.messages import SystemMessage, HumanMessage, AIMessage
# Gestione Multi-Backend (Locale vs Cloud)
from langchain_ollama import ChatOllama
from langchain_huggingface import HuggingFaceEmbeddings, ChatHuggingFace, HuggingFaceEndpoint
from sklearn.metrics.pairwise import cosine_similarity
# --- 1. DEFINIZIONE DELLO SCHEMA ---
class GraphTriple(BaseModel):
subject: str = Field(..., description="Entità sorgente (Canonical).")
predicate: str = Field(..., description="Relazione (snake_case).")
object: str = Field(..., description="Entità target.")
confidence: float = Field(..., description="Confidenza (0.0 - 1.0).")
source: Optional[str] = Field(None, description="ID del documento o chunk.")
class KnowledgeGraphExtraction(BaseModel):
reasoning: Optional[str] = Field(None, description="Breve ragionamento logico.")
triples: List[GraphTriple]
# --- 2. ESTRATTORE DINAMICO (Dynamic Few-Shot) ---
class NeuroSymbolicExtractor:
def __init__(self, model_name="llama3", temperature=0, gold_standard_path=None):
hf_token = os.getenv("HF_TOKEN")
if hf_token:
print("☁️ Rilevato ambiente Cloud (HF Spaces). Utilizzo HuggingFace Inference API.")
repo_id = "meta-llama/Meta-Llama-3-8B-Instruct"
try:
endpoint = HuggingFaceEndpoint(
repo_id=repo_id,
task="text-generation",
max_new_tokens=1024,
temperature=0.1,
huggingfacehub_api_token=hf_token
)
self.llm = ChatHuggingFace(llm=endpoint)
print(f"✅ Connesso a {repo_id} via API.")
except Exception as e:
print(f"❌ Errore connessione HF API: {e}. Fallback su CPU locale (sconsigliato).")
raise e
else:
print(f"🏠 Ambiente Locale rilevato. Inizializzazione Ollama: {model_name}...")
try:
self.llm = ChatOllama(
model=model_name,
temperature=temperature,
format="json",
base_url="http://localhost:11434"
)
except Exception as e:
print(f"⚠️ Errore Ollama: {e}")
# 2. Modello Embedding per la selezione dinamica
print("🧠 Caricamento modello embedding per Dynamic Selection...")
self.embedding_model = HuggingFaceEmbeddings(model_name="sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2")
# 3. Caricamento e Indicizzazione Gold Standard
self.examples = []
self.example_embeddings = None
if gold_standard_path and os.path.exists(gold_standard_path):
print(f"🌟 Indicizzazione vettoriale Gold Standard da: {gold_standard_path}")
self._index_examples(gold_standard_path)
else:
# Crea una lista vuota per evitare crash se il path non esiste
print("⚠️ Nessun Gold Standard trovato. Modalità Zero-Shot.")
# Template Specializzato (Prompt Engineering)
self.system_template_base = """Sei l'Agente Cognitivo (AC) del sistema Canusium xCH.
Il tuo compito è trasformare il testo non strutturato in un Digital Twin Graph (RDF).
SCHEMA JSON RICHIESTO:
{{
"reasoning": "Spiega brevemente perché hai scelto queste classi/relazioni...",
"triples": [
{{"subject": "Entità", "predicate": "prefix:Relazione", "object": "Entità", "confidence": 0.95}}
]
}}
ONTOLOGIA DI RIFERIMENTO (Usa questi prefissi):
- xchh: (Heritage) -> Per oggetti fisici, siti, reperti (es. xchh:HeritageObject, xchh:Site).
- crm: (CIDOC-CRM) -> Per relazioni standard (es. crm:P55_has_current_location, crm:P4_has_time-span).
- xche: (Experience) -> Per sessioni AR/VR, visitatori, interazioni (es. xche:ExperienceSession).
- xcha: (Agents) -> Per agenti umani o artificiali.
- skos: -> Per concetti generici o gerarchie.
ESEMPI CONTESTUALI (Dynamic Few-Shot):
{selected_examples}
REGOLE DI CONFIDENZA (Trust Layer):
- 1.0 (Fatto Curato): Informazione esplicita e certa nel testo.
- 0.8 - 0.9 (Inferenza): Deduzione logica forte ma non esplicita.
- < 0.7 (Ipotesi): Associazione probabile ma incerta (da marcare per revisione umana).
Canonicalizza i nomi (es. "Il Parco" -> "Parco Archeologico di Canne").
Rispondi ESCLUSIVAMENTE con un JSON valido.
"""
def _index_examples(self, path: str):
"""Carica il JSON e calcola i vettori per ogni esempio."""
try:
with open(path, 'r', encoding='utf-8') as f:
self.examples = json.load(f)
# Estraiamo solo il testo di input per calcolare l'embedding
texts = [ex['text'] for ex in self.examples]
self.example_embeddings = self.embedding_model.embed_documents(texts)
print(f"✅ Indicizzati {len(self.examples)} esempi di Gold Standard.")
except Exception as e:
print(f"❌ Errore indicizzazione Gold Standard: {e}")
self.examples = []
def _get_relevant_examples(self, query_text: str, k=2) -> str:
"""
Trova i k esempi più simili semanticamente al chunk attuale.
"""
if not self.examples or self.example_embeddings is None:
return "Nessun esempio disponibile."
# 1. Embed del chunk attuale
query_embedding = self.embedding_model.embed_query(query_text)
# 2. Calcolo similarità coseno
similarities = cosine_similarity([query_embedding], self.example_embeddings)[0]
# 3. Selezione dei top-k
top_k_indices = np.argsort(similarities)[-k:][::-1]
formatted_text = ""
for i, idx in enumerate(top_k_indices):
ex = self.examples[idx]
sim_score = similarities[idx]
formatted_text += f"\n--- ESEMPIO RILEVANTE #{i+1} (Sim: {sim_score:.2f}) ---\n"
formatted_text += f"INPUT: {ex['text']}\n"
# Gestione sicura nel caso triples manchi
triples_out = ex.get('triples', [])
formatted_text += f"OUTPUT: {json.dumps({'triples': triples_out}, ensure_ascii=False)}\n"
return formatted_text
def extract(self, text_chunk: str, source_id: str = "unknown", max_retries=3) -> KnowledgeGraphExtraction:
print(f"🧠 Processing {source_id} (Dynamic Mode)...")
# Selezione Esempi
relevant_examples_str = self._get_relevant_examples(text_chunk, k=2)
# Costruzione Prompt Finale
final_sys_text = self.system_template_base.format(selected_examples=relevant_examples_str)
sys_msg = SystemMessage(content=final_sys_text)
prompt = ChatPromptTemplate.from_messages([
sys_msg,
("human", "{text}")
])
chain = prompt | self.llm
for attempt in range(max_retries):
try:
response = chain.invoke({"text": text_chunk})
# Parsing della risposta (diversa tra Ollama e HF)
content = response.content
# Pulizia base se il modello chiacchiera prima del JSON
if "```json" in content:
content = content.split("```json")[1].split("```")[0].strip()
elif "```" in content:
content = content.split("```")[1].split("```")[0].strip()
data = json.loads(content)
# Normalizzazione output
if isinstance(data, list):
validated_data = KnowledgeGraphExtraction(triples=data, reasoning="Direct list output")
else:
# Filtra campi extra che il modello potrebbe inventare
triples = [GraphTriple(**t) for t in data.get("triples", [])]
validated_data = KnowledgeGraphExtraction(
reasoning=data.get("reasoning", "N/A"),
triples=triples
)
for t in validated_data.triples:
t.source = source_id
return validated_data
except (json.JSONDecodeError, ValidationError) as e:
print(f"⚠️ Errore Validazione (Tentativo {attempt+1}/{max_retries}): {e}")
# SELF-CORRECTION LOOP
prev_content = locals().get('content', 'No content')
correction_prompt = ChatPromptTemplate.from_messages([
sys_msg,
HumanMessage(content=text_chunk),
AIMessage(content=prev_content),
HumanMessage(content=f"Errore nel JSON precedente: {e}. Correggi e restituisci SOLO JSON valido senza markdown.")
])
chain = correction_prompt | self.llm
except Exception as e:
print(f"❌ Errore critico: {e}")
break
return KnowledgeGraphExtraction(triples=[])