"""NodeEncoder for converting agent descriptions into embeddings."""

import hashlib
import importlib
import importlib.util
import re
from collections.abc import Sequence
from typing import TYPE_CHECKING, Any
import torch
from pydantic import BaseModel, ConfigDict, PrivateAttr, field_validator

if TYPE_CHECKING:
    from sentence_transformers import SentenceTransformer
else:
    SentenceTransformer = Any

__all__ = ["NodeEncoder"]
_TOKEN_RE = re.compile(r"[\w']+")
_HASH_PROVIDER = "hash"
_HASH_PREFIX = f"{_HASH_PROVIDER}:"
_SENTENCE_TRANSFORMERS_PREFIXES = ("sentence-transformers/", "sentence-transformers:")


def _tokenize(text: str) -> list[str]:
    """Split text into tokens (words and numbers) in lower case."""
    return _TOKEN_RE.findall(text.lower())
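
# A quick illustration of _tokenize (assuming only the _TOKEN_RE regex above,
# which keeps word characters and apostrophes and drops other punctuation):
#
#     _tokenize("Agent-7 can't parse JSON!")
#     # -> ["agent", "7", "can't", "parse", "json"]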


class NodeEncoder(BaseModel):
    model_config = ConfigDict(arbitrary_types_allowed=True)

    model_name: str = "sentence-transformers/all-MiniLM-L6-v2"
    normalize_embeddings: bool = True
    fallback_dim: int = 384

    _model: SentenceTransformer | None = PrivateAttr(default=None)
    _provider: str = PrivateAttr(default="sentence-transformers")

    @field_validator("model_name")
    @classmethod
    def validate_model_name(cls, v: str) -> str:
        """Validate the model name."""
        if v == _HASH_PROVIDER or v.startswith(_HASH_PREFIX):
            if v.startswith(_HASH_PREFIX):
                _, _, raw_dim = v.partition(":")
                if not raw_dim.isdigit():
                    msg = f"Hash embedding dimension must be numeric, got {raw_dim!r}"
                    raise ValueError(msg)
                if int(raw_dim) < 1:
                    msg = f"Hash embedding dimension must be positive, got {raw_dim}"
                    raise ValueError(msg)
            return v
        if any(v.startswith(prefix) for prefix in _SENTENCE_TRANSFORMERS_PREFIXES):
            parts = v.split("/", 1) if "/" in v else v.split(":", 1)
            parts_expected = 2
            if len(parts) == parts_expected and parts[1].strip():
                return v
            msg = f"SentenceTransformer specification '{v}' is missing the model identifier"
            raise ValueError(msg)
        msg = "Unsupported embedding model. Expected 'sentence-transformers/<model>' or 'hash[:<dim>]'"
        raise ValueError(msg)
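
    # Validator behavior, sketched (only the rules encoded above apply):
    #   "sentence-transformers/all-MiniLM-L6-v2"  -> accepted
    #   "hash"                                    -> accepted (default dimension)
    #   "hash:128"                                -> accepted (explicit dimension)
    #   "hash:abc" / "hash:0"                     -> ValueError (dim not a positive integer)
    #   "sentence-transformers/"                  -> ValueError (missing model identifier)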

    def model_post_init(self, __context: Any, /) -> None:
        """Determine the provider (hash or sentence-transformers) and the fallback dimension."""
        if self.model_name == _HASH_PROVIDER or self.model_name.startswith(_HASH_PREFIX):
            self._provider = _HASH_PROVIDER
            if self.model_name.startswith(_HASH_PREFIX):
                _, _, raw_dim = self.model_name.partition(":")
                # Clamp explicit dimensions to at least 32 so hash vectors stay usable.
                self.fallback_dim = max(int(raw_dim), 32)
        else:
            self._provider = "sentence-transformers"

    def encode(self, texts: Sequence[str]) -> torch.Tensor:
        """Encode a list of texts into embeddings."""
        cleaned = [text.strip() if isinstance(text, str) else "" for text in texts]
        if not cleaned:
            return torch.zeros((0, 0), dtype=torch.float32)
        if self._provider == _HASH_PROVIDER:
            return self._hash_fallback(cleaned)
        model = self._load_model()
        if model is None:
            return self._hash_fallback(cleaned)
        embeddings = model.encode(
            cleaned,
            convert_to_tensor=True,
            normalize_embeddings=self.normalize_embeddings,
        )
        return embeddings.to(dtype=torch.float32)

    def _load_model(self) -> Any:
        """Lazily load SentenceTransformer if it is available."""
        if self._provider != "sentence-transformers":
            return None
        if self._model is not None:
            return self._model
        if importlib.util.find_spec("sentence_transformers") is None:
            self._provider = _HASH_PROVIDER
            return None
        module = importlib.import_module("sentence_transformers")
        self._model = module.SentenceTransformer(self.model_name)
        return self._model

    def _hash_fallback(self, texts: Sequence[str]) -> torch.Tensor:
        """Build normalized bag-of-words embeddings using the hashing trick."""
        dimension = max(self.fallback_dim, 32)
        matrix = torch.zeros((len(texts), dimension), dtype=torch.float32)
        for row, text in enumerate(texts):
            tokens = _tokenize(text)
            if not tokens:
                continue
            for token in tokens:
                digest = hashlib.blake2b(token.encode("utf-8"), digest_size=32).digest()
                index = int.from_bytes(digest[:8], byteorder="big", signed=False) % dimension
                matrix[row, index] += 1.0
            norm = torch.norm(matrix[row])
            if norm > 0:
                matrix[row] /= norm
        return matrix
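
    # The hashing trick in _hash_fallback, in brief: each token is hashed with
    # BLAKE2b, and the first 8 bytes of the digest, taken modulo `dimension`,
    # pick the bucket to increment. No vocabulary has to be stored, and
    # colliding tokens simply share a bucket. Each row is then L2-normalized,
    # so dot products between rows behave like cosine similarities.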

    @property
    def embedding_dim(self) -> int:
        """Dimension of embeddings generated by the selected provider."""
        if self._provider == _HASH_PROVIDER:
            # Mirror the clamp in _hash_fallback so the reported dimension
            # matches the vectors actually produced.
            return max(self.fallback_dim, 32)
        model = self._load_model()
        if model is not None:
            return model.get_sentence_embedding_dimension()
        return max(self.fallback_dim, 32)
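

if __name__ == "__main__":
    # Minimal smoke test, a sketch rather than part of the public API: the
    # "hash" provider needs no optional dependencies, so it runs anywhere.
    encoder = NodeEncoder(model_name="hash:64")
    vectors = encoder.encode(["parse web pages", "summarize documents"])
    print(vectors.shape)               # torch.Size([2, 64])
    print(encoder.embedding_dim)       # 64
    print(torch.norm(vectors, dim=1))  # each non-empty row has unit norm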