# NOTE(review): the following lines are web-scrape residue from the hosting
# page (Spaces header, commit hashes, line-number gutter), not source code.
# Commented out so the module remains valid Python; safe to delete.
# Spaces: Sleeping / Sleeping / File size: 4,601 Bytes
# 4f2a027 2e13826 4f2a027 aabbc15 4f2a027 aabbc15 4f2a027 aabbc15 4f2a027 aabbc15 4f2a027 23fab8c 4f2a027 f8c2f56
import os
import joblib
import gradio as gr
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel, Field
from typing import List
# Directory holding the serialized model artifacts shipped with the app.
CHECKPOINT_DIR = "checkpoints"
TFIDF_PATH = os.path.join(CHECKPOINT_DIR, "tfidf_vectorizer.pkl")
SVM_PATH = os.path.join(CHECKPOINT_DIR, "svm_stop_classifier.pkl")
# Human-readable names for the classifier's two classes: 0 -> NOT_STOP, 1 -> STOP.
LABEL_0 = "NOT_STOP"
LABEL_1 = "STOP"
# Populated by the loading block below; remain None only if loading fails
# (in which case a RuntimeError aborts startup).
tfidf_vectorizer = None
svm_model = None
# Load both artifacts at import time so the service fails fast on a broken
# deployment instead of 500-ing on the first request.
try:
    print(f"Loading TFIDF Vectorizer from {TFIDF_PATH}...")
    tfidf_vectorizer = joblib.load(TFIDF_PATH)
    print(f"Loading SVM Model from {SVM_PATH}...")
    svm_model = joblib.load(SVM_PATH)
    print("Models loaded successfully.")
except FileNotFoundError as e:
    print(f"ERROR: Model file not found: {e}")
    # Fix: the message previously said 'checkpoint/' while the actual directory
    # is 'checkpoints/' (see CHECKPOINT_DIR); also chain the original error.
    raise RuntimeError(f"Failed to load required model files. Ensure '{CHECKPOINT_DIR}/' is correctly populated. Error: {e}") from e
# FastAPI application. Interactive API docs live at FastAPI's default '/docs'
# path (docs_url is not overridden); the previous description advertised a
# non-existent '/api-docs' route.
app = FastAPI(
    title="STOP Classifier API",
    description="STOP/NOT_STOP text classification using Linear SVM. The main UI is at the root '/', while the API endpoints are at '/docs' and '/predict'.",
    version="1.0.0"
)
class PredictionRequest(BaseModel):
    """Request body for POST /predict: a batch of raw texts to classify."""

    texts: List[str] = Field(
        ...,
        description="A list of text strings to classify.",
        example=[
            "please discontinue all communication",
            "I will stop by the station after lunch",
        ],
    )
class PredictionResponse(BaseModel):
    """One classification result: the input text, the predicted label, and
    the per-class probabilities."""

    text: str = Field(..., description="The input text.")
    prediction: str = Field(
        ..., description="The predicted label (STOP or NOT_STOP)."
    )
    probability_NOT_STOP: float = Field(
        ..., description="Probability of NOT_STOP label."
    )
    probability_STOP: float = Field(
        ..., description="Probability of STOP label."
    )
    inference_model: str = Field(
        "SVM", description="The model used for inference."
    )
def predict_svm(texts: List[str]) -> List[PredictionResponse]:
    """Vectorize *texts* with the TF-IDF model and classify each with the SVM.

    Returns one PredictionResponse per input, in order; an empty input list
    yields an empty result without touching the models.

    NOTE: predict() and predict_proba() are both called because sklearn's
    Platt-scaled probabilities can disagree with the raw decision — the hard
    label is taken from predict(), the probabilities from predict_proba().
    """
    if not texts:
        return []
    features = tfidf_vectorizer.transform(texts)
    probabilities = svm_model.predict_proba(features)
    hard_labels = svm_model.predict(features)
    return [
        PredictionResponse(
            text=text,
            prediction=LABEL_1 if label == 1 else LABEL_0,
            probability_NOT_STOP=float(prob[0]),
            probability_STOP=float(prob[1]),
            inference_model="SVM",
        )
        for text, label, prob in zip(texts, hard_labels, probabilities)
    ]
@app.get("/health", status_code=200, tags=["API"])
def health_check():
    """Liveness probe: reports whether the SVM model object is in memory."""
    # Explicit None check instead of bool(svm_model): truthiness of an
    # arbitrary estimator object is not a meaningful "loaded" signal.
    return {"status": "ok", "model_loaded": svm_model is not None}
@app.post("/predict", response_model=List[PredictionResponse], tags=["API"])
async def post_predict(request: PredictionRequest):
    """Classify a batch of texts; returns one PredictionResponse per input."""
    try:
        return predict_svm(request.texts)
    except Exception as exc:
        # Surface any inference failure as a 500 with the underlying message.
        raise HTTPException(status_code=500, detail=f"Internal Server Error during POST prediction: {exc}")
@app.get("/predict", response_model=PredictionResponse, tags=["API"])
async def get_predict(text: str):
    """Classify a single text passed as the `text` query parameter.

    Raises 400 for blank input, 500 on inference failure.
    """
    if not text.strip():
        raise HTTPException(status_code=400, detail="Text query parameter cannot be empty.")
    try:
        results = predict_svm([text])
        if not results:
            raise HTTPException(status_code=500, detail="Prediction returned empty result.")
        return results[0]
    except HTTPException:
        # Bug fix: HTTPException is an Exception subclass, so the broad
        # handler below used to catch and re-wrap the 500 raised above.
        # Re-raise deliberate HTTP errors untouched.
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Internal Server Error during GET prediction: {e}")
def gradio_interface_fn(text_input):
    """Bridge between the Gradio UI and the SVM predictor.

    Returns a (label, probability-dict) pair. On blank input or any failure
    the dict slot is None and the label slot carries a human-readable message.
    """
    # Guard clause: None or whitespace-only input gets a prompt, no inference.
    if not text_input or not text_input.strip():
        return "Please enter text for classification.", None
    try:
        result = predict_svm([text_input])[0]
        probabilities = {
            LABEL_0: result.probability_NOT_STOP,
            LABEL_1: result.probability_STOP,
        }
    except Exception as exc:
        return f"An error occurred: {str(exc)}", None
    return result.prediction, probabilities
# Gradio UI wired to the same predictor used by the REST endpoints.
ui = gr.Interface(
    fn=gradio_interface_fn,
    inputs=gr.Textbox(lines=2, placeholder="Enter a message to classify...", label="Input Text"),
    outputs=[
        gr.Label(label="Classification Result"),
        gr.Label(label="Probabilities")
    ],
    title="STOP Classifier SVM",
    description="This is the user interface for the SVM model. The model classifies text as intended to end communication (STOP) or not (NOT_STOP). The API is available at the '/predict' endpoints."
)
# Mount the Gradio app at '/' (FastAPI's own routes, registered above, take
# precedence). Fix: removed a stray trailing '|' (scrape residue) that made
# this line a syntax error.
app = gr.mount_gradio_app(app, ui, path="/")