File size: 6,529 Bytes
8fdd265
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
"""
backend/main.py

Run:
    pip install -r requirements.txt
    uvicorn main:app --reload --port 8000
"""

import os
import re
import pickle
import time
from contextlib import asynccontextmanager
from typing import Optional

import nltk
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel, Field, field_validator

# ── NLTK setup ────────────────────────────────────────────────────────────────
# Make sure every NLTK resource used by cleaning() is available locally.
# Each entry is (download package name, lookup path inside the NLTK data dir);
# a resource is downloaded only when the lookup fails.
for _pkg, _path in [
    ("stopwords", "corpora/stopwords"),
    ("punkt_tab", "tokenizers/punkt_tab"),
    ("wordnet",   "corpora/wordnet"),
]:
    try:
        nltk.data.find(_path)
    except LookupError:
        nltk.download(_pkg, quiet=True)

# These imports stay below the resource check, as in the original layout.
from nltk.corpus import stopwords
from nltk.stem import WordNetLemmatizer

# frozenset rather than list: cleaning() tests every token for membership, and
# a hash lookup is O(1) where scanning the stop-word list is O(n) per token.
_STOP_WORDS = frozenset(stopwords.words("english"))
_LEMMATIZER = WordNetLemmatizer()


# ── cleaning() — exact copy from notebook cell 12 ────────────────────────────
def cleaning(text: str) -> str:
    """Normalise raw text into the space-joined token string fed to the vectorizer.

    Steps, in order: coerce to ``str`` and lower-case, delete every character
    that is not an ASCII letter or whitespace, tokenize with
    ``nltk.word_tokenize``, drop tokens found in ``_STOP_WORDS``, lemmatize the
    survivors with ``_LEMMATIZER``, and join them with single spaces.

    Args:
        text: Input to clean; any value is accepted because it is passed
            through ``str()`` first.

    Returns:
        The cleaned text, or an empty string when no token survives.
    """
    lowered = str(text).lower()
    letters_only = re.sub(r"[^a-zA-Z\s]", "", lowered)
    tokens = nltk.word_tokenize(letters_only)
    # Stop-word filtering happens before lemmatization, token by token.
    return " ".join(
        _LEMMATIZER.lemmatize(token)
        for token in tokens
        if token not in _STOP_WORDS
    )


# ── Artifact loading ──────────────────────────────────────────────────────────
# Directory containing the pickled artifacts. Taken from the ARTIFACT_DIR
# environment variable, falling back to "./artifacts".
ARTIFACT_DIR = os.getenv("ARTIFACT_DIR", "./artifacts")

# Module-level handles for the loaded artifacts. They start as None and are
# assigned by lifespan() at startup; the routes check for None before use.
MODEL      = None
VECTORIZER = None
ENCODER    = None


def _load(fname: str):
    """Unpickle and return the artifact called *fname* inside ARTIFACT_DIR.

    Args:
        fname: File name of the artifact, joined onto ARTIFACT_DIR.

    Returns:
        The object that was pickled into the file.

    Raises:
        FileNotFoundError: If the file does not exist. The message names the
            expected path and the unzip step that creates it.
    """
    path = os.path.join(ARTIFACT_DIR, fname)
    # Open first and translate the failure, instead of checking existence and
    # then opening: the file cannot disappear between the check and the open.
    try:
        artifact_file = open(path, "rb")
    except FileNotFoundError as err:
        raise FileNotFoundError(
            f"Artifact not found: {path}\n"
            f"Unzip model.zip into {ARTIFACT_DIR}/ first."
        ) from err
    with artifact_file:
        # SECURITY: unpickling can execute arbitrary code embedded in the file.
        # Only point ARTIFACT_DIR at artifacts from a trusted source.
        return pickle.load(artifact_file)


@asynccontextmanager
async def lifespan(app: FastAPI):
    """Load the three pickled artifacts before the app starts serving.

    Before yielding, ``model.pkl``, ``tfidf.pkl`` and ``encoder.pkl`` are read
    through ``_load`` and stored in the module globals ``MODEL``,
    ``VECTORIZER`` and ``ENCODER``. A missing file makes ``_load`` raise
    ``FileNotFoundError``, which propagates out of this function. After the
    yield only a shutdown message is printed.

    Args:
        app: The FastAPI application (not used inside the body).
    """
    global MODEL, VECTORIZER, ENCODER
    print(f"Loading artifacts from: {ARTIFACT_DIR}")
    MODEL      = _load("model.pkl")
    VECTORIZER = _load("tfidf.pkl")
    ENCODER    = _load("encoder.pkl")
    # The check mark was stored as mojibake in the original message.
    print(f"Model loaded ✓  |  {type(MODEL).__name__}  |  Classes: {list(ENCODER.classes_)}")
    yield
    print("Shutting down.")

# ── App ───────────────────────────────────────────────────────────────────────
# The ASGI application; lifespan() loads the model artifacts at startup.
app = FastAPI(
    title="Mental Health Sentiment Analysis API",
    version="1.0.0",
    lifespan=lifespan,
)

# CORS: any origin may call the API with any method and headers.
# Credentials stay disabled because a wildcard origin must not be combined
# with allow_credentials=True; doing so lets any website make credentialed
# requests. The routes in this file use neither cookies nor authentication.
# If credentials are ever needed, list the allowed origins explicitly.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=False,
    allow_methods=["*"],
    allow_headers=["*"],
)


# ── Schemas ───────────────────────────────────────────────────────────────────
class PredictRequest(BaseModel):
    # Request body for POST /predict: a single text, 3..5000 characters long
    # once surrounding whitespace has been removed.
    text: str = Field(..., min_length=3, max_length=5000)

    # mode="before" runs the strip ahead of the Field length constraints, so
    # whitespace padding can no longer be used to satisfy min_length.
    @field_validator("text", mode="before")
    @classmethod
    def strip_text(cls, v: object) -> object:
        # Non-string input is returned untouched so that Pydantic's own type
        # validation reports it.
        return v.strip() if isinstance(v, str) else v


# One entry of the per-class probability list returned with a prediction.
class ClassProbability(BaseModel):
    # Class name, taken from ENCODER.classes_ in _infer().
    label: str
    # Probability for that class; _infer() rounds it to 4 decimal places.
    probability: float


# Result of classifying one text, as built by _infer().
class PredictResponse(BaseModel):
    # Predicted class, decoded through ENCODER.inverse_transform.
    label: str
    # Probability of the predicted class, rounded to 4 decimal places.
    confidence: float
    # Every class with its probability, sorted from most to least likely.
    probabilities: list[ClassProbability]
    # The text after cleaning(), i.e. what was actually vectorized.
    cleaned_input: str
    # Time spent inside _infer(), in milliseconds, rounded to 2 decimals.
    latency_ms: float


class BatchPredictRequest(BaseModel):
    # Request body for POST /predict/batch: between 1 and 50 texts.
    texts: list[str] = Field(..., min_length=1, max_length=50)

    # Apply the same per-text rules as the single-text endpoint (strip, then
    # 3..5000 characters). Without this, a batch could carry arbitrarily large
    # strings past the size limit that /predict enforces.
    @field_validator("texts")
    @classmethod
    def check_each_text(cls, v: list[str]) -> list[str]:
        stripped = [item.strip() for item in v]
        for idx, item in enumerate(stripped):
            if not 3 <= len(item) <= 5000:
                raise ValueError(
                    f"texts[{idx}] must contain between 3 and 5000 characters "
                    "after stripping whitespace"
                )
        return stripped


# Response body of POST /predict/batch.
class BatchPredictResponse(BaseModel):
    # One PredictResponse per input text, in request order.
    results: list[PredictResponse]
    # Wall-clock time for the whole batch in milliseconds, rounded to 2 decimals.
    total_latency_ms: float


# Response body of GET /, filled in by health().
class HealthResponse(BaseModel):
    # health() always reports "ok" here.
    status: str
    # True once MODEL has been assigned by lifespan().
    model_loaded: bool
    # Class name of the loaded model object, or None when nothing is loaded.
    model_type: Optional[str] = None
    # Labels from ENCODER.classes_, or None when the encoder is not loaded.
    classes: Optional[list[str]] = None


# ── Core inference ────────────────────────────────────────────────────────────
def _infer(text: str) -> PredictResponse:
    """Classify one text with the loaded artifacts and time the whole call.

    The text goes through ``cleaning``, ``VECTORIZER.transform``,
    ``MODEL.predict`` and ``MODEL.predict_proba``; the predicted value is
    decoded with ``ENCODER.inverse_transform``.

    Args:
        text: Raw input text.

    Returns:
        A ``PredictResponse`` holding the label, its probability, all class
        probabilities sorted in descending order, the cleaned text and the
        elapsed time in milliseconds.

    Raises:
        HTTPException: 422 when nothing is left of the text after cleaning.
    """
    started = time.perf_counter()

    cleaned = cleaning(text)
    if cleaned.strip() == "":
        raise HTTPException(status_code=422, detail="Text is empty after preprocessing.")

    features = VECTORIZER.transform([cleaned])
    predicted = MODEL.predict(features)[0]
    label = ENCODER.inverse_transform([predicted])[0]
    class_probs = MODEL.predict_proba(features)[0]
    # NOTE(review): indexing by the predicted value presumes the model was
    # trained on 0..n-1 encoded labels, so the value doubles as a column
    # index. Confirm against the training notebook.
    confidence = float(class_probs[predicted])

    # NOTE(review): pairing ENCODER.classes_ with the probability row presumes
    # both follow the same class order. Confirm.
    ranked = sorted(
        zip(ENCODER.classes_, class_probs),
        key=lambda pair: pair[1],
        reverse=True,
    )
    probabilities = [
        ClassProbability(label=name, probability=round(float(prob), 4))
        for name, prob in ranked
    ]

    return PredictResponse(
        label=label,
        confidence=round(confidence, 4),
        probabilities=probabilities,
        cleaned_input=cleaned,
        latency_ms=round((time.perf_counter() - started) * 1000, 2),
    )


# ── Routes ────────────────────────────────────────────────────────────────────
@app.get("/", response_model=HealthResponse)
def health():
    # Health check: reports whether the artifacts are loaded, plus the model's
    # class name and the encoder's labels when they are available.
    # Every check is an explicit "is not None" test, so the result never
    # depends on how the loaded objects define truthiness (__len__/__bool__).
    return HealthResponse(
        status       = "ok",
        model_loaded = MODEL is not None,
        model_type   = type(MODEL).__name__ if MODEL is not None else None,
        classes      = list(ENCODER.classes_) if ENCODER is not None else None,
    )


@app.post("/predict", response_model=PredictResponse)
def predict(req: PredictRequest):
    # Classify a single text. Responds with 503 while MODEL is still None;
    # otherwise returns whatever _infer() produces, including its 422 for
    # text that is empty after cleaning.
    if MODEL is not None:
        return _infer(req.text)
    raise HTTPException(status_code=503, detail="Model not loaded.")


@app.post("/predict/batch", response_model=BatchPredictResponse)
def predict_batch(req: BatchPredictRequest):
    # Classify every text of the batch in request order and report the total
    # elapsed time. Responds with 503 while MODEL is still None. An
    # HTTPException raised by _infer() for any item aborts the whole batch.
    if MODEL is None:
        raise HTTPException(status_code=503, detail="Model not loaded.")

    started = time.perf_counter()
    predictions = list(map(_infer, req.texts))
    return BatchPredictResponse(
        results=predictions,
        total_latency_ms=round((time.perf_counter() - started) * 1000, 2),
    )


@app.get("/classes")
def get_classes():
    # List the labels known to the encoder as {"classes": [...]}.
    # Responds with 503 while ENCODER is still None.
    if ENCODER is not None:
        return {"classes": list(ENCODER.classes_)}
    raise HTTPException(status_code=503, detail="Model not loaded.")