JairoDanielMT's picture
Upload folder using huggingface_hub
4ef6c2b verified
Raw
History Blame Contribute Delete
5.92 kB
import os
import json
import faiss
import numpy as np
from sentence_transformers import SentenceTransformer
from typing import List, Dict, Any, Tuple
from src.ontology.models import OntologyRecord
from src.runtime.device_manager import DeviceManager
class EmbeddingEngine:
    """Per-category FAISS similarity search over ontology records.

    Each ``<category>.json`` file in an ontology directory is embedded with a
    SentenceTransformer model and stored in its own ``faiss.IndexFlatL2``.
    Every indexed vector (one for a record's canonical text and one per alias)
    has a parallel entry in ``self.records[category]`` so a FAISS row id maps
    straight back to the ``OntologyRecord`` it came from.

    On-disk layout inside ``index_dir``:
        ``<category>.index``         -- the serialized FAISS index
        ``<category>_records.json``  -- row-aligned list of dumped records
    """

    def __init__(self, model_name: str = "all-MiniLM-L6-v2", index_dir: str = "data/faiss_indices"):
        """Load the embedding model and prepare empty in-memory index maps.

        Args:
            model_name: SentenceTransformer model identifier.
            index_dir: Directory where FAISS indices and record mappings live.
        """
        # DeviceManager reports an execution-provider style name, while
        # SentenceTransformer wants a torch-style device string.
        device = DeviceManager.get_optimal_device()
        st_device = 'cuda' if device == 'CUDAExecutionProvider' else 'cpu'
        self.model = SentenceTransformer(model_name, device=st_device)
        self.index_dir = index_dir
        # category -> FAISS index
        self.indices: Dict[str, Any] = {}
        # category -> records, row-aligned with the FAISS index of that category
        self.records: Dict[str, List["OntologyRecord"]] = {}
        self.batch_size = DeviceManager.get_optimal_batch_size()

    def _category_paths(self, category: str) -> Tuple[str, str]:
        """Return ``(index_path, mapping_path)`` for *category* in ``index_dir``."""
        return (
            os.path.join(self.index_dir, f"{category}.index"),
            os.path.join(self.index_dir, f"{category}_records.json"),
        )

    def _load_category(self, category: str) -> bool:
        """Load one category's index and records from disk into memory.

        Returns:
            True if both files existed and were loaded, False otherwise.
        """
        index_path, mapping_path = self._category_paths(category)
        if not (os.path.exists(index_path) and os.path.exists(mapping_path)):
            return False
        index = faiss.read_index(index_path)
        with open(mapping_path, "r", encoding="utf-8") as f:
            data = json.load(f)
        records = [OntologyRecord(**item) for item in data]
        # Assign only after both parts loaded so the two maps never disagree.
        self.indices[category] = index
        self.records[category] = records
        return True

    def build_index(self, ontology_dir: str) -> None:
        """Build FAISS indices, one per category (incremental).

        A category is rebuilt only when its JSON file is newer than the index
        on disk. Categories that are already up to date are loaded into memory
        (if not loaded yet) so that a later ``search`` sees every category,
        not just the ones rebuilt in this call.

        Args:
            ontology_dir: Directory containing one ``<category>.json`` file
                per category, each holding a list of record dicts.
        """
        os.makedirs(self.index_dir, exist_ok=True)
        suffix = ".json"
        # sorted() keeps build order (and therefore log output) deterministic.
        for filename in sorted(os.listdir(ontology_dir)):
            if not filename.endswith(suffix):
                continue
            # Strip only the trailing extension; str.replace would also eat
            # any ".json" occurring in the middle of the name.
            category = filename[: -len(suffix)]
            path = os.path.join(ontology_dir, filename)
            index_path, mapping_path = self._category_paths(category)

            # Check for incremental update
            if os.path.exists(index_path) and os.path.exists(mapping_path):
                if os.path.getmtime(path) <= os.path.getmtime(index_path):
                    print(f"Skipping {category}: index is up-to-date.")
                    if category not in self.indices:
                        self._load_category(category)
                    continue

            print(f"Building index for {category}...")
            cat_records = []
            texts = []
            with open(path, "r", encoding="utf-8") as f:
                data = json.load(f)
            for item in data:
                record = OntologyRecord(**item)
                # One row for the canonical text, then one row per alias; the
                # record is repeated so row ids map directly to records.
                cat_records.append(record)
                texts.append(record.canonical)
                for alias in record.aliases:
                    cat_records.append(record)
                    texts.append(alias)

            if not texts:
                continue

            embeddings = self.model.encode(texts, show_progress_bar=True, batch_size=self.batch_size)
            embeddings = np.array(embeddings).astype('float32')
            dimension = embeddings.shape[1]
            index = faiss.IndexFlatL2(dimension)
            index.add(embeddings)

            faiss.write_index(index, index_path)
            with open(mapping_path, "w", encoding="utf-8") as f:
                json.dump([r.model_dump() for r in cat_records], f, ensure_ascii=False)

            self.indices[category] = index
            self.records[category] = cat_records

    def load_index(self) -> None:
        """Load all category indices from disk.

        An ``.index`` file without its matching ``_records.json`` mapping is
        skipped.

        Raises:
            FileNotFoundError: If ``index_dir`` does not exist.
        """
        if not os.path.exists(self.index_dir):
            raise FileNotFoundError(f"Index directory not found at {self.index_dir}")
        suffix = ".index"
        for filename in sorted(os.listdir(self.index_dir)):
            if filename.endswith(suffix):
                self._load_category(filename[: -len(suffix)])

    def calibrate_score(self, l2_distance: float) -> float:
        """Convert L2 distance to a calibrated confidence score in [0, 1].

        Args:
            l2_distance: Distance value as returned by the FAISS index search.

        Returns:
            Sigmoid-shaped confidence; 0.5 at cosine similarity 0.8.

        NOTE(review): the conversion below assumes unit-norm embeddings and a
        squared L2 distance. Nothing in this class requests normalization from
        the model explicitly -- confirm the configured model normalizes its
        output.
        """
        # For L2 normalized embeddings, dist = 2 - 2*cos_sim.
        # So cos_sim = 1 - dist/2
        cos_sim = 1.0 - (l2_distance / 2.0)
        # Logistic curve that penalizes low cosine similarities sharply.
        # With k=15 and x0=0.8: cos_sim 0.8 -> 0.50, 0.95 -> ~0.90, 1.0 -> ~0.95.
        k = 15.0  # steepness
        x0 = 0.8  # midpoint
        # Very large distances overflow exp() to inf, which still yields the
        # correct limit 0.0; only the RuntimeWarning is suppressed.
        with np.errstate(over="ignore"):
            conf = 1.0 / (1.0 + np.exp(-k * (cos_sim - x0)))
        return float(np.clip(conf, 0.0, 1.0))

    def search(self, query: str, category: str = None, top_k: int = 5) -> List[Tuple["OntologyRecord", float]]:
        """Search for similar concepts, optionally within a specific category.

        Args:
            query: Free text to embed and look up.
            category: Restrict the search to this category. If it is None or
                not a loaded category, every loaded category is searched.
            top_k: Maximum number of results to return.

        Returns:
            Up to ``top_k`` ``(record, confidence)`` pairs sorted by confidence
            descending. The same record can appear more than once when both
            its canonical text and its aliases match.

        Raises:
            FileNotFoundError: If nothing is loaded yet and ``index_dir`` does
                not exist.
        """
        if not self.indices:
            self.load_index()
        if top_k <= 0:
            return []

        query_vector = self.model.encode([query]).astype('float32')

        if category and category in self.indices:
            categories_to_search = [category]
        else:
            categories_to_search = list(self.indices)

        results = []
        for cat in categories_to_search:
            cat_records = self.records[cat]
            distances, indices = self.indices[cat].search(query_vector, top_k)
            for dist, idx in zip(distances[0], indices[0]):
                # FAISS pads missing neighbours with label -1; without the
                # lower bound that would index the last record by accident.
                if 0 <= idx < len(cat_records):
                    conf = self.calibrate_score(float(dist))
                    results.append((cat_records[int(idx)], conf))

        # Sort combined results by confidence descending
        results.sort(key=lambda x: x[1], reverse=True)
        return results[:top_k]