# test / pretrained.py
# dhruv107's picture
# llogs
# 745c9ac
from sentence_transformers import SentenceTransformer, util
import pickle
import pandas as pd
import numpy as np
import os
import json
from flask import Flask, request, jsonify
from werkzeug.utils import secure_filename
# import logging
# # Set up root logger, and add a file handler to root logger
# logging.basicConfig(filename = 'log_file.log',
# filemode='w',
# level = logging.DEBUG,
# format = '%(asctime)s:%(levelname)s:%(filename)s:%(funcName)s:%(lineno)d:%(message)s')
# logger = logging.getLogger()
# Flask application object; the /match_text route below is registered on it
# and the __main__ block at the bottom of the file starts it.
app = Flask(__name__)
@app.route('/match_text', methods=['POST'])
def similarity():
    """Return the cosine similarity between ``text1`` and ``text2``.

    The request body must be a JSON object containing the keys ``text1``
    and ``text2``. Both values are encoded with the module-level ``model``
    and the score at row 0, column 0 of the resulting cosine-similarity
    matrix is returned.

    Note: ``model`` is only assigned in this file's ``__main__`` block, so
    when the module is imported rather than run as a script the lookup
    fails and is reported through the 500 branch below.

    Returns:
        A ``(response, status)`` tuple:
        * ``{'similarity_score': <float>}, 200`` on success;
        * ``{'error': <message>}, 400`` when the body is not a JSON object
          or either key is missing;
        * ``{'error': <message>}, 500`` for any failure while scoring.
    """
    # silent=True makes a missing, malformed or wrongly-typed JSON body come
    # back as None instead of raising, so it can be answered as a client
    # error (400) rather than falling into the generic 500 handler.
    data = request.get_json(silent=True)
    if not isinstance(data, dict) or 'text1' not in data or 'text2' not in data:
        return jsonify({'error': 'Both text1 and text2 must be provided.'}), 400

    try:
        embeddings1 = model.encode(data['text1'], convert_to_tensor=True)
        embeddings2 = model.encode(data['text2'], convert_to_tensor=True)
        cosine_scores = util.cos_sim(embeddings1, embeddings2)
        # Only the first pair is reported, even if lists of sentences were sent.
        score = cosine_scores[0][0].item()
        print(f'{score}')
        return jsonify({'similarity_score': score}), 200
    except Exception as e:
        # Top-level boundary of the endpoint: always answer with JSON.
        return jsonify({'error': str(e)}), 500
if __name__ == '__main__':
    # Script entry point: load the embedding model once, then serve requests.
    # `model` is bound at module level here so the /match_text handler can
    # read it as a global.
    print('loading model...')
    # Alternative ways of locating the same model, kept for reference:
    # model = SentenceTransformer("all-MiniLM-L6-v2", cache_folder='./')
    # model = SentenceTransformer("models--sentence-transformers--all-MiniLM-L6-v2/snapshots/1a310852cf8e58d22c5ebff537711d504ad4ad66")
    model = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2", cache_folder='./')
    # Override the model's configured sequence-length limit.
    model.max_seq_length = 512
    print(f'model max length is :{model.max_seq_length}')
    # Listen on every interface at port 7860, with debug mode and
    # per-request threading both switched off.
    app.run(debug=False, port=7860, host='0.0.0.0', threaded=False)