File size: 4,601 Bytes
4f2a027
 
 
 
 
 
 
2e13826
4f2a027
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
aabbc15
4f2a027
 
 
 
 
aabbc15
4f2a027
 
 
aabbc15
 
4f2a027
 
 
 
 
 
aabbc15
4f2a027
 
 
 
 
 
 
23fab8c
4f2a027
 
 
f8c2f56
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
import os
import joblib
import gradio as gr
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel, Field
from typing import List

# Locations of the serialized model artifacts produced at training time.
CHECKPOINT_DIR = "checkpoints"
TFIDF_PATH = os.path.join(CHECKPOINT_DIR, "tfidf_vectorizer.pkl")
SVM_PATH = os.path.join(CHECKPOINT_DIR, "svm_stop_classifier.pkl")

# Human-readable class labels: class 0 -> NOT_STOP, class 1 -> STOP.
LABEL_0 = "NOT_STOP"
LABEL_1 = "STOP"

# Module-level model handles; None until the import-time loading below succeeds.
tfidf_vectorizer = None
svm_model = None

# Load both artifacts at import time; the service cannot predict without them,
# so a missing file aborts startup with a clear error.
try:
    print(f"Loading TFIDF Vectorizer from {TFIDF_PATH}...")
    tfidf_vectorizer = joblib.load(TFIDF_PATH)
    print(f"Loading SVM Model from {SVM_PATH}...")
    svm_model = joblib.load(SVM_PATH)
    print("Models loaded successfully.")
except FileNotFoundError as e:
    print(f"ERROR: Model file not found: {e}")
    # Bug fix: the message previously pointed at a non-existent 'checkpoint/'
    # directory; reference the actual CHECKPOINT_DIR instead. Chain the cause
    # with `from e` so the original traceback is preserved.
    raise RuntimeError(
        f"Failed to load required model files. "
        f"Ensure '{CHECKPOINT_DIR}/' is correctly populated. Error: {e}"
    ) from e

# FastAPI application. No custom docs_url is configured, so the interactive
# OpenAPI docs are served at FastAPI's default '/docs' path — the previous
# description incorrectly advertised '/api-docs'.
app = FastAPI(
    title="STOP Classifier API",
    description=(
        "STOP/NOT_STOP text classification using Linear SVM. "
        "The main UI is at the root '/', the API documentation is at '/docs', "
        "and the prediction endpoints are at '/predict'."
    ),
    version="1.0.0",
)

class PredictionRequest(BaseModel):
    """Request body for POST /predict: a batch of texts to classify."""
    # Batch of raw input strings; one prediction is returned per element.
    texts: List[str] = Field(
        ...,
        description="A list of text strings to classify.",
        example=[
            "please discontinue all communication",
            "I will stop by the station after lunch"
        ]
    )

class PredictionResponse(BaseModel):
    """Single prediction result: label plus both class probabilities."""
    text: str = Field(..., description="The input text.")
    prediction: str = Field(..., description="The predicted label (STOP or NOT_STOP).")
    probability_NOT_STOP: float = Field(..., description="Probability of NOT_STOP label.")
    probability_STOP: float = Field(..., description="Probability of STOP label.")
    # Constant identifier of the backing model; always "SVM" in this service.
    inference_model: str = Field("SVM", description="The model used for inference.")

def predict_svm(texts: List[str]) -> List[PredictionResponse]:
    """Classify each input text with the loaded TF-IDF + SVM pipeline.

    Returns one PredictionResponse per input, in the same order as `texts`.
    An empty input list yields an empty result without touching the models.
    Assumes the module-level vectorizer and model were loaded at import time.
    """
    if not texts:
        return []

    features = tfidf_vectorizer.transform(texts)
    probabilities = svm_model.predict_proba(features)
    labels = svm_model.predict(features)

    return [
        PredictionResponse(
            text=text,
            prediction=LABEL_1 if label == 1 else LABEL_0,
            probability_NOT_STOP=float(prob[0]),
            probability_STOP=float(prob[1]),
            inference_model="SVM",
        )
        for text, label, prob in zip(texts, labels, probabilities)
    ]

@app.get("/health", status_code=200, tags=["API"])
def health_check():
    """Liveness probe: reports whether both model artifacts are loaded."""
    # Explicit `is not None` instead of bool(): truthiness of an arbitrary
    # estimator object is not a reliable loaded/unloaded signal, and the
    # vectorizer is equally required for serving predictions.
    loaded = svm_model is not None and tfidf_vectorizer is not None
    return {"status": "ok", "model_loaded": loaded}

@app.post("/predict", response_model=List[PredictionResponse], tags=["API"])
async def post_predict(request: PredictionRequest):
    """Batch prediction endpoint: classify every text in the request body."""
    try:
        return predict_svm(request.texts)
    except Exception as e:
        # Surface any inference failure as a 500 with the underlying message.
        raise HTTPException(status_code=500, detail=f"Internal Server Error during POST prediction: {e}")

@app.get("/predict", response_model=PredictionResponse, tags=["API"])
async def get_predict(text: str):
    """Single-text prediction endpoint: classify the `text` query parameter.

    Raises 400 for blank input, 500 for inference failures.
    """
    if not text.strip():
        raise HTTPException(status_code=400, detail="Text query parameter cannot be empty.")

    try:
        results = predict_svm([text])
        if not results:
            raise HTTPException(status_code=500, detail="Prediction returned empty result.")
        return results[0]
    except HTTPException:
        # Bug fix: the 500 raised just above was previously caught by the
        # generic handler below and re-wrapped with a misleading detail
        # string; re-raise HTTPExceptions unchanged.
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Internal Server Error during GET prediction: {e}")

def gradio_interface_fn(text_input):
    """Interface function to be called by Gradio UI.

    Returns a (label, probabilities-dict) pair for the Label components;
    on blank input or any inference error, returns a message and None.
    """
    # Guard clause: reject missing or whitespace-only input up front.
    if not text_input or not text_input.strip():
        return "Please enter text for classification.", None

    try:
        response = predict_svm([text_input])[0]
        class_probabilities = {
            LABEL_0: response.probability_NOT_STOP,
            LABEL_1: response.probability_STOP,
        }
        return response.prediction, class_probabilities
    except Exception as e:
        # The UI has no error channel, so report failures as the label text.
        return f"An error occurred: {str(e)}", None
    
# Assemble the Gradio components, build the Interface, and mount it on the
# FastAPI app at the root path so "/" serves the UI while the API routes
# registered above keep working.
_input_box = gr.Textbox(
    lines=2,
    placeholder="Enter a message to classify...",
    label="Input Text",
)
_result_output = gr.Label(label="Classification Result")
_probability_output = gr.Label(label="Probabilities")

ui = gr.Interface(
    fn=gradio_interface_fn,
    inputs=_input_box,
    outputs=[_result_output, _probability_output],
    title="STOP Classifier SVM",
    description=(
        "This is the user interface for the SVM model. The model classifies "
        "text as intended to end communication (STOP) or not (NOT_STOP). "
        "The API is available at the '/predict' endpoints."
    ),
)

app = gr.mount_gradio_app(app, ui, path="/")