from typing import Optional
from datetime import datetime

import torch
from fastapi import APIRouter
from pydantic import BaseModel
from transformers import AutoTokenizer, AutoModel

from logger import log
from config import TEST_MODE

# Run on the first GPU when available, otherwise fall back to CPU.
device = "cuda:0" if torch.cuda.is_available() else "cpu"

router = APIRouter()


class SentenceEmbeddingsInput(BaseModel):
    inputs: list[str]
    model: str
    parameters: dict


class SentenceEmbeddingsOutput(BaseModel):
    embeddings: Optional[list[list[float]]] = None
    error: Optional[str] = None


# NOTE: the route path below is an assumption; the original snippet creates
# `router` but never attaches this handler to it.
@router.post("/sentence_embeddings", response_model=SentenceEmbeddingsOutput)
def sentence_embeddings(inputs: SentenceEmbeddingsInput):
    start_time = datetime.now()
    # Look up the per-model embedding function; unknown models get an error
    # response rather than an exception.
    fn = sentence_embeddings_mapping.get(inputs.model)
    if not fn:
        return SentenceEmbeddingsOutput(
            error=f'No sentence embeddings model found for {inputs.model}'
        )
    try:
        embeddings = fn(inputs.inputs, inputs.parameters)
        log({
            "task": "sentence_embeddings",
            "model": inputs.model,
            "start_time": start_time.isoformat(),
            "time_taken": (datetime.now() - start_time).total_seconds(),
            "inputs": inputs.inputs,
            "outputs": embeddings,
            "parameters": inputs.parameters,
        })
        # Record last-use time so idle models can be evicted later.
        loaded_models_last_updated[inputs.model] = datetime.now()
        return SentenceEmbeddingsOutput(embeddings=embeddings)
    except Exception as e:
        return SentenceEmbeddingsOutput(error=str(e))


def generic_sentence_embeddings(model_name: str):
    """Build a closure that embeds texts with the given Hugging Face model."""

    def process_texts(texts: list[str], parameters: dict):
        if TEST_MODE:
            # Return fixed dummy vectors so tests never load a real model.
            return [[0.1, 0.2]] * len(texts)
        if model_name in loaded_models:
            tokenizer, model = loaded_models[model_name]
        else:
            tokenizer = AutoTokenizer.from_pretrained(model_name)
            model = AutoModel.from_pretrained(model_name).to(device)
            # Cache under the model name so later requests reuse the pair
            # instead of reloading from disk.
            loaded_models[model_name] = (tokenizer, model)
        # Tokenize sentences
        encoded_input = tokenizer(
            texts, padding=True, truncation=True, return_tensors='pt'
        ).to(device)
        with torch.no_grad():
            model_output = model(**encoded_input)
            # CLS pooling: take the hidden state of the first token.
            sentence_embeddings = model_output[0][:, 0]
        # Normalize embeddings so dot products equal cosine similarity.
        sentence_embeddings = torch.nn.functional.normalize(sentence_embeddings, p=2, dim=1)
        return sentence_embeddings.tolist()

    return process_texts


# Model cache plus last-used timestamps, so a job polling every X minutes
# can unload models that have sat idle (sketch below).
loaded_models = {}
loaded_models_last_updated = {}
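
# A minimal eviction sketch for the polling job the comment above alludes to.
# Assumptions not in the original: the function name, the illustrative idle
# threshold, and that some external scheduler calls this periodically.
def evict_idle_models(max_idle_seconds: float = 15 * 60):
    now = datetime.now()
    for name, last_used in list(loaded_models_last_updated.items()):
        if (now - last_used).total_seconds() > max_idle_seconds:
            # Dropping the (tokenizer, model) pair lets PyTorch reclaim the
            # GPU memory once the tensors are garbage-collected.
            loaded_models.pop(name, None)
            del loaded_models_last_updated[name]
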
sentence_embeddings_mapping = {
    'BAAI/bge-base-en-v1.5': generic_sentence_embeddings('BAAI/bge-base-en-v1.5'),
    'BAAI/bge-large-en-v1.5': generic_sentence_embeddings('BAAI/bge-large-en-v1.5'),
}
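
# Usage sketch, not from the original: assumes this file is importable as
# `embeddings` and that the route path matches the assumed decorator above.
#
#   from fastapi import FastAPI
#   import embeddings
#
#   app = FastAPI()
#   app.include_router(embeddings.router)
#
# With the server running (e.g. `uvicorn main:app`):
#
#   curl -X POST http://localhost:8000/sentence_embeddings \
#     -H 'Content-Type: application/json' \
#     -d '{"inputs": ["hello world"], "model": "BAAI/bge-base-en-v1.5", "parameters": {}}'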