---
language:
- en
license: apache-2.0
tags:
- question-classification
- text-classification
- onnx
- english
- eat
- calibrated
datasets:
- TigreGotico/EAT
---

# eat-classifiers

English question classifiers that predict the expected answer type (EAT). They are trained on the [TigreGotico/EAT](https://huggingface.co/datasets/TigreGotico/EAT) dataset: 30,017 questions labelled with 53 fine-grained sub-types across 7 main categories (the six TREC coarse classes plus `BOOL`). Two-stage inference, in which the `eat7` prediction gates the `eat53` prediction, reaches **93.4% macro F1** on the test set.

Used by [little_questions](https://github.com/OpenJarbas/little_questions).

## Label taxonomy

7 main categories, 53 sub-types:

| Main | Sub-types |
|------|-----------|
| `ABBR` | abb, exp |
| `BOOL` | yesno |
| `DESC` | def, desc, manner, reason |
| `ENTY` | animal, body, color, cremat, currency, dismed, event, food, instru, lang, letter, other, plant, product, religion, sport, substance, symbol, techmeth, termeq, veh, word |
| `HUM` | desc, gr, ind, title |
| `LOC` | city, country, landmass, mount, other, state, water |
| `NUM` | code, count, date, dist, money, ord, other, perc, period, speed, temp, volsize, weight |

## Models

| File | Input variant | Second output (`output[1]`) |
|------|---------------|-----------------------------|
| `eat53_logreg_EN_0.9.0.onnx` | punctuated (written) | decision score |
| `eat53_sgd_EN_0.9.0.onnx` | punctuated (written) | decision score |
| `eat53_svm_EN_0.9.0.onnx` | punctuated (written) | decision score |
| `eat53_svm_cal_EN_0.9.0.onnx` | punctuated (written) | calibrated probability |
| `eat53_svm_cal_unpunct_EN_0.9.0.onnx` | unpunctuated (ASR/voice) | calibrated probability |
| `eat7_logreg_EN_0.9.0.onnx` | punctuated (written) | decision score |
| `eat7_sgd_EN_0.9.0.onnx` | punctuated (written) | decision score |
| `eat7_svm_EN_0.9.0.onnx` | punctuated (written) | decision score |
| `eat7_svm_cal_EN_0.9.0.onnx` | punctuated (written) | calibrated probability |
| `eat7_svm_cal_unpunct_EN_0.9.0.onnx` | unpunctuated (ASR/voice) | calibrated probability |

The calibrated SVM models (`svm_cal`) are provided in both punctuated and unpunctuated variants. Use the unpunctuated (`_unpunct`) models for ASR / voice-assistant input, which usually arrives without punctuation.

## Two-stage inference

`eat7` predicts the main category. The `eat53` probabilities for sub-types outside that category are then set to zero and the remainder renormalised, so the fine-grained label always agrees with the main one.

```python
import json

import numpy as np
import onnxruntime as rt

sess7 = rt.InferenceSession("eat7_svm_cal_EN_0.9.0.onnx")
sess53 = rt.InferenceSession("eat53_svm_cal_EN_0.9.0.onnx")

# Class lists are stored as JSON in each model's custom metadata
classes7 = json.loads(sess7.get_modelmeta().custom_metadata_map["classes"])
classes53 = json.loads(sess53.get_modelmeta().custom_metadata_map["classes"])
# Main category of each fine-grained label, e.g. "HUM:ind" -> "HUM"
main_of_53 = [c.split(":")[0] for c in classes53]


def classify(text):
    inp = np.array([text], dtype=object)
    # Stage 1: eat7 output[0] is the predicted class index -> main category
    main = classes7[int(sess7.run(None, {"input": inp})[0][0])]
    # Stage 2: eat53 output[1] holds the per-class probabilities
    _, probs = sess53.run(None, {"input": inp})
    row = probs[0].copy()
    # Keep only sub-types of the predicted main category, then renormalise
    for j, m in enumerate(main_of_53):
        if m != main:
            row[j] = 0.0
    row /= row.sum()
    return classes53[int(np.argmax(row))], float(row.max())


print(classify("Who invented the telephone?"))  # ('HUM:ind', 0.96)
```

## Benchmarks

Full results: [BENCHMARKS.md](BENCHMARKS.md)
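Macro F1 is the unweighted mean of the per-label F1 scores, so a rare sub-type such as `ENTY:letter` counts as much as a frequent one such as `HUM:ind`. The minimal pure-Python sketch below illustrates the metric; `macro_f1` is a hypothetical helper, and the evaluation behind the 93.4% figure may differ in details. It assumes the label set is the union of true and predicted labels, which matches the default behaviour of scikit-learn's `f1_score(..., average="macro")`.

```python
from collections import Counter


def macro_f1(y_true: list[str], y_pred: list[str]) -> float:
    """Unweighted mean of per-label F1 over every label seen in y_true or y_pred."""
    labels = sorted(set(y_true) | set(y_pred))
    tp, fp, fn = Counter(), Counter(), Counter()
    for t, p in zip(y_true, y_pred):
        if t == p:
            tp[t] += 1
        else:
            fp[p] += 1  # predicted p, but it was wrong
            fn[t] += 1  # missed the true label t
    f1s = []
    for label in labels:
        # F1 = 2TP / (2TP + FP + FN); define it as 0 when the label never occurs
        denom = 2 * tp[label] + fp[label] + fn[label]
        f1s.append(2 * tp[label] / denom if denom else 0.0)
    return sum(f1s) / len(f1s)


# Toy example (not real benchmark data)
y_true = ["HUM:ind", "HUM:ind", "LOC:city", "NUM:date"]
y_pred = ["HUM:ind", "LOC:city", "LOC:city", "NUM:date"]
print(round(macro_f1(y_true, y_pred), 3))  # 0.778
```

Here `HUM:ind` and `LOC:city` each score F1 = 2/3 and `NUM:date` scores 1, giving (2/3 + 2/3 + 1) / 3 = 7/9.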
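The label taxonomy can be held as a plain dict. Some sub-type names repeat across main categories (`desc` under both `DESC` and `HUM`, and `other` under `ENTY`, `LOC` and `NUM`), so a fine-grained label needs its main-category prefix to be unique. This sketch assumes the `MAIN:sub` string format suggested by the `classify` example (`'HUM:ind'`). `TAXONOMY`, `FINE_LABELS` and `main_category` are hypothetical names, and the label order here need not match the order stored in the models' `classes` metadata.

```python
# Hypothetical in-code copy of the label taxonomy table: main category -> sub-types
TAXONOMY: dict[str, list[str]] = {
    "ABBR": ["abb", "exp"],
    "BOOL": ["yesno"],
    "DESC": ["def", "desc", "manner", "reason"],
    "ENTY": ["animal", "body", "color", "cremat", "currency", "dismed", "event",
             "food", "instru", "lang", "letter", "other", "plant", "product",
             "religion", "sport", "substance", "symbol", "techmeth", "termeq",
             "veh", "word"],
    "HUM": ["desc", "gr", "ind", "title"],
    "LOC": ["city", "country", "landmass", "mount", "other", "state", "water"],
    "NUM": ["code", "count", "date", "dist", "money", "ord", "other", "perc",
            "period", "speed", "temp", "volsize", "weight"],
}

# Fine-grained labels in the assumed "MAIN:sub" form
FINE_LABELS = [f"{main}:{sub}" for main, subs in TAXONOMY.items() for sub in subs]


def main_category(fine_label: str) -> str:
    """Return the main category of a "MAIN:sub" label, rejecting unknown labels."""
    main, _, sub = fine_label.partition(":")
    if main not in TAXONOMY or sub not in TAXONOMY[main]:
        raise ValueError(f"unknown label: {fine_label!r}")
    return main


print(len(TAXONOMY), len(FINE_LABELS))  # 7 53
print(main_category("HUM:desc"))  # HUM
```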
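The masking loop inside `classify` can also be written as a standalone function, which makes the gating step testable without the ONNX files. This is a sketch: `gate_probs` is a hypothetical helper not shipped with the models. The uniform fallback, used when `eat53` assigns zero total probability to the gated category, is an assumption of this sketch. The loop in `classify` would divide by zero in that case.

```python
import numpy as np


def gate_probs(probs53, main_of_53, main):
    """Restrict one eat53 probability row to the sub-types of `main` and renormalise.

    probs53:    1-D sequence of eat53 probabilities (one row of output[1])
    main_of_53: main category of each eat53 class, in the same order
    main:       main category predicted by eat7
    """
    probs53 = np.asarray(probs53, dtype=float)
    mask = np.array([m == main for m in main_of_53], dtype=bool)
    if not mask.any():
        raise ValueError(f"no eat53 classes belong to main category {main!r}")
    row = np.where(mask, probs53, 0.0)  # zero out labels from other categories
    total = row.sum()
    if total <= 0.0:
        # Assumed fallback: no mass on this category, so spread it evenly over its sub-types
        return mask / mask.sum()
    return row / total


# Toy 3-class example (not real model output)
toy_main_of_53 = ["HUM", "HUM", "LOC"]
gated = gate_probs([0.25, 0.25, 0.5], toy_main_of_53, "HUM")  # -> 0.5, 0.5, 0.0
```

Returning a new array instead of editing in place keeps the raw `eat53` output available, for example to log how much probability the gate discarded.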