---
language:
- ko
license: other
library_name: transformers
pipeline_tag: text-classification
base_model: klue/bert-base
tags:
- bert
- klue
- korean
- text-classification
- minwon
- complaint
- public-administration
---
# MindE Civil Complaint (민원) Classifier (bert-v9)
A KLUE BERT-based model that automatically classifies Korean public-sector civil complaints into **11 categories**.
## Categories (11)
| ID | Category | per-class F1 |
|---:|---|---:|
| 1 | 교통 (traffic) | 0.882 |
| 2 | 건축 (construction) | 0.755 |
| 3 | 행정 (administration) | 0.812 |
| 4 | 보건위생 (health and sanitation) | 0.911 |
| 5 | 환경 (environment) | 0.874 |
| 6 | 문화_여가 (culture and leisure) | 0.825 |
| 7 | 농축산 (agriculture and livestock) | 0.909 |
| 8 | 복지 (welfare) | 0.866 |
| 9 | 세무 (tax) | 0.974 |
| 10 | 상하수도 (water and sewerage) | 0.921 |
| 11 | 경제 (economy) | 0.874 |
**Test set (20,788 samples)**
- Accuracy: **0.871**
- Macro F1: **0.873**
- Weighted F1: 0.871
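
As a quick consistency check, the macro F1 can be recomputed from the per-class table, since macro F1 is the unweighted mean of the per-class scores. The snippet below is only a sketch that copies the numbers from the table; the variable names are illustrative.

```python
# Per-class F1 scores copied from the category table (IDs 1-11, in order)
per_class_f1 = [0.882, 0.755, 0.812, 0.911, 0.874, 0.825,
                0.909, 0.866, 0.974, 0.921, 0.874]

# Macro F1 is the unweighted mean over classes
macro_f1 = sum(per_class_f1) / len(per_class_f1)
print(round(macro_f1, 3))  # 0.873

# Table ID (1-based) of the weakest category
weakest_id = per_class_f1.index(min(per_class_f1)) + 1
print(weakest_id)  # 2
```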
## Training data
- AI Hub dataset No. 143, "민원 업무 효율, 자동화를 위한 언어 AI 학습데이터" (language AI training data for civil complaint work efficiency and automation): ~860k samples, 18 original categories mapped to 11
- 8:1:1 train/validation/test split at the `group_id` level, plus a cap of 20k training samples per category
- Masking tokens (`#@주소#`, etc.) replaced with special tokens (`[ADDR]`, etc.)
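
The preprocessing steps above could look roughly like the sketch below. It is not the project's actual pipeline: only the `#@주소#` to `[ADDR]` pair and the `group_id` field come from this card, while the `category` field name, the random seed, and the use of random sampling to enforce the cap are assumptions made for illustration.

```python
import random
from collections import defaultdict

# Only this pair is stated in the card; any other mask tokens would need their own entries.
MASK_TO_SPECIAL = {"#@주소#": "[ADDR]"}

def replace_mask_tokens(text, mapping=MASK_TO_SPECIAL):
    """Swap dataset masking tokens for tokenizer-friendly special tokens."""
    for mask, special in mapping.items():
        text = text.replace(mask, special)
    return text

def split_by_group(records, seed=42):
    """8:1:1 split over unique group_id values, so no group spans two splits."""
    group_ids = sorted({r["group_id"] for r in records})
    random.Random(seed).shuffle(group_ids)
    n_train = len(group_ids) * 8 // 10
    n_val = len(group_ids) // 10
    splits = defaultdict(list)
    split_of = {}
    for i, gid in enumerate(group_ids):
        if i < n_train:
            split_of[gid] = "train"
        elif i < n_train + n_val:
            split_of[gid] = "val"
        else:
            split_of[gid] = "test"
    for r in records:
        splits[split_of[r["group_id"]]].append(r)
    return splits

def cap_per_category(train_records, cap=20_000, seed=42):
    """Keep at most `cap` training records per category (random subsample, an assumption)."""
    by_category = defaultdict(list)
    for r in train_records:
        by_category[r["category"]].append(r)  # "category" is a hypothetical field name
    rng = random.Random(seed)
    capped = []
    for category in sorted(by_category):
        items = by_category[category]
        if len(items) > cap:
            items = rng.sample(items, cap)
        capped.extend(items)
    return capped
```

Splitting on `group_id` rather than on individual rows keeps related complaints out of both train and test at once, which would otherwise inflate the test scores. If the replacement tokens are meant to be single vocabulary items, they would presumably also have to be registered with the tokenizer before fine-tuning.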
## Training configuration
- Base: `klue/bert-base`
- max_length: 128
- batch_size: 32
- epochs: 3
- learning_rate: 2e-5
- warmup_ratio: 0.1
- weight_decay: 0.01
- Training time: ~45 minutes (RTX 4060 Ti)
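
To give a feel for what these settings imply, the sketch below works out the number of optimizer steps and warmup steps. It assumes a single device with no gradient accumulation, and it uses the largest training set the 20k-per-category cap allows (11 x 20,000); the actual training set may be smaller, and the rounding used here is a choice made for this example.

```python
import math

# Values taken from the list above
config = {"batch_size": 32, "epochs": 3, "warmup_ratio": 0.1}

def schedule_steps(num_train_examples, batch_size, epochs, warmup_ratio):
    """Return (steps per epoch, total steps, warmup steps) for one device."""
    steps_per_epoch = math.ceil(num_train_examples / batch_size)
    total_steps = steps_per_epoch * epochs
    warmup_steps = math.ceil(total_steps * warmup_ratio)
    return steps_per_epoch, total_steps, warmup_steps

# Upper bound only: 11 categories x 20k cap
max_train_examples = 11 * 20_000
print(schedule_steps(max_train_examples, **config))  # (6875, 20625, 2063)
```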
## Usage example
```python
from transformers import AutoTokenizer, AutoModelForSequenceClassification
import torch

tokenizer = AutoTokenizer.from_pretrained("atti433/minde-classifier")
model = AutoModelForSequenceClassification.from_pretrained("atti433/minde-classifier")

# "Cars keep parking illegally in front of my house, and it is very inconvenient."
text = "집 앞에 차가 자꾸 불법주차해서 너무 불편합니다."
inputs = tokenizer(text, return_tensors="pt", truncation=True, max_length=128)

# Inference only, so skip gradient tracking
with torch.no_grad():
    logits = model(**inputs).logits
probs = torch.softmax(logits, dim=-1)

# Labels in the same order as the category table
labels = ['교통','건축','행정','보건위생','환경','문화_여가','농축산','복지','세무','상하수도','경제']
pred = labels[probs.argmax().item()]
print(pred, probs.max().item())
```
Alternatively, use this project's `chatbot_service.classify_complaint()`.
## Limitations
- The training data (AI Hub 143) consists mostly of complaints from Changwon City (창원시), so the model may be biased toward regional vocabulary.
- The "건축" (construction) category has the lowest F1 (0.755). This is attributed to label noise: the 안전건설과 (Safety and Construction Division) `raw_category` had road and facility complaints mixed in.
- Homonyms and very short texts (e.g. "신호등", "traffic light") get low confidence. Taking the top-3 predictions and letting an LLM make the final decision is recommended.
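
The top-3 recommendation above can be sketched as follows. This is an illustrative helper, not part of the project: the function names, the example logits, and the 0.6 confidence threshold are all hypothetical, and the hand-off to an LLM is left out.

```python
import math

# Same label order as the usage example
LABELS = ['교통', '건축', '행정', '보건위생', '환경', '문화_여가',
          '농축산', '복지', '세무', '상하수도', '경제']

def top_k_candidates(logits, labels=LABELS, k=3):
    """Softmax over raw logits, then return the k most probable (label, prob) pairs."""
    m = max(logits)  # subtract the max for numerical stability
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    probs = [e / total for e in exps]
    ranked = sorted(zip(labels, probs), key=lambda pair: pair[1], reverse=True)
    return ranked[:k]

def needs_second_opinion(candidates, threshold=0.6):
    """True when the best candidate falls below a (hypothetical) confidence threshold."""
    return candidates[0][1] < threshold

# Made-up logits for a short, ambiguous input
example_logits = [2.1, 0.3, 1.9, -1.0, 0.8, -0.5, -1.2, 0.0, -2.0, -0.7, 0.1]
candidates = top_k_candidates(example_logits)
print([label for label, _ in candidates])  # ['교통', '행정', '환경']
print(needs_second_opinion(candidates))    # True
```

With the usage example, the input would be `logits[0].tolist()`. When `needs_second_opinion` returns `True`, the three candidates and the original text can be passed to an LLM for the final choice.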