Commit · df5822b
Parent(s): b5db444
loaded model using AutoModel and manual pooling
app.py CHANGED
@@ -1,6 +1,8 @@
 from fastapi import FastAPI, HTTPException
 from pydantic import BaseModel
-from
+from transformers import AutoTokenizer, AutoModel
+import torch
+from torch.nn.functional import F
 import logging
 
 logging.basicConfig(level=logging.INFO)
@@ -9,11 +11,19 @@ logger=logging.getLogger(__name__)
 logger.info("Server Starting")
 try:
     logger.info("Loading model")
-
+    MODEL_NAME = "Sid-the-sloth/leetcode_unixcoder_final"
+    device="cpu"
+    tokenizer=AutoTokenizer.from_pretrained(MODEL_NAME)
+    model=AutoModel.from_pretrained(MODEL_NAME)
+
+    model.to(device)
+    model.eval()
+
     logger.info("Model Loaded")
 except:
     logger.error("Failed to load Model")
     model=None
+    tokenizer=None
 
 app=FastAPI()
 
@@ -24,16 +34,33 @@ class EmbedRequest(BaseModel):
 class EmbedResponse(BaseModel):
     embedding: list[float]
 
+def mean_pooling(model_output, attention_mask):
+    """
+    Performs mean pooling on the last hidden state of the model.
+    This turns token-level embeddings into a single sentence-level embedding.
+    """
+    token_embeddings = model_output.last_hidden_state
+    input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
+    return torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(input_mask_expanded.sum(1), min=1e-9)
+
+
 @app.get("/")
 def root_status():
     return {"status":"ok","model":model is not None}
 
 @app.post("/embed",response_model=EmbedResponse)
 def get_embedding(request: EmbedRequest):
-    if model
-        HTTPException(status_code=503,detail="
+    if not model or not tokenizer:
+        HTTPException(status_code=503,detail="/Tokenizer could not be loaded")
     try:
+        encoded_input = tokenizer(request.text, padding=True, truncation=True, return_tensors='pt').to(device)
+        model_output = model(**encoded_input)
         embedding=model.encode(request.text).tolist()
+        sentence_embedding = mean_pooling(model_output, encoded_input['attention_mask'])
+
+        normalized_embedding = F.normalize(sentence_embedding, p=2, dim=1)
+
+        embedding_list = normalized_embedding[0].tolist()
         return EmbedResponse(embedding=embedding)
     except Exception as e:
         logger.error("Error during embedding generation %s",e)
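As committed, a few spots in the new code would not run as written: `from torch.nn.functional import F` raises an ImportError (the module is conventionally imported as `import torch.nn.functional as F`), the old `model.encode(request.text)` line is still present even though an `AutoModel` instance has no `encode` method, the handler builds `embedding_list` but still returns `embedding`, and the 503 `HTTPException` is constructed without `raise`, so the request falls through to the `try` block anyway. Below is a minimal, self-contained sketch of the AutoModel + mean-pooling + L2-normalization path this commit introduces, with those spots adjusted. The model name and pooling logic come from the diff; the `embed()` helper, the `torch.no_grad()` context, and the corrected import are illustrative assumptions, not the committed code.

# Hypothetical standalone sketch of the embedding path introduced by this commit.
import torch
import torch.nn.functional as F  # conventional alias; `from torch.nn.functional import F` fails
from transformers import AutoTokenizer, AutoModel

MODEL_NAME = "Sid-the-sloth/leetcode_unixcoder_final"  # from the diff
device = "cpu"

tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoModel.from_pretrained(MODEL_NAME).to(device)
model.eval()

def mean_pooling(model_output, attention_mask):
    """Average token embeddings into one sentence embedding, ignoring padding."""
    token_embeddings = model_output.last_hidden_state
    mask = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
    return torch.sum(token_embeddings * mask, 1) / torch.clamp(mask.sum(1), min=1e-9)

def embed(text: str) -> list[float]:
    # Hypothetical helper, not part of the committed app.py.
    encoded = tokenizer(text, padding=True, truncation=True, return_tensors="pt").to(device)
    with torch.no_grad():  # inference only
        model_output = model(**encoded)
    sentence_embedding = mean_pooling(model_output, encoded["attention_mask"])
    normalized = F.normalize(sentence_embedding, p=2, dim=1)  # unit-length vector
    return normalized[0].tolist()

Inside the FastAPI handler, this path would end with `return EmbedResponse(embedding=embed(request.text))`, and the guard would normally read `raise HTTPException(status_code=503, ...)` so that a missing model or tokenizer actually short-circuits the request.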
|