---
language:
- en
license: apache-2.0
tags:
- question-classification
- text-classification
- onnx
- english
- eat
- calibrated
datasets:
- TigreGotico/EAT
---

# eat-classifiers

English question answer-type (EAT) classifiers trained on the
[TigreGotico/EAT](https://huggingface.co/datasets/TigreGotico/EAT) dataset
(30,017 questions, 53 fine-grained labels across 7 TREC categories).

Two-stage inference, in which the 7-class `eat7` model gates the 53-class `eat53` model, achieves **93.4% macro F1** on the test set.

Used by [little_questions](https://github.com/OpenJarbas/little_questions).

## Label taxonomy

7 main categories, 53 sub-types:

| Main | Sub-types |
|------|-----------|
| `ABBR` | abb, exp |
| `BOOL` | yesno |
| `DESC` | def, desc, manner, reason |
| `ENTY` | animal, body, color, cremat, currency, dismed, event, food, instru, lang, letter, other, plant, product, religion, sport, substance, symbol, techmeth, termeq, veh, word |
| `HUM` | desc, gr, ind, title |
| `LOC` | city, country, landmass, mount, other, state, water |
| `NUM` | code, count, date, dist, money, ord, other, perc, period, speed, temp, volsize, weight |
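
As a quick sanity check on the table above, the taxonomy can be written out as a dictionary and expanded into fine-grained labels. This is only a sketch: it assumes labels use the `MAIN:sub` string format (as the `HUM:ind` example elsewhere in this card suggests), and `TAXONOMY`, `FINE_LABELS`, and `split_label` are illustrative names, not part of the released models.

```python
# Taxonomy copied from the table above (main category -> sub-types).
TAXONOMY = {
    "ABBR": ["abb", "exp"],
    "BOOL": ["yesno"],
    "DESC": ["def", "desc", "manner", "reason"],
    "ENTY": ["animal", "body", "color", "cremat", "currency", "dismed",
             "event", "food", "instru", "lang", "letter", "other", "plant",
             "product", "religion", "sport", "substance", "symbol",
             "techmeth", "termeq", "veh", "word"],
    "HUM": ["desc", "gr", "ind", "title"],
    "LOC": ["city", "country", "landmass", "mount", "other", "state", "water"],
    "NUM": ["code", "count", "date", "dist", "money", "ord", "other", "perc",
            "period", "speed", "temp", "volsize", "weight"],
}

# Assumed label format: "MAIN:sub", e.g. "HUM:ind".
FINE_LABELS = [f"{main}:{sub}" for main, subs in TAXONOMY.items() for sub in subs]


def split_label(label):
    """Split a fine-grained label into (main category, sub-type)."""
    main, _, sub = label.partition(":")
    return main, sub


print(len(TAXONOMY), len(FINE_LABELS))
print(split_label("HUM:ind"))
```

Note that sub-type names such as `desc` and `other` recur under several main categories, so a sub-type is only unambiguous together with its main category.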

## Models

| File | Input variant | Output |
|------|---------------|--------|
| `eat53_logreg_EN_0.9.0.onnx` | punctuated (written) | decision score |
| `eat53_sgd_EN_0.9.0.onnx` | punctuated (written) | decision score |
| `eat53_svm_EN_0.9.0.onnx` | punctuated (written) | decision score |
| `eat53_svm_cal_EN_0.9.0.onnx` | punctuated (written) | calibrated probability |
| `eat53_svm_cal_unpunct_EN_0.9.0.onnx` | unpunctuated (ASR/voice) | calibrated probability |
| `eat7_logreg_EN_0.9.0.onnx` | punctuated (written) | decision score |
| `eat7_sgd_EN_0.9.0.onnx` | punctuated (written) | decision score |
| `eat7_svm_EN_0.9.0.onnx` | punctuated (written) | decision score |
| `eat7_svm_cal_EN_0.9.0.onnx` | punctuated (written) | calibrated probability |
| `eat7_svm_cal_unpunct_EN_0.9.0.onnx` | unpunctuated (ASR/voice) | calibrated probability |

The calibrated SVM models are provided in both punctuated and unpunctuated variants.
Use the unpunctuated (`_unpunct`) model for ASR / voice assistant input.
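
The file names in the table follow a regular pattern, so picking a model can be sketched as a small helper. `model_filename` is a hypothetical convenience function, not something shipped with the models, and it assumes the naming scheme shown in the table (`<task>_<algo>[_unpunct]_EN_<version>.onnx`) holds for every file.

```python
def model_filename(task="eat53", algo="svm_cal", unpunctuated=False, version="0.9.0"):
    """Build a model file name following the pattern in the table above.

    Hypothetical helper: it only assembles the string and does not check
    that the file exists.
    """
    if task not in ("eat7", "eat53"):
        raise ValueError(f"unknown task: {task!r}")
    if algo not in ("logreg", "sgd", "svm", "svm_cal"):
        raise ValueError(f"unknown algorithm: {algo!r}")
    # Per the table, only the calibrated SVM has an unpunctuated variant.
    if unpunctuated and algo != "svm_cal":
        raise ValueError("unpunctuated variants are only listed for svm_cal")
    suffix = "_unpunct" if unpunctuated else ""
    return f"{task}_{algo}{suffix}_EN_{version}.onnx"


# Written text: punctuated, calibrated SVM
print(model_filename("eat53"))
# ASR / voice input: unpunctuated variant
print(model_filename("eat7", unpunctuated=True))
```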

## Two-stage inference

The `eat7` model predicts the main category. The `eat53` probabilities for sub-types outside that category are then zeroed and the remainder is renormalized.

```python
import json

import numpy as np
import onnxruntime as rt

# Load the 7-class gate and the 53-class fine-grained model
sess7 = rt.InferenceSession("eat7_svm_cal_EN_0.9.0.onnx")
sess53 = rt.InferenceSession("eat53_svm_cal_EN_0.9.0.onnx")

# Class names are stored as JSON in each model's metadata
classes7 = json.loads(sess7.get_modelmeta().custom_metadata_map["classes"])
classes53 = json.loads(sess53.get_modelmeta().custom_metadata_map["classes"])

# Main category of each fine-grained label, e.g. "HUM:ind" -> "HUM"
main_of_53 = [c.split(":")[0] for c in classes53]


def classify(text):
    inp = np.array([text], dtype=object)
    # Stage 1: predict the main category with eat7
    main = classes7[int(sess7.run(None, {"input": inp})[0][0])]
    # Stage 2: get eat53 probabilities and mask sub-types outside `main`
    _, probs = sess53.run(None, {"input": inp})
    row = probs[0].copy()
    for j, m in enumerate(main_of_53):
        if m != main:
            row[j] = 0.0
    # Renormalize so the remaining probabilities sum to 1
    row /= row.sum()
    return classes53[int(np.argmax(row))], float(row.max())


print(classify("Who invented the telephone?"))  # ('HUM:ind', 0.96)
```
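
To make the gating step concrete without loading any model, the masking and renormalization can be sketched on toy numbers. The four labels and their probabilities below are invented for illustration and are not model outputs; `gate` is a hypothetical helper that mirrors the logic of `classify` above and adds a guard against an all-zero row.

```python
import numpy as np


def gate(probs, fine_labels, main):
    """Keep only sub-types of `main`, renormalize, and return the best label."""
    mask = np.array([label.split(":")[0] == main for label in fine_labels])
    row = np.where(mask, probs, 0.0)
    total = row.sum()
    if total == 0.0:
        # No probability mass left inside the gated category.
        raise ValueError(f"no probability mass for main category {main!r}")
    row = row / total
    best = int(np.argmax(row))
    return fine_labels[best], float(row[best])


# Toy values, invented for illustration only.
labels = ["HUM:ind", "HUM:gr", "LOC:city", "NUM:date"]
probs = np.array([0.40, 0.10, 0.45, 0.05])

print(labels[int(np.argmax(probs))])  # ungated argmax picks "LOC:city"
print(gate(probs, labels, "HUM"))     # gated on HUM picks "HUM:ind", p = 0.40 / 0.50
```

In this toy case the ungated argmax would choose `LOC:city`, while gating on `HUM` selects `HUM:ind` with a renormalized probability of 0.8. This is how the coarse prediction can override a slightly higher-scoring sub-type from another category.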

## Benchmarks

Full results: [BENCHMARKS.md](BENCHMARKS.md)