# File size: 2,480 Bytes | 513d6d1
import logging
import torch
import numpy as np
import os
from typing import List, Dict, Any, Tuple
# Module logger, plus the absolute directory containing this file; the
# directory anchors the on-disk embedding-model cache.
logger: logging.Logger = logging.getLogger(__name__)
BASE_DIR = os.path.split(os.path.abspath(__file__))[0]
class SemanticNormalizer:
    """
    Grounds natural language entities into a controlled ontology using embeddings.
    Solves the 'sharp object' -> 'knife' problem.

    Call ``fit_ontology`` with the allowed labels, then ``normalize`` free-text
    queries against them. The SentenceTransformer model is imported and loaded
    lazily on first use, so constructing an instance does not load it.
    """

    def __init__(self, model_name="all-MiniLM-L6-v2"):
        self.model_name = model_name
        # Run on the GPU when torch can see one; otherwise fall back to CPU.
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self._model = None  # populated on demand by _load()
        self._ontology_embeddings: Dict[str, torch.Tensor] = {}  # label -> embedding
        self._ontology_labels: List[str] = []

    def _load(self):
        """Load the embedding model on the first call; later calls are no-ops."""
        # Test for None explicitly rather than relying on the model's truthiness.
        if self._model is not None:
            return
        from sentence_transformers import SentenceTransformer
        # Sanctuary for embeddings: model files are cached under BASE_DIR,
        # and the directory is created if it does not exist yet.
        cache_dir = os.path.join(BASE_DIR, "mission_models", "LinguisticBackbone")
        os.makedirs(cache_dir, exist_ok=True)
        logger.info("[SEMANTIC] Loading embedding model %s...", self.model_name)
        self._model = SentenceTransformer(
            self.model_name, cache_folder=cache_dir, device=self.device
        )
        logger.info("[SEMANTIC] Model loaded.")

    def fit_ontology(self, labels: List[str]):
        """Pre-computes embeddings for the ontology labels.

        Replaces any previously fitted ontology: labels from an earlier call are
        no longer candidates for ``normalize``.

        Args:
            labels: Canonical labels that queries will be mapped onto. The
                sequence is copied, so later mutation by the caller has no effect.
        """
        labels = list(labels)
        self._load()
        # Nothing to embed for an empty ontology; skip the encoder call.
        embeddings = self._model.encode(labels, convert_to_tensor=True) if labels else []
        # Rebuild the index from scratch (so stale labels from a previous fit
        # cannot linger) and only after encoding succeeded (so a failure leaves
        # the previous ontology intact).
        self._ontology_labels = labels
        self._ontology_embeddings = dict(zip(labels, embeddings))
        logger.info("[SEMANTIC] Indexed %d ontology labels.", len(labels))

    def normalize(self, query: str, threshold: float = 0.45) -> List[Tuple[str, float]]:
        """Maps a query string to the closest ontology labels.

        Args:
            query: Free-text entity to ground.
            threshold: Minimum cosine similarity for a label to be returned.

        Returns:
            ``(label, score)`` pairs with ``score >= threshold``, sorted from the
            best to the worst match. Empty when ``query`` is empty or no ontology
            has been fitted.
        """
        if not query or not self._ontology_embeddings:
            return []
        self._load()
        query_emb = self._model.encode(query, convert_to_tensor=True)
        labels = list(self._ontology_embeddings)
        label_matrix = torch.stack(list(self._ontology_embeddings.values()))
        # Score every label in one batched call instead of one call per label.
        # Assumes encode(str) yields a single 1-D vector; unsqueeze adds a batch
        # axis so it broadcasts against the (num_labels, dim) label matrix.
        scores = torch.nn.functional.cosine_similarity(
            query_emb.unsqueeze(0), label_matrix, dim=-1
        ).tolist()
        results = [(label, score) for label, score in zip(labels, scores) if score >= threshold]
        # Sort by best match
        results.sort(key=lambda pair: pair[1], reverse=True)
        return results
# Shared module-level instance. Building it does not load the embedding model;
# that is deferred until the first fit_ontology()/normalize() call.
semantic_normalizer: SemanticNormalizer = SemanticNormalizer()