import os
from typing import List

import joblib
import gradio as gr
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel, Field

# Paths to the serialized TF-IDF vectorizer and SVM classifier.
CHECKPOINT_DIR = "checkpoints"
TFIDF_PATH = os.path.join(CHECKPOINT_DIR, "tfidf_vectorizer.pkl")
SVM_PATH = os.path.join(CHECKPOINT_DIR, "svm_stop_classifier.pkl")

LABEL_0 = "NOT_STOP"
LABEL_1 = "STOP"

tfidf_vectorizer = None
svm_model = None

# Load both artifacts at import time so the app fails fast if they are missing.
try:
    print(f"Loading TF-IDF vectorizer from {TFIDF_PATH}...")
    tfidf_vectorizer = joblib.load(TFIDF_PATH)
    print(f"Loading SVM model from {SVM_PATH}...")
    svm_model = joblib.load(SVM_PATH)
    print("Models loaded successfully.")
except FileNotFoundError as e:
    print(f"ERROR: Model file not found: {e}")
    raise RuntimeError(
        f"Failed to load required model files. Ensure 'checkpoints/' is correctly populated. Error: {e}"
    )

app = FastAPI(
    title="STOP Classifier API",
    description=(
        "STOP/NOT_STOP text classification using a linear SVM. "
        "The main UI is served at the root '/'; the REST API is available at "
        "'/predict' and documented at '/docs'."
    ),
    version="1.0.0",
)


class PredictionRequest(BaseModel):
    texts: List[str] = Field(
        ...,
        description="A list of text strings to classify.",
        example=[
            "please discontinue all communication",
            "I will stop by the station after lunch",
        ],
    )


class PredictionResponse(BaseModel):
    text: str = Field(..., description="The input text.")
    prediction: str = Field(..., description="The predicted label (STOP or NOT_STOP).")
    probability_NOT_STOP: float = Field(..., description="Probability of the NOT_STOP label.")
    probability_STOP: float = Field(..., description="Probability of the STOP label.")
    inference_model: str = Field("SVM", description="The model used for inference.")


def predict_svm(texts: List[str]) -> List[PredictionResponse]:
    """Vectorize the input texts and classify each one with the SVM."""
    if not texts:
        return []
    # The pickled model is expected to expose predict_proba
    # (e.g. SVC(probability=True) or a calibrated linear SVM).
    vec = tfidf_vectorizer.transform(texts)
    probs = svm_model.predict_proba(vec)
    preds = svm_model.predict(vec)

    results = []
    for i, txt in enumerate(texts):
        pred_label = LABEL_1 if preds[i] == 1 else LABEL_0
        results.append(
            PredictionResponse(
                text=txt,
                prediction=pred_label,
                probability_NOT_STOP=float(probs[i][0]),
                probability_STOP=float(probs[i][1]),
                inference_model="SVM",
            )
        )
    return results


@app.get("/health", status_code=200, tags=["API"])
def health_check():
    return {"status": "ok", "model_loaded": svm_model is not None}


@app.post("/predict", response_model=List[PredictionResponse], tags=["API"])
async def post_predict(request: PredictionRequest):
    try:
        return predict_svm(request.texts)
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Internal Server Error during POST prediction: {e}")


@app.get("/predict", response_model=PredictionResponse, tags=["API"])
async def get_predict(text: str):
    if not text.strip():
        raise HTTPException(status_code=400, detail="Text query parameter cannot be empty.")
    try:
        results = predict_svm([text])
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Internal Server Error during GET prediction: {e}")
    if not results:
        raise HTTPException(status_code=500, detail="Prediction returned empty result.")
    return results[0]


def gradio_interface_fn(text_input):
    """Interface function called by the Gradio UI."""
    if not text_input or not text_input.strip():
        return "Please enter text for classification.", None
    try:
        result = predict_svm([text_input])[0]
        prob_display = {
            LABEL_0: result.probability_NOT_STOP,
            LABEL_1: result.probability_STOP,
        }
        return result.prediction, prob_display
    except Exception as e:
        return f"An error occurred: {e}", None


ui = gr.Interface(
    fn=gradio_interface_fn,
    inputs=gr.Textbox(lines=2, placeholder="Enter a message to classify...", label="Input Text"),
    outputs=[
        gr.Label(label="Classification Result"),
        gr.Label(label="Probabilities"),
    ],
    title="STOP Classifier SVM",
    description=(
        "This is the user interface for the SVM model. The model classifies text as "
        "intended to end communication (STOP) or not (NOT_STOP). The API is available "
        "at the '/predict' endpoints."
    ),
)

# Mount the Gradio UI at the root path; the FastAPI routes defined above keep precedence.
app = gr.mount_gradio_app(app, ui, path="/")
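

# Optional local entry point: a minimal sketch for serving the app with uvicorn.
# The file name (app.py), host, and port 7860 are assumptions; adjust them to
# your deployment (e.g. `uvicorn app:app --host 0.0.0.0 --port 7860`).
# Once running, the UI is at http://localhost:7860/ and the API can be exercised with:
#   curl -X POST http://localhost:7860/predict \
#        -H "Content-Type: application/json" \
#        -d '{"texts": ["please discontinue all communication"]}'
if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=7860)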