Spaces:
Running
Running
fahmiaziz98
committed on
Commit
·
2e2b2b3
1
Parent(s):
376886a
[UPDATE]: Input str or List[str]
Browse files
- src/api/routers/embedding.py +21 -26
- src/models/schemas/requests.py +5 -7
src/api/routers/embedding.py
CHANGED
|
@@ -27,11 +27,8 @@ from src.core.exceptions import (
|
|
| 27 |
from src.api.dependencies import get_model_manager
|
| 28 |
from src.utils.validators import (
|
| 29 |
extract_embedding_kwargs,
|
| 30 |
-
validate_texts,
|
| 31 |
count_tokens_batch,
|
| 32 |
)
|
| 33 |
-
from src.config.settings import get_settings
|
| 34 |
-
|
| 35 |
|
| 36 |
router = APIRouter(tags=["embeddings"])
|
| 37 |
|
|
@@ -64,10 +61,8 @@ def _ensure_model_type(config, expected_type: str, model_id: str) -> None:
|
|
| 64 |
summary="Generate single/batch embeddings",
|
| 65 |
description="Generate embeddings for multiple texts in a single request",
|
| 66 |
)
|
| 67 |
-
async def
|
| 68 |
-
request: EmbedRequest,
|
| 69 |
-
manager: ModelManager = Depends(get_model_manager),
|
| 70 |
-
settings=Depends(get_settings),
|
| 71 |
):
|
| 72 |
"""
|
| 73 |
Generate embeddings for multiple texts.
|
|
@@ -79,13 +74,11 @@ async def create_embeddings_document(
|
|
| 79 |
Raises:
|
| 80 |
HTTPException: On validation or generation errors
|
| 81 |
"""
|
|
|
|
|
|
|
|
|
|
|
|
|
| 82 |
try:
|
| 83 |
-
# Validate input
|
| 84 |
-
validate_texts(
|
| 85 |
-
request.input,
|
| 86 |
-
max_length=settings.MAX_TEXT_LENGTH,
|
| 87 |
-
max_batch_size=settings.MAX_BATCH_SIZE,
|
| 88 |
-
)
|
| 89 |
kwargs = extract_embedding_kwargs(request)
|
| 90 |
|
| 91 |
model = manager.get_model(request.model)
|
|
@@ -95,7 +88,7 @@ async def create_embeddings_document(
|
|
| 95 |
|
| 96 |
start_time = time.time()
|
| 97 |
|
| 98 |
-
embeddings = model.embed(input=
|
| 99 |
processing_time = time.time() - start_time
|
| 100 |
|
| 101 |
data = [
|
|
@@ -108,8 +101,8 @@ async def create_embeddings_document(
|
|
| 108 |
]
|
| 109 |
|
| 110 |
token_usage = TokenUsage(
|
| 111 |
-
prompt_tokens=count_tokens_batch(
|
| 112 |
-
total_tokens=count_tokens_batch(
|
| 113 |
)
|
| 114 |
|
| 115 |
response = DenseEmbedResponse(
|
|
@@ -120,8 +113,8 @@ async def create_embeddings_document(
|
|
| 120 |
)
|
| 121 |
|
| 122 |
logger.info(
|
| 123 |
-
f"Generated {len(
|
| 124 |
-
f"in {processing_time:.3f}s ({len(
|
| 125 |
)
|
| 126 |
|
| 127 |
return response
|
|
@@ -133,10 +126,10 @@ async def create_embeddings_document(
|
|
| 133 |
except EmbeddingGenerationError as e:
|
| 134 |
raise HTTPException(status_code=e.status_code, detail=e.message)
|
| 135 |
except Exception as e:
|
| 136 |
-
logger.exception("Unexpected error in
|
| 137 |
raise HTTPException(
|
| 138 |
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
| 139 |
-
detail=f"Failed to create
|
| 140 |
)
|
| 141 |
|
| 142 |
|
|
@@ -160,8 +153,10 @@ async def create_sparse_embedding(
|
|
| 160 |
Raises:
|
| 161 |
HTTPException: On validation or generation errors
|
| 162 |
"""
|
|
|
|
|
|
|
|
|
|
| 163 |
try:
|
| 164 |
-
validate_texts(request.input)
|
| 165 |
kwargs = extract_embedding_kwargs(request)
|
| 166 |
|
| 167 |
model = manager.get_model(request.model)
|
|
@@ -171,12 +166,12 @@ async def create_sparse_embedding(
|
|
| 171 |
|
| 172 |
start_time = time.time()
|
| 173 |
|
| 174 |
-
sparse_results = model.embed(input=
|
| 175 |
processing_time = time.time() - start_time
|
| 176 |
|
| 177 |
sparse_embeddings = [
|
| 178 |
SparseEmbedding(
|
| 179 |
-
text=
|
| 180 |
indices=sparse_result["indices"],
|
| 181 |
values=sparse_result["values"],
|
| 182 |
)
|
|
@@ -190,8 +185,8 @@ async def create_sparse_embedding(
|
|
| 190 |
)
|
| 191 |
|
| 192 |
logger.info(
|
| 193 |
-
f"Generated {len(
|
| 194 |
-
f"in {processing_time:.3f}s ({len(
|
| 195 |
)
|
| 196 |
|
| 197 |
return response
|
|
@@ -206,5 +201,5 @@ async def create_sparse_embedding(
|
|
| 206 |
logger.exception("Unexpected error in create_sparse_embedding")
|
| 207 |
raise HTTPException(
|
| 208 |
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
| 209 |
-
detail=f"Failed to create
|
| 210 |
)
|
|
|
|
| 27 |
from src.api.dependencies import get_model_manager
|
| 28 |
from src.utils.validators import (
|
| 29 |
extract_embedding_kwargs,
|
|
|
|
| 30 |
count_tokens_batch,
|
| 31 |
)
|
|
|
|
|
|
|
| 32 |
|
| 33 |
router = APIRouter(tags=["embeddings"])
|
| 34 |
|
|
|
|
| 61 |
summary="Generate single/batch embeddings",
|
| 62 |
description="Generate embeddings for multiple texts in a single request",
|
| 63 |
)
|
| 64 |
+
async def create_embeddings(
|
| 65 |
+
request: EmbedRequest, manager: ModelManager = Depends(get_model_manager)
|
|
|
|
|
|
|
| 66 |
):
|
| 67 |
"""
|
| 68 |
Generate embeddings for multiple texts.
|
|
|
|
| 74 |
Raises:
|
| 75 |
HTTPException: On validation or generation errors
|
| 76 |
"""
|
| 77 |
+
|
| 78 |
+
if isinstance(request.input, str):
|
| 79 |
+
texts = [request.input]
|
| 80 |
+
|
| 81 |
try:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 82 |
kwargs = extract_embedding_kwargs(request)
|
| 83 |
|
| 84 |
model = manager.get_model(request.model)
|
|
|
|
| 88 |
|
| 89 |
start_time = time.time()
|
| 90 |
|
| 91 |
+
embeddings = model.embed(input=texts, **kwargs)
|
| 92 |
processing_time = time.time() - start_time
|
| 93 |
|
| 94 |
data = [
|
|
|
|
| 101 |
]
|
| 102 |
|
| 103 |
token_usage = TokenUsage(
|
| 104 |
+
prompt_tokens=count_tokens_batch(texts),
|
| 105 |
+
total_tokens=count_tokens_batch(texts),
|
| 106 |
)
|
| 107 |
|
| 108 |
response = DenseEmbedResponse(
|
|
|
|
| 113 |
)
|
| 114 |
|
| 115 |
logger.info(
|
| 116 |
+
f"Generated {len(texts)} embeddings "
|
| 117 |
+
f"in {processing_time:.3f}s ({len(texts) / processing_time:.1f} texts/s)"
|
| 118 |
)
|
| 119 |
|
| 120 |
return response
|
|
|
|
| 126 |
except EmbeddingGenerationError as e:
|
| 127 |
raise HTTPException(status_code=e.status_code, detail=e.message)
|
| 128 |
except Exception as e:
|
| 129 |
+
logger.exception("Unexpected error in create_embeddings")
|
| 130 |
raise HTTPException(
|
| 131 |
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
| 132 |
+
detail=f"Failed to create embeddings: {str(e)}",
|
| 133 |
)
|
| 134 |
|
| 135 |
|
|
|
|
| 153 |
Raises:
|
| 154 |
HTTPException: On validation or generation errors
|
| 155 |
"""
|
| 156 |
+
if isinstance(request.input, str):
|
| 157 |
+
texts = [request.input]
|
| 158 |
+
|
| 159 |
try:
|
|
|
|
| 160 |
kwargs = extract_embedding_kwargs(request)
|
| 161 |
|
| 162 |
model = manager.get_model(request.model)
|
|
|
|
| 166 |
|
| 167 |
start_time = time.time()
|
| 168 |
|
| 169 |
+
sparse_results = model.embed(input=texts, **kwargs)
|
| 170 |
processing_time = time.time() - start_time
|
| 171 |
|
| 172 |
sparse_embeddings = [
|
| 173 |
SparseEmbedding(
|
| 174 |
+
text=texts[idx],
|
| 175 |
indices=sparse_result["indices"],
|
| 176 |
values=sparse_result["values"],
|
| 177 |
)
|
|
|
|
| 185 |
)
|
| 186 |
|
| 187 |
logger.info(
|
| 188 |
+
f"Generated {len(texts)} embeddings "
|
| 189 |
+
f"in {processing_time:.3f}s ({len(texts) / processing_time:.1f} texts/s)"
|
| 190 |
)
|
| 191 |
|
| 192 |
return response
|
|
|
|
| 201 |
logger.exception("Unexpected error in create_sparse_embedding")
|
| 202 |
raise HTTPException(
|
| 203 |
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
| 204 |
+
detail=f"Failed to create sparse embedding: {str(e)}",
|
| 205 |
)
|
src/models/schemas/requests.py
CHANGED
|
@@ -5,7 +5,7 @@ This module defines all Pydantic models for incoming API requests,
|
|
| 5 |
with validation and documentation.
|
| 6 |
"""
|
| 7 |
|
| 8 |
-
from typing import List, Optional, Literal
|
| 9 |
from pydantic import BaseModel, Field, field_validator, ConfigDict
|
| 10 |
from .common import EmbeddingOptions
|
| 11 |
|
|
@@ -55,18 +55,18 @@ class EmbedRequest(BaseEmbedRequest):
|
|
| 55 |
Used for /embeddings and /embed_sparse endpoint to process multiple texts at once.
|
| 56 |
|
| 57 |
Attributes:
|
| 58 |
-
input: List of input texts to embed
|
| 59 |
"""
|
| 60 |
|
| 61 |
-
input: List[str] = Field(
|
| 62 |
...,
|
| 63 |
-
description="List of input texts to generate embeddings for",
|
| 64 |
min_length=1,
|
| 65 |
)
|
| 66 |
|
| 67 |
@field_validator("input")
|
| 68 |
@classmethod
|
| 69 |
-
def validate_texts(cls, v: List[str]) -> List[str]:
|
| 70 |
"""Validate that all texts are non-empty."""
|
| 71 |
if not v:
|
| 72 |
raise ValueError("texts list cannot be empty")
|
|
@@ -81,8 +81,6 @@ class EmbedRequest(BaseEmbedRequest):
|
|
| 81 |
raise ValueError(f"texts[{idx}] must be a string")
|
| 82 |
if not text.strip():
|
| 83 |
raise ValueError(f"texts[{idx}] cannot be empty or whitespace only")
|
| 84 |
-
if len(text) > 8192:
|
| 85 |
-
raise ValueError(f"texts[{idx}] exceeds maximum length (8192)")
|
| 86 |
validated.append(text)
|
| 87 |
|
| 88 |
return validated
|
|
|
|
| 5 |
with validation and documentation.
|
| 6 |
"""
|
| 7 |
|
| 8 |
+
from typing import List, Optional, Literal, Union
|
| 9 |
from pydantic import BaseModel, Field, field_validator, ConfigDict
|
| 10 |
from .common import EmbeddingOptions
|
| 11 |
|
|
|
|
| 55 |
Used for /embeddings and /embed_sparse endpoint to process multiple texts at once.
|
| 56 |
|
| 57 |
Attributes:
|
| 58 |
+
input: Str or List of input texts to embed
|
| 59 |
"""
|
| 60 |
|
| 61 |
+
input: Union[str, List[str]] = Field(
|
| 62 |
...,
|
| 63 |
+
description="Str or List of input texts to generate embeddings for",
|
| 64 |
min_length=1,
|
| 65 |
)
|
| 66 |
|
| 67 |
@field_validator("input")
|
| 68 |
@classmethod
|
| 69 |
+
def validate_texts(cls, v: Union[str, List[str]]) -> List[str]:
|
| 70 |
"""Validate that all texts are non-empty."""
|
| 71 |
if not v:
|
| 72 |
raise ValueError("texts list cannot be empty")
|
|
|
|
| 81 |
raise ValueError(f"texts[{idx}] must be a string")
|
| 82 |
if not text.strip():
|
| 83 |
raise ValueError(f"texts[{idx}] cannot be empty or whitespace only")
|
|
|
|
|
|
|
| 84 |
validated.append(text)
|
| 85 |
|
| 86 |
return validated
|