Spaces:
Sleeping
Sleeping
| import os | |
| import json | |
| import faiss | |
| import numpy as np | |
| from sentence_transformers import SentenceTransformer | |
| from typing import List, Dict, Any, Tuple | |
| from src.ontology.models import OntologyRecord | |
| from src.runtime.device_manager import DeviceManager | |
class EmbeddingEngine:
    """Semantic search over ontology records using per-category FAISS indices.

    Each ontology JSON file becomes one category. Every record contributes one
    vector for its canonical name and one per alias; ``self.records[category]``
    is kept row-aligned with ``self.indices[category]`` so a FAISS row id maps
    straight back to the owning record (a record therefore appears once per
    indexed text).

    Embeddings are unit-normalised at both build and query time, which is what
    ``calibrate_score`` relies on. Indices that were built earlier with a
    model that does not normalise its output must be rebuilt.
    """

    def __init__(self, model_name: str = "all-MiniLM-L6-v2", index_dir: str = "data/faiss_indices"):
        """Load the sentence-embedding model and prepare empty in-memory stores.

        Args:
            model_name: SentenceTransformer model identifier.
            index_dir: Directory where ``<category>.index`` and
                ``<category>_records.json`` files are written / read.
        """
        # DeviceManager reports an execution-provider name; SentenceTransformer
        # expects a torch-style device string such as 'cuda' or 'cpu'.
        device = DeviceManager.get_optimal_device()
        st_device = 'cuda' if device == 'CUDAExecutionProvider' else 'cpu'
        self.model = SentenceTransformer(model_name, device=st_device)
        self.index_dir = index_dir
        self.indices: Dict[str, Any] = {}        # category -> FAISS index
        self.records: Dict[str, List[Any]] = {}  # category -> row-aligned records
        self.batch_size = DeviceManager.get_optimal_batch_size()

    def _load_category(self, category: str) -> bool:
        """Load one category's index and record mapping from ``index_dir``.

        Returns:
            True if both files existed and were loaded, False otherwise.
        """
        index_path = os.path.join(self.index_dir, f"{category}.index")
        mapping_path = os.path.join(self.index_dir, f"{category}_records.json")
        if not (os.path.exists(index_path) and os.path.exists(mapping_path)):
            return False
        index = faiss.read_index(index_path)
        with open(mapping_path, "r", encoding="utf-8") as f:
            data = json.load(f)
        records = [OntologyRecord(**item) for item in data]
        # Assign both together so a failure above never leaves an index
        # without its matching record list.
        self.indices[category] = index
        self.records[category] = records
        return True

    def build_index(self, ontology_dir: str):
        """Build FAISS indices, one per category (incremental).

        A category is rebuilt only when its JSON file is newer than the
        index file on disk. Categories that are already up to date are loaded
        into memory instead, so every category is searchable afterwards.

        Args:
            ontology_dir: Directory containing one ``<category>.json`` file
                per category, each holding a list of record dicts.
        """
        os.makedirs(self.index_dir, exist_ok=True)
        suffix = ".json"
        for filename in os.listdir(ontology_dir):
            if not filename.endswith(suffix):
                continue
            # Strip only the trailing extension (str.replace would also
            # remove ".json" occurring in the middle of the name).
            category = filename[:-len(suffix)]
            path = os.path.join(ontology_dir, filename)
            index_path = os.path.join(self.index_dir, f"{category}.index")
            mapping_path = os.path.join(self.index_dir, f"{category}_records.json")

            # Check for incremental update
            if os.path.exists(index_path) and os.path.exists(mapping_path):
                if os.path.getmtime(path) <= os.path.getmtime(index_path):
                    print(f"Skipping {category}: index is up-to-date.")
                    # search() only lazy-loads when nothing is in memory, so
                    # a skipped category must be loaded here or it would be
                    # invisible after a partial rebuild.
                    if category not in self.indices:
                        self._load_category(category)
                    continue

            print(f"Building index for {category}...")
            cat_records = []
            texts = []
            with open(path, "r", encoding="utf-8") as f:
                data = json.load(f)
            for item in data:
                record = OntologyRecord(**item)
                # One row for the canonical name, one per alias; the record
                # is repeated so row i of the index maps to cat_records[i].
                cat_records.append(record)
                texts.append(record.canonical)
                for alias in record.aliases:
                    cat_records.append(record)
                    texts.append(alias)
            if not texts:
                continue

            embeddings = self.model.encode(
                texts,
                show_progress_bar=True,
                batch_size=self.batch_size,
                normalize_embeddings=True,  # required by calibrate_score
            )
            embeddings = np.ascontiguousarray(embeddings, dtype='float32')
            dimension = embeddings.shape[1]
            index = faiss.IndexFlatL2(dimension)
            index.add(embeddings)

            faiss.write_index(index, index_path)
            with open(mapping_path, "w", encoding="utf-8") as f:
                json.dump([r.model_dump() for r in cat_records], f, ensure_ascii=False)
            self.indices[category] = index
            self.records[category] = cat_records

    def load_index(self):
        """Load all category indices from disk.

        Index files without a matching ``<category>_records.json`` are
        ignored.

        Raises:
            FileNotFoundError: If ``index_dir`` does not exist.
        """
        if not os.path.exists(self.index_dir):
            raise FileNotFoundError(f"Index directory not found at {self.index_dir}")
        suffix = ".index"
        for filename in os.listdir(self.index_dir):
            if filename.endswith(suffix):
                self._load_category(filename[:-len(suffix)])

    def calibrate_score(self, l2_distance: float) -> float:
        """Convert a FAISS L2 distance to a confidence score in [0, 1].

        Args:
            l2_distance: Squared L2 distance as returned by ``IndexFlatL2``
                for unit-normalised vectors.

        Returns:
            Sigmoid-calibrated confidence in [0, 1].
        """
        # For unit vectors, squared distance = 2 - 2*cos_sim,
        # so cos_sim = 1 - dist/2.
        cos_sim = 1.0 - (l2_distance / 2.0)
        # Clamp to the valid cosine range: distances from non-normalised data
        # or FAISS padding values would otherwise overflow np.exp below.
        cos_sim = float(np.clip(cos_sim, -1.0, 1.0))
        # Sigmoid that penalises low cosine similarities sharply:
        # cos_sim 0.8 -> 0.50, 0.95 -> ~0.90, 1.0 -> ~0.95.
        k = 15.0   # steepness
        x0 = 0.8   # midpoint
        conf = 1.0 / (1.0 + np.exp(-k * (cos_sim - x0)))
        return float(np.clip(conf, 0.0, 1.0))

    def search(self, query: str, category: str = None, top_k: int = 5) -> List[Tuple["OntologyRecord", float]]:
        """Search for similar concepts, optionally within a specific category.

        Args:
            query: Free-text query to embed.
            category: Restrict the search to this category. If it is falsy or
                not a loaded category, all loaded categories are searched.
            top_k: Maximum number of results to return.

        Returns:
            Up to ``top_k`` ``(record, confidence)`` pairs sorted by
            confidence, highest first. The same record may appear more than
            once when several of its names match.

        Raises:
            FileNotFoundError: If nothing is loaded and ``index_dir`` is
                missing.
        """
        if top_k <= 0:
            return []
        if not self.indices:
            self.load_index()

        query_vector = np.ascontiguousarray(
            self.model.encode([query], normalize_embeddings=True),
            dtype='float32',
        )

        if category and category in self.indices:
            categories_to_search = [category]
        else:
            categories_to_search = list(self.indices)

        results = []
        for cat in categories_to_search:
            cat_records = self.records[cat]
            distances, row_ids = self.indices[cat].search(query_vector, top_k)
            for dist, idx in zip(distances[0], row_ids[0]):
                # FAISS pads with -1 when the index holds fewer than top_k
                # vectors; a bare `idx < len(...)` would let -1 through and
                # silently return the last record.
                if 0 <= idx < len(cat_records):
                    conf = self.calibrate_score(float(dist))
                    results.append((cat_records[int(idx)], conf))

        # Sort combined results by confidence descending
        results.sort(key=lambda x: x[1], reverse=True)
        return results[:top_k]