Nightfury16 commited on
Commit
4f2a027
·
0 Parent(s):

Initial commit: Dockerized low-latency SVM classifier with FastAPI and Gradio UI.

Browse files
Files changed (4) hide show
  1. Dockerfile +15 -0
  2. READMe.md +61 -0
  3. app.py +128 -0
  4. requirements.txt +5 -0
Dockerfile ADDED
@@ -0,0 +1,15 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
FROM python:3.9-slim

WORKDIR /app

# Copy and install dependencies first so this layer is cached across
# code-only rebuilds.
COPY requirements.txt .
RUN pip install --no-cache-dir -r requirements.txt

COPY app.py .
COPY checkpoint/ checkpoint/

# Hugging Face Spaces expects the service on port 7860.
# Fix: use the key=value ENV form; the legacy space-separated form
# ("ENV PORT 7860") is deprecated by Docker.
ENV PORT=7860
EXPOSE 7860

CMD ["uvicorn", "app:app", "--host", "0.0.0.0", "--port", "7860"]
READMe.md ADDED
@@ -0,0 +1,61 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ title: STOP
3
+ sdk: docker
4
+ app_port: 7860
5
+ colorFrom: red
6
+ colorTo: indigo
7
+ description: Low-latency STOP/NOT_STOP text classification using Linear SVM deployed with FastAPI and Docker.
8
+ ---
9
+
10
+ # STOP Classifier API
11
+
12
+ This Hugging Face Space hosts a low-latency text classification service deployed with Docker and FastAPI.
13
+
14
+ The service uses a highly efficient Linear Support Vector Machine (SVM) model trained on text features extracted via TF-IDF to classify messages as either intending to end communication (`STOP`) or not (`NOT_STOP`). As confirmed by the training script, the SVM model provides millisecond-level inference, which is ideal for the required low-latency API.
15
+
16
+ ## Project Structure
17
+
18
+ The deployment uses the following structure:
19
+
20
+ ```
21
+ .
22
+ ├── app.py
23
+ ├── Dockerfile
24
+ ├── requirements.txt
25
+ ├── README.md
26
+ └── checkpoint/
27
+ ├── tfidf_vectorizer.pkl
28
+ └── svm_stop_classifier.pkl
29
+ ```
30
+
31
+ ## API Endpoints
32
+
33
+ The FastAPI application provides three endpoints: a health check and two prediction routes:
34
+
35
+ ### 1. Health Check (GET)
36
+
37
+ * **Path:** `/health`
38
+ * **Method:** `GET`
39
+ * **Description:** A simple endpoint to confirm the service is running and the models are loaded.
40
+
41
+ ### 2. Single Prediction (GET)
42
+
43
+ * **Path:** `/predict?text=<your_text>`
44
+ * **Method:** `GET`
45
+ * **Description:** Classifies a single text string passed as a query parameter. This is suitable for quick, individual queries.
46
+ * **Example Query:** `/predict?text=please%20discontinue%20all%20contact`
47
+
48
+ ### 3. Batch Prediction (POST)
49
+
50
+ * **Path:** `/predict`
51
+ * **Method:** `POST`
52
+ * **Description:** Classifies a list of text strings in a single request. This is the recommended approach for high-throughput, low-latency production use cases due to reduced overhead.
53
+ * **Request Body (JSON):**
54
+
55
+ ```json
56
+ {
57
+ "texts": [
58
+ "do not ever text me again",
59
+ "I will stop by your office tomorrow"
60
+ ]
61
+ }
+ ```
app.py ADDED
@@ -0,0 +1,128 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
import os
import joblib
import gradio as gr
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel, Field
from typing import List

# Locations of the serialized TF-IDF vectorizer and SVM classifier shipped
# in the image under checkpoint/ (see Dockerfile COPY).
CHECKPOINT_DIR = "checkpoint"
TFIDF_PATH = os.path.join(CHECKPOINT_DIR, "tfidf_vectorizer.pkl")
SVM_PATH = os.path.join(CHECKPOINT_DIR, "svm_stop_classifier.pkl")

# Human-readable names for class 0 / class 1.
LABEL_0 = "NOT_STOP"
LABEL_1 = "STOP"

tfidf_vectorizer = None
svm_model = None

# Load both artifacts eagerly at import time: a missing file should abort
# startup rather than surface as a 500 on the first request.
try:
    print(f"Loading TFIDF Vectorizer from {TFIDF_PATH}...")
    tfidf_vectorizer = joblib.load(TFIDF_PATH)
    print(f"Loading SVM Model from {SVM_PATH}...")
    svm_model = joblib.load(SVM_PATH)
    print("Models loaded successfully.")
except FileNotFoundError as e:
    print(f"ERROR: Model file not found: {e}")
    raise RuntimeError(f"Failed to load required model files. Ensure 'checkpoint/' is correctly populated. Error: {e}")
# FastAPI application object; the Gradio UI is mounted onto it at "/" at
# the bottom of this module.
app = FastAPI(
    title="STOP Classifier API",
    description="STOP/NOT_STOP text classification using Linear SVM. The main UI is at the root '/', while the API endpoints are at '/api-docs' and '/predict'.",
    version="1.0.0"
)
class PredictionRequest(BaseModel):
    """Request body for POST /predict: a batch of texts to classify."""

    texts: List[str] = Field(
        ...,
        description="A list of text strings to classify.",
        example=[
            "please discontinue all communication",
            "I will stop by the station after lunch"
        ]
    )
class PredictionResponse(BaseModel):
    """Classification result for one input text, with both class probabilities."""

    text: str = Field(..., description="The input text.")
    prediction: str = Field(..., description="The predicted label (STOP or NOT_STOP).")
    probability_NOT_STOP: float = Field(..., description="Probability of NOT_STOP label.")
    probability_STOP: float = Field(..., description="Probability of STOP label.")
    inference_model: str = Field("SVM", description="The model used for inference.")
def predict_svm(texts: List[str]) -> List[PredictionResponse]:
    """Vectorize *texts* with the module-level TF-IDF vectorizer and classify
    each one with the module-level SVM.

    Returns one PredictionResponse per input, in input order; an empty input
    list yields an empty result without touching the models.
    """
    if not texts:
        return []

    features = tfidf_vectorizer.transform(texts)
    probabilities = svm_model.predict_proba(features)
    labels = svm_model.predict(features)

    # NOTE(review): the label comes from predict() while probabilities come
    # from predict_proba(); assumes the pickled model exposes both — confirm
    # against the training script.
    return [
        PredictionResponse(
            text=text,
            prediction=LABEL_1 if label == 1 else LABEL_0,
            probability_NOT_STOP=float(prob[0]),
            probability_STOP=float(prob[1]),
            inference_model="SVM",
        )
        for text, label, prob in zip(texts, labels, probabilities)
    ]
@app.get("/health", status_code=200, tags=["API"])
def health_check():
    """Liveness probe: report service status and whether the SVM model loaded."""
    model_ready = bool(svm_model)
    return {"status": "ok", "model_loaded": model_ready}
@app.post("/predict", response_model=List[PredictionResponse], tags=["API"])
async def post_predict(request: PredictionRequest):
    """Batch prediction: classify every string in ``request.texts``.

    Unexpected inference failures are surfaced as HTTP 500.
    """
    try:
        return predict_svm(request.texts)
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Internal Server Error during POST prediction: {e}")
@app.get("/predict", response_model=PredictionResponse, tags=["API"])
async def get_predict(text: str):
    """Single-text prediction via the ``text`` query parameter.

    Raises:
        HTTPException 400: blank/whitespace-only input.
        HTTPException 500: empty prediction result or unexpected failure.
    """
    if not text.strip():
        raise HTTPException(status_code=400, detail="Text query parameter cannot be empty.")

    try:
        results = predict_svm([text])
        if not results:
            raise HTTPException(status_code=500, detail="Prediction returned empty result.")

        return results[0]

    except HTTPException:
        # Bug fix: the broad handler below used to catch the HTTPException
        # raised just above and re-wrap it in a nested, generic 500 detail.
        # Let HTTPExceptions propagate unchanged.
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Internal Server Error during GET prediction: {e}")
def gradio_interface_fn(text_input):
    """Gradio callback: classify one text.

    Returns a (label, {label: probability}) pair on success, or an error
    message with ``None`` probabilities on blank input / inference failure.
    """
    if not text_input or not text_input.strip():
        return "Please enter text for classification.", None

    try:
        result = predict_svm([text_input])[0]

        # Simplified from an equivalent but convoluted conditional dict build:
        # both class probabilities, keyed by their labels. (gr.Label orders by
        # confidence, so insertion order is irrelevant.)
        prob_display = {
            LABEL_0: result.probability_NOT_STOP,
            LABEL_1: result.probability_STOP,
        }

        return result.prediction, prob_display

    except Exception as e:
        return f"An error occurred: {str(e)}", None
# Gradio front end for the same predict_svm pipeline the API uses.
ui = gr.Interface(
    fn=gradio_interface_fn,
    inputs=gr.Textbox(lines=2, placeholder="Enter a message to classify...", label="Input Text"),
    outputs=[
        gr.Label(label="Classification Result"),
        gr.Label(label="Probabilities")
    ],
    title="STOP Classifier (Low-Latency SVM)",
    description="This is the user interface for the SVM model. The model classifies text as intended to end communication (STOP) or not (NOT_STOP). The API is available at the '/predict' endpoints."
)

# Bug fix: Gradio's public API is gr.mount_gradio_app; gr.mount_app does not
# exist and would raise AttributeError at import time, crashing the Space.
app = gr.mount_gradio_app(app, ui, path="/")
requirements.txt ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ fastapi
2
+ uvicorn[standard]
3
+ scikit-learn
4
+ joblib
5
+ pydantic