# Spaces: Sleeping
# (Hugging Face Spaces status-banner residue from the page scrape; kept as a comment
#  so the file remains valid Python.)
| import os | |
| import joblib | |
| import gradio as gr | |
| from fastapi import FastAPI, HTTPException | |
| from pydantic import BaseModel, Field | |
| from typing import List | |
# Filesystem locations of the serialized model artifacts produced at training time.
CHECKPOINT_DIR = "checkpoints"
TFIDF_PATH = os.path.join(CHECKPOINT_DIR, "tfidf_vectorizer.pkl")
SVM_PATH = os.path.join(CHECKPOINT_DIR, "svm_stop_classifier.pkl")

# Human-readable class labels: class index 0 -> NOT_STOP, class index 1 -> STOP.
LABEL_0 = "NOT_STOP"
LABEL_1 = "STOP"

# Module-level model handles; None until the import-time loading block assigns them.
tfidf_vectorizer = None
svm_model = None
# Load both artifacts at import time so the app fails fast when files are missing.
try:
    print(f"Loading TFIDF Vectorizer from {TFIDF_PATH}...")
    tfidf_vectorizer = joblib.load(TFIDF_PATH)
    print(f"Loading SVM Model from {SVM_PATH}...")
    svm_model = joblib.load(SVM_PATH)
    print("Models loaded successfully.")
except FileNotFoundError as e:
    print(f"ERROR: Model file not found: {e}")
    # BUG FIX: the message referred to 'checkpoint/' while the artifacts actually
    # live under CHECKPOINT_DIR ('checkpoints/'); use the constant so the message
    # stays consistent. Chain the original exception so the traceback keeps
    # which file was missing.
    raise RuntimeError(
        f"Failed to load required model files. Ensure '{CHECKPOINT_DIR}/' is correctly populated. Error: {e}"
    ) from e
# FastAPI application object; the Gradio UI is mounted on top of it at "/" below.
app = FastAPI(
    title="STOP Classifier API",
    description="STOP/NOT_STOP text classification using Linear SVM. The main UI is at the root '/', while the API endpoints are at '/api-docs' and '/predict'.",
    version="1.0.0"
)
# Request body for batch prediction: one or more raw text strings.
# (Documented with comments rather than a docstring so the OpenAPI schema
#  description generated by pydantic is not changed.)
class PredictionRequest(BaseModel):
    texts: List[str] = Field(
        ...,
        description="A list of text strings to classify.",
        example=[
            "please discontinue all communication",
            "I will stop by the station after lunch"
        ]
    )
# One prediction result per input text. probability_NOT_STOP and
# probability_STOP come from predict_proba, so they should sum to ~1.0.
# (Comments, not a docstring, to leave the generated schema untouched.)
class PredictionResponse(BaseModel):
    text: str = Field(..., description="The input text.")
    prediction: str = Field(..., description="The predicted label (STOP or NOT_STOP).")
    probability_NOT_STOP: float = Field(..., description="Probability of NOT_STOP label.")
    probability_STOP: float = Field(..., description="Probability of STOP label.")
    inference_model: str = Field("SVM", description="The model used for inference.")
| def predict_svm(texts: List[str]) -> List[PredictionResponse]: | |
| if not texts: | |
| return [] | |
| vec = tfidf_vectorizer.transform(texts) | |
| probs = svm_model.predict_proba(vec) | |
| preds = svm_model.predict(vec) | |
| results = [] | |
| for i, txt in enumerate(texts): | |
| pred_label = LABEL_1 if preds[i] == 1 else LABEL_0 | |
| results.append(PredictionResponse( | |
| text=txt, | |
| prediction=pred_label, | |
| probability_NOT_STOP=float(probs[i][0]), | |
| probability_STOP=float(probs[i][1]), | |
| inference_model="SVM" | |
| )) | |
| return results | |
def health_check():
    """Health probe: report whether the SVM model object is in memory.

    Uses an explicit ``is not None`` check instead of ``bool(svm_model)``:
    truthiness of an arbitrary estimator object is not guaranteed to mean
    "loaded" (objects defining ``__len__`` can be falsy while present).
    """
    # NOTE(review): no route decorator is visible in this chunk — presumably the
    # route is registered elsewhere (e.g. app.add_api_route); confirm.
    return {"status": "ok", "model_loaded": svm_model is not None}
async def post_predict(request: PredictionRequest):
    """POST prediction endpoint: classify every text in the request body.

    Returns a list of PredictionResponse objects, one per input text.
    Raises HTTPException(500) on any inference failure.
    """
    # NOTE(review): no route decorator is visible in this chunk — presumably the
    # route is registered elsewhere (e.g. app.add_api_route); confirm.
    try:
        return predict_svm(request.texts)
    except Exception as e:
        # Chain the original exception so server logs keep the root cause.
        raise HTTPException(
            status_code=500,
            detail=f"Internal Server Error during POST prediction: {e}"
        ) from e
async def get_predict(text: str):
    """GET prediction endpoint: classify a single text query parameter.

    Raises HTTPException(400) for blank input, HTTPException(500) for an
    empty prediction result or any unexpected inference failure.
    """
    # NOTE(review): no route decorator is visible in this chunk — presumably the
    # route is registered elsewhere (e.g. app.add_api_route); confirm.
    if not text.strip():
        raise HTTPException(status_code=400, detail="Text query parameter cannot be empty.")
    try:
        results = predict_svm([text])
        if not results:
            raise HTTPException(status_code=500, detail="Prediction returned empty result.")
        return results[0]
    except HTTPException:
        # BUG FIX: the generic handler below used to swallow the specific
        # "empty result" HTTPException and re-wrap it, losing its detail.
        # Let HTTPExceptions propagate unchanged.
        raise
    except Exception as e:
        raise HTTPException(
            status_code=500,
            detail=f"Internal Server Error during GET prediction: {e}"
        ) from e
def gradio_interface_fn(text_input):
    """Bridge between the Gradio UI and the SVM classifier.

    Returns a ``(label, probability_dict)`` pair. On blank input or on any
    inference error the second element is None and the first element carries
    a human-readable message instead of a label.
    """
    blank = not text_input or not text_input.strip()
    if blank:
        return "Please enter text for classification.", None
    try:
        outcome = predict_svm([text_input])[0]
        probabilities = {
            LABEL_0: outcome.probability_NOT_STOP,
            LABEL_1: outcome.probability_STOP,
        }
        return outcome.prediction, probabilities
    except Exception as exc:
        return f"An error occurred: {str(exc)}", None
# Gradio UI definition; gradio_interface_fn performs the actual inference and
# returns (label, probability-dict) matching the two Label outputs below.
ui = gr.Interface(
    fn=gradio_interface_fn,
    inputs=gr.Textbox(lines=2, placeholder="Enter a message to classify...", label="Input Text"),
    outputs=[
        gr.Label(label="Classification Result"),
        gr.Label(label="Probabilities")
    ],
    title="STOP Classifier SVM",
    description="This is the user interface for the SVM model. The model classifies text as intended to end communication (STOP) or not (NOT_STOP). The API is available at the '/predict' endpoints."
)
# Mount the Gradio UI onto the FastAPI app so "/" serves the UI while the
# FastAPI routes remain available on the same server.
app = gr.mount_gradio_app(app, ui, path="/")