# embedding_FastAPI/routes/sms_router.py
# Chittrarasu — commit 81c6189 ("deploy")
import numpy as np
from fastapi import APIRouter, HTTPException
from fastapi.concurrency import run_in_threadpool

from schemas.schema import (
    SMSRequest,
    EmbeddingResponse,
    SimilarityRequest,
    SimilarityResponse,
    PredictionRequest,
    PredictionResponse,
)
from service.embedded_service import generate_embeddings
from service.prediction_service import predict_label
# Module-level router that the SMS endpoints below (embeddings, pairwise
# similarity, label prediction) register themselves on via @router.post.
router: APIRouter = APIRouter()
@router.post("/get_embeddings/", response_model=EmbeddingResponse)
async def get_embeddings(sms_request: SMSRequest):
    """Return one embedding vector per message in the request.

    Args:
        sms_request: Request body whose ``messages`` are passed as-is to
            ``generate_embeddings``.

    Returns:
        EmbeddingResponse holding the embeddings and their common length
        as ``dimensions``.

    Raises:
        HTTPException: 400 if no messages were supplied; 500 if the embedding
            service returned nothing, returned non-list embeddings, returned
            a different number of embeddings than messages, or returned
            vectors of unequal length.
    """
    messages = sms_request.messages
    if not messages:
        raise HTTPException(status_code=400, detail="No messages provided.")

    # generate_embeddings is a plain (synchronous) call; running it in the
    # threadpool keeps it from blocking the event loop for other requests.
    embeddings = await run_in_threadpool(generate_embeddings, messages)

    if not embeddings or not all(isinstance(emb, list) for emb in embeddings):
        raise HTTPException(status_code=500, detail="Failed to generate embeddings.")

    # Callers pair embeddings with messages by position, so the counts must match.
    # NOTE(review): assumes SMSRequest.messages is a list of strings — confirm in schema.
    if len(embeddings) != len(messages):
        raise HTTPException(
            status_code=500,
            detail="Embedding count does not match message count.",
        )

    dimensions = len(embeddings[0])
    # A single "dimensions" value is only truthful if every vector agrees.
    if any(len(emb) != dimensions for emb in embeddings):
        raise HTTPException(
            status_code=500,
            detail="Embeddings have inconsistent dimensions.",
        )

    return EmbeddingResponse(dimensions=dimensions, embeddings=embeddings)
@router.post("/calculate_similarity/", response_model=SimilarityResponse)
async def calculate_similarity(similarity_request: SimilarityRequest):
    """Return the cosine similarity between the embeddings of two messages.

    Args:
        similarity_request: Request body providing ``message1`` and ``message2``.

    Returns:
        SimilarityResponse whose ``similarity_score`` is a finite float in
        [-1.0, 1.0].

    Raises:
        HTTPException: 500 if two embeddings were not produced, if their
            lengths differ, or if either has zero or non-finite norm (cosine
            similarity is undefined there).
    """
    # Offload the synchronous embedding call so it does not block the event loop.
    embeddings = await run_in_threadpool(
        generate_embeddings,
        [similarity_request.message1, similarity_request.message2],
    )
    # `is None` rather than truthiness so an ndarray result does not raise here.
    if embeddings is None or len(embeddings) != 2:
        raise HTTPException(status_code=500, detail="Failed to generate embeddings for both messages.")

    vec1 = np.asarray(embeddings[0], dtype=float)
    vec2 = np.asarray(embeddings[1], dtype=float)
    if vec1.shape != vec2.shape:
        # np.dot would otherwise raise an unhandled ValueError.
        raise HTTPException(status_code=500, detail="Embeddings have mismatched dimensions.")

    norm_product = np.linalg.norm(vec1) * np.linalg.norm(vec2)
    # A zero or non-finite norm would make the ratio NaN, which cannot be
    # serialized into the JSON response.
    if norm_product == 0 or not np.isfinite(norm_product):
        raise HTTPException(
            status_code=500,
            detail="Cannot compute similarity for a zero-norm or non-finite embedding.",
        )

    cosine_similarity = float(np.dot(vec1, vec2) / norm_product)
    # Floating-point rounding can push the ratio marginally outside [-1, 1].
    cosine_similarity = min(1.0, max(-1.0, cosine_similarity))
    return SimilarityResponse(similarity_score=cosine_similarity)
@router.post("/predict_label/", response_model=PredictionResponse)
async def predict_sms_label(prediction_request: PredictionRequest):
    """Classify a single message and return its label and probability.

    Args:
        prediction_request: Request body whose ``message`` is passed to
            ``predict_label``.

    Returns:
        PredictionResponse with the ``label`` and ``probability`` returned by
        ``predict_label``.

    Raises:
        HTTPException: 500 when ``predict_label`` reports failure.
    """
    # predict_label is a plain (synchronous) call; run it in the threadpool
    # so inference does not stall the event loop for concurrent requests.
    label, probability = await run_in_threadpool(predict_label, prediction_request.message)

    # predict_label appears to report failure in-band via the sentinel label
    # "Error" instead of raising; translate that into an HTTP error.
    if label == "Error":
        raise HTTPException(status_code=500, detail="Prediction failed.")

    return PredictionResponse(label=label, probability=probability)