fahmiaziz98 commited on
Commit
2e2b2b3
·
1 Parent(s): 376886a

[UPDATE]: Input str or List[str]

Browse files
src/api/routers/embedding.py CHANGED
@@ -27,11 +27,8 @@ from src.core.exceptions import (
27
  from src.api.dependencies import get_model_manager
28
  from src.utils.validators import (
29
  extract_embedding_kwargs,
30
- validate_texts,
31
  count_tokens_batch,
32
  )
33
- from src.config.settings import get_settings
34
-
35
 
36
  router = APIRouter(tags=["embeddings"])
37
 
@@ -64,10 +61,8 @@ def _ensure_model_type(config, expected_type: str, model_id: str) -> None:
64
  summary="Generate single/batch embeddings",
65
  description="Generate embeddings for multiple texts in a single request",
66
  )
67
- async def create_embeddings_document(
68
- request: EmbedRequest,
69
- manager: ModelManager = Depends(get_model_manager),
70
- settings=Depends(get_settings),
71
  ):
72
  """
73
  Generate embeddings for multiple texts.
@@ -79,13 +74,11 @@ async def create_embeddings_document(
79
  Raises:
80
  HTTPException: On validation or generation errors
81
  """
 
 
 
 
82
  try:
83
- # Validate input
84
- validate_texts(
85
- request.input,
86
- max_length=settings.MAX_TEXT_LENGTH,
87
- max_batch_size=settings.MAX_BATCH_SIZE,
88
- )
89
  kwargs = extract_embedding_kwargs(request)
90
 
91
  model = manager.get_model(request.model)
@@ -95,7 +88,7 @@ async def create_embeddings_document(
95
 
96
  start_time = time.time()
97
 
98
- embeddings = model.embed(input=request.input, **kwargs)
99
  processing_time = time.time() - start_time
100
 
101
  data = [
@@ -108,8 +101,8 @@ async def create_embeddings_document(
108
  ]
109
 
110
  token_usage = TokenUsage(
111
- prompt_tokens=count_tokens_batch(request.input),
112
- total_tokens=count_tokens_batch(request.input),
113
  )
114
 
115
  response = DenseEmbedResponse(
@@ -120,8 +113,8 @@ async def create_embeddings_document(
120
  )
121
 
122
  logger.info(
123
- f"Generated {len(request.input)} embeddings "
124
- f"in {processing_time:.3f}s ({len(request.input) / processing_time:.1f} texts/s)"
125
  )
126
 
127
  return response
@@ -133,10 +126,10 @@ async def create_embeddings_document(
133
  except EmbeddingGenerationError as e:
134
  raise HTTPException(status_code=e.status_code, detail=e.message)
135
  except Exception as e:
136
- logger.exception("Unexpected error in create_embeddings_document")
137
  raise HTTPException(
138
  status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
139
- detail=f"Failed to create batch embeddings: {str(e)}",
140
  )
141
 
142
 
@@ -160,8 +153,10 @@ async def create_sparse_embedding(
160
  Raises:
161
  HTTPException: On validation or generation errors
162
  """
 
 
 
163
  try:
164
- validate_texts(request.input)
165
  kwargs = extract_embedding_kwargs(request)
166
 
167
  model = manager.get_model(request.model)
@@ -171,12 +166,12 @@ async def create_sparse_embedding(
171
 
172
  start_time = time.time()
173
 
174
- sparse_results = model.embed(input=request.input, **kwargs)
175
  processing_time = time.time() - start_time
176
 
177
  sparse_embeddings = [
178
  SparseEmbedding(
179
- text=request.input[idx],
180
  indices=sparse_result["indices"],
181
  values=sparse_result["values"],
182
  )
@@ -190,8 +185,8 @@ async def create_sparse_embedding(
190
  )
191
 
192
  logger.info(
193
- f"Generated {len(request.input)} embeddings "
194
- f"in {processing_time:.3f}s ({len(request.input) / processing_time:.1f} texts/s)"
195
  )
196
 
197
  return response
@@ -206,5 +201,5 @@ async def create_sparse_embedding(
206
  logger.exception("Unexpected error in create_sparse_embedding")
207
  raise HTTPException(
208
  status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
209
- detail=f"Failed to create query embedding: {str(e)}",
210
  )
 
27
  from src.api.dependencies import get_model_manager
28
  from src.utils.validators import (
29
  extract_embedding_kwargs,
 
30
  count_tokens_batch,
31
  )
 
 
32
 
33
  router = APIRouter(tags=["embeddings"])
34
 
 
61
  summary="Generate single/batch embeddings",
62
  description="Generate embeddings for multiple texts in a single request",
63
  )
64
+ async def create_embeddings(
65
+ request: EmbedRequest, manager: ModelManager = Depends(get_model_manager)
 
 
66
  ):
67
  """
68
  Generate embeddings for multiple texts.
 
74
  Raises:
75
  HTTPException: On validation or generation errors
76
  """
77
+
78
+ if isinstance(request.input, str):
79
+ texts = [request.input]
80
+
81
  try:
 
 
 
 
 
 
82
  kwargs = extract_embedding_kwargs(request)
83
 
84
  model = manager.get_model(request.model)
 
88
 
89
  start_time = time.time()
90
 
91
+ embeddings = model.embed(input=texts, **kwargs)
92
  processing_time = time.time() - start_time
93
 
94
  data = [
 
101
  ]
102
 
103
  token_usage = TokenUsage(
104
+ prompt_tokens=count_tokens_batch(texts),
105
+ total_tokens=count_tokens_batch(texts),
106
  )
107
 
108
  response = DenseEmbedResponse(
 
113
  )
114
 
115
  logger.info(
116
+ f"Generated {len(texts)} embeddings "
117
+ f"in {processing_time:.3f}s ({len(texts) / processing_time:.1f} texts/s)"
118
  )
119
 
120
  return response
 
126
  except EmbeddingGenerationError as e:
127
  raise HTTPException(status_code=e.status_code, detail=e.message)
128
  except Exception as e:
129
+ logger.exception("Unexpected error in create_embeddings")
130
  raise HTTPException(
131
  status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
132
+ detail=f"Failed to create embeddings: {str(e)}",
133
  )
134
 
135
 
 
153
  Raises:
154
  HTTPException: On validation or generation errors
155
  """
156
+ if isinstance(request.input, str):
157
+ texts = [request.input]
158
+
159
  try:
 
160
  kwargs = extract_embedding_kwargs(request)
161
 
162
  model = manager.get_model(request.model)
 
166
 
167
  start_time = time.time()
168
 
169
+ sparse_results = model.embed(input=texts, **kwargs)
170
  processing_time = time.time() - start_time
171
 
172
  sparse_embeddings = [
173
  SparseEmbedding(
174
+ text=texts[idx],
175
  indices=sparse_result["indices"],
176
  values=sparse_result["values"],
177
  )
 
185
  )
186
 
187
  logger.info(
188
+ f"Generated {len(texts)} embeddings "
189
+ f"in {processing_time:.3f}s ({len(texts) / processing_time:.1f} texts/s)"
190
  )
191
 
192
  return response
 
201
  logger.exception("Unexpected error in create_sparse_embedding")
202
  raise HTTPException(
203
  status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
204
+ detail=f"Failed to create sparse embedding: {str(e)}",
205
  )
src/models/schemas/requests.py CHANGED
@@ -5,7 +5,7 @@ This module defines all Pydantic models for incoming API requests,
5
  with validation and documentation.
6
  """
7
 
8
- from typing import List, Optional, Literal
9
  from pydantic import BaseModel, Field, field_validator, ConfigDict
10
  from .common import EmbeddingOptions
11
 
@@ -55,18 +55,18 @@ class EmbedRequest(BaseEmbedRequest):
55
  Used for /embeddings and /embed_sparse endpoint to process multiple texts at once.
56
 
57
  Attributes:
58
- input: List of input texts to embed
59
  """
60
 
61
- input: List[str] = Field(
62
  ...,
63
- description="List of input texts to generate embeddings for",
64
  min_length=1,
65
  )
66
 
67
  @field_validator("input")
68
  @classmethod
69
- def validate_texts(cls, v: List[str]) -> List[str]:
70
  """Validate that all texts are non-empty."""
71
  if not v:
72
  raise ValueError("texts list cannot be empty")
@@ -81,8 +81,6 @@ class EmbedRequest(BaseEmbedRequest):
81
  raise ValueError(f"texts[{idx}] must be a string")
82
  if not text.strip():
83
  raise ValueError(f"texts[{idx}] cannot be empty or whitespace only")
84
- if len(text) > 8192:
85
- raise ValueError(f"texts[{idx}] exceeds maximum length (8192)")
86
  validated.append(text)
87
 
88
  return validated
 
5
  with validation and documentation.
6
  """
7
 
8
+ from typing import List, Optional, Literal, Union
9
  from pydantic import BaseModel, Field, field_validator, ConfigDict
10
  from .common import EmbeddingOptions
11
 
 
55
  Used for /embeddings and /embed_sparse endpoint to process multiple texts at once.
56
 
57
  Attributes:
58
+ input: Str or List of input texts to embed
59
  """
60
 
61
+ input: Union[str, List[str]] = Field(
62
  ...,
63
+ description="Str or List of input texts to generate embeddings for",
64
  min_length=1,
65
  )
66
 
67
  @field_validator("input")
68
  @classmethod
69
+ def validate_texts(cls, v: Union[str, List[str]]) -> List[str]:
70
  """Validate that all texts are non-empty."""
71
  if not v:
72
  raise ValueError("texts list cannot be empty")
 
81
  raise ValueError(f"texts[{idx}] must be a string")
82
  if not text.strip():
83
  raise ValueError(f"texts[{idx}] cannot be empty or whitespace only")
 
 
84
  validated.append(text)
85
 
86
  return validated