fahmiaziz98 commited on
Commit
fa16bad
·
1 Parent(s): bc2efca

Add query endpoint for embedding and refactor embedding models

Browse files
Files changed (5) hide show
  1. app.py +66 -2
  2. core/config.py +9 -0
  3. core/embedding.py +17 -98
  4. core/model_manager.py +4 -11
  5. core/sparse.py +125 -0
app.py CHANGED
@@ -85,6 +85,70 @@ def create_app() -> FastAPI:
85
 
86
  app = create_app()
87
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
88
 
89
  @app.post("/embed", response_model=Union[EmbedResponse, SparseEmbedResponse])
90
  async def create_embedding(request: EmbedRequest):
@@ -117,7 +181,7 @@ async def create_embedding(request: EmbedRequest):
117
 
118
  if config.type == "sparse-embeddings":
119
  # Sparse embedding
120
- sparse_result = model.embed(request.text, prompt=request.prompt)
121
  processing_time = time.time() - start_time
122
 
123
  if isinstance(sparse_result, dict) and "indices" in sparse_result:
@@ -136,7 +200,7 @@ async def create_embedding(request: EmbedRequest):
136
  )
137
 
138
  # Dense embedding
139
- embedding = model.embed([request.text], request.prompt)[0]
140
  processing_time = time.time() - start_time
141
 
142
  return EmbedResponse(
 
85
 
86
  app = create_app()
87
 
88
+ @app.post("/query", response_model=Union[EmbedResponse, SparseEmbedResponse])
89
+ async def create_query(request: EmbedRequest):
90
+ """Create a single dense or sparse query embedding for the given text.
91
+
92
+ The request must include `model_id`. For sparse models (config type
93
+ "sparse-embeddings") the endpoint returns a `SparseEmbedResponse`,
94
+ otherwise a dense `EmbedResponse` is returned.
95
+
96
+ Args:
97
+ request: `EmbedRequest` pydantic model with text, prompt and model_id.
98
+ Returns:
99
+ Union[EmbedResponse, SparseEmbedResponse]: The embedding response.
100
+ Raises:
101
+ HTTPException: on validation or internal errors with appropriate
102
+ HTTP status codes.
103
+ """
104
+
105
+ if not request.model_id:
106
+ raise HTTPException(status_code=400, detail="model_id is required")
107
+
108
+ try:
109
+ assert model_manager is not None
110
+ model = model_manager.get_model(request.model_id)
111
+ start_time = time.time()
112
+
113
+ config = model_manager.model_configs[request.model_id]
114
+
115
+ if config.type == "sparse-embeddings":
116
+ # Sparse embedding
117
+ sparse_result = model.query_embed(text=[request.text], prompt=request.prompt)
118
+ processing_time = time.time() - start_time
119
+
120
+ if isinstance(sparse_result, dict) and "indices" in sparse_result:
121
+ sparse_embedding = SparseEmbedding(
122
+ text=request.text,
123
+ indices=sparse_result["indices"],
124
+ values=sparse_result["values"],
125
+ )
126
+ else:
127
+ raise ValueError(f"Unexpected sparse result format: {sparse_result}")
128
+
129
+ return SparseEmbedResponse(
130
+ sparse_embedding=sparse_embedding,
131
+ model_id=request.model_id,
132
+ processing_time=processing_time,
133
+ )
134
+
135
+ # Dense embedding
136
+ embedding = model.query_embed(text=[request.text], prompt=request.prompt)[0]
137
+ processing_time = time.time() - start_time
138
+
139
+ return EmbedResponse(
140
+ embedding=embedding,
141
+ dimension=len(embedding),
142
+ model_id=request.model_id,
143
+ processing_time=processing_time,
144
+ )
145
+
146
+ except AssertionError:
147
+ logger.exception("Model manager is not initialized")
148
+ raise HTTPException(status_code=500, detail="Server not ready")
149
+ except Exception:
150
+ logger.exception("Error creating query embedding")
151
+ raise HTTPException(status_code=500, detail="Failed to create query embedding")
152
 
153
  @app.post("/embed", response_model=Union[EmbedResponse, SparseEmbedResponse])
154
  async def create_embedding(request: EmbedRequest):
 
181
 
182
  if config.type == "sparse-embeddings":
183
  # Sparse embedding
184
+ sparse_result = model.embed_documents(text=[request.text], prompt=request.prompt)
185
  processing_time = time.time() - start_time
186
 
187
  if isinstance(sparse_result, dict) and "indices" in sparse_result:
 
200
  )
201
 
202
  # Dense embedding
203
+ embedding = model.embed_documents(text=[request.text], prompt=request.prompt)[0]
204
  processing_time = time.time() - start_time
205
 
206
  return EmbedResponse(
core/config.py ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Any, Dict
2
+
3
+
4
+ class ModelConfig:
5
+ def __init__(self, model_id: str, config: Dict[str, Any]):
6
+ self.id = model_id
7
+ self.name = config["name"]
8
+ self.type = config["type"]
9
+ self.repository = config["repository"]
core/embedding.py CHANGED
@@ -1,19 +1,9 @@
1
- from loguru import logger
2
- from typing import Dict, List, Optional, Any
3
  from sentence_transformers import SentenceTransformer
4
- from sentence_transformers import SparseEncoder
5
 
 
6
 
7
- class ModelConfig:
8
- def __init__(self, model_id: str, config: Dict[str, Any]):
9
- self.id = model_id
10
- self.name = config["name"]
11
- self.type = config["type"] # "embedding" or "sparse"
12
- self.dimension = int(config["dimension"])
13
- self.max_tokens = int(config["max_tokens"])
14
- self.description = config["description"]
15
- self.language = config["language"]
16
- self.repository = config["repository"]
17
 
18
  class EmbeddingModel:
19
  """
@@ -43,115 +33,44 @@ class EmbeddingModel:
43
  logger.error(f"Failed to load embedding model {self.config.id}: {e}")
44
  raise
45
 
46
- def embed(self, texts: List[str], prompt: Optional[str] = None) -> List[List[float]]:
47
  """
48
- method to generate embeddings for a list of texts.
49
 
50
  Args:
51
- texts: List of input texts
52
  prompt: Optional prompt for instruction-based models
53
 
54
- Returns:
55
- List of embedding vectors
56
  """
57
  if not self._loaded:
58
  self.load()
59
 
60
  try:
61
- embeddings = self.model.encode(texts, prompt=prompt)
62
- return [embedding.tolist() for embedding in embeddings]
63
  except Exception as e:
64
  logger.error(f"Embedding generation failed: {e}")
65
  raise
66
 
67
- class SparseEmbeddingModel:
68
- """
69
- Sparse embedding model wrapper.
70
-
71
- Attributes:
72
- config: ModelConfig instance
73
- model: SparseEncoder instance
74
- _loaded: Flag indicating if the model is loaded
75
- """
76
-
77
- def __init__(self, config: ModelConfig):
78
- self.config = config
79
- self.model: Optional[SparseEncoder] = None
80
- self._loaded = False
81
-
82
- def load(self) -> None:
83
- """Load the sparse embedding model."""
84
- if self._loaded:
85
- return
86
-
87
- logger.info(f"Loading sparse model: {self.config.name}")
88
- try:
89
- self.model = SparseEncoder(self.config.name)
90
- self._loaded = True
91
- logger.success(f"Loaded sparse model: {self.config.id}")
92
- except Exception as e:
93
- logger.error(f"Failed to load sparse model {self.config.id}: {e}")
94
- raise
95
-
96
- def embed(self, text: str, prompt: Optional[str] = None) -> Dict[Any, Any]:
97
- """
98
- Generate a sparse embedding for a single text.
99
-
100
- Args:
101
- text: Input text
102
- prompt: Optional prompt for instruction-based models
103
-
104
- Returns:
105
- Sparse embedding as a dictionary with 'indices' and 'values' keys.
106
- """
107
-
108
- try:
109
- tensor = self.model.encode([text])
110
-
111
- values = tensor[0].coalesce().values().tolist()
112
- indices = tensor[0].coalesce().indices()[0].tolist()
113
-
114
- return {
115
- "indices": indices,
116
- "values": values
117
- }
118
-
119
-
120
- except Exception as e:
121
- logger.error(f"Embedding error: {e}")
122
- raise
123
-
124
- def embed_batch(self, texts: List[str], prompt: Optional[str] = None) -> List[Dict[str, Any]]:
125
  """
126
- Generate sparse embeddings for a batch of texts.
127
 
128
  Args:
129
  texts: List of input texts
130
  prompt: Optional prompt for instruction-based models
131
 
132
- Returns:
133
- List of sparse embeddings as dictionaries with 'text' and 'sparse_embedding' keys.
134
  """
135
  if not self._loaded:
136
  self.load()
137
 
138
  try:
139
- tensors = self.model.encode(texts)
140
- results = []
141
-
142
- for i, tensor in enumerate(tensors):
143
- values = tensor.coalesce().values().tolist()
144
- indices = tensor.coalesce().indices()[0].tolist()
145
-
146
- results.append({
147
- "text": texts[i],
148
- "sparse_embedding": {
149
- "indices": indices,
150
- "values": values
151
- }
152
- })
153
-
154
- return results
155
  except Exception as e:
156
- logger.error(f"Sparse embedding generation failed: {e}")
157
  raise
 
1
+ from typing import List, Optional
 
2
  from sentence_transformers import SentenceTransformer
3
+ from loguru import logger
4
 
5
+ from .config import ModelConfig
6
 
 
 
 
 
 
 
 
 
 
 
7
 
8
  class EmbeddingModel:
9
  """
 
33
  logger.error(f"Failed to load embedding model {self.config.id}: {e}")
34
  raise
35
 
36
+ def query_embed(self, text: List[str], prompt: Optional[str] = None) -> List[float]:
37
  """
38
+ method to generate embedding for a single text.
39
 
40
  Args:
41
+ text: Input text
42
  prompt: Optional prompt for instruction-based models
43
 
44
+ Returns:
45
+ Embedding vector
46
  """
47
  if not self._loaded:
48
  self.load()
49
 
50
  try:
51
+ embedding = self.model.encode_query(text, prompt=prompt)
52
+ return embedding[0].tolist()
53
  except Exception as e:
54
  logger.error(f"Embedding generation failed: {e}")
55
  raise
56
 
57
+ def embed_documents(self, texts: List[str], prompt: Optional[str] = None) -> List[List[float]]:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
58
  """
59
+ method to generate embeddings for a list of texts.
60
 
61
  Args:
62
  texts: List of input texts
63
  prompt: Optional prompt for instruction-based models
64
 
65
+ Returns:
66
+ List of embedding vectors
67
  """
68
  if not self._loaded:
69
  self.load()
70
 
71
  try:
72
+ embeddings = self.model.encode_document(texts, prompt=prompt)
73
+ return [embedding.tolist() for embedding in embeddings]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
74
  except Exception as e:
75
+ logger.error(f"Embedding generation failed: {e}")
76
  raise
core/model_manager.py CHANGED
@@ -1,9 +1,11 @@
1
  import yaml
2
  from pathlib import Path
3
  from loguru import logger
4
- from typing import Dict, List, Any, Optional, Union
5
  from threading import Lock
6
- from .embedding import ModelConfig, EmbeddingModel, SparseEmbeddingModel
 
 
7
 
8
  class ModelManager:
9
  """
@@ -20,7 +22,6 @@ class ModelManager:
20
  def __init__(self, config_path: str = "config.yaml"):
21
  self.models: Dict[str, Union[EmbeddingModel, SparseEmbeddingModel]] = {}
22
  self.model_configs: Dict[str, ModelConfig] = {}
23
- self.default_model_id: Optional[str] = None
24
  self._lock = Lock() # For thread safety
25
  self._preload_complete = False
26
 
@@ -40,9 +41,6 @@ class ModelManager:
40
  for model_id, model_cfg in config["models"].items():
41
  self.model_configs[model_id] = ModelConfig(model_id, model_cfg)
42
 
43
- if "default" in config and "model" in config["default"]:
44
- self.default_model_id = config["default"]["model"]
45
-
46
  logger.info(f"Loaded {len(self.model_configs)} model configurations")
47
 
48
  except Exception as e:
@@ -140,10 +138,6 @@ class ModelManager:
140
  "id": config.id,
141
  "name": config.name,
142
  "type": config.type,
143
- "dimension": config.dimension,
144
- "max_tokens": config.max_tokens,
145
- "description": config.description,
146
- "language": config.language,
147
  "loaded": is_loaded,
148
  "repository": config.repository,
149
  }
@@ -210,7 +204,6 @@ High-performance API for generating text embeddings using multiple model archite
210
  loaded_models.append({
211
  "id": model_id,
212
  "type": self.model_configs[model_id].type,
213
- "dimension": model.config.dimension,
214
  "name": model.config.name
215
  })
216
 
 
1
  import yaml
2
  from pathlib import Path
3
  from loguru import logger
4
+ from typing import Dict, List, Any, Union
5
  from threading import Lock
6
+ from .embedding import EmbeddingModel
7
+ from .sparse import SparseEmbeddingModel
8
+ from .config import ModelConfig
9
 
10
  class ModelManager:
11
  """
 
22
  def __init__(self, config_path: str = "config.yaml"):
23
  self.models: Dict[str, Union[EmbeddingModel, SparseEmbeddingModel]] = {}
24
  self.model_configs: Dict[str, ModelConfig] = {}
 
25
  self._lock = Lock() # For thread safety
26
  self._preload_complete = False
27
 
 
41
  for model_id, model_cfg in config["models"].items():
42
  self.model_configs[model_id] = ModelConfig(model_id, model_cfg)
43
 
 
 
 
44
  logger.info(f"Loaded {len(self.model_configs)} model configurations")
45
 
46
  except Exception as e:
 
138
  "id": config.id,
139
  "name": config.name,
140
  "type": config.type,
 
 
 
 
141
  "loaded": is_loaded,
142
  "repository": config.repository,
143
  }
 
204
  loaded_models.append({
205
  "id": model_id,
206
  "type": self.model_configs[model_id].type,
 
207
  "name": model.config.name
208
  })
209
 
core/sparse.py ADDED
@@ -0,0 +1,125 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Any, Dict, List, Optional
2
+ from sentence_transformers import SparseEncoder
3
+ from loguru import logger
4
+
5
+ from .config import ModelConfig
6
+
7
+
8
+ class SparseEmbeddingModel:
9
+ """
10
+ Sparse embedding model wrapper.
11
+
12
+ Attributes:
13
+ config: ModelConfig instance
14
+ model: SparseEncoder instance
15
+ _loaded: Flag indicating if the model is loaded
16
+ """
17
+
18
+ def __init__(self, config: ModelConfig):
19
+ self.config = config
20
+ self.model: Optional[SparseEncoder] = None
21
+ self._loaded = False
22
+
23
+ def load(self) -> None:
24
+ """Load the sparse embedding model."""
25
+ if self._loaded:
26
+ return
27
+
28
+ logger.info(f"Loading sparse model: {self.config.name}")
29
+ try:
30
+ self.model = SparseEncoder(self.config.name)
31
+ self._loaded = True
32
+ logger.success(f"Loaded sparse model: {self.config.id}")
33
+ except Exception as e:
34
+ logger.error(f"Failed to load sparse model {self.config.id}: {e}")
35
+ raise
36
+
37
+ def query_embed(self, text: List[str], prompt: Optional[str] = None) -> Dict[Any, Any]:
38
+ """
39
+ Generate a sparse embedding for a single text.
40
+
41
+ Args:
42
+ text: Input text
43
+ prompt: Optional prompt for instruction-based models
44
+ Returns:
45
+ Sparse embedding as a dictionary with 'indices' and 'values' keys.
46
+ """
47
+ if not self._loaded:
48
+ self.load()
49
+
50
+ try:
51
+ tensor = self.model.encode_query(text)
52
+
53
+ values = tensor[0].coalesce().values().tolist()
54
+ indices = tensor[0].coalesce().indices()[0].tolist()
55
+
56
+ return {
57
+ "indices": indices,
58
+ "values": values
59
+ }
60
+ except Exception as e:
61
+ logger.error(f"Embedding error: {e}")
62
+ raise
63
+
64
+ def embed_documents(self, text: List[str], prompt: Optional[str] = None) -> Dict[Any, Any]:
65
+ """
66
+ Generate a sparse embedding for a single text.
67
+
68
+ Args:
69
+ text: Input text
70
+ prompt: Optional prompt for instruction-based models
71
+
72
+ Returns:
73
+ Sparse embedding as a dictionary with 'indices' and 'values' keys.
74
+ """
75
+
76
+ try:
77
+ tensor = self.model.encode(text)
78
+
79
+ values = tensor[0].coalesce().values().tolist()
80
+ indices = tensor[0].coalesce().indices()[0].tolist()
81
+
82
+ return {
83
+ "indices": indices,
84
+ "values": values
85
+ }
86
+
87
+
88
+ except Exception as e:
89
+ logger.error(f"Embedding error: {e}")
90
+ raise
91
+
92
+ def embed_batch(self, texts: List[str], prompt: Optional[str] = None) -> List[Dict[str, Any]]:
93
+ """
94
+ Generate sparse embeddings for a batch of texts.
95
+
96
+ Args:
97
+ texts: List of input texts
98
+ prompt: Optional prompt for instruction-based models
99
+
100
+ Returns:
101
+ List of sparse embeddings as dictionaries with 'text' and 'sparse_embedding' keys.
102
+ """
103
+ if not self._loaded:
104
+ self.load()
105
+
106
+ try:
107
+ tensors = self.model.encode(texts)
108
+ results = []
109
+
110
+ for i, tensor in enumerate(tensors):
111
+ values = tensor.coalesce().values().tolist()
112
+ indices = tensor.coalesce().indices()[0].tolist()
113
+
114
+ results.append({
115
+ "text": texts[i],
116
+ "sparse_embedding": {
117
+ "indices": indices,
118
+ "values": values
119
+ }
120
+ })
121
+
122
+ return results
123
+ except Exception as e:
124
+ logger.error(f"Sparse embedding generation failed: {e}")
125
+ raise