AutomatedSemanticDiscovery / src /graph /entity_resolver.py
GaetanoParente's picture
Update src/graph/entity_resolver.py
a28147d verified
raw
history blame
3.23 kB
import numpy as np
from sklearn.cluster import DBSCAN
from langchain_huggingface import HuggingFaceEmbeddings
from collections import Counter
class EntityResolver:
def __init__(self, model_name="all-MiniLM-L6-v2", similarity_threshold=0.85):
"""
Inizializza il modello per il calcolo delle similarità.
similarity_threshold: quanto devono essere vicini i vettori (0-1).
Convertito in 'eps' per DBSCAN.
"""
print("🧩 Inizializzazione Entity Resolver (DBSCAN)...")
self.embedding_model = HuggingFaceEmbeddings(model_name=model_name)
# DBSCAN usa la distanza, non la similarità. Distanza = 1 - Similarità.
# Se threshold è 0.85 (alta similarità), eps deve essere 0.15 (bassa distanza).
self.eps = 1 - similarity_threshold
def resolve_entities(self, triples):
"""
Prende una lista di triple (GraphTriple) e normalizza i nomi delle entità.
"""
if not triples:
return []
# 1. Estrazione di tutte le entità uniche (Soggetti e Oggetti)
all_entities = set()
for t in triples:
all_entities.add(t.subject)
all_entities.add(t.object)
unique_entities = list(all_entities)
print(f" Analisi di {len(unique_entities)} entità uniche per deduplica...")
if len(unique_entities) < 2:
return triples
# 2. Calcolo Embeddings
embeddings = self.embedding_model.embed_documents(unique_entities)
X = np.array(embeddings)
# 3. Clustering DBSCAN
# metrica='cosine' è fondamentale per vettori semantici
clustering = DBSCAN(eps=self.eps, min_samples=1, metric='cosine').fit(X)
labels = clustering.labels_
# 4. Creazione Mappa {Variante -> Canonico}
# Raggruppiamo le entità per Cluster ID
cluster_map = {}
for entity, label in zip(unique_entities, labels):
if label not in cluster_map:
cluster_map[label] = []
cluster_map[label].append(entity)
# Per ogni cluster, eleggiamo il "Canonico" (es. la stringa più lunga)
entity_replacement_map = {}
for label, variants in cluster_map.items():
if len(variants) > 1:
# Euristiche di canonicalizzazione:
# 1. Preferisci quella che inizia con maiuscola
# 2. Preferisci la più lunga (spesso più descrittiva: "San Marco" vs "Basilica di San Marco")
canonical = sorted(variants, key=len, reverse=True)[0]
print(f" ✨ Deduplica: {variants} -> '{canonical}'")
for v in variants:
entity_replacement_map[v] = canonical
else:
entity_replacement_map[variants[0]] = variants[0]
# 5. Riscrittura Triple
resolved_triples = []
for t in triples:
# Sostituiamo soggetto e oggetto con le versioni canoniche
t.subject = entity_replacement_map.get(t.subject, t.subject)
t.object = entity_replacement_map.get(t.object, t.object)
resolved_triples.append(t)
return resolved_triples