# WangchanBERTa Fine-tuned for Thai News Multi-label Classification

WangchanBERTa fine-tuned on the prachathai-67k dataset for multi-label news topic classification in Thai.

Source code: https://github.com/Datchthana1/NLP_WachangBERT/tree/main
## Model Description

Given a Thai news article, the model predicts which of 12 topic categories it belongs to. A single article can belong to multiple categories simultaneously.
## Training Results
| Epoch | F1 Score (micro) | Eval Loss |
|---|---|---|
| 1 | 0.9297 | 0.1734 |
| 2 | 0.9343 | 0.1642 |
| 3 | 0.9360 | 0.1621 |
**Best micro-F1: 0.9360** (epoch 3)
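The micro-averaged F1 above pools true/false positives across all 12 labels before computing the score, so frequent labels weigh more than rare ones. A minimal toy illustration with `scikit-learn` (this is not the repository's evaluation script, just a sketch of the metric):

```python
import numpy as np
from sklearn.metrics import f1_score

# Toy example: 3 articles x 4 labels, multi-hot ground truth vs. predictions.
y_true = np.array([[1, 0, 1, 0],
                   [0, 1, 0, 0],
                   [1, 1, 0, 1]])
y_pred = np.array([[1, 0, 1, 0],
                   [0, 1, 1, 0],
                   [1, 0, 0, 1]])

# Micro averaging: TP=5, FP=1, FN=1 pooled over every cell,
# giving precision = recall = 5/6.
micro = f1_score(y_true, y_pred, average="micro")
print(round(micro, 4))  # → 0.8333
```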
## Labels

| ID | Label | Thai name |
|---|---|---|
| 0 | politics | การเมือง |
| 1 | human_rights | สิทธิมนุษยชน |
| 2 | quality_of_life | คุณภาพชีวิต |
| 3 | international | ต่างประเทศ |
| 4 | social | สังคม |
| 5 | environment | สิ่งแวดล้อม |
| 6 | economics | เศรษฐกิจ |
| 7 | culture | วัฒนธรรม |
| 8 | labor | แรงงาน |
| 9 | national_security | ความมั่นคง |
| 10 | ict | เทคโนโลยี |
| 11 | education | การศึกษา |
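For convenience, the table above can be expressed as the `id2label`/`label2id` mappings that `transformers` stores in a model's `config.json` (a sketch; whether the uploaded checkpoint includes these mappings is an assumption):

```python
LABELS = ["politics", "human_rights", "quality_of_life", "international",
          "social", "environment", "economics", "culture",
          "labor", "national_security", "ict", "education"]

# Mappings in the format transformers keeps in config.json.
id2label = {i: label for i, label in enumerate(LABELS)}
label2id = {label: i for i, label in enumerate(LABELS)}

print(id2label[6])        # → economics
print(label2id["labor"])  # → 8
```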
## Usage

```python
import torch
from transformers import AutoModelForSequenceClassification, CamembertTokenizer

model_name = "Datchthana/wangchanberta-prachathai"

# Tokenizer comes from the base model; fine-tuned weights from the hub.
tokenizer = CamembertTokenizer.from_pretrained("airesearch/wangchanberta-base-att-spm-uncased")
model = AutoModelForSequenceClassification.from_pretrained(model_name)
model.eval()

LABELS = ["politics", "human_rights", "quality_of_life", "international",
          "social", "environment", "economics", "culture",
          "labor", "national_security", "ict", "education"]

def predict(text):
    inputs = tokenizer(text, return_tensors="pt", truncation=True, max_length=128)
    with torch.no_grad():
        logits = model(**inputs).logits
    # Sigmoid (not softmax): each label is scored independently for multi-label output.
    probs = torch.sigmoid(logits)[0]
    return {label: round(prob.item(), 3) for label, prob in zip(LABELS, probs)}

# "The government announced a new economic policy to stimulate investment."
text = "รัฐบาลประกาศนโยบายเศรษฐกิจใหม่เพื่อกระตุ้นการลงทุน"
prediction = predict(text)
print(dict(sorted(prediction.items(), key=lambda x: x[1], reverse=True)))
```
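Because the sigmoid scores are independent per label, an article can clear the decision threshold for several topics at once. A small helper to turn the probability dict into a label set (0.5 is a common default; per-label threshold tuning on the validation set is possible but not specified in this card):

```python
def to_labels(prediction, threshold=0.5):
    """Keep every label whose independent sigmoid score clears the threshold."""
    return [label for label, prob in prediction.items() if prob >= threshold]

# Hypothetical scores for illustration, not actual model output.
example = {"politics": 0.91, "economics": 0.78, "culture": 0.04}
print(to_labels(example))  # → ['politics', 'economics']
```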
## Training Details
| Parameter | Value |
|---|---|
| Base model | airesearch/wangchanberta-base-att-spm-uncased |
| Dataset | PyThaiNLP/prachathai67k |
| Train size | 54,379 |
| Validation size | 6,721 |
| Epochs | 3 |
| Batch size | 16 |
| Learning rate | 2e-5 |
| Max token length | 128 |
| Problem type | multi_label_classification |
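Setting `problem_type="multi_label_classification"` makes `transformers` train with `BCEWithLogitsLoss`: an independent sigmoid plus binary cross-entropy per label against multi-hot float targets, rather than a softmax over mutually exclusive classes. A standalone PyTorch sketch of that loss (toy numbers, not from the actual training run):

```python
import torch

logits = torch.tensor([[2.0, -1.0, 0.5]])   # raw model outputs for 3 labels
targets = torch.tensor([[1.0, 0.0, 1.0]])   # article tagged with labels 0 and 2

loss = torch.nn.BCEWithLogitsLoss()(logits, targets)

# Same value computed by hand from the per-label sigmoid probabilities.
probs = torch.sigmoid(logits)
manual = -(targets * probs.log() + (1 - targets) * (1 - probs).log()).mean()
assert torch.allclose(loss, manual)
print(round(loss.item(), 4))  # → 0.3048
```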