# app/main.py — SPLADE sparse-vector encoding service (FastAPI)
from fastapi import FastAPI
import os
import json
import uvicorn
import numpy as np
from transformers import AutoModelForMaskedLM, AutoTokenizer
import torch
import pickle
# ASGI application instance served by uvicorn (see __main__ guard below).
app = FastAPI()
# Module-level model state; populated by load_model_and_tokenizer().
tokenizer, model = None, None
@app.get('/')
def homepage():
return json.dumps({'Message': 'Welcome to SPLADE'})
@app.get('/health_check')
def health_check():
return json.dumps({"success": True}), 200
def load_model_and_tokenizer():
    """Populate the module-level ``tokenizer`` and ``model`` from pickles.

    NOTE(review): ``pickle.load`` executes arbitrary code embedded in the
    file — only load artifacts you fully trust; prefer transformers'
    ``from_pretrained`` for distributable checkpoints.
    """
    global tokenizer, model

    def _unpickle(path):
        # Load one pickled artifact from disk.
        with open(path, 'rb') as fh:
            return pickle.load(fh)

    tokenizer = _unpickle('./app/tokenizer.pkl')
    model = _unpickle('./app/model.pkl')
@app.get('/get_vector')
def get_vector(search_queries_list):
    """Encode queries into SPLADE sparse vectors.

    Parameters
    ----------
    search_queries_list : str
        JSON-encoded list of query strings.

    Returns
    -------
    dict
        ``{"vector": <JSON-encoded list>}`` — the max-pooled
        ``log(1 + relu(logits))`` activations per query (SPLADE aggregation).
    """
    search_queries = json.loads(search_queries_list)
    global tokenizer, model
    # Lazy-load on first use. The previous code did the opposite: it raised
    # ValueError whenever the model was not yet loaded, and then reloaded
    # both pickles from disk on *every* request.
    if tokenizer is None or model is None:
        load_model_and_tokenizer()
    search_tokens = tokenizer(
        search_queries, return_tensors='pt',
        padding=True, truncation=True
    )
    # Inference only — skip autograd bookkeeping to save time and memory.
    with torch.no_grad():
        search_output = model(**search_tokens)
        # Aggregate the token-level vecs and transform to sparse:
        # log-saturated ReLU, zeroed at padding positions via the attention
        # mask, then max-pooled over the token dimension.
        search_vecs = torch.max(
            torch.log(1 + torch.relu(search_output.logits)) * search_tokens.attention_mask.unsqueeze(-1), dim=1
        )[0].squeeze().cpu().numpy()
    return {"vector": json.dumps(search_vecs.tolist())}
if __name__ == '__main__':
    # Load the pickled model/tokenizer once at startup, then serve on all
    # interfaces at port 7860 (the Hugging Face Spaces default port).
    load_model_and_tokenizer()
    uvicorn.run(app, host='0.0.0.0', port=7860)