---
language:
- en
license: apache-2.0
tags:
- question-classification
- text-classification
- onnx
- english
- eat
- calibrated
datasets:
- TigreGotico/EAT
---
# eat-classifiers
English question answer-type (EAT) classifiers trained on the
[TigreGotico/EAT](https://huggingface.co/datasets/TigreGotico/EAT) dataset
(30,017 questions, 53 fine-grained labels in 7 main categories: the six TREC coarse classes plus `BOOL`).
Two-stage inference, in which the eat7 prediction restricts which eat53 labels can be chosen, achieves **93.4% macro F1** on the test set.
Used by [little_questions](https://github.com/OpenJarbas/little_questions).
## Label taxonomy
7 main categories, 53 sub-types:
| Main | Sub-types |
|------|-----------|
| `ABBR` | abb, exp |
| `BOOL` | yesno |
| `DESC` | def, desc, manner, reason |
| `ENTY` | animal, body, color, cremat, currency, dismed, event, food, instru, lang, letter, other, plant, product, religion, sport, substance, symbol, techmeth, termeq, veh, word |
| `HUM` | desc, gr, ind, title |
| `LOC` | city, country, landmass, mount, other, state, water |
| `NUM` | code, count, date, dist, money, ord, other, perc, period, speed, temp, volsize, weight |
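
If you need the taxonomy in code, for example to validate predictions or to map a fine label back to its main category, the table can be written down as a plain dictionary. This is a minimal sketch: `TAXONOMY`, `fine_labels` and `main_of` are hypothetical names, and the `MAIN:sub` label format is assumed from the `HUM:ind` output in the two-stage example below. The prefix matters because some sub-types (`desc`, `other`) occur under more than one main category.

```python
# Hypothetical helper: the label taxonomy table written as a dict.
TAXONOMY = {
    "ABBR": ["abb", "exp"],
    "BOOL": ["yesno"],
    "DESC": ["def", "desc", "manner", "reason"],
    "ENTY": ["animal", "body", "color", "cremat", "currency", "dismed", "event",
             "food", "instru", "lang", "letter", "other", "plant", "product",
             "religion", "sport", "substance", "symbol", "techmeth", "termeq",
             "veh", "word"],
    "HUM": ["desc", "gr", "ind", "title"],
    "LOC": ["city", "country", "landmass", "mount", "other", "state", "water"],
    "NUM": ["code", "count", "date", "dist", "money", "ord", "other", "perc",
            "period", "speed", "temp", "volsize", "weight"],
}


def fine_labels():
    """All fine labels in the assumed 'MAIN:sub' format."""
    return [f"{main}:{sub}" for main, subs in TAXONOMY.items() for sub in subs]


def main_of(label):
    """Map a fine label such as 'HUM:ind' to its main category 'HUM'."""
    main, _, sub = label.partition(":")
    if sub not in TAXONOMY.get(main, []):
        raise ValueError(f"unknown label: {label!r}")
    return main


# Sanity checks against the counts stated above: 7 main categories, 53 sub-types.
assert len(TAXONOMY) == 7
assert len(fine_labels()) == 53
```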
## Models
| File | Input text | Second output (index 1) |
|------|---------------|-----------|
| `eat53_logreg_EN_0.9.0.onnx` | punctuated (written) | decision score |
| `eat53_sgd_EN_0.9.0.onnx` | punctuated (written) | decision score |
| `eat53_svm_EN_0.9.0.onnx` | punctuated (written) | decision score |
| `eat53_svm_cal_EN_0.9.0.onnx` | punctuated (written) | calibrated probability |
| `eat53_svm_cal_unpunct_EN_0.9.0.onnx` | unpunctuated (ASR/voice) | calibrated probability |
| `eat7_logreg_EN_0.9.0.onnx` | punctuated (written) | decision score |
| `eat7_sgd_EN_0.9.0.onnx` | punctuated (written) | decision score |
| `eat7_svm_EN_0.9.0.onnx` | punctuated (written) | decision score |
| `eat7_svm_cal_EN_0.9.0.onnx` | punctuated (written) | calibrated probability |
| `eat7_svm_cal_unpunct_EN_0.9.0.onnx` | unpunctuated (ASR/voice) | calibrated probability |

Unpunctuated variants are provided only for the calibrated SVMs.
Use the `_unpunct` models for ASR / voice-assistant input, where transcripts usually arrive without punctuation.
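
The file names follow a regular pattern, so choosing a model can be scripted. A minimal sketch; `model_filename` is a hypothetical helper that only accepts the combinations listed in the table (the unpunctuated variant exists only for `svm_cal`).

```python
# Hypothetical helper that builds a file name from the Models table.
VERSION = "0.9.0"
MODELS = ("logreg", "sgd", "svm", "svm_cal")


def model_filename(task="eat53", model="svm_cal", voice=False):
    """Return the ONNX file name for a task ('eat7' or 'eat53') and model.

    voice=True selects the unpunctuated (ASR) variant, which the table
    lists only for the calibrated SVM ('svm_cal').
    """
    if task not in ("eat7", "eat53"):
        raise ValueError(f"unknown task: {task!r}")
    if model not in MODELS:
        raise ValueError(f"unknown model: {model!r}")
    if voice and model != "svm_cal":
        raise ValueError("unpunctuated variants exist only for 'svm_cal'")
    suffix = "_unpunct" if voice else ""
    return f"{task}_{model}{suffix}_EN_{VERSION}.onnx"
```

For a voice assistant, both stages would then load `model_filename("eat7", voice=True)` and `model_filename("eat53", voice=True)`.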
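## Two-stage inference
The eat7 model predicts the main category. eat53 probabilities for sub-types outside that category are set to zero, and the remainder is renormalised before taking the argmax.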
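```python
import json

import numpy as np
import onnxruntime as rt

# Stage 1 (eat7) picks the main category; stage 2 (eat53) picks the sub-type.
sess7 = rt.InferenceSession("eat7_svm_cal_EN_0.9.0.onnx")
sess53 = rt.InferenceSession("eat53_svm_cal_EN_0.9.0.onnx")

# Class names are stored as a JSON list in each model's metadata.
classes7 = json.loads(sess7.get_modelmeta().custom_metadata_map["classes"])
classes53 = json.loads(sess53.get_modelmeta().custom_metadata_map["classes"])
# Main category of each eat53 label, e.g. "HUM:ind" -> "HUM".
main_of_53 = [c.split(":")[0] for c in classes53]

def classify(text):
    inp = np.array([text], dtype=object)
    # Output 0 of eat7 is the predicted label, used as an index into classes7.
    main = classes7[int(sess7.run(None, {"input": inp})[0][0])]
    # Output 1 of eat53 holds the calibrated probabilities for all 53 labels.
    _, probs = sess53.run(None, {"input": inp})
    row = probs[0].copy()
    # Zero out sub-types that do not belong to the predicted main category.
    for j, m in enumerate(main_of_53):
        if m != main:
            row[j] = 0.0
    row /= row.sum()  # renormalise over the remaining sub-types
    return classes53[int(np.argmax(row))], float(row.max())

print(classify("Who invented the telephone?"))  # ('HUM:ind', 0.96)
```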
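
The masking step can be factored out so it also covers the non-calibrated models. Calibrated models return probabilities, which can be zeroed and renormalised as in `classify()`. The plain `logreg`/`sgd`/`svm` models return decision scores, which can be negative and do not sum to one, so renormalising them is not meaningful. A minimal sketch under that assumption: `gate` is a hypothetical helper, not part of the released models. It takes the argmax over the allowed indices only and reports no confidence for decision scores.

```python
def gate(row, classes53, main, calibrated=True):
    """Pick the best fine label whose main category equals `main`.

    row: the 53 values from the eat53 model's second output (probabilities
    for *_cal models, decision scores otherwise). Labels are assumed to be
    'MAIN:sub' strings, as in the classify() example.
    Returns (label, confidence); confidence is None for decision scores.
    """
    allowed = [j for j, c in enumerate(classes53) if c.split(":")[0] == main]
    if not allowed:  # main category absent from eat53: fall back to no gating
        allowed = list(range(len(classes53)))
    best = max(allowed, key=lambda j: float(row[j]))
    if not calibrated:
        return classes53[best], None
    # Renormalise the probability mass over the allowed sub-types only.
    total = sum(float(row[j]) for j in allowed)
    conf = float(row[best]) / total if total > 0 else float("nan")
    return classes53[best], conf
```

Inside `classify()`, `gate(probs[0], classes53, main)` would replace the manual masking loop. With a non-calibrated eat53 model, pass `calibrated=False`.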
## Benchmarks
Full results: [BENCHMARKS.md](BENCHMARKS.md)
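
The headline number is macro F1: F1 is computed separately for each label and then averaged with equal weight, so rare labels count as much as frequent ones. Below is a minimal pure-Python sketch of that metric. It assumes lists of gold and predicted label strings and averages over the union of gold and predicted labels, which is what scikit-learn's `f1_score(..., average="macro")` does by default. `macro_f1` is a hypothetical helper, not the script behind BENCHMARKS.md.

```python
from collections import Counter


def macro_f1(y_true, y_pred):
    """Unweighted mean of per-label F1 over labels seen in y_true or y_pred."""
    if len(y_true) != len(y_pred):
        raise ValueError("y_true and y_pred must have the same length")
    if not y_true:
        raise ValueError("need at least one example")
    tp, fp, fn = Counter(), Counter(), Counter()
    for t, p in zip(y_true, y_pred):
        if t == p:
            tp[t] += 1
        else:
            fp[p] += 1  # predicted p, but it was wrong
            fn[t] += 1  # missed the true label t
    labels = set(y_true) | set(y_pred)
    # Per-label F1 = 2TP / (2TP + FP + FN), equivalent to 2PR / (P + R).
    scores = [2 * tp[l] / (2 * tp[l] + fp[l] + fn[l]) for l in labels]
    return sum(scores) / len(scores)
```

For example, `macro_f1(gold, [classify(q)[0] for q in questions])` scores the two-stage classifier on any labelled list of questions.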