"""HTTP service exposing a SetFit text classifier through FastAPI."""

from typing import Any, Dict, List

from fastapi import FastAPI
from pydantic import BaseModel
from setfit import SetFitModel


class TextInput(BaseModel):
    """Request body for ``/predict``: the list of messages to classify."""

    text: List[str]


# Initialize the FastAPI app.
app = FastAPI()

# Load the model once, at import time, so every request reuses the same instance.
model = SetFitModel.from_pretrained("assets")


def _to_native(label: Any) -> Any:
    """Return ``label`` as a plain Python value when it offers ``tolist()``.

    NumPy scalars/arrays and torch tensors expose ``tolist()``, which yields
    built-in Python numbers or lists that the JSON encoder can handle. Values
    without that method (for example ``str`` labels) are returned unchanged.
    """
    to_list = getattr(label, "tolist", None)
    return to_list() if callable(to_list) else label


# No return annotation on the endpoints below: recent FastAPI versions treat a
# return annotation as the response model, which would alter serialization.
@app.post("/predict")
async def predict(input: TextInput):
    """Classify every message in the request body.

    Args:
        input: Parsed request body; ``input.text`` holds the messages.

    Returns:
        A list with one ``{"Message": <text>, "label": <prediction>}`` dict per
        message, in request order. An empty ``text`` list yields ``[]``.
    """
    messages = input.text

    # Nothing to classify; skip the model call entirely.
    if not messages:
        return []

    # One batched call instead of one model invocation per message.
    # NOTE(review): inference is synchronous and runs on the event loop here;
    # offloading it to a worker thread would need confirmation that the model
    # (and its tokenizer) is safe to call concurrently.
    predictions = model.predict(messages)

    return [
        {"Message": message, "label": _to_native(label)}
        for message, label in zip(messages, predictions)
    ]


@app.get("/")
def read_root():
    """Return a static welcome payload."""
    return {"message": "Welcome"}