"""NodeEncoder for converting agent descriptions into embeddings.""" import hashlib import importlib import importlib.util import re from collections.abc import Sequence from typing import TYPE_CHECKING, Any import torch from pydantic import BaseModel, ConfigDict, PrivateAttr, field_validator if TYPE_CHECKING: from sentence_transformers import SentenceTransformer else: SentenceTransformer = Any __all__ = ["NodeEncoder"] _TOKEN_RE = re.compile(r"[\w']+") _HASH_PROVIDER = "hash" _HASH_PREFIX = f"{_HASH_PROVIDER}:" _SENTENCE_TRANSFORMERS_PREFIXES = ("sentence-transformers/", "sentence-transformers:") def _tokenize(text: str) -> list[str]: """Split text into tokens (words and numbers) in lower case.""" return _TOKEN_RE.findall(text.lower()) class NodeEncoder(BaseModel): model_config = ConfigDict(arbitrary_types_allowed=True) model_name: str = "sentence-transformers/all-MiniLM-L6-v2" normalize_embeddings: bool = True fallback_dim: int = 384 _model: SentenceTransformer | None = PrivateAttr(default=None) _provider: str = PrivateAttr(default="sentence-transformers") @field_validator("model_name") @classmethod def validate_model_name(cls, v: str) -> str: """Validate the model name.""" if v == _HASH_PROVIDER or v.startswith(_HASH_PREFIX): if v.startswith(_HASH_PREFIX): _, _, raw_dim = v.partition(":") if not raw_dim.isdigit(): msg = f"Hash embedding dimension must be numeric, got {raw_dim!r}" raise ValueError(msg) if int(raw_dim) < 1: msg = f"Hash embedding dimension must be positive, got {raw_dim}" raise ValueError(msg) return v if any(v.startswith(prefix) for prefix in _SENTENCE_TRANSFORMERS_PREFIXES): parts = v.split("/", 1) if "/" in v else v.split(":", 1) parts_expected = 2 if len(parts) == parts_expected and parts[1].strip(): return v msg = f"SentenceTransformer specification '{v}' is missing the model identifier" raise ValueError(msg) msg = "Unsupported embedding model. Expected 'sentence-transformers/' or 'hash[:]'" raise ValueError(msg) def model_post_init(self, __context, /) -> None: """Determine the provider (hash or sentence-transformers) and fallback dim.""" if self.model_name == _HASH_PROVIDER or self.model_name.startswith(_HASH_PREFIX): self._provider = _HASH_PROVIDER if self.model_name.startswith(_HASH_PREFIX): _, _, raw_dim = self.model_name.partition(":") self.fallback_dim = max(int(raw_dim), 32) else: self._provider = "sentence-transformers" def encode(self, texts: Sequence[str]) -> torch.Tensor: """Encode a list of texts into embeddings.""" cleaned = [text.strip() if isinstance(text, str) else "" for text in texts] if not cleaned: return torch.zeros((0, 0), dtype=torch.float32) if self._provider == _HASH_PROVIDER: return self._hash_fallback(cleaned) model = self._load_model() if model is None: return self._hash_fallback(cleaned) embeddings = model.encode( cleaned, convert_to_tensor=True, normalize_embeddings=self.normalize_embeddings, ) return embeddings.to(dtype=torch.float32) def _load_model(self) -> Any: """Lazily load SentenceTransformer if it is available.""" if self._provider != "sentence-transformers": return None if self._model is not None: return self._model if importlib.util.find_spec("sentence_transformers") is None: self._provider = _HASH_PROVIDER return None module = importlib.import_module("sentence_transformers") self._model = module.SentenceTransformer(self.model_name) return self._model def _hash_fallback(self, texts: Sequence[str]) -> torch.Tensor: """Build normalized bag-of-words embeddings using the hash trick.""" dimension = max(self.fallback_dim, 32) matrix = torch.zeros((len(texts), dimension), dtype=torch.float32) for row, text in enumerate(texts): tokens = _tokenize(text) if not tokens: continue for token in tokens: digest = hashlib.blake2b(token.encode("utf-8"), digest_size=32).digest() index = int.from_bytes(digest[:8], byteorder="big", signed=False) % dimension matrix[row, index] += 1.0 norm = torch.norm(matrix[row]) if norm > 0: matrix[row] /= norm return matrix @property def embedding_dim(self) -> int: """Dimension of embeddings generated by the selected provider.""" if self._provider == _HASH_PROVIDER: return self.fallback_dim model = self._load_model() if model is not None: return model.get_sentence_embedding_dimension() return self.fallback_dim