"""Embedding-based grounding of free-text entities into a controlled ontology."""

import logging
import os
from typing import List, Dict, Any, Tuple

import numpy as np
import torch

logger = logging.getLogger(__name__)
BASE_DIR = os.path.dirname(os.path.abspath(__file__))


class SemanticNormalizer:
    """
    Grounds natural language entities into a controlled ontology using embeddings.
    Solves the 'sharp object' -> 'knife' problem.

    The embedding model is loaded lazily on first use, so constructing an
    instance is cheap and does not import ``sentence_transformers``.
    """

    def __init__(self, model_name="all-MiniLM-L6-v2"):
        """Store configuration; the model itself is loaded on demand.

        Args:
            model_name: Name handed to ``SentenceTransformer`` when loading.
        """
        self.model_name = model_name
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self._model = None
        self._ontology_embeddings = {}  # label -> embedding
        self._ontology_labels = []

    def _load(self) -> None:
        """Load the embedding model once; later calls are no-ops."""
        # Compare against None explicitly: truthiness of a model object may be
        # driven by __len__ rather than by "has it been loaded".
        if self._model is not None:
            return

        from sentence_transformers import SentenceTransformer

        # Keep downloaded weights in a directory next to this module.
        cache_dir = os.path.join(BASE_DIR, "mission_models", "LinguisticBackbone")
        os.makedirs(cache_dir, exist_ok=True)

        logger.info("[SEMANTIC] Loading embedding model %s...", self.model_name)
        self._model = SentenceTransformer(
            self.model_name, cache_folder=cache_dir, device=self.device
        )
        logger.info("[SEMANTIC] Model loaded.")

    def fit_ontology(self, labels: List[str]) -> None:
        """Pre-computes embeddings for the ontology labels.

        Each call replaces the previously indexed ontology entirely, so labels
        from an earlier call never leak into later ``normalize`` results.

        Args:
            labels: Ontology labels to index. Copied, so later mutation of the
                caller's list does not affect this instance.
        """
        # Materialise once: protects against caller-side mutation and against
        # one-shot iterables being exhausted before the zip below.
        labels = list(labels)

        if not labels:
            # Nothing to encode; reset state without loading the model.
            self._ontology_labels = []
            self._ontology_embeddings = {}
            logger.info("[SEMANTIC] Indexed %d ontology labels.", 0)
            return

        self._load()
        embeddings = self._model.encode(labels, convert_to_tensor=True)

        self._ontology_labels = labels
        # Rebuild rather than update in place so stale labels are dropped.
        self._ontology_embeddings = dict(zip(labels, embeddings))
        logger.info("[SEMANTIC] Indexed %d ontology labels.", len(labels))

    def normalize(self, query: str, threshold: float = 0.45) -> List[Tuple[str, float]]:
        """Maps a query string to the closest ontology labels.

        Args:
            query: Free-text entity to ground.
            threshold: Minimum cosine similarity (inclusive) for a label to be
                returned.

        Returns:
            ``(label, score)`` pairs whose cosine similarity to the query is at
            least ``threshold``, sorted by descending score. Empty when the
            query is empty or no ontology has been indexed.
        """
        if not query or not self._ontology_labels or not self._ontology_embeddings:
            return []

        self._load()
        query_emb = self._model.encode(query, convert_to_tensor=True)

        # Score every label in one batched call instead of one call per label.
        labels = list(self._ontology_embeddings)
        label_matrix = torch.stack([self._ontology_embeddings[name] for name in labels])
        scores = torch.nn.functional.cosine_similarity(
            query_emb.reshape(1, -1), label_matrix, dim=1
        ).tolist()

        results = [
            (label, score) for label, score in zip(labels, scores) if score >= threshold
        ]
        # Sort by best match
        results.sort(key=lambda pair: pair[1], reverse=True)
        return results


# Singleton instance
semantic_normalizer = SemanticNormalizer()