fahmiaziz98 commited on
Commit
dd7d594
·
1 Parent(s): 17491dc

update: Response format

Browse files
src/api/routers/embedding.py CHANGED
@@ -7,6 +7,7 @@ multiple texts in a single request.
7
 
8
  import time
9
  from fastapi import APIRouter, Depends, HTTPException, status
 
10
  from loguru import logger
11
 
12
  from src.models.schemas import (
@@ -14,8 +15,6 @@ from src.models.schemas import (
14
  DenseEmbedResponse,
15
  EmbeddingObject,
16
  TokenUsage,
17
- SparseEmbedResponse,
18
- SparseEmbedding,
19
  )
20
  from src.core.manager import ModelManager
21
  from src.core.exceptions import (
@@ -31,16 +30,17 @@ from src.utils.validators import (
31
  ensure_model_type,
32
  )
33
 
34
- router = APIRouter(tags=["embeddings"])
35
 
36
 
37
  @router.post(
38
  "/embeddings",
39
  response_model=DenseEmbedResponse,
 
40
  summary="Generate single/batch embeddings",
41
  description="Generate embeddings for multiple texts in a single request",
42
  )
43
- async def create_embeddings(
44
  request: EmbedRequest, manager: ModelManager = Depends(get_model_manager)
45
  ):
46
  """
@@ -100,6 +100,66 @@ async def create_embeddings(
100
 
101
  return response
102
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
103
  except (ValidationError, ModelNotFoundError) as e:
104
  raise HTTPException(status_code=e.status_code, detail=e.message)
105
  except ModelNotLoadedError as e:
@@ -116,7 +176,7 @@ async def create_embeddings(
116
 
117
  @router.post(
118
  "/embed_sparse",
119
- response_model=SparseEmbedResponse,
120
  summary="Generate single/batch sparse embeddings",
121
  description="Generate embedding for a multiple query text",
122
  )
@@ -151,28 +211,18 @@ async def create_sparse_embedding(
151
 
152
  sparse_results = model.embed(input=texts, **kwargs)
153
  processing_time = time.time() - start_time
154
-
155
- sparse_embeddings = [
156
- SparseEmbedding(
157
- text=texts[idx],
158
- indices=sparse_result["indices"],
159
- values=sparse_result["values"],
160
- )
161
- for idx, sparse_result in enumerate(sparse_results)
162
  ]
163
 
164
- response = SparseEmbedResponse(
165
- embeddings=sparse_embeddings,
166
- count=len(sparse_embeddings),
167
- model=request.model,
168
- )
169
-
170
  logger.info(
171
  f"Generated {len(texts)} embeddings "
172
  f"in {processing_time:.3f}s ({len(texts) / processing_time:.1f} texts/s)"
173
  )
174
 
175
- return response
176
 
177
  except (ValidationError, ModelNotFoundError) as e:
178
  raise HTTPException(status_code=e.status_code, detail=e.message)
 
7
 
8
  import time
9
  from fastapi import APIRouter, Depends, HTTPException, status
10
+ from fastapi.responses import JSONResponse
11
  from loguru import logger
12
 
13
  from src.models.schemas import (
 
15
  DenseEmbedResponse,
16
  EmbeddingObject,
17
  TokenUsage,
 
 
18
  )
19
  from src.core.manager import ModelManager
20
  from src.core.exceptions import (
 
30
  ensure_model_type,
31
  )
32
 
33
+ router = APIRouter()
34
 
35
 
36
  @router.post(
37
  "/embeddings",
38
  response_model=DenseEmbedResponse,
39
+ tags=["OpenAI Compatible"],
40
  summary="Generate single/batch embeddings",
41
  description="Generate embeddings for multiple texts in a single request",
42
  )
43
+ async def create_openai_embeddings(
44
  request: EmbedRequest, manager: ModelManager = Depends(get_model_manager)
45
  ):
46
  """
 
100
 
101
  return response
102
 
103
+ except (ValidationError, ModelNotFoundError) as e:
104
+ raise HTTPException(status_code=e.status_code, detail=e.message)
105
+ except ModelNotLoadedError as e:
106
+ raise HTTPException(status_code=e.status_code, detail=e.message)
107
+ except EmbeddingGenerationError as e:
108
+ raise HTTPException(status_code=e.status_code, detail=e.message)
109
+ except Exception as e:
110
+ logger.exception("Unexpected error in create_openai_embeddings")
111
+ raise HTTPException(
112
+ status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
113
+ detail=f"Failed to create embeddings: {str(e)}",
114
+ )
115
+
116
+
117
+ @router.post(
118
+ "/embed",
119
+ tags=["embeddings"],
120
+ summary="Generate single/batch dense embeddings",
121
+ description="Generate embedding for a multiple query text",
122
+ )
123
+ async def create_embeddings(
124
+ request: EmbedRequest, manager: ModelManager = Depends(get_model_manager)
125
+ ):
126
+ """
127
+ Generate embeddings for multiple texts.
128
+
129
+ The endpoint validates the request, checks that the requested
130
+ model is a dense embedding model, and returns a
131
+ :class:`DenseEmbedResponse`.
132
+
133
+ Raises:
134
+ HTTPException: On validation or generation errors
135
+ """
136
+
137
+ texts = [request.input] if isinstance(request.input, str) else request.input
138
+
139
+ if not texts or not isinstance(texts, list):
140
+ raise ValidationError("Input must be a non-empty list or string.")
141
+
142
+ try:
143
+ kwargs = extract_embedding_kwargs(request)
144
+
145
+ model = manager.get_model(request.model)
146
+ config = manager.model_configs.get(request.model)
147
+
148
+ ensure_model_type(config, "embeddings", request.model)
149
+
150
+ start_time = time.time()
151
+
152
+ embeddings = model.embed(input=texts, **kwargs)
153
+ processing_time = time.time() - start_time
154
+
155
+
156
+ logger.info(
157
+ f"Generated {len(texts)} embeddings "
158
+ f"in {processing_time:.3f}s ({len(texts) / processing_time:.1f} texts/s)"
159
+ )
160
+
161
+ return JSONResponse(content=embeddings)
162
+
163
  except (ValidationError, ModelNotFoundError) as e:
164
  raise HTTPException(status_code=e.status_code, detail=e.message)
165
  except ModelNotLoadedError as e:
 
176
 
177
  @router.post(
178
  "/embed_sparse",
179
+ tags=["embeddings"],
180
  summary="Generate single/batch sparse embeddings",
181
  description="Generate embedding for a multiple query text",
182
  )
 
211
 
212
  sparse_results = model.embed(input=texts, **kwargs)
213
  processing_time = time.time() - start_time
214
+
215
+ formatted_embeddings = [
216
+ [{"index": i, "value": v} for i, v in zip(res["indices"], res["values"])]
217
+ for res in sparse_results
 
 
 
 
218
  ]
219
 
 
 
 
 
 
 
220
  logger.info(
221
  f"Generated {len(texts)} embeddings "
222
  f"in {processing_time:.3f}s ({len(texts) / processing_time:.1f} texts/s)"
223
  )
224
 
225
+ return JSONResponse(content=formatted_embeddings)
226
 
227
  except (ValidationError, ModelNotFoundError) as e:
228
  raise HTTPException(status_code=e.status_code, detail=e.message)
src/models/schemas/__init__.py CHANGED
@@ -6,7 +6,6 @@ the application.
6
  """
7
 
8
  from .common import (
9
- SparseEmbedding,
10
  ModelInfo,
11
  HealthStatus,
12
  ErrorResponse,
@@ -18,7 +17,6 @@ from .requests import BaseEmbedRequest, EmbedRequest, RerankRequest
18
  from .responses import (
19
  BaseEmbedResponse,
20
  DenseEmbedResponse,
21
- SparseEmbedResponse,
22
  RerankResponse,
23
  EmbeddingObject,
24
  TokenUsage,
@@ -29,7 +27,6 @@ from .responses import (
29
 
30
  __all__ = [
31
  # Common
32
- "SparseEmbedding",
33
  "ModelInfo",
34
  "HealthStatus",
35
  "ErrorResponse",
@@ -43,7 +40,6 @@ __all__ = [
43
  "DenseEmbedResponse",
44
  "EmbeddingObject",
45
  "TokenUsage",
46
- "SparseEmbedResponse",
47
  "RerankResponse",
48
  "RerankResult",
49
  "ModelsListResponse",
 
6
  """
7
 
8
  from .common import (
 
9
  ModelInfo,
10
  HealthStatus,
11
  ErrorResponse,
 
17
  from .responses import (
18
  BaseEmbedResponse,
19
  DenseEmbedResponse,
 
20
  RerankResponse,
21
  EmbeddingObject,
22
  TokenUsage,
 
27
 
28
  __all__ = [
29
  # Common
 
30
  "ModelInfo",
31
  "HealthStatus",
32
  "ErrorResponse",
 
40
  "DenseEmbedResponse",
41
  "EmbeddingObject",
42
  "TokenUsage",
 
43
  "RerankResponse",
44
  "RerankResult",
45
  "ModelsListResponse",
src/models/schemas/common.py CHANGED
@@ -5,40 +5,10 @@ This module contains Pydantic models used by both requests and responses,
5
  such as SparseEmbedding and ModelInfo.
6
  """
7
 
8
- from typing import List, Optional, Literal
9
  from pydantic import BaseModel, Field, ConfigDict
10
 
11
 
12
- class SparseEmbedding(BaseModel):
13
- """
14
- Sparse embedding representation.
15
-
16
- Sparse embeddings are represented as two parallel arrays:
17
- - indices: positions of non-zero values
18
- - values: the actual values at those positions
19
-
20
- Attributes:
21
- indices: List of indices for non-zero elements
22
- values: List of values corresponding to the indices
23
- text: Optional original text that was embedded
24
- """
25
-
26
- indices: List[int] = Field(
27
- ..., description="Indices of non-zero elements in the sparse vector"
28
- )
29
- values: List[float] = Field(..., description="Values corresponding to the indices")
30
- text: Optional[str] = Field(None, description="Original text that was embedded")
31
-
32
- class Config:
33
- json_schema_extra = {
34
- "example": {
35
- "indices": [10, 25, 42, 100],
36
- "values": [0.85, 0.62, 0.91, 0.73],
37
- "text": "example query text",
38
- }
39
- }
40
-
41
-
42
  class ModelInfo(BaseModel):
43
  """
44
  Information about an available model.
 
5
  such as SparseEmbedding and ModelInfo.
6
  """
7
 
8
+ from typing import Optional, Literal
9
  from pydantic import BaseModel, Field, ConfigDict
10
 
11
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
12
  class ModelInfo(BaseModel):
13
  """
14
  Information about an available model.
src/models/schemas/responses.py CHANGED
@@ -7,7 +7,7 @@ ensuring consistent output format across all endpoints.
7
 
8
  from typing import List, Literal
9
  from pydantic import BaseModel, Field
10
- from .common import SparseEmbedding, ModelInfo
11
 
12
 
13
  class BaseEmbedResponse(BaseModel):
@@ -68,45 +68,6 @@ class DenseEmbedResponse(BaseEmbedResponse):
68
  }
69
 
70
 
71
- class SparseEmbedResponse(BaseEmbedResponse):
72
- """
73
- Response model for single/batch sparse embeddings.
74
-
75
- Used for /embed_sparse endpoint sparse models.
76
-
77
- Attributes:
78
- embeddings: List of generated sparse embeddings
79
- count: Number of embeddings returned
80
- model: Identifier of the model used
81
- """
82
-
83
- embeddings: List[SparseEmbedding] = Field(
84
- ..., description="List of sparse embeddings"
85
- )
86
- count: int = Field(..., description="Number of embeddings", ge=1)
87
-
88
- class Config:
89
- json_schema_extra = {
90
- "example": {
91
- "embeddings": [
92
- {
93
- "indices": [10, 25, 42],
94
- "values": [0.85, 0.62, 0.91],
95
- "text": "first text",
96
- },
97
- {
98
- "indices": [15, 30, 50],
99
- "values": [0.73, 0.88, 0.65],
100
- "text": "second text",
101
- },
102
- ],
103
- "count": 2,
104
- "model_id": "splade-pp-v2",
105
- "processing_time": 0.0892,
106
- }
107
- }
108
-
109
-
110
  class RerankResult(BaseModel):
111
  """
112
  Single reranking result.
 
7
 
8
  from typing import List, Literal
9
  from pydantic import BaseModel, Field
10
+ from .common import ModelInfo
11
 
12
 
13
  class BaseEmbedResponse(BaseModel):
 
68
  }
69
 
70
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
71
  class RerankResult(BaseModel):
72
  """
73
  Single reranking result.