"""Minimal FastAPI service exposing SPLADE sparse-vector encoding of text queries."""

from fastapi import FastAPI
import os
import json
import uvicorn
import numpy as np
from transformers import AutoModelForMaskedLM, AutoTokenizer
import torch
import pickle

app = FastAPI()

# Populated by load_model_and_tokenizer(); both stay None until first load.
tokenizer, model = None, None


@app.get('/')
def homepage():
    """Landing route.

    FastAPI serializes the returned dict to JSON itself; the original
    `json.dumps(...)` double-encoded the payload (client received a quoted
    JSON string instead of an object).
    """
    return {'Message': 'Welcome to SPLADE'}


@app.get('/health_check')
def health_check():
    """Liveness probe.

    The original Flask-style `(body, 200)` tuple is not a FastAPI idiom —
    FastAPI would serialize the tuple as a two-element JSON array. Returning
    the dict alone yields the intended body with the default 200 status.
    """
    return {"success": True}


def load_model_and_tokenizer():
    """Load the pickled tokenizer and model into the module globals.

    NOTE(review): pickle.load executes arbitrary code on load — these files
    must only ever be trusted build artifacts, never user-supplied input.
    """
    global tokenizer, model
    with open('./app/tokenizer.pkl', 'rb') as f:
        tokenizer = pickle.load(f)
    with open('./app/model.pkl', 'rb') as f:
        model = pickle.load(f)


@app.get('/get_vector')
def get_vector(search_queries_list):
    """Encode a JSON-encoded list of query strings into SPLADE sparse vectors.

    Parameters
    ----------
    search_queries_list : str
        JSON array of query strings, e.g. '["red shoes", "running gear"]'.

    Returns
    -------
    dict
        {"vector": <JSON-encoded list of per-query weight vectors>}.
    """
    search_queries = json.loads(search_queries_list)
    global tokenizer, model
    # BUGFIX: the original raised ValueError when the model was not loaded
    # and only then (unreachably) called the loader — so every request failed
    # when the app was started via `uvicorn app:app` instead of __main__.
    # Lazy-load instead so the first request triggers the load.
    if tokenizer is None or model is None:
        load_model_and_tokenizer()

    search_tokens = tokenizer(
        search_queries, return_tensors='pt', padding=True, truncation=True
    )
    # Inference only — no_grad avoids building the autograd graph.
    with torch.no_grad():
        search_output = model(**search_tokens)

    # SPLADE aggregation: log(1 + ReLU(logits)), zero out padding via the
    # attention mask, then max-pool over the token dimension.
    search_vecs = torch.max(
        torch.log(1 + torch.relu(search_output.logits))
        * search_tokens.attention_mask.unsqueeze(-1),
        dim=1,
    )[0].squeeze().detach().cpu().numpy()

    return {"vector": json.dumps(search_vecs.tolist())}


if __name__ == '__main__':
    # Warm the model before serving so the first request isn't slow.
    load_model_and_tokenizer()
    uvicorn.run(app, host='0.0.0.0', port=7860)