| import logging |
| import torch |
| import numpy as np |
| import os |
| from typing import List, Dict, Any, Tuple |
|
|
# Module-level logger, named after this module so records can be filtered by origin.
logger = logging.getLogger(__name__)
# Absolute path of the directory containing this file; the embedding-model
# cache directory is created beneath it.
BASE_DIR = os.path.dirname(os.path.abspath(__file__))
|
|
class SemanticNormalizer:
    """
    Grounds natural language entities into a controlled ontology using embeddings.
    Solves the 'sharp object' -> 'knife' problem.

    Typical use: call ``fit_ontology`` once with the ontology labels, then call
    ``normalize`` for each free-text query. The embedding model is only
    instantiated on first use (see ``_load``).
    """

    def __init__(self, model_name="all-MiniLM-L6-v2"):
        """Store configuration; no model is loaded here.

        Args:
            model_name: Name handed to ``SentenceTransformer`` when the model
                is eventually loaded.
        """
        self.model_name = model_name
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self._model = None  # created lazily by _load()
        self._ontology_embeddings = {}  # label -> embedding tensor
        self._ontology_labels = []  # labels from the most recent fit_ontology()

    def _load(self):
        """Instantiate the embedding model on first call; later calls are no-ops.

        The ``sentence_transformers`` import is deferred to this point so that
        importing this module does not require the package. Model files are
        cached under ``BASE_DIR/mission_models/LinguisticBackbone``, which is
        created if missing.
        """
        # Identity check: "loaded" must not depend on the truthiness of the
        # model object itself.
        if self._model is not None:
            return

        from sentence_transformers import SentenceTransformer

        cache_dir = os.path.join(BASE_DIR, "mission_models", "LinguisticBackbone")
        os.makedirs(cache_dir, exist_ok=True)

        logger.info("[SEMANTIC] Loading embedding model %s...", self.model_name)
        self._model = SentenceTransformer(
            self.model_name, cache_folder=cache_dir, device=self.device
        )
        logger.info("[SEMANTIC] Model loaded.")

    def fit_ontology(self, labels: List[str]):
        """Pre-computes embeddings for the ontology labels.

        Each call replaces the previously indexed ontology, so labels from an
        earlier fit never leak into later ``normalize`` results.

        Args:
            labels: Ontology labels to index. The sequence is copied, so the
                caller may mutate it afterwards. An empty sequence clears the
                index without touching the model.
        """
        labels = list(labels)
        self._ontology_labels = labels

        if not labels:
            # Nothing to encode; leave the index empty rather than passing an
            # empty batch to the model.
            self._ontology_embeddings = {}
            logger.info("[SEMANTIC] Indexed %d ontology labels.", 0)
            return

        self._load()
        embeddings = self._model.encode(labels, convert_to_tensor=True)
        # Build a fresh mapping instead of updating the old one in place, so
        # stale entries from a previous fit are dropped.
        self._ontology_embeddings = dict(zip(labels, embeddings))
        logger.info("[SEMANTIC] Indexed %d ontology labels.", len(labels))

    def normalize(self, query: str, threshold: float = 0.45) -> List[Tuple[str, float]]:
        """Maps a query string to the closest ontology labels.

        Args:
            query: Free-text entity to ground.
            threshold: Minimum cosine similarity (inclusive) for a label to be
                returned.

        Returns:
            ``(label, score)`` pairs with ``score >= threshold``, sorted by
            descending score. Empty when ``query`` is empty or no ontology has
            been fitted.
        """
        if not query or not self._ontology_labels or not self._ontology_embeddings:
            return []

        self._load()
        query_emb = self._model.encode(query, convert_to_tensor=True)

        # Score every label in one batched call instead of one similarity
        # computation (and one tensor->float conversion) per label.
        labels = list(self._ontology_embeddings)
        label_matrix = torch.stack([self._ontology_embeddings[label] for label in labels])
        scores = torch.nn.functional.cosine_similarity(
            query_emb.reshape(1, -1), label_matrix, dim=1
        ).tolist()

        results = [
            (label, score)
            for label, score in zip(labels, scores)
            if score >= threshold
        ]
        results.sort(key=lambda pair: pair[1], reverse=True)
        return results
|
|
| |
# Shared module-level instance. Constructing it does not load the embedding
# model; that happens on the first fit_ontology()/normalize() call.
semantic_normalizer = SemanticNormalizer()
|
|