# checkin_or_not / checkin_or_not_API.py
# Author: mjpsm
# Commit 5fdd5e9 (verified): added the newly retrained model (version 3)
from fastapi import FastAPI
from pydantic import BaseModel
from huggingface_hub import hf_hub_download
from sentence_transformers import SentenceTransformer
import xgboost as xgb
import numpy as np
# FastAPI application instance; the route handlers further down register
# themselves on it through the @app.get / @app.post decorators.
app = FastAPI(
    version="1.0.0",
    title="Check-In Classifier API",
    description="Detects whether a message is a CHECKIN or NOT_CHECKIN using XGBoost + MiniLM.",
)
# ------------------------------------------------------------------
# Sentence-embedding model (MiniLM) that turns message text into vectors
# ------------------------------------------------------------------
# Constructed once when the module is imported; the module-level
# `embedder` object is then reused for every request.
EMBEDDER_NAME = "all-MiniLM-L6-v2"
embedder = SentenceTransformer(EMBEDDER_NAME)
# ------------------------------------------------------------------
# XGBoost classifier weights, fetched from the Hugging Face Hub
# ------------------------------------------------------------------
MODEL_REPO = "mjpsm/checkin_or_not_model"
MODEL_FILE = "checkin_or_not_classifierV3.json"  # retrained model, version 3

# hf_hub_download hands back the local path of the fetched file, which is
# then fed to an empty Booster. This runs at import time, so the service
# cannot start without access to the model file.
model_path = hf_hub_download(repo_id=MODEL_REPO, filename=MODEL_FILE)

booster = xgb.Booster()
booster.load_model(model_path)
# ------------------------------------------------------------------
# Request / response schemas for the /predict endpoint
# ------------------------------------------------------------------
class PredictionRequest(BaseModel):
    """Body of a /predict call: the message to classify."""

    text: str  # raw message text; embedded as-is


class PredictionResponse(BaseModel):
    """Result of a /predict call."""

    label: str  # "CHECKIN" or "NOT_CHECKIN" as produced by predict()
    score: float  # first booster output for the message
@app.get("/")
def home():
    """Report that the service is up (simple liveness check)."""
    status_payload = {
        "status": "online",
        "message": "Check-In Classification API is running.",
    }
    return status_payload
@app.post("/predict", response_model=PredictionResponse)
def predict(req: PredictionRequest):
    """Classify one message as CHECKIN or NOT_CHECKIN.

    The message text is embedded with the MiniLM ``embedder`` and the
    embedding is scored by the XGBoost ``booster``. The first value the
    booster returns becomes ``score``; a score of 0.5 or more is labelled
    ``CHECKIN`` and anything lower ``NOT_CHECKIN``.

    NOTE(review): ``score`` is presumably a probability (binary logistic
    objective), but the objective comes from the downloaded model file and
    cannot be confirmed from this code.
    """
    # Passing a one-element list keeps a leading batch axis, so the DMatrix
    # presumably holds a single row (one embedding vector).
    features = xgb.DMatrix(embedder.encode([req.text]))

    # One row in, so only the first prediction is of interest.
    first_prediction = booster.predict(features)[0]
    score = float(first_prediction)

    if score >= 0.5:
        label = "CHECKIN"
    else:
        label = "NOT_CHECKIN"
    return PredictionResponse(label=label, score=score)