# STOP / app.py
# Author: Nightfury16 — commit 23fab8c ("updated checkpoints")
import os
import joblib
import gradio as gr
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel, Field
from typing import List
# Directory holding the serialized model artifacts shipped with the Space.
CHECKPOINT_DIR = "checkpoints"
TFIDF_PATH = os.path.join(CHECKPOINT_DIR, "tfidf_vectorizer.pkl")
SVM_PATH = os.path.join(CHECKPOINT_DIR, "svm_stop_classifier.pkl")
# Human-readable names for the binary classes (class 0 -> NOT_STOP, class 1 -> STOP).
LABEL_0 = "NOT_STOP"
LABEL_1 = "STOP"
# Populated at import time by the loading block below; remain None if loading fails.
tfidf_vectorizer = None
svm_model = None
# Load the fitted TF-IDF vectorizer and SVM classifier at import time so the
# app fails fast (before serving requests) when checkpoint files are missing.
try:
    print(f"Loading TFIDF Vectorizer from {TFIDF_PATH}...")
    tfidf_vectorizer = joblib.load(TFIDF_PATH)
    print(f"Loading SVM Model from {SVM_PATH}...")
    svm_model = joblib.load(SVM_PATH)
    print("Models loaded successfully.")
except FileNotFoundError as e:
    print(f"ERROR: Model file not found: {e}")
    # Fix: the message previously said 'checkpoint/' while the actual directory
    # is 'checkpoints/'; derive it from CHECKPOINT_DIR so it cannot drift again.
    # Chain the original exception for a complete traceback.
    raise RuntimeError(f"Failed to load required model files. Ensure '{CHECKPOINT_DIR}/' is correctly populated. Error: {e}") from e
# FastAPI application; the Gradio UI is mounted over "/" at the bottom of the file,
# so the API surface is /predict (GET and POST) plus /health.
app = FastAPI(
    title="STOP Classifier API",
    description="STOP/NOT_STOP text classification using Linear SVM. The main UI is at the root '/', while the API endpoints are at '/api-docs' and '/predict'.",
    version="1.0.0"
)
class PredictionRequest(BaseModel):
    """Request body for POST /predict: a batch of texts to classify."""
    texts: List[str] = Field(
        ...,
        description="A list of text strings to classify.",
        example=[
            "please discontinue all communication",
            "I will stop by the station after lunch"
        ]
    )
class PredictionResponse(BaseModel):
    """Single classification result: input text, predicted label, class probabilities."""
    text: str = Field(..., description="The input text.")
    prediction: str = Field(..., description="The predicted label (STOP or NOT_STOP).")
    probability_NOT_STOP: float = Field(..., description="Probability of NOT_STOP label.")
    probability_STOP: float = Field(..., description="Probability of STOP label.")
    inference_model: str = Field("SVM", description="The model used for inference.")
def predict_svm(texts: List[str]) -> List[PredictionResponse]:
    """Classify a batch of texts with the module-level TF-IDF + SVM pipeline.

    Returns one PredictionResponse per input text, in input order. An empty
    input list short-circuits to an empty result.
    """
    if not texts:
        return []
    features = tfidf_vectorizer.transform(texts)
    # Both calls are kept: with Platt-scaled SVMs, argmax of predict_proba
    # and predict() may disagree, and we report predict()'s label.
    probabilities = svm_model.predict_proba(features)
    labels = svm_model.predict(features)
    return [
        PredictionResponse(
            text=text,
            prediction=LABEL_1 if label == 1 else LABEL_0,
            probability_NOT_STOP=float(prob[0]),
            probability_STOP=float(prob[1]),
            inference_model="SVM"
        )
        for text, label, prob in zip(texts, labels, probabilities)
    ]
@app.get("/health", status_code=200, tags=["API"])
def health_check():
    """Liveness probe: reports service status and whether the SVM model loaded.

    Returns a JSON object: {"status": "ok", "model_loaded": bool}.
    """
    # Explicit None check instead of bool(svm_model): truthiness of arbitrary
    # estimator objects is not guaranteed (e.g. sklearn Pipeline defines
    # __len__, so an empty-ish container could read as False despite being set).
    return {"status": "ok", "model_loaded": svm_model is not None}
@app.post("/predict", response_model=List[PredictionResponse], tags=["API"])
async def post_predict(request: PredictionRequest):
    """Batch prediction endpoint: classify every text in the request body."""
    try:
        return predict_svm(request.texts)
    except Exception as exc:
        # Boundary-level catch: surface any inference failure as a 500.
        raise HTTPException(status_code=500, detail=f"Internal Server Error during POST prediction: {exc}")
@app.get("/predict", response_model=PredictionResponse, tags=["API"])
async def get_predict(text: str):
    """Single-text prediction endpoint: classify the `text` query parameter.

    Raises 400 for blank input, 500 on inference failure or empty result.
    """
    if not text.strip():
        raise HTTPException(status_code=400, detail="Text query parameter cannot be empty.")
    # Fix: only the inference call is wrapped. Previously the deliberate
    # HTTPException(500, "Prediction returned empty result.") was raised inside
    # the try and immediately swallowed by `except Exception`, which re-wrapped
    # it with a different detail message.
    try:
        results = predict_svm([text])
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Internal Server Error during GET prediction: {e}")
    if not results:
        raise HTTPException(status_code=500, detail="Prediction returned empty result.")
    return results[0]
def gradio_interface_fn(text_input):
    """Adapter between the Gradio UI and predict_svm.

    Returns (predicted label, {label: probability}) on success, or an error
    message string paired with None on blank input or inference failure.
    """
    if not text_input or not text_input.strip():
        return "Please enter text for classification.", None
    try:
        (outcome,) = predict_svm([text_input])
    except Exception as e:
        # Best-effort UI: show the failure in the result slot instead of crashing.
        return f"An error occurred: {str(e)}", None
    return outcome.prediction, {
        LABEL_0: outcome.probability_NOT_STOP,
        LABEL_1: outcome.probability_STOP,
    }
# Gradio front-end: single textbox in, predicted label plus per-class
# probabilities out, backed by gradio_interface_fn.
ui = gr.Interface(
    fn=gradio_interface_fn,
    inputs=gr.Textbox(lines=2, placeholder="Enter a message to classify...", label="Input Text"),
    outputs=[
        gr.Label(label="Classification Result"),
        gr.Label(label="Probabilities")
    ],
    title="STOP Classifier SVM",
    description="This is the user interface for the SVM model. The model classifies text as intended to end communication (STOP) or not (NOT_STOP). The API is available at the '/predict' endpoints."
)
# Mount the Gradio UI at the root of the FastAPI app; API routes registered
# above (/health, /predict) remain reachable alongside it.
app = gr.mount_gradio_app(app, ui, path="/")