nhs0327 committed on
Commit
a0c98eb
·
verified ·
1 Parent(s): 5523ac1

Upload app.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. app.py +77 -0
app.py ADDED
@@ -0,0 +1,77 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ 재난문자 분류 API — Hugging Face Spaces 배포용
3
+
4
+ HUB_MODEL_ID를 push_to_hub.py 실행 후 업로드한 모델 ID로 변경하세요.
5
+ """
6
+
7
+ import re
8
+ import torch
9
+ import torch.nn.functional as F
10
+ from fastapi import FastAPI
11
+ from pydantic import BaseModel
12
+ from transformers import AutoTokenizer, AutoModelForSequenceClassification
13
+
# ── Needs editing ──────────────────────────────
# Hugging Face Hub repository that the tokenizer and model are loaded from.
HUB_MODEL_ID = "nhs0327/koelectra-disaster-v3"
# ───────────────────────────────────────────────

# FastAPI application object; the route decorators in this file attach to it.
app = FastAPI(title="재난문자 분류 API")
19
+
# Token length every input is truncated / padded to before inference.
MAX_LENGTH = 96
# Class names, looked up by the index of the model's highest-probability output.
LABEL_NAMES = ['긴급', '주의', '일반']
# Per-label confidence floor: a model prediction whose confidence falls below
# the floor for its label is reported with "uncertain": True.
UNCERTAIN_THRESH = {'긴급': 0.60, '주의': 0.70, '일반': 0.70}

# Bracketed tag of 1-20 characters; each match is replaced with the
# placeholder '[기관]' before any rule or model sees the text.
_ORG_PATTERN = re.compile(r'\[[^\]]{1,20}\]')
# Substrings that force the '긴급' label without running the model.
_CERT_EMERG = [
    '즉시 대피', '대피명령', '대피 명령', '긴급대피', '긴급 대피', '신속히 대피',
    '지진 발생', '쓰나미', '민방공 경보', '민방공경보', '테러 발생',
]
# Substrings that force the '주의' label without running the model.
_CERT_CAUTION = [
    '호우경보', '호우주의보', '태풍경보', '태풍주의보',
    '한파경보', '한파주의보', '폭염경보', '폭염주의보',
    '대설경보', '대설주의보', '강풍경보', '강풍주의보',
    '풍랑경보', '풍랑주의보',
]
# Substrings that force the '일반' label, but only when the text also
# contains 'cm'.
_CERT_GENERAL = ['찾습니다', '실종된']
36
+
# CPU device handle.
# NOTE(review): nothing in this file moves the model or its inputs to
# `device`; confirm whether it is still needed.
device = torch.device("cpu")
# Both artifacts are loaded from the Hub repository once, at import time.
tokenizer = AutoTokenizer.from_pretrained(HUB_MODEL_ID)
model = AutoModelForSequenceClassification.from_pretrained(HUB_MODEL_ID)
# Evaluation mode, so that inference does not use training-only behaviour.
model.eval()
41
+
42
+
class ClassifyRequest(BaseModel):
    """Request body accepted by ``POST /classify``."""

    # Raw text of the message to classify.
    message: str
+
46
+
def _rule_label(text):
    """Return the label forced by a keyword rule, or ``None`` if none applies.

    Rules are tried in priority order: emergency keywords, then caution
    keywords, then general keywords (which additionally require 'cm' to be
    present in the text).
    """
    if any(kw in text for kw in _CERT_EMERG):
        return "긴급"
    if any(kw in text for kw in _CERT_CAUTION):
        return "주의"
    if any(kw in text for kw in _CERT_GENERAL) and 'cm' in text:
        return "일반"
    return None


@app.post("/classify")
async def classify(req: ClassifyRequest):
    """Classify a message as '긴급', '주의' or '일반'.

    Bracketed tags in the message are first replaced with '[기관]'. If a
    keyword rule matches the resulting text, the rule's label is returned
    with confidence 1.0 and ``"stage": "rule"``. Otherwise the text is run
    through the model and the highest-probability label is returned with
    ``"stage": "model"``, an ``uncertain`` flag (confidence below the
    per-label threshold) and the full per-label probability map.

    Args:
        req: Request body carrying the message text.

    Returns:
        A dict with ``label``, ``confidence``, ``stage`` and ``uncertain``;
        model-stage responses also include ``probs``.
    """
    text = _ORG_PATTERN.sub('[기관]', req.message)

    forced = _rule_label(text)
    if forced is not None:
        return {"label": forced, "confidence": 1.0, "stage": "rule", "uncertain": False}

    # Inference runs synchronously inside the coroutine (there is no await),
    # so the shared tokenizer and model are never used concurrently.
    inputs = tokenizer(text, truncation=True, padding='max_length',
                       max_length=MAX_LENGTH, return_tensors='pt')
    with torch.no_grad():
        probs = F.softmax(model(**inputs).logits, dim=-1)[0]
    pred_idx = probs.argmax().item()
    label = LABEL_NAMES[pred_idx]
    confidence = probs[pred_idx].item()

    return {
        "label": label,
        "confidence": round(confidence, 4),
        "stage": "model",
        "uncertain": confidence < UNCERTAIN_THRESH[label],
        # Pair names with probabilities instead of hard-coding the class count.
        "probs": {name: round(p, 4) for name, p in zip(LABEL_NAMES, probs.tolist())},
    }
73
+
74
+
@app.get("/health")
async def health():
    """Report that the service is up; always returns a fixed payload."""
    return {"status": "ok"}