File size: 2,013 Bytes
79eca0e
9288345
 
 
 
 
 
 
 
709b775
9288345
 
709b775
 
 
79eca0e
709b775
 
 
 
 
 
 
 
 
 
 
 
 
c11261f
 
 
 
 
 
 
 
 
 
 
 
 
9288345
 
 
 
81c6189
 
 
 
 
9288345
81c6189
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
from fastapi import APIRouter, HTTPException
from schemas.schema import (
    SMSRequest,
    EmbeddingResponse,
    SimilarityRequest,
    SimilarityResponse,
    PredictionRequest,
    PredictionResponse
)
from service.embedded_service import generate_embeddings
from service.prediction_service import predict_label
import numpy as np

# Module-level router; the handlers defined below attach to it via @router.post.
router: APIRouter = APIRouter()

@router.post("/get_embeddings/", response_model=EmbeddingResponse)
async def get_embeddings(sms_request: SMSRequest):
    """Return one embedding vector per input SMS message.

    Args:
        sms_request: Request body whose ``messages`` field holds the texts to embed.

    Returns:
        EmbeddingResponse carrying the shared vector length (``dimensions``) and
        the list of embeddings produced by ``generate_embeddings``.

    Raises:
        HTTPException: 400 if no messages were supplied; 500 if the embedding
            service returned nothing, returned non-list vectors, returned a
            different number of vectors than messages, or returned vectors of
            differing lengths.
    """
    messages = sms_request.messages
    if not messages:
        raise HTTPException(status_code=400, detail="No messages provided.")

    # NOTE(review): generate_embeddings is called synchronously inside an async
    # handler; if it is CPU-heavy it will block the event loop — confirm.
    embeddings = generate_embeddings(messages)

    if not embeddings or not all(isinstance(emb, list) for emb in embeddings):
        raise HTTPException(status_code=500, detail="Failed to generate embeddings.")

    # Each message must receive exactly one vector; otherwise a partial batch
    # would be returned without any indication that something went wrong.
    if len(embeddings) != len(messages):
        raise HTTPException(
            status_code=500,
            detail="Embedding count does not match the number of messages.",
        )

    # `embeddings` is known to be non-empty here, so indexing [0] is safe.
    dimensions = len(embeddings[0])

    # A single `dimensions` value is reported for the whole batch, so it is only
    # truthful if every vector has that same length.
    if any(len(emb) != dimensions for emb in embeddings):
        raise HTTPException(
            status_code=500,
            detail="Generated embeddings have inconsistent dimensions.",
        )

    return EmbeddingResponse(dimensions=dimensions, embeddings=embeddings)

@router.post("/calculate_similarity/", response_model=SimilarityResponse)
async def calculate_similarity(similarity_request: SimilarityRequest):
    """Return the cosine similarity between the embeddings of two messages.

    Args:
        similarity_request: Request body carrying ``message1`` and ``message2``.

    Returns:
        SimilarityResponse whose ``similarity_score`` is the cosine of the angle
        between the two embedding vectors, clipped to the range [-1.0, 1.0].

    Raises:
        HTTPException: 500 if two embeddings were not produced, if the two
            vectors have different shapes, or if either vector has a zero or
            non-finite norm (cosine similarity is undefined in that case).
    """
    embeddings = generate_embeddings([similarity_request.message1, similarity_request.message2])

    # Explicit None check: len(None) would raise TypeError instead of a clean 500.
    if embeddings is None or len(embeddings) != 2:
        raise HTTPException(status_code=500, detail="Failed to generate embeddings for both messages.")

    vec1 = np.asarray(embeddings[0], dtype=float)
    vec2 = np.asarray(embeddings[1], dtype=float)

    # np.dot would otherwise raise an unhandled ValueError on mismatched shapes.
    if vec1.shape != vec2.shape:
        raise HTTPException(status_code=500, detail="Embeddings have mismatched dimensions.")

    denominator = np.linalg.norm(vec1) * np.linalg.norm(vec2)

    # A zero or non-finite norm makes the ratio NaN/inf. Those are not valid JSON
    # numbers, so without this guard the request fails while rendering the response.
    if denominator == 0 or not np.isfinite(denominator):
        raise HTTPException(
            status_code=500,
            detail="Cannot compute similarity for a zero or non-finite embedding.",
        )

    cosine_similarity = float(np.dot(vec1, vec2) / denominator)

    # Floating-point rounding can push the ratio marginally outside [-1, 1].
    cosine_similarity = min(1.0, max(-1.0, cosine_similarity))

    return SimilarityResponse(similarity_score=cosine_similarity)

@router.post("/predict_label/", response_model=PredictionResponse)
async def predict_sms_label(prediction_request: PredictionRequest):
    """Classify a single SMS message using the prediction service.

    Args:
        prediction_request: Request body carrying the ``message`` to classify.

    Returns:
        PredictionResponse holding the label and probability returned by
        ``predict_label``.

    Raises:
        HTTPException: 500 when ``predict_label`` returns the label ``"Error"``.
    """
    outcome = predict_label(prediction_request.message)
    predicted_label, predicted_probability = outcome

    # A label of "Error" is treated as the prediction service's failure sentinel.
    if predicted_label != "Error":
        return PredictionResponse(label=predicted_label, probability=predicted_probability)

    raise HTTPException(status_code=500, detail="Prediction failed.")