# Ezhil
# modified code
# e9a2c4c
from fastapi import FastAPI
from pydantic import BaseModel
from typing import List
from model import get_embeddings, predict_sms_category
from service import calculate_cosine_similarity
# FastAPI application instance; the @app.get / @app.post decorators in this
# module register their route handlers on it.
app = FastAPI()
class MessageRequest(BaseModel):
    """Request body carrying a batch of messages (accepted by ``/embed``)."""

    # Texts to process; the /embed handler passes this list to get_embeddings.
    messages: List[str]
class CosineSimilarityRequest(BaseModel):
    """Request body for ``/cosine_similarity``: the two texts to compare."""

    # First text of the pair (first argument to calculate_cosine_similarity).
    message1: str
    # Second text of the pair (second argument to calculate_cosine_similarity).
    message2: str
class PredictionRequest(BaseModel):
    """Request body for ``/predict``: a single message to classify."""

    # Text passed to predict_sms_category by the /predict handler.
    message: str
class EmbeddingResponse(BaseModel):
    """Response body returned by ``/embed``."""

    # Filled from ``shape[1]`` of the get_embeddings result.
    dimensions: int
    # Filled from ``tolist()`` of the get_embeddings result
    # (presumably one inner list per input message — confirm against model.py).
    numeric_values: List[List[float]]
class CosineSimilarityResponse(BaseModel):
    """Response body returned by ``/cosine_similarity``."""

    # Value returned by calculate_cosine_similarity for the two messages.
    similarity: float
class PredictionResponse(BaseModel):
    """Response body returned by ``/predict``."""

    # Value returned by predict_sms_category for the submitted message.
    label: str
@app.get("/")
def home():
    """Return a static welcome payload that points clients at ``/docs``."""
    greeting = "Welcome to the SMS classifier API. Use /docs for documentation."
    return {"Message": greeting}
@app.post("/embed", response_model=EmbeddingResponse)
def embed(request: MessageRequest):
    """Embed every message in the request and return the resulting vectors.

    Passes ``request.messages`` to ``get_embeddings`` and builds an
    ``EmbeddingResponse`` whose ``dimensions`` is ``shape[1]`` of the result
    and whose ``numeric_values`` is the result's ``tolist()`` form.
    """
    # Assumes get_embeddings returns a 2-D array-like exposing .shape and
    # .tolist() (e.g. a NumPy array) — TODO confirm against model.py.
    # NOTE(review): behaviour for an empty ``messages`` list depends on what
    # get_embeddings returns in that case; shape[1] may not exist — verify.
    vectors = get_embeddings(request.messages)
    width = vectors.shape[1]
    rows = vectors.tolist()
    return EmbeddingResponse(dimensions=width, numeric_values=rows)
@app.post("/cosine_similarity", response_model=CosineSimilarityResponse)
def cosine_similarity(request: CosineSimilarityRequest):
    """Score the two submitted messages with ``calculate_cosine_similarity``.

    The helper's return value is wrapped unchanged in a
    ``CosineSimilarityResponse``.
    """
    first, second = request.message1, request.message2
    return CosineSimilarityResponse(
        similarity=calculate_cosine_similarity(first, second),
    )
@app.post("/predict", response_model=PredictionResponse)
def predict(request: PredictionRequest):
    """Classify the submitted message with ``predict_sms_category``.

    The helper's return value is wrapped unchanged in a
    ``PredictionResponse`` as ``label``.
    """
    return PredictionResponse(label=predict_sms_category(request.message))