---
language:
  - en
license: apache-2.0
tags:
  - question-classification
  - text-classification
  - onnx
  - english
  - eat
  - calibrated
datasets:
  - TigreGotico/EAT
---

eat-classifiers

English expected answer type (EAT) classifiers for questions, trained on the TigreGotico/EAT dataset (30,017 questions, 53 fine-grained labels across 7 main categories: the six TREC coarse classes plus BOOL for yes/no questions).

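With two-stage inference, the eat7 model predicts the main category and the eat53 prediction is restricted to that category's sub-types. This setup reaches 93.4% macro F1 on the test set.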
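Macro F1 is the unweighted mean of the per-class F1 scores, so a rare sub-type such as `NUM:speed` counts as much as a frequent one. The sketch below shows that arithmetic in plain Python, following what we understand to be scikit-learn's `f1_score(average="macro")` convention (labels taken from the union of true and predicted values, 0 for undefined precision or recall). It is an illustration only, not the evaluation script behind the 93.4% figure, and the toy labels are made up.

```python
from collections import Counter


def macro_f1(y_true, y_pred):
    """Unweighted mean of per-class F1 over every label in y_true or y_pred."""
    labels = sorted(set(y_true) | set(y_pred))
    tp = Counter(t for t, p in zip(y_true, y_pred) if t == p)
    pred_count = Counter(y_pred)
    true_count = Counter(y_true)
    scores = []
    for label in labels:
        precision = tp[label] / pred_count[label] if pred_count[label] else 0.0
        recall = tp[label] / true_count[label] if true_count[label] else 0.0
        denom = precision + recall
        scores.append(2 * precision * recall / denom if denom else 0.0)
    return sum(scores) / len(scores)


# Toy labels for illustration only.
y_true = ["HUM:ind", "HUM:ind", "LOC:city", "NUM:date"]
y_pred = ["HUM:ind", "LOC:city", "LOC:city", "NUM:date"]
print(round(macro_f1(y_true, y_pred), 4))  # 0.7778
```

Here `HUM:ind` and `LOC:city` each get F1 = 2/3 and `NUM:date` gets 1, so the macro average is 7/9.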

Used by little_questions.

Label taxonomy

7 main categories, 53 sub-types:

| Main | Sub-types |
|------|-----------|
| ABBR | abb, exp |
| BOOL | yesno |
| DESC | def, desc, manner, reason |
| ENTY | animal, body, color, cremat, currency, dismed, event, food, instru, lang, letter, other, plant, product, religion, sport, substance, symbol, techmeth, termeq, veh, word |
| HUM | desc, gr, ind, title |
| LOC | city, country, landmass, mount, other, state, water |
| NUM | code, count, date, dist, money, ord, other, perc, period, speed, temp, volsize, weight |
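Some sub-type names repeat across categories (`desc` under DESC and HUM, `other` under ENTY, LOC and NUM), so a fine-grained label only makes sense together with its main category. The sketch below encodes the table as a dictionary and builds labels in the `MAIN:sub` form. That string format is an assumption inferred from the inference example (`c.split(":")` and the `'HUM:ind'` output). The order of `LABELS53` is the table's order, not necessarily the models' class order, so always read the class list from the model metadata before indexing model outputs.

```python
# Sub-types per main category, copied from the taxonomy table.
TAXONOMY = {
    "ABBR": ["abb", "exp"],
    "BOOL": ["yesno"],
    "DESC": ["def", "desc", "manner", "reason"],
    "ENTY": ["animal", "body", "color", "cremat", "currency", "dismed", "event",
             "food", "instru", "lang", "letter", "other", "plant", "product",
             "religion", "sport", "substance", "symbol", "techmeth", "termeq",
             "veh", "word"],
    "HUM": ["desc", "gr", "ind", "title"],
    "LOC": ["city", "country", "landmass", "mount", "other", "state", "water"],
    "NUM": ["code", "count", "date", "dist", "money", "ord", "other", "perc",
            "period", "speed", "temp", "volsize", "weight"],
}

# Fine-grained labels qualified by main category, e.g. "HUM:ind".
LABELS53 = [f"{main}:{sub}" for main, subs in TAXONOMY.items() for sub in subs]


def main_category(label):
    """Return the main category of a fine-grained label ("HUM:ind" -> "HUM")."""
    return label.split(":", 1)[0]


print(len(TAXONOMY), len(LABELS53))  # 7 53
```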

Models

| File | Input variant | Output[1] (second output) |
|------|---------------|--------------------------|
| eat53_logreg_EN_0.9.0.onnx | punctuated (written) | decision score |
| eat53_sgd_EN_0.9.0.onnx | punctuated (written) | decision score |
| eat53_svm_EN_0.9.0.onnx | punctuated (written) | decision score |
| eat53_svm_cal_EN_0.9.0.onnx | punctuated (written) | calibrated probability |
| eat53_svm_cal_unpunct_EN_0.9.0.onnx | unpunctuated (ASR/voice) | calibrated probability |
| eat7_logreg_EN_0.9.0.onnx | punctuated (written) | decision score |
| eat7_sgd_EN_0.9.0.onnx | punctuated (written) | decision score |
| eat7_svm_EN_0.9.0.onnx | punctuated (written) | decision score |
| eat7_svm_cal_EN_0.9.0.onnx | punctuated (written) | calibrated probability |
| eat7_svm_cal_unpunct_EN_0.9.0.onnx | unpunctuated (ASR/voice) | calibrated probability |

The calibrated SVM models come in both punctuated and unpunctuated variants. Use the unpunctuated (`_unpunct`) models for ASR or voice-assistant input, which typically lacks punctuation.
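The filenames in the table follow a regular pattern: `{eat7|eat53}_{algorithm}[_unpunct]_EN_0.9.0.onnx`, where the `_unpunct` suffix exists only for `svm_cal`. A small helper can pick the right file for written versus ASR input. `model_filename` is a hypothetical convenience function, not part of this repository or of little_questions; it only encodes the naming shown in the table.

```python
# Algorithms published for each level, per the Models table.
ALGORITHMS = {"logreg", "sgd", "svm", "svm_cal"}


def model_filename(level, algo="svm_cal", asr_input=False):
    """Hypothetical helper: build a model filename from the table's naming pattern.

    level: "eat7" or "eat53". asr_input=True selects the unpunctuated variant,
    which is only published for the calibrated SVM ("svm_cal").
    """
    if level not in ("eat7", "eat53"):
        raise ValueError(f"unknown level: {level!r}")
    if algo not in ALGORITHMS:
        raise ValueError(f"unknown algorithm: {algo!r}")
    if asr_input and algo != "svm_cal":
        raise ValueError("unpunctuated models exist only for svm_cal")
    suffix = "_unpunct" if asr_input else ""
    return f"{level}_{algo}{suffix}_EN_0.9.0.onnx"


print(model_filename("eat53", asr_input=True))
# eat53_svm_cal_unpunct_EN_0.9.0.onnx
```

Raising on `asr_input=True` with a non-calibrated algorithm avoids silently loading a punctuated model for voice input.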

Two-stage inference

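```python
import json

import numpy as np
import onnxruntime as rt

# Calibrated SVMs: output[1] holds class probabilities (see the Models table).
sess7 = rt.InferenceSession("eat7_svm_cal_EN_0.9.0.onnx")
sess53 = rt.InferenceSession("eat53_svm_cal_EN_0.9.0.onnx")
# Class names are stored as JSON in each model's custom metadata.
classes7 = json.loads(sess7.get_modelmeta().custom_metadata_map["classes"])
classes53 = json.loads(sess53.get_modelmeta().custom_metadata_map["classes"])
main_of_53 = [c.split(":")[0] for c in classes53]  # "HUM:ind" -> "HUM"

def classify(text):
    inp = np.array([text], dtype=object)
    # Stage 1: eat7 predicts the main category (output[0] is a class index).
    main = classes7[int(sess7.run(None, {"input": inp})[0][0])]
    # Stage 2: eat53 probabilities, restricted to that category's sub-types.
    _, probs = sess53.run(None, {"input": inp})
    full = np.asarray(probs[0], dtype=np.float64)
    row = full.copy()
    for j, m in enumerate(main_of_53):
        if m != main:
            row[j] = 0.0
    total = row.sum()
    if total == 0.0:
        # Guard against division by zero: fall back to the ungated eat53 output.
        row = full
    else:
        row /= total  # renormalize over the surviving sub-types
    return classes53[int(np.argmax(row))], float(row.max())

print(classify("Who invented the telephone?"))  # ('HUM:ind', 0.96)
```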
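The zero-and-renormalize gating in `classify` assumes calibrated probabilities. With the decision-score models (`logreg`, `sgd`, `svm`), output[1] holds raw scores that can be negative, so setting excluded classes to 0.0 could make an excluded label win the argmax. A safer sketch masks excluded classes with `-inf` and takes the argmax without renormalizing. `gated_argmax` is a hypothetical helper, and it assumes the decision scores come one column per class in the same order as the `classes` metadata, as the calibrated outputs do.

```python
import numpy as np


def gated_argmax(scores, classes53, main):
    """Best fine-grained label whose main category equals `main`.

    Works on raw decision scores (possibly negative, not summing to 1) by
    masking other categories with -inf instead of zeroing and renormalizing.
    """
    scores = np.asarray(scores, dtype=np.float64)
    mask = np.array([c.split(":", 1)[0] == main for c in classes53])
    if not mask.any():
        raise ValueError(f"no fine-grained labels under main category {main!r}")
    masked = np.where(mask, scores, -np.inf)
    return classes53[int(np.argmax(masked))]


# Toy scores for illustration only (not real model output).
classes = ["HUM:ind", "HUM:gr", "LOC:city", "NUM:date"]
scores = [-0.4, -1.2, 0.9, -2.0]
print(gated_argmax(scores, classes, "HUM"))  # HUM:ind
```

Without the gate, the argmax of these toy scores would be `LOC:city`; zeroing instead of masking would not help, because an excluded 0.0 still beats the negative HUM scores.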

Benchmarks

Full results: BENCHMARKS.md