File size: 2,167 Bytes
892a768
ad471a0
 
 
892a768
 
ad471a0
 
892a768
 
 
ad471a0
 
 
 
 
 
 
 
 
 
 
 
 
 
892a768
ad471a0
892a768
 
ad471a0
5ba2fa9
ad471a0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
from fastapi import APIRouter, HTTPException
import os
import pickle

from back_end.models.embedding_model import generate_embedding
from back_end.schemas.request import TextRequest
from sklearn.linear_model import LogisticRegression
from scipy.spatial.distance import cosine

router = APIRouter()

# Resolve the model path relative to this file so the app works no matter
# which working directory it is launched from.
BASE_DIR = os.path.dirname(os.path.abspath(__file__))
MODEL_PATH = os.path.join(BASE_DIR, "..", "models", "logistic.pkl")

# Load the pre-trained logistic regression classifier once at import time.
# NOTE: pickle.load must only ever be used on trusted files — this model
# ships with the application, so that assumption holds here.
try:
    with open(MODEL_PATH, "rb") as f:
        logistic_model = pickle.load(f)
except FileNotFoundError as exc:
    # Chain the original exception so the root cause survives in tracebacks.
    raise RuntimeError(f"Model file not found at {MODEL_PATH}") from exc
except pickle.UnpicklingError as exc:
    raise RuntimeError(f"Error unpickling model file at {MODEL_PATH}") from exc


@router.post("/generate_embedding/")
def get_embedding(request: TextRequest):
    """Return a 768-dimensional embedding vector for the supplied text.

    Raises:
        HTTPException: 400 when the request text is empty.
    """
    text = request.text
    if not text:
        raise HTTPException(status_code=400, detail="Text cannot be empty")

    vector = generate_embedding(text)
    return {"dimensions": len(vector), "embedding": vector}


@router.post("/cosine_similarity/")
def get_cosine_similarity(request: TextRequest):
    """Return the cosine similarity between two input texts.

    Expects the request to carry both ``text`` and ``text2`` values.

    Raises:
        HTTPException: 400 when either text is missing or empty.
    """
    # BUG FIX: the original used hasattr(), which is always True for fields
    # declared on a Pydantic model, so empty/blank inputs slipped through
    # validation. Check the actual values instead; getattr keeps the 400
    # behavior if `text2` is not declared on TextRequest at all.
    text1 = getattr(request, "text", None)
    text2 = getattr(request, "text2", None)
    if not text1 or not text2:
        raise HTTPException(status_code=400, detail="Both text inputs must be provided")

    embedding1 = generate_embedding(text1)
    embedding2 = generate_embedding(text2)

    # scipy's cosine() is a distance in [0, 2]; similarity = 1 - distance.
    similarity = 1 - cosine(embedding1, embedding2)
    return {"cosine_similarity": similarity}


@router.post("/logistic_prediction/")
def get_logistic_prediction(request: TextRequest):
    """Return the logistic regression model's prediction for the input text.

    Raises:
        HTTPException: 400 when the text is empty; 500 when prediction fails.
    """
    if not request.text:
        raise HTTPException(status_code=400, detail="Text cannot be empty")

    embedding = generate_embedding(request.text)
    try:
        prediction = logistic_model.predict([embedding])[0]
    except Exception as e:  # surface model errors as a 500 at the API boundary
        raise HTTPException(status_code=500, detail=f"Model prediction failed: {str(e)}") from e

    # BUG FIX: scikit-learn returns numpy scalar types, which FastAPI's
    # JSON encoder cannot serialize; convert to a native Python type when
    # possible (str labels have no .item() and pass through unchanged).
    if hasattr(prediction, "item"):
        prediction = prediction.item()

    return {"prediction": prediction}