philipp-zettl committed on
Commit
7afed3f
·
verified ·
1 Parent(s): cc18cf3

Add vrom_hub/embedder.py

Browse files
Files changed (1) hide show
  1. vrom_hub/embedder.py +89 -0
vrom_hub/embedder.py ADDED
@@ -0,0 +1,89 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Chunk embedder using sentence-transformers.
3
+
4
+ Produces 384-dimensional normalized embeddings compatible with
5
+ the vROM ecosystem (Xenova/all-MiniLM-L6-v2, cosine metric).
6
+ """
7
+
8
+ from __future__ import annotations
9
+
10
+ import logging
11
+ from typing import TYPE_CHECKING
12
+
13
+ import numpy as np
14
+
15
+ if TYPE_CHECKING:
16
+ from vrom_hub.chunker import Chunk
17
+
18
+ logger = logging.getLogger(__name__)
19
+
20
+
21
class ChunkEmbedder:
    """
    Embeds chunk text using sentence-transformers/all-MiniLM-L6-v2.

    The model produces 384-dimensional embeddings. Vectors are L2-normalized
    for cosine similarity (consistent with the WASM runtime).

    The underlying model is created lazily on first access of ``model``, so
    constructing an embedder does not import ``sentence_transformers``.

    Args:
        model_name: Model identifier handed to ``SentenceTransformer``.
        device: Device handed to ``SentenceTransformer``; ``None`` defers the
            choice to the library.
        batch_size: Batch size used by ``embed_texts``.
    """

    # Accessor names tried in order when asking the model for its output size.
    # The accessor name differs between sentence-transformers releases;
    # ``get_sentence_embedding_dimension`` is the long-standing one.
    _DIMENSION_GETTERS = (
        "get_embedding_dimension",
        "get_sentence_embedding_dimension",
    )

    def __init__(
        self,
        model_name: str = "sentence-transformers/all-MiniLM-L6-v2",
        device: str | None = None,
        batch_size: int = 64,
    ):
        self.model_name = model_name
        self.batch_size = batch_size
        self._model = None  # populated on first access of ``model``
        self._device = device

    @classmethod
    def _embedding_dimension(cls, model) -> int:
        """Return the embedding size of *model* via whichever accessor it has.

        Raises:
            AttributeError: If the model exposes none of the known accessors.
        """
        for getter_name in cls._DIMENSION_GETTERS:
            getter = getattr(model, getter_name, None)
            if callable(getter):
                return getter()
        raise AttributeError(
            f"{type(model).__name__} exposes none of: "
            f"{', '.join(cls._DIMENSION_GETTERS)}"
        )

    @property
    def model(self):
        """The loaded ``SentenceTransformer``, created on first access."""
        if self._model is None:
            # Imported here so the dependency is only required once a model
            # is actually needed.
            from sentence_transformers import SentenceTransformer

            self._model = SentenceTransformer(self.model_name, device=self._device)
            logger.info(
                "Loaded embedding model: %s (dim=%s)",
                self.model_name,
                self._embedding_dimension(self._model),
            )
        return self._model

    @property
    def dimensions(self) -> int:
        """Embedding dimensionality reported by the model (loads it if needed)."""
        return self._embedding_dimension(self.model)

    def embed_texts(
        self,
        texts: list[str],
        *,
        show_progress_bar: bool = True,
    ) -> np.ndarray:
        """
        Embed a list of texts.

        Args:
            texts: Strings to embed.
            show_progress_bar: Forwarded to ``encode``; defaults to ``True``.

        Returns:
            np.ndarray of shape (len(texts), dim) with L2-normalized vectors.
            An empty input yields an empty float32 array of shape (0, dim).
        """
        if not texts:
            # Build the empty result explicitly so the documented 2-D shape
            # holds for empty input as well, instead of relying on what
            # ``encode`` returns for an empty batch.
            return np.empty((0, self.dimensions), dtype=np.float32)

        logger.info(
            "Embedding %d texts in batches of %s...", len(texts), self.batch_size
        )
        embeddings = self.model.encode(
            texts,
            batch_size=self.batch_size,
            show_progress_bar=show_progress_bar,
            normalize_embeddings=True,  # L2-normalize for cosine
            convert_to_numpy=True,
        )
        logger.info("Embeddings shape: %s", embeddings.shape)
        return embeddings

    def embed_chunks(self, chunks: list[Chunk]) -> np.ndarray:
        """
        Embed a list of Chunk objects by their text content.

        Returns:
            np.ndarray of shape (len(chunks), dim) with L2-normalized vectors.
        """
        return self.embed_texts([chunk.text for chunk in chunks])

    def embed_query(self, query: str) -> np.ndarray:
        """Embed a single query string. Returns shape (dim,)."""
        encoded = self.model.encode(
            [query],
            show_progress_bar=False,  # a progress bar for one item is just noise
            normalize_embeddings=True,
            convert_to_numpy=True,
        )
        return encoded[0]