# NOTE: the lines below were file-viewer page residue (author, commit, size)
# captured with the source; preserved here as a comment so the file stays valid Python.
# fahmiaziz98 — commit a23b910 "[TYPO] fix typo" — raw / history blame — 6.56 kB
"""
Single/Batch embedding generation endpoints.
This module provides routes for generating embeddings for
multiple texts in a single request.
"""
import time
from fastapi import APIRouter, Depends, HTTPException, status
from loguru import logger
from src.models.schemas import (
EmbedRequest,
DenseEmbedResponse,
EmbeddingObject,
TokenUsage,
SparseEmbedResponse,
SparseEmbedding,
)
from src.core.manager import ModelManager
from src.core.exceptions import (
ModelNotFoundError,
ModelNotLoadedError,
EmbeddingGenerationError,
ValidationError,
)
from src.api.dependencies import get_model_manager
from src.utils.validators import extract_embedding_kwargs, validate_texts, count_tokens_batch
from src.config.settings import get_settings
# Shared router for every embedding endpoint defined in this module.
router = APIRouter(tags=["embeddings"])
@router.post(
    "/embeddings",
    response_model=DenseEmbedResponse,
    summary="Generate single/batch embeddings",
    description="Generate embeddings for multiple texts in a single request",
)
async def create_embeddings_document(
    request: EmbedRequest,
    manager: ModelManager = Depends(get_model_manager),
    settings=Depends(get_settings),
):
    """
    Generate dense embeddings for one or more texts.

    Args:
        request: EmbedRequest with input, model, and optional parameters.
        manager: Model manager dependency.
        settings: Application settings (supplies MAX_TEXT_LENGTH / MAX_BATCH_SIZE).

    Returns:
        DenseEmbedResponse with one EmbeddingObject per input text.

    Raises:
        HTTPException: On validation, wrong-model-type, or generation errors.
    """
    try:
        # Enforce per-text length and batch-size limits before touching the model.
        validate_texts(
            request.input,
            max_length=settings.MAX_TEXT_LENGTH,
            max_batch_size=settings.MAX_BATCH_SIZE,
        )
        kwargs = extract_embedding_kwargs(request)
        model = manager.get_model(request.model)
        config = manager.model_configs[request.model]

        if config.type != "embeddings":
            # Guard clause: this endpoint serves dense embedding models only.
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail=f"Model '{request.model}' is not a dense model. Type: {config.type}",
            )

        start_time = time.time()
        embeddings = model.embed(input=request.input, **kwargs)
        processing_time = time.time() - start_time

        data = [
            EmbeddingObject(object="embedding", embedding=embedding, index=idx)
            for idx, embedding in enumerate(embeddings)
        ]
        # Embedding requests have no completion stage, so prompt == total tokens.
        # Count once instead of twice.
        token_count = count_tokens_batch(request.input)
        token_usage = TokenUsage(
            prompt_tokens=token_count,
            total_tokens=token_count,
        )
        response = DenseEmbedResponse(
            object="list",
            data=data,
            model=request.model,
            usage=token_usage,
        )
        logger.info(
            f"Generated {len(request.input)} embeddings "
            f"in {processing_time:.3f}s ({len(request.input) / processing_time:.1f} texts/s)"
        )
        return response
    except HTTPException:
        # BUGFIX: re-raise HTTPExceptions unchanged (e.g. the wrong-model-type 400
        # above) — previously the broad handler below converted them into 500s.
        raise
    except (
        ValidationError,
        ModelNotFoundError,
        ModelNotLoadedError,
        EmbeddingGenerationError,
    ) as e:
        # Domain errors carry their own HTTP status code and message.
        raise HTTPException(status_code=e.status_code, detail=e.message)
    except Exception as e:
        logger.exception("Unexpected error in create_embeddings_document")
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"Failed to create batch embeddings: {str(e)}",
        )
@router.post(
    "/embed_sparse",
    response_model=SparseEmbedResponse,
    summary="Generate single/batch sparse embeddings",
    description="Generate embedding for a multiple query text",
)
async def create_sparse_embedding(
    request: EmbedRequest,
    manager: ModelManager = Depends(get_model_manager),
):
    """
    Generate single/batch sparse embeddings.

    Args:
        request: EmbedRequest with input, model, and optional parameters.
        manager: Model manager dependency.

    Returns:
        SparseEmbedResponse with one SparseEmbedding (indices/values) per input text.

    Raises:
        HTTPException: On validation, wrong-model-type, or generation errors.
    """
    try:
        validate_texts(request.input)
        kwargs = extract_embedding_kwargs(request)
        model = manager.get_model(request.model)
        config = manager.model_configs[request.model]

        if config.type != "sparse-embeddings":
            # Guard clause: this endpoint serves sparse embedding models only.
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail=f"Model '{request.model}' is not a sparse model. Type: {config.type}",
            )

        start_time = time.time()
        sparse_results = model.embed(input=request.input, **kwargs)
        processing_time = time.time() - start_time

        # Each sparse result is a mapping with "indices" and "values" (see usage below).
        sparse_embeddings = [
            SparseEmbedding(
                text=request.input[idx],
                indices=sparse_result["indices"],
                values=sparse_result["values"],
            )
            for idx, sparse_result in enumerate(sparse_results)
        ]
        response = SparseEmbedResponse(
            embeddings=sparse_embeddings,
            count=len(sparse_embeddings),
            model=request.model,
        )
        # BUGFIX: was `request.texts`, which does not exist on EmbedRequest
        # (the field is `input`) — every successful request crashed here with
        # an AttributeError that the broad handler turned into a 500.
        logger.info(
            f"Generated {len(request.input)} embeddings "
            f"in {processing_time:.3f}s ({len(request.input) / processing_time:.1f} texts/s)"
        )
        return response
    except HTTPException:
        # BUGFIX: re-raise HTTPExceptions unchanged (e.g. the wrong-model-type 400
        # above) — previously the broad handler below converted them into 500s.
        raise
    except (
        ValidationError,
        ModelNotFoundError,
        ModelNotLoadedError,
        EmbeddingGenerationError,
    ) as e:
        # Domain errors carry their own HTTP status code and message.
        raise HTTPException(status_code=e.status_code, detail=e.message)
    except Exception as e:
        # BUGFIX: log/error strings said "query embedding" (copy-paste from the
        # query endpoint); corrected to reference this sparse endpoint.
        logger.exception("Unexpected error in create_sparse_embedding")
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"Failed to create sparse embeddings: {str(e)}",
        )