itsLu commited on
Commit
1ef8e3b
·
1 Parent(s): 65c34b7

feat: add Longformer pipeline with lazy-load model registry

Browse files
Files changed (2) hide show
  1. app.py +178 -34
  2. requirements.txt +1 -0
app.py CHANGED
@@ -1,10 +1,19 @@
 
1
  from contextlib import asynccontextmanager
2
  from fastapi import FastAPI, HTTPException
3
  from fastapi.middleware.cors import CORSMiddleware
4
  from pydantic import BaseModel
5
  import torch
6
  import torch.nn.functional as F
7
- from transformers import AutoTokenizer, BertForSequenceClassification, BertConfig
 
 
 
 
 
 
 
 
8
  import joblib
9
  import os
10
 
@@ -17,8 +26,10 @@ CHECKPOINT_PATH = os.getenv(
17
  os.path.join(_BASE, "saved_models", "mentalbert_v3flat_best.pt"),
18
  )
19
  LABEL_ENCODER_PATH = os.path.join(MODEL_DIR, "label_encoder.joblib")
 
 
 
20
  N_CLASSES = 7
21
- MAX_LEN = 128
22
  DEVICE = torch.device("cpu")
23
 
24
  LABEL_MAP: dict[str, str] = {
@@ -31,15 +42,19 @@ LABEL_MAP: dict[str, str] = {
31
  "Suicidal": "suicidal",
32
  }
33
 
34
- model_state: dict = {}
 
 
35
 
36
 
37
- @asynccontextmanager
38
- async def lifespan(_app: FastAPI):
39
- tokenizer = AutoTokenizer.from_pretrained(TOKENIZER_DIR)
 
40
 
41
- # Build BERT architecture offline using known mental-bert/bert-base-uncased dims,
42
- # then load the fine-tuned weights from the .pt checkpoint.
 
43
  cfg = BertConfig(
44
  vocab_size=30522,
45
  hidden_size=768,
@@ -54,15 +69,156 @@ async def lifespan(_app: FastAPI):
54
  model.load_state_dict(state_dict)
55
  model.to(DEVICE)
56
  model.eval()
57
-
58
  label_encoder = joblib.load(LABEL_ENCODER_PATH)
59
- model_state.update({"tokenizer": tokenizer, "model": model, "label_encoder": label_encoder})
60
- print("MentalBERT model loaded successfully.")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
61
  yield
62
- model_state.clear()
63
 
64
 
65
- app = FastAPI(title="VibeCheck API", version="1.0.0", lifespan=lifespan)
66
 
67
  app.add_middleware(
68
  CORSMiddleware,
@@ -75,6 +231,7 @@ app.add_middleware(
75
 
76
  class ClassifyRequest(BaseModel):
77
  text: str
 
78
 
79
 
80
  class ClassifyResponse(BaseModel):
@@ -84,35 +241,22 @@ class ClassifyResponse(BaseModel):
84
 
85
  @app.get("/")
86
  def health():
87
- return {"status": "ok", "model_loaded": bool(model_state)}
 
88
 
89
 
90
  @app.post("/classify", response_model=ClassifyResponse)
91
- def classify(req: ClassifyRequest):
92
  text = req.text.strip()
93
  if not text:
94
  raise HTTPException(status_code=422, detail="text must not be empty")
95
 
96
- tokenizer = model_state["tokenizer"]
97
- model = model_state["model"]
98
- label_encoder = model_state["label_encoder"]
99
-
100
- inputs = tokenizer(
101
- text,
102
- return_tensors="pt",
103
- truncation=True,
104
- padding="max_length",
105
- max_length=MAX_LEN,
106
- )
107
- inputs = {k: v.to(DEVICE) for k, v in inputs.items()}
108
-
109
- with torch.no_grad():
110
- logits = model(**inputs).logits
111
 
112
- probs = F.softmax(logits, dim=-1)
113
- confidence = float(probs.max().item())
114
- pred_idx = int(torch.argmax(probs, dim=-1).item())
115
- raw_label: str = label_encoder.inverse_transform([pred_idx])[0]
116
 
117
  return ClassifyResponse(
118
  classification=LABEL_MAP.get(raw_label, "normal"),
 
1
+ import asyncio
2
  from contextlib import asynccontextmanager
3
  from fastapi import FastAPI, HTTPException
4
  from fastapi.middleware.cors import CORSMiddleware
5
  from pydantic import BaseModel
6
  import torch
7
  import torch.nn.functional as F
8
+ from transformers import (
9
+ AutoTokenizer,
10
+ BertForSequenceClassification,
11
+ BertConfig,
12
+ BertTokenizerFast,
13
+ LongformerForSequenceClassification,
14
+ LongformerTokenizerFast,
15
+ )
16
+ from huggingface_hub import hf_hub_download
17
  import joblib
18
  import os
19
 
 
26
  os.path.join(_BASE, "saved_models", "mentalbert_v3flat_best.pt"),
27
  )
28
  LABEL_ENCODER_PATH = os.path.join(MODEL_DIR, "label_encoder.joblib")
29
+
30
+ HF_REPO = "itsLu/mentalbert-longformer-stage3"
31
+ THRESHOLD_1A = 0.6
32
  N_CLASSES = 7
 
33
  DEVICE = torch.device("cpu")
34
 
35
  LABEL_MAP: dict[str, str] = {
 
42
  "Suicidal": "suicidal",
43
  }
44
 
45
+ # Registry: model_name -> loaded state dict
46
+ model_registry: dict = {}
47
+ _registry_locks: dict[str, asyncio.Lock] = {}
48
 
49
 
50
+ def _get_lock(name: str) -> asyncio.Lock:
51
+ if name not in _registry_locks:
52
+ _registry_locks[name] = asyncio.Lock()
53
+ return _registry_locks[name]
54
 
55
+
56
+ def _load_mentalbert() -> dict:
57
+ tokenizer = AutoTokenizer.from_pretrained(TOKENIZER_DIR)
58
  cfg = BertConfig(
59
  vocab_size=30522,
60
  hidden_size=768,
 
69
  model.load_state_dict(state_dict)
70
  model.to(DEVICE)
71
  model.eval()
 
72
  label_encoder = joblib.load(LABEL_ENCODER_PATH)
73
+ print("MentalBERT loaded.")
74
+ return {"tokenizer": tokenizer, "model": model, "label_encoder": label_encoder}
75
+
76
+
77
+ def _load_longformer() -> dict:
78
+ # Stage 1A tokenizer (shared with 1B and 2)
79
+ tok_bert = BertTokenizerFast.from_pretrained(HF_REPO, subfolder="stage1a")
80
+
81
+ model_1a = BertForSequenceClassification.from_pretrained(HF_REPO, subfolder="stage1a")
82
+ model_1a.to(DEVICE).eval()
83
+
84
+ model_1b = BertForSequenceClassification.from_pretrained(HF_REPO, subfolder="stage1b")
85
+ model_1b.to(DEVICE).eval()
86
+
87
+ model_2 = BertForSequenceClassification.from_pretrained(HF_REPO, subfolder="stage2")
88
+ model_2.to(DEVICE).eval()
89
+
90
+ le_path = hf_hub_download(HF_REPO, "stage2/label_encoder.joblib")
91
+ label_encoder_2 = joblib.load(le_path)
92
+
93
+ tok_longformer = LongformerTokenizerFast.from_pretrained(HF_REPO, subfolder="stage3")
94
+ model_3 = LongformerForSequenceClassification.from_pretrained(HF_REPO, subfolder="stage3")
95
+ model_3.to(DEVICE).eval()
96
+
97
+ print("Longformer pipeline loaded.")
98
+ return {
99
+ "tok_bert": tok_bert,
100
+ "model_1a": model_1a,
101
+ "model_1b": model_1b,
102
+ "model_2": model_2,
103
+ "label_encoder_2": label_encoder_2,
104
+ "tok_longformer": tok_longformer,
105
+ "model_3": model_3,
106
+ }
107
+
108
+
109
+ async def get_or_load(name: str) -> dict:
110
+ if name in model_registry:
111
+ return model_registry[name]
112
+ lock = _get_lock(name)
113
+ async with lock:
114
+ # Double-check after acquiring lock
115
+ if name in model_registry:
116
+ return model_registry[name]
117
+ loop = asyncio.get_event_loop()
118
+ if name == "mentalbert":
119
+ state = await loop.run_in_executor(None, _load_mentalbert)
120
+ elif name == "longformer":
121
+ state = await loop.run_in_executor(None, _load_longformer)
122
+ else:
123
+ raise HTTPException(status_code=400, detail=f"Unknown model: {name}")
124
+ model_registry[name] = state
125
+ return model_registry[name]
126
+
127
+
128
+ def _run_mentalbert(text: str, state: dict) -> tuple[str, float]:
129
+ tokenizer = state["tokenizer"]
130
+ model = state["model"]
131
+ label_encoder = state["label_encoder"]
132
+
133
+ inputs = tokenizer(
134
+ text,
135
+ return_tensors="pt",
136
+ truncation=True,
137
+ padding="max_length",
138
+ max_length=128,
139
+ )
140
+ inputs = {k: v.to(DEVICE) for k, v in inputs.items()}
141
+
142
+ with torch.no_grad():
143
+ logits = model(**inputs).logits
144
+
145
+ probs = F.softmax(logits, dim=-1)
146
+ confidence = float(probs.max().item())
147
+ pred_idx = int(torch.argmax(probs, dim=-1).item())
148
+ raw_label: str = label_encoder.inverse_transform([pred_idx])[0]
149
+ return raw_label, confidence
150
+
151
+
152
+ def _run_longformer(text: str, state: dict) -> tuple[str, float]:
153
+ tok_bert = state["tok_bert"]
154
+ model_1a = state["model_1a"]
155
+ model_1b = state["model_1b"]
156
+ model_2 = state["model_2"]
157
+ label_encoder_2 = state["label_encoder_2"]
158
+ tok_longformer = state["tok_longformer"]
159
+ model_3 = state["model_3"]
160
+
161
+ def bert_inputs(text_: str) -> dict:
162
+ enc = tok_bert(
163
+ text_,
164
+ return_tensors="pt",
165
+ truncation=True,
166
+ padding="max_length",
167
+ max_length=128,
168
+ )
169
+ return {k: v.to(DEVICE) for k, v in enc.items()}
170
+
171
+ # Stage 1A — suicidal gate
172
+ with torch.no_grad():
173
+ logits_1a = model_1a(**bert_inputs(text)).logits
174
+ probs_1a = F.softmax(logits_1a, dim=-1)
175
+ if float(probs_1a[0, 1].item()) >= THRESHOLD_1A:
176
+ return "Suicidal", float(probs_1a[0, 1].item())
177
+
178
+ # Stage 1B — normal vs distress
179
+ with torch.no_grad():
180
+ logits_1b = model_1b(**bert_inputs(text)).logits
181
+ probs_1b = F.softmax(logits_1b, dim=-1)
182
+ if float(probs_1b[0, 1].item()) <= 0.5:
183
+ return "Normal", float(1.0 - probs_1b[0, 1].item())
184
+
185
+ # Stage 2 — 5-class distress
186
+ with torch.no_grad():
187
+ logits_2 = model_2(**bert_inputs(text)).logits
188
+ probs_2 = F.softmax(logits_2, dim=-1)
189
+ pred_idx_2 = int(torch.argmax(probs_2, dim=-1).item())
190
+ raw_2: str = label_encoder_2.inverse_transform([pred_idx_2])[0]
191
+ if raw_2 != "Depression":
192
+ return raw_2, float(probs_2.max().item())
193
+
194
+ # Stage 3 — depression vs suicidal re-scorer
195
+ enc3 = tok_longformer(
196
+ text,
197
+ return_tensors="pt",
198
+ truncation=True,
199
+ padding="max_length",
200
+ max_length=1024,
201
+ )
202
+ enc3 = {k: v.to(DEVICE) for k, v in enc3.items()}
203
+ # Global attention on [CLS]
204
+ global_attn = torch.zeros_like(enc3["attention_mask"])
205
+ global_attn[:, 0] = 1
206
+ enc3["global_attention_mask"] = global_attn
207
+
208
+ with torch.no_grad():
209
+ logits_3 = model_3(**enc3).logits
210
+ probs_3 = F.softmax(logits_3, dim=-1)
211
+ raw_3 = "Suicidal" if float(probs_3[0, 1].item()) > 0.5 else "Depression"
212
+ return raw_3, float(probs_3.max().item())
213
+
214
+
215
+ @asynccontextmanager
216
+ async def lifespan(_app: FastAPI):
217
  yield
218
+ model_registry.clear()
219
 
220
 
221
+ app = FastAPI(title="VibeCheck API", version="2.0.0", lifespan=lifespan)
222
 
223
  app.add_middleware(
224
  CORSMiddleware,
 
231
 
232
  class ClassifyRequest(BaseModel):
233
  text: str
234
+ model: str = "mentalbert"
235
 
236
 
237
  class ClassifyResponse(BaseModel):
 
241
 
242
  @app.get("/")
243
  def health():
244
+ loaded = list(model_registry.keys())
245
+ return {"status": "ok", "loaded_models": loaded}
246
 
247
 
248
  @app.post("/classify", response_model=ClassifyResponse)
249
+ async def classify(req: ClassifyRequest):
250
  text = req.text.strip()
251
  if not text:
252
  raise HTTPException(status_code=422, detail="text must not be empty")
253
 
254
+ state = await get_or_load(req.model)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
255
 
256
+ if req.model == "mentalbert":
257
+ raw_label, confidence = _run_mentalbert(text, state)
258
+ else:
259
+ raw_label, confidence = _run_longformer(text, state)
260
 
261
  return ClassifyResponse(
262
  classification=LABEL_MAP.get(raw_label, "normal"),
requirements.txt CHANGED
@@ -7,3 +7,4 @@ scikit-learn==1.5.1
7
  joblib==1.4.2
8
  safetensors==0.4.3
9
  pydantic==2.8.2
 
 
7
  joblib==1.4.2
8
  safetensors==0.4.3
9
  pydantic==2.8.2
10
+ sentencepiece