|
|
from functools import lru_cache

from fastapi import FastAPI, HTTPException
from sentence_transformers import util

import models
from schema import Prediction
|
|
|
|
|
# FastAPI application instance; the @app.get / @app.post decorators in this
# module register their handlers on it.
app = FastAPI()
|
|
|
|
|
@app.get("/")
def home_page():
    """Serve the fixed welcome payload for the root route."""
    welcome = {"Home": "Welcome to prediction hub"}
    return welcome
|
|
|
|
|
@app.get("/embeddings")
def display_embedding(message: str = "Hello guys enter a text to get embeddings"):
    """Return the embedding computed for ``message`` together with its length.

    Args:
        message: Text handed to ``models.get_embedding``.

    Returns:
        ``{"Dimension": {n: values}}`` where ``n`` is ``len(embedding)`` and
        ``values`` is ``embedding.tolist()``.

    Raises:
        HTTPException: Status 500 carrying the underlying error text when the
            embedding cannot be computed or converted.
    """
    try:
        embedding = models.get_embedding(message)
        dimension = len(embedding)
        values = embedding.tolist()
    except Exception as e:
        # The previous error branch returned a *set* literal ({f"..."}), which
        # was serialised as a JSON array with HTTP 200 and so was
        # indistinguishable from success. Report the failure with a proper
        # error status instead, keeping the same message text.
        raise HTTPException(
            status_code=500,
            detail=f"Unable to fetch the embeddings. Error :{e}",
        ) from e
    return {"Dimension": {dimension: values}}
|
|
|
|
|
@lru_cache(maxsize=1)
def _load_classifier():
    """Load the pickled model once; later calls reuse the cached object.

    A failed load is not cached, so the next request retries it.
    """
    return models.load_model('log_reg_model.pkl')


@app.post("/prediction")
def display_prediction(prediction: Prediction):
    """Classify ``prediction.message`` with the stored model.

    Args:
        prediction: Request body; only its ``message`` attribute is read here.

    Returns:
        ``{"Prediction": "<message> is a <result>"}`` where ``result`` is the
        ``.tolist()`` output of the model's ``predict`` call.
    """
    message = prediction.message
    # The embedding helper is given a one-element list, as before.
    embedding = models.get_embedding([message])
    # Reuse the cached model instead of unpickling it on every request.
    result = _load_classifier().predict(embedding).tolist()
    # NOTE(review): ``result`` is a list, so it is rendered with brackets in
    # the response text; kept as-is to preserve the existing response format.
    return {"Prediction": f"{message} is a {result}"}
|
|
|
|
|
@app.post("/cosine_similarity")
def display_cosine_similarity(prediction: Prediction):
    """Report the cosine similarity between the two messages in the request.

    Args:
        prediction: Request body; its ``message`` and ``message_1`` attributes
            are embedded together in one ``models.get_embedding`` call.

    Returns:
        A single-entry dict whose key is a sentence naming both messages and
        whose value is the similarity rounded to four decimal places.
    """
    first_text, second_text = prediction.message, prediction.message_1
    vectors = models.get_embedding([first_text, second_text])
    # cos_sim yields a one-element result; .item() extracts the plain number.
    score = util.cos_sim(vectors[0], vectors[1]).item()
    label = f"Cosine Similarity between {first_text} and {second_text} is"
    return {label: round(score, 4)}
|
|
|