# DataNeuron / model_handler.py
# Last commit: eef6783 (verified) "Update model_handler.py" by Bhanushray; file size 1.39 kB.
from sentence_transformers.cross_encoder import CrossEncoder
import torch
class SimilarityModelHandler:
    """Score the semantic similarity of two texts with a shared cross-encoder.

    The underlying ``CrossEncoder`` is stored on the class itself, so it is
    loaded by the first constructor call and reused by every later instance.
    """

    # Pre-trained cross-encoder checkpoint passed to ``CrossEncoder``.
    # Alternative checkpoint referenced in earlier code: 'cross-encoder/stsb-roberta-large'.
    MODEL_NAME = 'cross-encoder/stsb-roberta-base'

    # Shared model instance, cached on the class to prevent reloading.
    # Stays None until the first successful load in __init__.
    SIMILARITY_MODEL_INSTANCE = None

    def __init__(self):
        """Load the cross-encoder into the class-level cache if it is not loaded yet.

        Uses the GPU when ``torch.cuda.is_available()`` reports one, otherwise
        the CPU. If a model is already cached, this does nothing.
        """
        # Compare against None explicitly: an already-loaded model must never
        # be reloaded just because the object happens to evaluate as falsy.
        if SimilarityModelHandler.SIMILARITY_MODEL_INSTANCE is None:
            print("INITIALIZING AND LOADING THE MODEL...")
            # Prefer the GPU when available, fall back to CPU.
            device = 'cuda' if torch.cuda.is_available() else 'cpu'
            print(f"SERVICE IS RUNNING ON DEVICE: {device}")
            SimilarityModelHandler.SIMILARITY_MODEL_INSTANCE = CrossEncoder(
                SimilarityModelHandler.MODEL_NAME, device=device
            )
            print("MODEL LOADED SUCCESSFULLY.")

    def calculate_Similarity(self, text_One: str, text_Two: str) -> float:
        """Return the model's similarity score for the pair (text_One, text_Two).

        The pair is scored as a single-item batch, and the one-element result
        is unwrapped to a plain Python float.
        """
        # NOTE(review): earlier comments described the score as lying in 0-1;
        # that presumably relies on the CrossEncoder's default activation for
        # this single-label checkpoint — confirm for the installed version.
        scores = self.SIMILARITY_MODEL_INSTANCE.predict([(text_One, text_Two)])
        # predict() yields a one-element array for one pair; .item() unwraps it
        # and float() guarantees the annotated return type.
        return float(scores.item())
# Shared handler for the API layer. Building it here, at import time, loads
# the cross-encoder on first import if no model has been cached yet.
MODEL_HANDLER: SimilarityModelHandler = SimilarityModelHandler()