AmrYassinIsFree commited on
Commit
7052097
·
1 Parent(s): f56dbf3

add lib embedd models

Browse files
Files changed (3) hide show
  1. models.py +10 -0
  2. requirements.txt +1 -0
  3. wrapper.py +15 -1
models.py CHANGED
@@ -32,4 +32,14 @@ REGISTRY: dict[str, ModelConfig] = {
32
  model_id="sentence-transformers/all-MiniLM-L6-v2",
33
  backend="fastembed",
34
  ),
 
 
 
 
 
 
 
 
 
 
35
  }
 
32
  model_id="sentence-transformers/all-MiniLM-L6-v2",
33
  backend="fastembed",
34
  ),
35
+ "bge-small-le": ModelConfig(
36
+ name="bge-small-en-v1.5 (libembedding)",
37
+ model_id="BAAI/bge-small-en-v1.5",
38
+ backend="libembedding",
39
+ ),
40
+ "all-minilm-le": ModelConfig(
41
+ name="all-MiniLM-L6-v2 (libembedding)",
42
+ model_id="sentence-transformers/all-MiniLM-L6-v2",
43
+ backend="libembedding",
44
+ ),
45
  }
requirements.txt CHANGED
@@ -4,5 +4,6 @@ datasets
4
  psutil
5
  tabulate
6
  fastembed
 
7
  numpy
8
  scipy
 
4
  psutil
5
  tabulate
6
  fastembed
7
+ libembedding
8
  numpy
9
  scipy
wrapper.py CHANGED
@@ -51,10 +51,24 @@ class FastEmbedWrapper:
51
  return np.array(embeddings, dtype=np.float32)
52
 
53
 
54
- def load_model(cfg: ModelConfig) -> SBertWrapper | GGUFWrapper | FastEmbedWrapper:
 
 
 
 
 
 
 
 
 
 
 
 
55
  """Factory: returns the right wrapper for the model's backend."""
56
  if cfg.backend == "gguf":
57
  return GGUFWrapper(cfg)
58
  if cfg.backend == "fastembed":
59
  return FastEmbedWrapper(cfg)
 
 
60
  return SBertWrapper(cfg)
 
51
  return np.array(embeddings, dtype=np.float32)
52
 
53
 
54
+ class LibEmbedWrapper:
55
+ """Wraps libembedding.TextEmbedding."""
56
+
57
+ def __init__(self, cfg: ModelConfig):
58
+ from libembedding import TextEmbedding
59
+ self._model = TextEmbedding(cfg.model_id)
60
+
61
+ def encode(self, sentences: list[str], batch_size: int = 64, **kwargs) -> np.ndarray:
62
+ embeddings = list(self._model.embed(sentences, batch_size=batch_size))
63
+ return np.array(embeddings, dtype=np.float32)
64
+
65
+
66
+ def load_model(cfg: ModelConfig) -> SBertWrapper | GGUFWrapper | FastEmbedWrapper | LibEmbedWrapper:
67
  """Factory: returns the right wrapper for the model's backend."""
68
  if cfg.backend == "gguf":
69
  return GGUFWrapper(cfg)
70
  if cfg.backend == "fastembed":
71
  return FastEmbedWrapper(cfg)
72
+ if cfg.backend == "libembedding":
73
+ return LibEmbedWrapper(cfg)
74
  return SBertWrapper(cfg)