"""NodeEncoder for converting agent descriptions into embeddings."""

import hashlib
import importlib
import importlib.util
import re
from collections.abc import Sequence
from typing import TYPE_CHECKING, Any

import torch
from pydantic import BaseModel, ConfigDict, PrivateAttr, field_validator

if TYPE_CHECKING:
    from sentence_transformers import SentenceTransformer
else:
    SentenceTransformer = Any

__all__ = ["NodeEncoder"]

_TOKEN_RE = re.compile(r"[\w']+")
_HASH_PROVIDER = "hash"
_HASH_PREFIX = f"{_HASH_PROVIDER}:"
_SENTENCE_TRANSFORMERS_PREFIXES = ("sentence-transformers/", "sentence-transformers:")


def _tokenize(text: str) -> list[str]:
    """Split text into tokens (words and numbers) in lower case."""
    return _TOKEN_RE.findall(text.lower())


class NodeEncoder(BaseModel):
    """Encode node descriptions with a SentenceTransformer model, falling back to hashed bag-of-words embeddings."""

    model_config = ConfigDict(arbitrary_types_allowed=True)

    model_name: str = "sentence-transformers/all-MiniLM-L6-v2"
    normalize_embeddings: bool = True
    fallback_dim: int = 384
    _model: SentenceTransformer | None = PrivateAttr(default=None)
    _provider: str = PrivateAttr(default="sentence-transformers")

    @field_validator("model_name")
    @classmethod
    def validate_model_name(cls, v: str) -> str:
        """Validate the model name."""
        if v == _HASH_PROVIDER or v.startswith(_HASH_PREFIX):
            if v.startswith(_HASH_PREFIX):
                _, _, raw_dim = v.partition(":")
                if not raw_dim.isdigit():
                    msg = f"Hash embedding dimension must be numeric, got {raw_dim!r}"
                    raise ValueError(msg)
                if int(raw_dim) < 1:
                    msg = f"Hash embedding dimension must be positive, got {raw_dim}"
                    raise ValueError(msg)
            return v

        if any(v.startswith(prefix) for prefix in _SENTENCE_TRANSFORMERS_PREFIXES):
            parts = v.split("/", 1) if "/" in v else v.split(":", 1)
            parts_expected = 2
            if len(parts) == parts_expected and parts[1].strip():
                return v
            msg = f"SentenceTransformer specification '{v}' is missing the model identifier"
            raise ValueError(msg)

        msg = "Unsupported embedding model. Expected 'sentence-transformers/<model>' or 'hash[:<dim>]'"
        raise ValueError(msg)

    def model_post_init(self, __context: Any, /) -> None:
        """Determine the provider (hash or sentence-transformers) and fallback dim."""
        if self.model_name == _HASH_PROVIDER or self.model_name.startswith(_HASH_PREFIX):
            self._provider = _HASH_PROVIDER
            if self.model_name.startswith(_HASH_PREFIX):
                _, _, raw_dim = self.model_name.partition(":")
                # Clamp to a 32-dim floor so hash embeddings keep enough
                # buckets to be useful, even if a smaller dim was requested.
                self.fallback_dim = max(int(raw_dim), 32)
        else:
            self._provider = "sentence-transformers"

    def encode(self, texts: Sequence[str]) -> torch.Tensor:
        """Encode a list of texts into embeddings."""
        # Non-string items are coerced to empty strings rather than raising.
        cleaned = [text.strip() if isinstance(text, str) else "" for text in texts]
        if not cleaned:
            return torch.zeros((0, 0), dtype=torch.float32)

        if self._provider == _HASH_PROVIDER:
            return self._hash_fallback(cleaned)

        model = self._load_model()
        if model is None:
            # sentence-transformers is unavailable; degrade to hashed embeddings.
            return self._hash_fallback(cleaned)

        embeddings = model.encode(
            cleaned,
            convert_to_tensor=True,
            normalize_embeddings=self.normalize_embeddings,
        )
        return embeddings.to(dtype=torch.float32)

    def _load_model(self) -> Any:
        """Lazily load SentenceTransformer if it is available."""
        if self._provider != "sentence-transformers":
            return None

        if self._model is not None:
            return self._model

        if importlib.util.find_spec("sentence_transformers") is None:
            # Library not installed: silently degrade to the hash provider.
            self._provider = _HASH_PROVIDER
            return None

        module = importlib.import_module("sentence_transformers")
        # The validator also accepts 'sentence-transformers:<model>'; normalize
        # the colon form to the Hub id 'sentence-transformers/<model>' that
        # SentenceTransformer expects.
        model_id = self.model_name
        for prefix in _SENTENCE_TRANSFORMERS_PREFIXES:
            if model_id.startswith(prefix):
                model_id = f"sentence-transformers/{model_id[len(prefix):]}"
                break
        self._model = module.SentenceTransformer(model_id)

        return self._model

    def _hash_fallback(self, texts: Sequence[str]) -> torch.Tensor:
        """Build normalized bag-of-words embeddings using the hashing trick."""
        # Apply the same 32-dim floor used in model_post_init.
        dimension = max(self.fallback_dim, 32)
        matrix = torch.zeros((len(texts), dimension), dtype=torch.float32)

        for row, text in enumerate(texts):
            tokens = _tokenize(text)
            if not tokens:
                continue
            for token in tokens:
                # Hash each token to a stable bucket index via the first
                # 8 bytes of its blake2b digest.
                digest = hashlib.blake2b(token.encode("utf-8"), digest_size=32).digest()
                index = int.from_bytes(digest[:8], byteorder="big", signed=False) % dimension
                matrix[row, index] += 1.0

            # L2-normalize each row so hash embeddings are comparable to the
            # normalized model embeddings.
            norm = torch.norm(matrix[row])
            if norm > 0:
                matrix[row] /= norm

        return matrix

    @property
    def embedding_dim(self) -> int:
        """Dimension of embeddings generated by the selected provider."""
        if self._provider == _HASH_PROVIDER:
            return max(self.fallback_dim, 32)

        model = self._load_model()
        if model is not None:
            return model.get_sentence_embedding_dimension()

        # _load_model() switched us to the hash provider; report the same
        # clamped dimension that _hash_fallback will actually produce.
        return max(self.fallback_dim, 32)
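

# A minimal usage sketch (illustrative only; the example texts are made up).
# The 'hash' provider needs no model download, so it is convenient for a
# quick smoke test of the encoder.
if __name__ == "__main__":
    encoder = NodeEncoder(model_name="hash:128")
    vectors = encoder.encode(["plan travel routes", "summarize documents"])
    print(vectors.shape)  # torch.Size([2, 128])
    print(encoder.embedding_dim)  # 128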