# NOTE(review): the three lines here ("Spaces:", "Runtime error", "Runtime error")
# are paste residue from a judge/runner, not Python — kept only as a comment so
# the file parses. The runtime error is consistent with the invalid
# `from torch.nn.functional import F` import below.
import logging

from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from transformers import AutoTokenizer, AutoModel
import torch
# BUG FIX: the original `from torch.nn.functional import F` raises ImportError —
# `F` is the conventional *alias* for the module, not a name defined inside it.
import torch.nn.functional as F

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
logger.info("Server Starting")
# Model is loaded once at import time; on failure the server still boots and
# the endpoints respond 503 (see get_embedding's guard).
MODEL_NAME = "Sid-the-sloth/leetcode_unixcoder_final"
device = "cpu"

try:
    logger.info("Loading model")
    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
    model = AutoModel.from_pretrained(MODEL_NAME)
    model.to(device)
    model.eval()  # inference mode: disables dropout etc.
    logger.info("Model Loaded")
except Exception:
    # BUG FIX: was a bare `except:` that logged no traceback, hiding the real
    # cause (e.g. the broken F import, network failure, bad model id).
    # logger.exception records the full stack trace.
    logger.exception("Failed to load Model")
    model = None
    tokenizer = None
app = FastAPI()


# Req and Response Pydantic models
class EmbedRequest(BaseModel):
    # Source text to embed.
    text: str


class EmbedResponse(BaseModel):
    # Sentence embedding vector (one float per model hidden dimension).
    embedding: list[float]
def mean_pooling(model_output, attention_mask):
    """
    Collapse token-level hidden states into a single sentence-level embedding.

    Averages the model's last hidden state over the sequence dimension,
    counting only positions where the attention mask is 1 (real tokens,
    not padding).
    """
    hidden = model_output.last_hidden_state
    # Broadcast the (batch, seq) mask up to hidden's shape so it can weight
    # every feature of every token.
    mask = attention_mask.unsqueeze(-1).expand(hidden.size()).float()
    summed = (hidden * mask).sum(dim=1)
    # Clamp guards against division by zero on an all-padding row.
    counts = mask.sum(dim=1).clamp(min=1e-9)
    return summed / counts
# NOTE(review): no @app.get("/") decorator is visible here — it may have been
# lost when this file was pasted/formatted. Without it this health check is
# never registered as a route; confirm against the original source.
def root_status():
    """Health check: reports server liveness and whether the model loaded."""
    return {"status":"ok","model":model is not None}
# NOTE(review): no @app.post(...) decorator is visible here — it may have been
# lost when this file was pasted/formatted; confirm the route registration
# against the original source.
def get_embedding(request: EmbedRequest):
    """
    Embed `request.text` with the loaded model.

    Tokenizes the text, runs the encoder, mean-pools the last hidden state,
    and returns the L2-normalized sentence embedding.

    Raises:
        HTTPException: 503 if the model/tokenizer failed to load at startup;
            500 on any failure during tokenization or inference.
    """
    if model is None or tokenizer is None:
        # BUG FIX: the original constructed this exception without `raise`,
        # making the guard a no-op — the handler then crashed below on a None
        # model. Also restored the garbled detail text ("/Tokenizer ...").
        raise HTTPException(status_code=503, detail="Model/Tokenizer could not be loaded")
    try:
        encoded_input = tokenizer(
            request.text, padding=True, truncation=True, return_tensors='pt'
        ).to(device)
        # Inference only — no_grad skips autograd bookkeeping (less memory/CPU).
        with torch.no_grad():
            model_output = model(**encoded_input)
        sentence_embedding = mean_pooling(model_output, encoded_input['attention_mask'])
        normalized_embedding = F.normalize(sentence_embedding, p=2, dim=1)
        # [0]: single input text -> single embedding row.
        return EmbedResponse(embedding=normalized_embedding[0].tolist())
    except Exception as e:
        logger.error("Error during embedding generation %s", e)
        # BUG FIX: the original *returned* the HTTPException, which FastAPI
        # would serialize into a 200 response body; it must be raised so the
        # client actually receives a 500.
        raise HTTPException(status_code=500, detail="Error generating embeddings") from e