현진_app.py 우울모델 추가
Browse files
app.py
CHANGED
|
@@ -3,11 +3,15 @@ import re
|
|
| 3 |
import time
|
| 4 |
import requests
|
| 5 |
import numpy as np
|
|
|
|
|
|
|
|
|
|
| 6 |
from fastapi import FastAPI, HTTPException
|
| 7 |
from pydantic import BaseModel
|
| 8 |
from sentence_transformers import SentenceTransformer
|
| 9 |
-
import os
|
| 10 |
from typing import Optional, List,Dict
|
|
|
|
|
|
|
| 11 |
|
| 12 |
|
| 13 |
#####################################
|
|
@@ -1147,6 +1151,39 @@ def chat_response(user_input, mode="emotion", max_retries=5):
|
|
| 1147 |
return "🚨 모델 로딩이 너무 오래 걸립니다. 잠시 후 다시 시도하세요."
|
| 1148 |
|
| 1149 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1150 |
#####################################
|
| 1151 |
# 6) FastAPI Endpoint
|
| 1152 |
#####################################
|
|
@@ -1202,6 +1239,19 @@ class ChatOrRecommendRequest(BaseModel):
|
|
| 1202 |
# (5) 자동 분기 엔드포인트
|
| 1203 |
@app.post("/chat_or_recommend")
|
| 1204 |
def chat_or_recommend(req: ChatOrRecommendRequest):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1205 |
user_text = req.user_input
|
| 1206 |
mode = req.mode.lower()
|
| 1207 |
|
|
|
|
| 3 |
import time
|
| 4 |
import requests
|
| 5 |
import numpy as np
|
| 6 |
+
import torch
|
| 7 |
+
import joblib
|
| 8 |
+
import xgboost as xgb
|
| 9 |
from fastapi import FastAPI, HTTPException
|
| 10 |
from pydantic import BaseModel
|
| 11 |
from sentence_transformers import SentenceTransformer
|
|
|
|
| 12 |
from typing import Optional, List,Dict
|
| 13 |
+
from transformers import BertTokenizer, BertForSequenceClassification
|
| 14 |
+
|
| 15 |
|
| 16 |
|
| 17 |
#####################################
|
|
|
|
| 1151 |
return "🚨 모델 로딩이 너무 오래 걸립니다. 잠시 후 다시 시도하세요."
|
| 1152 |
|
| 1153 |
|
| 1154 |
+
#우울분류 모델 추가
|
| 1155 |
+
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 1156 |
+
|
| 1157 |
+
tokenizer = BertTokenizer.from_pretrained("monologg/kobert")
|
| 1158 |
+
bert_model = BertForSequenceClassification.from_pretrained("monologg/kobert", num_labels=2)
|
| 1159 |
+
bert_model.load_state_dict(torch.load("emotion_bert_model.pth", map_location=device))
|
| 1160 |
+
bert_model.to(device)
|
| 1161 |
+
bert_model.eval()
|
| 1162 |
+
|
| 1163 |
+
xgb_model = joblib.load("xgboost_model.pkl")
|
| 1164 |
+
vectorizer = joblib.load("tfidf_vectorizer.pkl")
|
| 1165 |
+
|
| 1166 |
+
def predict_depression(text: str):
|
| 1167 |
+
encoding = tokenizer(text, truncation=True, padding="max_length", max_length=64, return_tensors="pt")
|
| 1168 |
+
input_ids = encoding["input_ids"].to(device)
|
| 1169 |
+
attention_mask = encoding["attention_mask"].to(device)
|
| 1170 |
+
with torch.no_grad():
|
| 1171 |
+
outputs = bert_model(input_ids, attention_mask=attention_mask)
|
| 1172 |
+
probabilities = torch.nn.functional.softmax(outputs.logits, dim=1)
|
| 1173 |
+
kobert_score = probabilities[0][1].item()
|
| 1174 |
+
text_vec = vectorizer.transform([text])
|
| 1175 |
+
xgb_proba = xgb_model.predict_proba(text_vec)[0][1]
|
| 1176 |
+
kobert_score = max(0.35, min(kobert_score, 0.88))
|
| 1177 |
+
xgb_proba = max(0.3, min(xgb_proba, 0.83))
|
| 1178 |
+
combined_score = (kobert_score * 0.55) + (xgb_proba * 0.45)
|
| 1179 |
+
if combined_score > 0.78:
|
| 1180 |
+
label = "상담 권장"
|
| 1181 |
+
elif combined_score > 0.65:
|
| 1182 |
+
label = "관심 필요"
|
| 1183 |
+
else:
|
| 1184 |
+
label = "정상"
|
| 1185 |
+
return combined_score, label
|
| 1186 |
+
|
| 1187 |
#####################################
|
| 1188 |
# 6) FastAPI Endpoint
|
| 1189 |
#####################################
|
|
|
|
| 1239 |
# (5) 자동 분기 엔드포인트
|
| 1240 |
@app.post("/chat_or_recommend")
|
| 1241 |
def chat_or_recommend(req: ChatOrRecommendRequest):
|
| 1242 |
+
depression_score, depression_label = predict_depression(req.user_input)
|
| 1243 |
+
if depression_label == "상담 권장":
|
| 1244 |
+
counseling_response = (
|
| 1245 |
+
"입력하신 메시지에서 심각한 우울 신호가 감지되었습니다.\n"
|
| 1246 |
+
"전문 상담을 받으실 것을 강력히 권장드립니다.\n"
|
| 1247 |
+
"빠른 시일 내에 전문가와 상담하시길 바랍니다."
|
| 1248 |
+
)
|
| 1249 |
+
return {
|
| 1250 |
+
"mode": "counseling",
|
| 1251 |
+
"response": counseling_response,
|
| 1252 |
+
"depression_score": round(depression_score, 4),
|
| 1253 |
+
"depression_label": depression_label
|
| 1254 |
+
}
|
| 1255 |
user_text = req.user_input
|
| 1256 |
mode = req.mode.lower()
|
| 1257 |
|