Spaces:
Running
Running
"""
Single/batch embedding generation endpoints.

This module provides routes for generating dense and sparse
embeddings for one or more texts in a single request.
"""
| import time | |
| from fastapi import APIRouter, Depends, HTTPException, status | |
| from loguru import logger | |
| from src.models.schemas import ( | |
| EmbedRequest, | |
| DenseEmbedResponse, | |
| EmbeddingObject, | |
| TokenUsage, | |
| SparseEmbedResponse, | |
| SparseEmbedding, | |
| ) | |
| from src.core.manager import ModelManager | |
| from src.core.exceptions import ( | |
| ModelNotFoundError, | |
| ModelNotLoadedError, | |
| EmbeddingGenerationError, | |
| ValidationError, | |
| ) | |
| from src.api.dependencies import get_model_manager | |
| from src.utils.validators import extract_embedding_kwargs, validate_texts, count_tokens_batch | |
| from src.config.settings import get_settings | |
# Shared router for the embedding endpoints; the tag groups them in OpenAPI docs.
router = APIRouter(tags=["embeddings"])
# NOTE(review): no @router decorator is visible here — presumably these handlers
# are registered elsewhere (or the decorator was lost); confirm route wiring.
async def create_embeddings_document(
    request: EmbedRequest,
    manager: ModelManager = Depends(get_model_manager),
    settings=Depends(get_settings),
):
    """
    Generate dense embeddings for one or more texts.

    Args:
        request: EmbedRequest with input, model, and optional parameters
        manager: Model manager dependency
        settings: Application settings (supplies length/batch limits)

    Returns:
        DenseEmbedResponse with one EmbeddingObject per input text

    Raises:
        HTTPException: On validation or generation errors
    """
    try:
        # Enforce per-text length and batch-size limits before touching the model.
        validate_texts(
            request.input,
            max_length=settings.MAX_TEXT_LENGTH,
            max_batch_size=settings.MAX_BATCH_SIZE,
        )
        kwargs = extract_embedding_kwargs(request)
        model = manager.get_model(request.model)
        config = manager.model_configs[request.model]

        # Guard clause: this endpoint only serves dense embedding models.
        if config.type != "embeddings":
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail=f"Model '{request.model}' is not a dense model. Type: {config.type}",
            )

        start_time = time.time()
        embeddings = model.embed(input=request.input, **kwargs)
        processing_time = time.time() - start_time

        data = [
            EmbeddingObject(object="embedding", embedding=embedding, index=idx)
            for idx, embedding in enumerate(embeddings)
        ]

        # Embedding models consume prompt tokens only, so total == prompt.
        # Count once instead of twice.
        prompt_tokens = count_tokens_batch(request.input)
        token_usage = TokenUsage(
            prompt_tokens=prompt_tokens,
            total_tokens=prompt_tokens,
        )
        response = DenseEmbedResponse(
            object="list",
            data=data,
            model=request.model,
            usage=token_usage,
        )
        logger.info(
            f"Generated {len(request.input)} embeddings "
            f"in {processing_time:.3f}s ({len(request.input) / processing_time:.1f} texts/s)"
        )
        return response
    except HTTPException:
        # Bug fix: let deliberate HTTP errors (e.g. the 400 above) propagate
        # unchanged instead of being rewrapped as a 500 by the handler below.
        raise
    except (
        ValidationError,
        ModelNotFoundError,
        ModelNotLoadedError,
        EmbeddingGenerationError,
    ) as e:
        # Domain errors carry their own HTTP status and message.
        raise HTTPException(status_code=e.status_code, detail=e.message)
    except Exception as e:
        logger.exception("Unexpected error in create_embeddings_document")
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"Failed to create batch embeddings: {str(e)}",
        )
async def create_sparse_embedding(
    request: EmbedRequest,
    manager: ModelManager = Depends(get_model_manager),
):
    """
    Generate sparse embeddings for one or more texts.

    Args:
        request: EmbedRequest with input, model, and optional parameters
        manager: Model manager dependency

    Returns:
        SparseEmbedResponse with one SparseEmbedding per input text

    Raises:
        HTTPException: On validation or generation errors
    """
    try:
        validate_texts(request.input)
        kwargs = extract_embedding_kwargs(request)
        model = manager.get_model(request.model)
        config = manager.model_configs[request.model]

        # Guard clause: this endpoint only serves sparse embedding models.
        if config.type != "sparse-embeddings":
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail=f"Model '{request.model}' is not a sparse model. Type: {config.type}",
            )

        start_time = time.time()
        sparse_results = model.embed(input=request.input, **kwargs)
        processing_time = time.time() - start_time

        sparse_embeddings = [
            SparseEmbedding(
                text=request.input[idx],
                indices=sparse_result["indices"],
                values=sparse_result["values"],
            )
            for idx, sparse_result in enumerate(sparse_results)
        ]
        response = SparseEmbedResponse(
            embeddings=sparse_embeddings,
            count=len(sparse_embeddings),
            model=request.model,
        )
        # Bug fix: the request field is `input` (see validate_texts above);
        # `request.texts` raised AttributeError and turned success into a 500.
        logger.info(
            f"Generated {len(request.input)} embeddings "
            f"in {processing_time:.3f}s ({len(request.input) / processing_time:.1f} texts/s)"
        )
        return response
    except HTTPException:
        # Bug fix: let deliberate HTTP errors (e.g. the 400 above) propagate
        # unchanged instead of being rewrapped as a 500 by the handler below.
        raise
    except (
        ValidationError,
        ModelNotFoundError,
        ModelNotLoadedError,
        EmbeddingGenerationError,
    ) as e:
        # Domain errors carry their own HTTP status and message.
        raise HTTPException(status_code=e.status_code, detail=e.message)
    except Exception as e:
        # Bug fix: error text previously said "query embedding" — copy/paste
        # from another handler; corrected to match this endpoint.
        logger.exception("Unexpected error in create_sparse_embedding")
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"Failed to create sparse embeddings: {str(e)}",
        )