nhs0327 committed on
Commit
d46ae28
·
verified ·
1 Parent(s): 5f4c36c

v9n 모델로 업데이트 (KoELECTRA → koelectra-disaster-v9n)

Browse files
Files changed (1) hide show
  1. app.py +32 -34
app.py CHANGED
@@ -1,5 +1,5 @@
1
  """
2
- 재난문자 분류 API — Hugging Face Spaces 배포용
3
 
4
  HUB_MODEL_ID를 push_to_hub.py 실행 후 업로드한 모델 ID로 변경하세요.
5
  """
@@ -12,28 +12,26 @@ from pydantic import BaseModel
12
  from transformers import AutoTokenizer, AutoModelForSequenceClassification
13
 
14
  # ── 수정 필요 ──────────────────────────────────
15
- HUB_MODEL_ID = "nhs0327/koelectra-disaster-v3"
16
  # ──────────────────────────────────────────────
17
 
18
  app = FastAPI(title="재난문자 분류 API")
19
 
20
- MAX_LENGTH = 96
21
- LABEL_NAMES = ['긴급', '주의', '일반']
22
- UNCERTAIN_THRESH = {'긴급': 0.60, '주의': 0.70, '일반': 0.70}
23
- EMERG_THRESH = 0.10
24
-
25
- _ORG_PATTERN = re.compile(r'\[[^\]]{1,20}\]')
26
- _CERT_EMERG = [
27
- '즉시 대피', '대피명령', '대피 명령', '긴급대피', '긴급 대피', '신속히 대피',
28
- '지진 발생', '쓰나미', '민방공 경보', '민방공경보', '테러 발생',
29
- ]
30
- _CERT_CAUTION = [
31
- '호우경보', '호우주의보', '태풍경보', '태풍주의보',
32
- '한파경보', '한파주의', '폭염경보', '폭염주의보',
33
- '대설경보', '대설주의보', '강풍경보', '강풍주의보',
34
- '풍랑경보', '풍랑주의보',
35
- ]
36
- _CERT_GENERAL = ['찾습니다', '실종된']
37
 
38
  device = torch.device("cpu")
39
  tokenizer = AutoTokenizer.from_pretrained(HUB_MODEL_ID)
@@ -49,31 +47,31 @@ class ClassifyRequest(BaseModel):
49
  async def classify(req: ClassifyRequest):
50
  text = _ORG_PATTERN.sub('[기관]', req.message)
51
 
52
- has_emerg = any(kw in text for kw in _CERT_EMERG)
53
- has_caution = any(kw in text for kw in _CERT_CAUTION)
54
- has_general = any(kw in text for kw in _CERT_GENERAL) and 'cm' in text
55
-
56
- if has_emerg and not has_caution:
57
- return {"label": "긴급", "confidence": 1.0, "stage": "rule", "uncertain": False}
58
- if has_caution and not has_emerg:
59
- return {"label": "주의", "confidence": 1.0, "stage": "rule", "uncertain": False}
60
- if has_general and not has_emerg and not has_caution:
61
- return {"label": "일반", "confidence": 1.0, "stage": "rule", "uncertain": False}
62
-
63
  inputs = tokenizer(text, truncation=True, padding='max_length',
64
  max_length=MAX_LENGTH, return_tensors='pt')
65
  with torch.no_grad():
66
  probs = F.softmax(model(**inputs).logits, dim=-1)[0]
67
- pred_idx = 0 if probs[0].item() >= EMERG_THRESH else probs.argmax().item()
 
 
 
 
 
 
 
 
 
68
  label = LABEL_NAMES[pred_idx]
69
- confidence = probs[pred_idx].item()
 
70
 
71
  return {
72
  "label": label,
 
73
  "confidence": round(confidence, 4),
74
  "stage": "model",
75
- "uncertain": confidence < UNCERTAIN_THRESH[label],
76
- "probs": {LABEL_NAMES[i]: round(probs[i].item(), 4) for i in range(3)},
77
  }
78
 
79
 
 
1
  """
2
+ 재난문자 분류 API — Hugging Face Spaces 배포용 (KLUE-BERT 5-class)
3
 
4
  HUB_MODEL_ID를 push_to_hub.py 실행 후 업로드한 모델 ID로 변경하세요.
5
  """
 
12
  from transformers import AutoTokenizer, AutoModelForSequenceClassification
13
 
14
  # ── 수정 필요 ──────────────────────────────────
15
+ HUB_MODEL_ID = "nhs0327/koelectra-disaster-v9n"
16
  # ──────────────────────────────────────────────
17
 
18
  app = FastAPI(title="재난문자 분류 API")
19
 
20
+ MAX_LENGTH = 128
21
+ LABEL_NAMES = ['긴급 아님', '낮은 긴급성', '중간 긴급성', '높은 긴급성', '매우 높은 긴급성']
22
+ UNCERTAIN_THRESH = 0.70
23
+ L3_THRESHOLD = 0.69
24
+
25
+ _ORG_PATTERN = re.compile(r'\[[^\]]{1,20}\]')
26
+
27
+
28
+ def label_to_priority(idx: int) -> str:
29
+ if idx == 4:
30
+ return '긴급'
31
+ if idx in (2, 3):
32
+ return '주의'
33
+ return '일반'
34
+
 
 
35
 
36
  device = torch.device("cpu")
37
  tokenizer = AutoTokenizer.from_pretrained(HUB_MODEL_ID)
 
47
  async def classify(req: ClassifyRequest):
48
  text = _ORG_PATTERN.sub('[기관]', req.message)
49
 
 
 
 
 
 
 
 
 
 
 
 
50
  inputs = tokenizer(text, truncation=True, padding='max_length',
51
  max_length=MAX_LENGTH, return_tensors='pt')
52
  with torch.no_grad():
53
  probs = F.softmax(model(**inputs).logits, dim=-1)[0]
54
+
55
+ import numpy as np
56
+ probs_np = probs.cpu().numpy()
57
+ if probs_np[3] >= L3_THRESHOLD:
58
+ pred_idx = 3
59
+ else:
60
+ probs_mod = probs_np.copy()
61
+ probs_mod[3] = -1.0
62
+ pred_idx = int(probs_mod.argmax())
63
+
64
  label = LABEL_NAMES[pred_idx]
65
+ confidence = float(probs_np[pred_idx])
66
+ priority = label_to_priority(pred_idx)
67
 
68
  return {
69
  "label": label,
70
+ "priority": priority,
71
  "confidence": round(confidence, 4),
72
  "stage": "model",
73
+ "uncertain": confidence < UNCERTAIN_THRESH,
74
+ "probs": {LABEL_NAMES[i]: round(probs[i].item(), 4) for i in range(5)},
75
  }
76
 
77