from fastapi import FastAPI
import json
import uvicorn
import torch
import pickle

app = FastAPI()

tokenizer, model = None, None

@app.get('/')
def homepage():
    # FastAPI serializes dicts to JSON; returning json.dumps() double-encodes the body
    return {'Message': 'Welcome to SPLADE'}

@app.get('/health_check')
def health_check():
    return {"success": True}

def load_model_and_tokenizer():
    global tokenizer, model
    with open('./app/tokenizer.pkl', 'rb') as f:
        tokenizer = pickle.load(f)

    with open('./app/model.pkl', 'rb') as f:
        model = pickle.load(f)
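
# A minimal sketch (not part of the running service) of how the pickled
# artifacts above could be produced. The checkpoint name below is an
# assumption for illustration; the actual SPLADE checkpoint used is not
# recorded in this file:
#
#     from transformers import AutoModelForMaskedLM, AutoTokenizer
#     ckpt = 'naver/splade-cocondenser-ensembledistil'  # hypothetical choice
#     with open('./app/tokenizer.pkl', 'wb') as f:
#         pickle.dump(AutoTokenizer.from_pretrained(ckpt), f)
#     with open('./app/model.pkl', 'wb') as f:
#         pickle.dump(AutoModelForMaskedLM.from_pretrained(ckpt), f)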

@app.get('/get_vector')
def get_vector(search_queries_list: str):
    search_queries = json.loads(search_queries_list)
    global tokenizer, model
    # Lazily load the model and tokenizer if the startup path did not run
    if tokenizer is None or model is None:
        load_model_and_tokenizer()

    search_tokens = tokenizer(
        search_queries, return_tensors='pt',
        padding=True, truncation=True
    )
    # Inference only: no gradient graph needed
    with torch.no_grad():
        search_output = model(**search_tokens)

    # SPLADE pooling: max over the sequence dimension of log(1 + relu(logits)),
    # with padding positions zeroed out via the attention mask
    search_vecs = torch.max(
        torch.log(1 + torch.relu(search_output.logits)) * search_tokens.attention_mask.unsqueeze(-1), dim=1
    )[0].squeeze().detach().cpu().numpy()

    return {"vector":json.dumps(search_vecs.tolist())}

if __name__ == '__main__':
    load_model_and_tokenizer()
    uvicorn.run(app, host='0.0.0.0', port=7860)
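
# Example client call: a sketch assuming the service is reachable on
# localhost:7860 and that the 'requests' package is available (it is not a
# dependency of this app):
#
#     import json
#     import requests
#
#     resp = requests.get(
#         'http://localhost:7860/get_vector',
#         params={'search_queries_list': json.dumps(['what is splade?'])},
#     )
#     vector = json.loads(resp.json()['vector'])  # SPLADE vector(s) as nested lists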