import json
import os
from typing import Any, Dict, List, Optional, Tuple

import faiss
import numpy as np
from sentence_transformers import SentenceTransformer

from src.ontology.models import OntologyRecord
from src.runtime.device_manager import DeviceManager


class EmbeddingEngine:
    """Embed ontology terms and answer nearest-neighbour queries with FAISS.

    One flat L2 index is kept per ontology category. Every canonical name and
    every alias of a record is embedded as its own vector; a parallel list
    maps each vector position back to the record it came from, so a record
    with aliases appears several times in that list.
    """

    _JSON_SUFFIX = ".json"
    _INDEX_SUFFIX = ".index"

    def __init__(self, model_name: str = "all-MiniLM-L6-v2", index_dir: str = "data/faiss_indices"):
        """Load the sentence-embedding model and prepare empty in-memory maps.

        Args:
            model_name: Name passed to ``SentenceTransformer``.
            index_dir: Directory where ``<category>.index`` and
                ``<category>_records.json`` files are written and read.
        """
        device = DeviceManager.get_optimal_device()
        # SentenceTransformer wants a torch-style device string; anything
        # other than the CUDA provider name falls back to CPU.
        st_device = 'cuda' if device == 'CUDAExecutionProvider' else 'cpu'
        self.model = SentenceTransformer(model_name, device=st_device)
        self.index_dir = index_dir
        # category -> FAISS index
        self.indices: Dict[str, Any] = {}
        # category -> records aligned position-by-position with the index
        self.records: Dict[str, List[OntologyRecord]] = {}
        self.batch_size = DeviceManager.get_optimal_batch_size()

    def _encode(self, texts: List[str], show_progress_bar: bool = False) -> np.ndarray:
        """Embed ``texts`` as unit-length float32 vectors.

        Normalisation is requested explicitly because ``calibrate_score``
        relies on the identity ``squared_L2 = 2 - 2 * cos_sim``, which only
        holds for unit vectors. Index and query vectors both go through this
        helper so they always share the same convention.
        """
        vectors = self.model.encode(
            texts,
            show_progress_bar=show_progress_bar,
            batch_size=self.batch_size,
            normalize_embeddings=True,
        )
        return np.ascontiguousarray(vectors, dtype='float32')

    def _load_category(self, category: str, index_path: str, mapping_path: str) -> None:
        """Read one category's index and record mapping from disk into memory."""
        index = faiss.read_index(index_path)
        with open(mapping_path, "r", encoding="utf-8") as f:
            data = json.load(f)
        records = [OntologyRecord(**item) for item in data]
        # Assign only after both files loaded, so the two maps stay in sync.
        self.indices[category] = index
        self.records[category] = records

    def build_index(self, ontology_dir: str) -> None:
        """Build FAISS indices, one per category (incremental).

        Each ``*.json`` file in ``ontology_dir`` is one category. A category
        is rebuilt only when its JSON file is newer than its index file on
        disk; otherwise the existing index is loaded so it is still
        searchable from this instance.

        NOTE: indices written before normalisation was enforced, with a model
        that does not normalise its own output, are not comparable with
        normalised queries and must be rebuilt (delete the index files).

        Args:
            ontology_dir: Directory containing the per-category JSON files.
        """
        os.makedirs(self.index_dir, exist_ok=True)

        for filename in os.listdir(ontology_dir):
            if not filename.endswith(self._JSON_SUFFIX):
                continue

            # Strip only the trailing suffix; str.replace would also remove
            # ".json" occurring inside the name.
            category = filename[:-len(self._JSON_SUFFIX)]
            path = os.path.join(ontology_dir, filename)
            index_path = os.path.join(self.index_dir, f"{category}.index")
            mapping_path = os.path.join(self.index_dir, f"{category}_records.json")

            # Check for incremental update
            if os.path.exists(index_path) and os.path.exists(mapping_path):
                if os.path.getmtime(path) <= os.path.getmtime(index_path):
                    print(f"Skipping {category}: index is up-to-date.")
                    # Without this, a partially rebuilt engine would have a
                    # non-empty self.indices, search() would never call
                    # load_index(), and this category would be unsearchable.
                    if category not in self.indices or category not in self.records:
                        self._load_category(category, index_path, mapping_path)
                    continue

            print(f"Building index for {category}...")
            cat_records = []
            texts = []

            with open(path, "r", encoding="utf-8") as f:
                data = json.load(f)
            for item in data:
                record = OntologyRecord(**item)
                # One vector per surface form, each pointing at the same record.
                for text in (record.canonical, *record.aliases):
                    cat_records.append(record)
                    texts.append(text)

            if not texts:
                continue

            embeddings = self._encode(texts, show_progress_bar=True)
            index = faiss.IndexFlatL2(embeddings.shape[1])
            index.add(embeddings)

            faiss.write_index(index, index_path)
            with open(mapping_path, "w", encoding="utf-8") as f:
                json.dump([r.model_dump() for r in cat_records], f, ensure_ascii=False)

            self.indices[category] = index
            self.records[category] = cat_records

    def load_index(self) -> None:
        """Load all category indices from disk.

        An index file without a matching ``<category>_records.json`` is
        skipped.

        Raises:
            FileNotFoundError: If the index directory does not exist.
        """
        if not os.path.exists(self.index_dir):
            raise FileNotFoundError(f"Index directory not found at {self.index_dir}")

        for filename in os.listdir(self.index_dir):
            if not filename.endswith(self._INDEX_SUFFIX):
                continue
            category = filename[:-len(self._INDEX_SUFFIX)]
            index_path = os.path.join(self.index_dir, filename)
            mapping_path = os.path.join(self.index_dir, f"{category}_records.json")
            if os.path.exists(mapping_path):
                self._load_category(category, index_path, mapping_path)

    def calibrate_score(self, l2_distance: float, steepness: float = 15.0, midpoint: float = 0.8) -> float:
        """Convert an L2 distance to a calibrated confidence score in [0, 1].

        Args:
            l2_distance: Distance as reported by the FAISS flat L2 index,
                which is the *squared* Euclidean distance. For unit vectors
                this equals ``2 - 2 * cos_sim``.
            steepness: Slope of the logistic curve.
            midpoint: Cosine similarity that maps to a confidence of 0.5.

        Returns:
            The logistic confidence. With the defaults, a cosine of 0.8 maps
            to 0.5, 0.95 to about 0.90 and 1.0 to about 0.95.
        """
        cos_sim = 1.0 - (l2_distance / 2.0)
        # Keep the cosine in its valid range: rounding can push it slightly
        # past 1, and sentinel/unnormalised distances would otherwise
        # overflow np.exp below.
        cos_sim = float(np.clip(cos_sim, -1.0, 1.0))
        conf = 1.0 / (1.0 + np.exp(-steepness * (cos_sim - midpoint)))
        return float(np.clip(conf, 0.0, 1.0))

    def search(self, query: str, category: Optional[str] = None, top_k: int = 5) -> List[Tuple[OntologyRecord, float]]:
        """Search for similar concepts, optionally within a specific category.

        Args:
            query: Free-text query to embed.
            category: Restrict the search to this category. When it is falsy
                or not a loaded category, every loaded category is searched.
            top_k: Maximum number of results to return.

        Returns:
            Up to ``top_k`` ``(record, confidence)`` pairs sorted by
            confidence, highest first. A record may appear more than once
            when several of its surface forms match.

        Raises:
            FileNotFoundError: If no index is in memory and the index
                directory does not exist.
        """
        if not self.indices:
            self.load_index()
        if top_k <= 0:
            return []

        query_vector = self._encode([query])

        if category and category in self.indices:
            categories_to_search = [category]
        else:
            categories_to_search = list(self.indices)

        results = []
        for cat in categories_to_search:
            cat_records = self.records[cat]
            distances, labels = self.indices[cat].search(query_vector, top_k)
            for dist, idx in zip(distances[0], labels[0]):
                # FAISS pads missing neighbours with label -1; without the
                # lower bound that would silently pick the last record.
                if 0 <= idx < len(cat_records):
                    conf = self.calibrate_score(float(dist))
                    results.append((cat_records[int(idx)], conf))

        # Sort combined results by confidence descending
        results.sort(key=lambda x: x[1], reverse=True)
        return results[:top_k]