pluto90 commited on
Commit
5c8556d
·
verified ·
1 Parent(s): d89ba8b

Update api/main.py

Browse files
Files changed (1) hide show
  1. api/main.py +248 -0
api/main.py CHANGED
@@ -0,0 +1,248 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # main.py
2
+
3
+ from fastapi import FastAPI, HTTPException
4
+ from fastapi.middleware.cors import CORSMiddleware
5
+ from pydantic import BaseModel
6
+ import json
7
+ import os
8
+ import sys
9
+ from datetime import datetime
10
+ import torch
11
+ from transformers import DistilBertTokenizer, DistilBertForSequenceClassification
12
+ import re
13
+ import shap
14
+ import numpy as np
15
+
16
+ sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
17
+ os.chdir(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
18
+
19
+ # Config
20
+ MODEL_DIR = "models"
21
+ BEST_METRICS_PATH = "models/best_metrics.json"
22
+ DRIFT_LOG_PATH = "models/drift_log.json"
23
+ RETRAIN_LOG_PATH = "models/retrain_log.json"
24
+
25
+ app = FastAPI(
26
+ title="Sentiment ML System",
27
+ description="Production ML system with DistilBERT",
28
+ version="2.0.0"
29
+ )
30
+
31
+ app.add_middleware(
32
+ CORSMiddleware,
33
+ allow_origins=["http://localhost:5173"],
34
+ allow_credentials=True,
35
+ allow_methods=["*"],
36
+ allow_headers=["*"],
37
+ )
38
+
39
+ # Load model
40
+ print("Loading DistilBERT model...")
41
+ tokenizer = DistilBertTokenizer.from_pretrained(MODEL_DIR)
42
+ model = DistilBertForSequenceClassification.from_pretrained(MODEL_DIR)
43
+ model.eval()
44
+
45
+ device = "cuda" if torch.cuda.is_available() else "cpu"
46
+ model.to(device)
47
+ print(f"✓ DistilBERT loaded on {device}")
48
+
49
+ class ReviewRequest(BaseModel):
50
+ review: str
51
+
52
+ class PredictionResponse(BaseModel):
53
+ sentiment: str
54
+ confidence: float
55
+ label: int
56
+ timestamp: str
57
+
58
+ class ExplanationResponse(BaseModel):
59
+ sentiment: str
60
+ confidence: float
61
+ label: int
62
+ explanation: list
63
+ timestamp: str
64
+
65
+ def preprocess_text(text):
66
+ text = text.lower()
67
+ text = re.sub(r"<.*?>", "", text)
68
+ text = re.sub(r"[^a-z0-9\s]", "", text)
69
+ return text.strip()
70
+
71
+ @app.get("/")
72
+ def root():
73
+ return {"status": "running", "message": "Sentiment ML System - DistilBERT"}
74
+
75
+ @app.post("/predict", response_model=PredictionResponse)
76
+ def predict(request: ReviewRequest):
77
+ if not request.review.strip():
78
+ raise HTTPException(status_code=400, detail="Review text cannot be empty")
79
+
80
+ try:
81
+ review = preprocess_text(request.review)
82
+
83
+ inputs = tokenizer(
84
+ review,
85
+ return_tensors="pt",
86
+ truncation=True,
87
+ max_length=256,
88
+ padding="max_length"
89
+ )
90
+
91
+ inputs = {k: v.to(device) for k, v in inputs.items()}
92
+
93
+ with torch.no_grad():
94
+ outputs = model(**inputs)
95
+ logits = outputs.logits
96
+ probabilities = torch.softmax(logits, dim=-1)
97
+ label = int(torch.argmax(probabilities, dim=-1).item())
98
+ confidence = float(probabilities[0][label].item())
99
+
100
+ sentiment = "Positive" if label == 1 else "Negative"
101
+
102
+ return PredictionResponse(
103
+ sentiment=sentiment,
104
+ confidence=round(confidence, 4),
105
+ label=label,
106
+ timestamp=datetime.now().strftime("%Y-%m-%d %H:%M:%S")
107
+ )
108
+ except Exception as e:
109
+ raise HTTPException(status_code=500, detail=str(e))
110
+
111
+ @app.get("/metrics")
112
+ def get_metrics():
113
+ response = {}
114
+
115
+ if os.path.exists(BEST_METRICS_PATH):
116
+ with open(BEST_METRICS_PATH, "r") as f:
117
+ response["best_model"] = json.load(f)
118
+ else:
119
+ response["best_model"] = None
120
+
121
+ if os.path.exists(DRIFT_LOG_PATH):
122
+ with open(DRIFT_LOG_PATH, "r") as f:
123
+ response["drift_log"] = json.load(f)
124
+ else:
125
+ response["drift_log"] = []
126
+
127
+ if os.path.exists(RETRAIN_LOG_PATH):
128
+ with open(RETRAIN_LOG_PATH, "r") as f:
129
+ response["retrain_log"] = json.load(f)
130
+ else:
131
+ response["retrain_log"] = []
132
+
133
+ return response
134
+
135
+ @app.get("/health")
136
+ def health():
137
+ return {
138
+ "status": "healthy",
139
+ "model": "DistilBERT",
140
+ "timestamp": datetime.now().strftime("%Y-%m-%d %H:%M:%S")
141
+ }
142
+
143
+
144
+ @app.post("/explain", response_model=ExplanationResponse)
145
+ def explain(request: ReviewRequest):
146
+ if not request.review.strip():
147
+ raise HTTPException(status_code=400, detail="Review text cannot be empty")
148
+
149
+ try:
150
+ review = preprocess_text(request.review)
151
+
152
+ # Get prediction first
153
+ inputs = tokenizer(
154
+ review,
155
+ return_tensors="pt",
156
+ truncation=True,
157
+ max_length=256,
158
+ padding="max_length",
159
+ return_offsets_mapping=True
160
+ )
161
+
162
+ offset_mapping = inputs.pop("offset_mapping")[0]
163
+ inputs = {k: v.to(device) for k, v in inputs.items()}
164
+
165
+ with torch.no_grad():
166
+ outputs = model(**inputs)
167
+ logits = outputs.logits
168
+ probabilities = torch.softmax(logits, dim=-1)
169
+ label = int(torch.argmax(probabilities, dim=-1).item())
170
+ confidence = float(probabilities[0][label].item())
171
+
172
+ sentiment = "Positive" if label == 1 else "Negative"
173
+
174
+ # SHAP explanation
175
+ def model_predict(texts):
176
+ """Wrapper for SHAP"""
177
+ all_probs = []
178
+ for text in texts:
179
+ text_clean = preprocess_text(text)
180
+ inputs = tokenizer(
181
+ text_clean,
182
+ return_tensors="pt",
183
+ truncation=True,
184
+ max_length=256,
185
+ padding="max_length"
186
+ )
187
+ inputs = {k: v.to(device) for k, v in inputs.items()}
188
+
189
+ with torch.no_grad():
190
+ outputs = model(**inputs)
191
+ probs = torch.softmax(outputs.logits, dim=-1).cpu().numpy()[0]
192
+ all_probs.append(probs)
193
+
194
+ return np.array(all_probs)
195
+
196
+ # Create explainer
197
+ explainer = shap.Explainer(model_predict, tokenizer)
198
+
199
+ # Get SHAP values
200
+ shap_values = explainer([review])
201
+
202
+ # Extract word impacts for the predicted class
203
+ tokens = tokenizer.tokenize(review)
204
+ token_impacts = shap_values.values[0, :, label]
205
+
206
+ # Map tokens back to words
207
+ word_impacts = []
208
+ current_word = ""
209
+ current_impact = 0.0
210
+
211
+ for i, (token, impact) in enumerate(zip(tokens, token_impacts)):
212
+ if token.startswith("##"):
213
+ # Continuation of previous word
214
+ current_word += token[2:]
215
+ current_impact += impact
216
+ else:
217
+ # New word
218
+ if current_word:
219
+ word_impacts.append({
220
+ "word": current_word,
221
+ "impact": round(float(current_impact), 4)
222
+ })
223
+ current_word = token
224
+ current_impact = impact
225
+
226
+ # Add last word
227
+ if current_word:
228
+ word_impacts.append({
229
+ "word": current_word,
230
+ "impact": round(float(current_impact), 4)
231
+ })
232
+
233
+ # Filter out special tokens and very low impacts
234
+ word_impacts = [
235
+ w for w in word_impacts
236
+ if w["word"] not in ["[CLS]", "[SEP]", "[PAD]"] and abs(w["impact"]) > 0.01
237
+ ]
238
+
239
+ return ExplanationResponse(
240
+ sentiment=sentiment,
241
+ confidence=round(confidence, 4),
242
+ label=label,
243
+ explanation=word_impacts,
244
+ timestamp=datetime.now().strftime("%Y-%m-%d %H:%M:%S")
245
+ )
246
+
247
+ except Exception as e:
248
+ raise HTTPException(status_code=500, detail=str(e))