# NOTE(review): removed GitHub UI extraction residue that preceded the module
# docstring (author "fahmiaziz98", commit message "update: Response format",
# hash ecfc38e) — it was not valid Python source.
"""
Single/Batch embedding generation endpoints.
This module provides routes for generating embeddings for
multiple texts in a single request.
"""
import time
from fastapi import APIRouter, Depends, HTTPException, status
from fastapi.responses import JSONResponse
from loguru import logger
from src.models.schemas import (
EmbedRequest,
DenseEmbedResponse,
EmbeddingObject,
TokenUsage,
)
from src.core.manager import ModelManager
from src.core.exceptions import (
ModelNotFoundError,
ModelNotLoadedError,
EmbeddingGenerationError,
ValidationError,
)
from src.api.dependencies import get_model_manager
from src.utils.validators import (
extract_embedding_kwargs,
count_tokens_batch,
ensure_model_type,
)
# Router holding the embedding endpoints defined in this module.
router = APIRouter()
@router.post(
    "/embeddings",
    response_model=DenseEmbedResponse,
    tags=["OpenAI Compatible"],
    summary="Generate single/batch embeddings",
    description="Generate embeddings for multiple texts in a single request",
)
async def create_openai_embeddings(
    request: EmbedRequest, manager: ModelManager = Depends(get_model_manager)
):
    """
    Generate embeddings for one or more texts (OpenAI-compatible endpoint).

    A string input is normalized to a single-element list. The endpoint
    validates the request, checks that the requested model is a dense
    embedding model, and returns a :class:`DenseEmbedResponse` with
    per-text :class:`EmbeddingObject` entries and token usage.

    Raises:
        HTTPException: On validation or generation errors.
    """
    texts = [request.input] if isinstance(request.input, str) else request.input
    try:
        # Validation must happen inside the try block: previously a
        # ValidationError raised here escaped the handlers below and
        # surfaced as an unhandled 500 instead of its intended status.
        if not texts or not isinstance(texts, list):
            raise ValidationError("Input must be a non-empty list or string.")
        kwargs = extract_embedding_kwargs(request)
        model = manager.get_model(request.model)
        config = manager.model_configs.get(request.model)
        ensure_model_type(config, "embeddings", request.model)

        start_time = time.time()
        embeddings = model.embed(input=texts, **kwargs)
        processing_time = time.time() - start_time

        data = [
            EmbeddingObject(
                object="embedding",
                embedding=embedding,
                index=idx,
            )
            for idx, embedding in enumerate(embeddings)
        ]
        # Compute token count once; the OpenAI schema repeats it in both fields.
        prompt_tokens = count_tokens_batch(texts)
        token_usage = TokenUsage(
            prompt_tokens=prompt_tokens,
            total_tokens=prompt_tokens,
        )
        response = DenseEmbedResponse(
            object="list",
            data=data,
            model=request.model,
            usage=token_usage,
        )
        # Guard against a zero-duration timing window (would divide by zero).
        rate = len(texts) / processing_time if processing_time > 0 else float("inf")
        logger.info(
            f"Generated {len(texts)} embeddings "
            f"in {processing_time:.3f}s ({rate:.1f} texts/s)"
        )
        return response
    except (
        ValidationError,
        ModelNotFoundError,
        ModelNotLoadedError,
        EmbeddingGenerationError,
    ) as e:
        # All domain exceptions carry status_code/message, so a single
        # handler replaces the previous three identical except clauses.
        raise HTTPException(status_code=e.status_code, detail=e.message) from e
    except Exception as e:
        logger.exception("Unexpected error in create_openai_embeddings")
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"Failed to create embeddings: {str(e)}",
        ) from e
@router.post(
    "/embed",
    tags=["Embeddings"],
    summary="Generate single/batch dense embeddings",
    description="Generate embedding for a multiple query text",
)
async def create_embeddings(
    request: EmbedRequest, manager: ModelManager = Depends(get_model_manager)
):
    """
    Generate raw dense embeddings for one or more texts.

    A string input is normalized to a single-element list. The endpoint
    validates the request, checks that the requested model is a dense
    embedding model, and returns the raw embedding vectors as JSON
    (no OpenAI envelope, unlike ``/embeddings``).

    Raises:
        HTTPException: On validation or generation errors.
    """
    texts = [request.input] if isinstance(request.input, str) else request.input
    try:
        # Validation must happen inside the try block: previously a
        # ValidationError raised here escaped the handlers below and
        # surfaced as an unhandled 500 instead of its intended status.
        if not texts or not isinstance(texts, list):
            raise ValidationError("Input must be a non-empty list or string.")
        kwargs = extract_embedding_kwargs(request)
        model = manager.get_model(request.model)
        config = manager.model_configs.get(request.model)
        ensure_model_type(config, "embeddings", request.model)

        start_time = time.time()
        embeddings = model.embed(input=texts, **kwargs)
        processing_time = time.time() - start_time

        # Guard against a zero-duration timing window (would divide by zero).
        rate = len(texts) / processing_time if processing_time > 0 else float("inf")
        logger.info(
            f"Generated {len(texts)} embeddings "
            f"in {processing_time:.3f}s ({rate:.1f} texts/s)"
        )
        # NOTE(review): assumes model.embed returns JSON-serializable lists
        # (not numpy arrays) — confirm against the model implementation.
        return JSONResponse(content=embeddings)
    except (
        ValidationError,
        ModelNotFoundError,
        ModelNotLoadedError,
        EmbeddingGenerationError,
    ) as e:
        # All domain exceptions carry status_code/message, so a single
        # handler replaces the previous three identical except clauses.
        raise HTTPException(status_code=e.status_code, detail=e.message) from e
    except Exception as e:
        logger.exception("Unexpected error in create_embeddings")
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"Failed to create embeddings: {str(e)}",
        ) from e
@router.post(
    "/embed_sparse",
    tags=["Embeddings"],
    summary="Generate single/batch sparse embeddings",
    description="Generate embedding for a multiple query text",
)
async def create_sparse_embedding(
    request: EmbedRequest,
    manager: ModelManager = Depends(get_model_manager),
):
    """
    Generate sparse embeddings for one or more texts.

    A string input is normalized to a single-element list. The endpoint
    validates the request, checks that the requested model is a sparse
    embedding model, and returns, per input text, a list of
    ``{"index": ..., "value": ...}`` pairs.

    Raises:
        HTTPException: On validation or generation errors.
    """
    texts = [request.input] if isinstance(request.input, str) else request.input
    try:
        # Validation must happen inside the try block: previously a
        # ValidationError raised here escaped the handlers below and
        # surfaced as an unhandled 500 instead of its intended status.
        if not texts or not isinstance(texts, list):
            raise ValidationError("Input must be a non-empty list or string.")
        kwargs = extract_embedding_kwargs(request)
        model = manager.get_model(request.model)
        config = manager.model_configs.get(request.model)
        ensure_model_type(config, "sparse-embeddings", request.model)

        start_time = time.time()
        sparse_results = model.embed(input=texts, **kwargs)
        processing_time = time.time() - start_time

        # Flatten each sparse result ({"indices": [...], "values": [...]})
        # into a list of index/value pairs for the JSON response.
        formatted_embeddings = [
            [{"index": i, "value": v} for i, v in zip(res["indices"], res["values"])]
            for res in sparse_results
        ]
        # Guard against a zero-duration timing window (would divide by zero).
        rate = len(texts) / processing_time if processing_time > 0 else float("inf")
        logger.info(
            f"Generated {len(texts)} embeddings "
            f"in {processing_time:.3f}s ({rate:.1f} texts/s)"
        )
        return JSONResponse(content=formatted_embeddings)
    except (
        ValidationError,
        ModelNotFoundError,
        ModelNotLoadedError,
        EmbeddingGenerationError,
    ) as e:
        # All domain exceptions carry status_code/message, so a single
        # handler replaces the previous three identical except clauses.
        raise HTTPException(status_code=e.status_code, detail=e.message) from e
    except Exception as e:
        logger.exception("Unexpected error in create_sparse_embedding")
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"Failed to create sparse embedding: {str(e)}",
        ) from e