import asyncio
import base64
from typing import Union

import numpy as np
import tiktoken
from fastapi import APIRouter, Depends
from openai import AsyncOpenAI
from openai.types.create_embedding_response import Usage
from sentence_transformers import SentenceTransformer

from api.config import SETTINGS
from api.models import EMBEDDED_MODEL
from api.utils.protocol import EmbeddingCreateParams, Embedding, CreateEmbeddingResponse
from api.utils.request import check_api_key

embedding_router = APIRouter()

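
# FastAPI dependency that yields the process-wide embedding backend: either a
# local SentenceTransformer model or an AsyncOpenAI client pointing at a
# text-embeddings-inference (TEI) server.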
def get_embedding_engine():
    yield EMBEDDED_MODEL

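
# NOTE: only the handler is shown here; in the full module it is presumably
# registered on `embedding_router` with `check_api_key` as a dependency,
# which is why those names are imported above.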
async def create_embeddings(
    request: EmbeddingCreateParams,
    model_name: str = None,
    client: Union[SentenceTransformer, AsyncOpenAI] = Depends(get_embedding_engine),
):
    """Create embeddings for the given input text(s)."""
    if request.model is None:
        request.model = model_name

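    # Normalize the input into a list of strings: wrap a bare string, and decode
    # token-id inputs (list[int] or list[list[int]]) back to text with the
    # model's tiktoken encoding.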
    if isinstance(request.input, str):
        request.input = [request.input]
    elif isinstance(request.input, list):
        if isinstance(request.input[0], int):
            decoding = tiktoken.model.encoding_for_model(request.model)
            request.input = [decoding.decode(request.input)]
        elif isinstance(request.input[0], list):
            decoding = tiktoken.model.encoding_for_model(request.model)
            request.input = [decoding.decode(text) for text in request.input]

    data, total_tokens = [], 0

    # support for tei: https://github.com/huggingface/text-embeddings-inference
    if isinstance(client, AsyncOpenAI):
        global_batch_size = SETTINGS.max_concurrent_requests * SETTINGS.max_client_batch_size
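        # Walk the inputs in windows of max_concurrent_requests * max_client_batch_size
        # texts, issuing one embeddings request per client batch and awaiting the
        # whole window at once.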
        for i in range(0, len(request.input), global_batch_size):
            tasks = []
            texts = request.input[i: i + global_batch_size]
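            # Split the window into sub-batches of max_client_batch_size; each text
            # is cut to its first 510 characters as a crude guard against the
            # backend's sequence-length limit (characters, not tokens).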
            for j in range(0, len(texts), SETTINGS.max_client_batch_size):
                tasks.append(
                    client.embeddings.create(
                        input=[text[:510] for text in texts[j: j + SETTINGS.max_client_batch_size]],
                        model=request.model,
                    )
                )
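            # asyncio.gather preserves task order, so the flattened vectors stay
            # aligned with the original input order within this window.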
            res = await asyncio.gather(*tasks)
            vecs = np.asarray([e.embedding for r in res for e in r.data])
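            # Zero-pad each vector on the right when the configured
            # SETTINGS.embedding_size exceeds the model's native dimension, so
            # responses report a uniform dimensionality.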
            bs, dim = vecs.shape
            if SETTINGS.embedding_size > dim:
                zeros = np.zeros((bs, SETTINGS.embedding_size - dim))
                vecs = np.c_[vecs, zeros]
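            # "base64" mirrors OpenAI's encoding_format option: each vector's raw
            # bytes are base64-encoded; otherwise plain lists of floats are returned.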
            if request.encoding_format == "base64":
                vecs = [base64.b64encode(v.tobytes()).decode("utf-8") for v in vecs]
            else:
                vecs = vecs.tolist()

            data.extend(
                Embedding(
                    index=i + j,  # i is already an absolute offset into request.input
                    object="embedding",
                    embedding=embed,
                )
                for j, embed in enumerate(vecs)
            )
            total_tokens += sum(r.usage.total_tokens for r in res)
    else:
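        # Local SentenceTransformer path: encode in chunks of 1024 texts with
        # L2-normalized vectors; token usage is approximated by character count.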
        batches = [request.input[i: i + 1024] for i in range(0, len(request.input), 1024)]
        for num_batch, batch in enumerate(batches):
            token_num = sum(len(i) for i in batch)
            vecs = client.encode(batch, normalize_embeddings=True)
            bs, dim = vecs.shape
            if SETTINGS.embedding_size > dim:
                zeros = np.zeros((bs, SETTINGS.embedding_size - dim))
                vecs = np.c_[vecs, zeros]

            if request.encoding_format == "base64":
                vecs = [base64.b64encode(v.tobytes()).decode("utf-8") for v in vecs]
            else:
                vecs = vecs.tolist()

            data.extend(
                Embedding(
                    index=num_batch * 1024 + i,
                    object="embedding",
                    embedding=embedding,
                )
                for i, embedding in enumerate(vecs)
            )
            total_tokens += token_num

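    # Assemble an OpenAI-compatible response; prompt_tokens equals total_tokens
    # because an embeddings request has no completion tokens.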
    return CreateEmbeddingResponse(
        data=data,
        model=request.model,
        object="list",
        usage=Usage(prompt_tokens=total_tokens, total_tokens=total_tokens),
    )
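
# Example of a request body this handler accepts (OpenAI-style; the model name
# below is a placeholder, not one defined in this project):
#   {"model": "<embedding-model-name>", "input": ["hello", "world"], "encoding_format": "float"}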