GameAI / question_classifier.py
j-js's picture
Update question_classifier.py
462a39a verified
from __future__ import annotations
def normalize_category(category: str | None) -> str:
    """Map a free-form category label onto a canonical category name.

    Matching is case-insensitive and ignores surrounding whitespace.
    Underscores, hyphens and runs of whitespace inside the label are treated
    as a single space, so "data_insight", "data-insight" and "Data  Insight"
    are all recognised.

    Args:
        category: The raw label, or None when no category was supplied.

    Returns:
        "Quantitative", "DataInsight", "Verbal" or "General" for recognised
        labels (missing/blank labels count as "General"). An unrecognised
        label is returned as given, with surrounding whitespace removed.
    """
    raw = (category or "").strip()
    # Unify separators so every spelling of a multi-word alias compares equal.
    key = " ".join(raw.lower().replace("_", " ").replace("-", " ").split())

    if key in {"quantitative", "quant", "q", "math"}:
        return "Quantitative"
    if key in {"datainsight", "data insight", "di", "data"}:
        return "DataInsight"
    if key in {"verbal", "v"}:
        return "Verbal"
    if key in {"general", "", "unknown", "none", "null"}:
        return "General"
    # Unknown label: hand it back trimmed rather than silently relabelling it.
    return raw
def classify_question(question_text: str, category: str | None = None) -> dict:
    """Classify a question into a category, topic and question type.

    The category is first canonicalised with ``normalize_category``. For the
    "Quantitative", "DataInsight" and "Verbal" categories the lower-cased
    question text is scanned for keywords, and the first matching rule wins.
    For any other category the text is checked for quantitative and then
    data-insight keywords, and the question is re-classified under that
    category if one is found.

    Args:
        question_text: The question text; None or "" is treated as empty.
        category: Optional category label in any spelling accepted by
            ``normalize_category``.

    Returns:
        A dict with the keys "category", "topic" and "type". When nothing
        matches in the uncategorised path, the result is
        ``{"category": "General", "topic": "unknown", "type": "unknown"}``.
    """
    import re  # local import keeps this function self-contained

    q = (question_text or "").lower()
    normalized = normalize_category(category)

    def has_any(*keywords: str) -> bool:
        # Plain substring test, so "percent" also matches "percentage".
        return any(k in q for k in keywords)

    def labelled(topic: str, kind: str) -> dict:
        return {"category": normalized, "topic": topic, "type": kind}

    # "ratio" and "mod" are matched as whole words: as bare substrings they
    # also hit "operation"/"duration"/"rational" and "mode"/"model"/"modified",
    # which misrouted unrelated questions.
    mentions_ratio = re.search(r"\bratios?\b", q) is not None
    mentions_mod = re.search(r"\bmod(?:ulo|ulus|ular)?\b", q) is not None
    mentions_percent = has_any("percent", "%")

    if normalized == "Quantitative":
        if mentions_percent and has_any(
            "then", "after", "followed by", "successive", "increase", "decrease", "discount"
        ):
            return labelled("percent", "successive_percent")
        if mentions_percent:
            return labelled("percent", "percent_change")
        # Any colon is taken as ratio notation (e.g. "3:4").
        if mentions_ratio or ":" in q:
            return labelled("ratio", "ratio_total")
        if has_any("probability", "chosen at random"):
            return labelled("probability", "simple_probability")
        if has_any("divisible", "remainder") or mentions_mod:
            return labelled("number_theory", "remainder_or_divisibility")
        if "|" in q:
            return labelled("algebra", "absolute_value")
        if has_any("circle", "radius", "circumference", "triangle", "perimeter", "area"):
            return labelled("geometry", "geometry")
        if has_any("average", "mean", "median"):
            return labelled("statistics", "average")
        if "sequence" in q:
            return labelled("sequence", "sequence")
        if "=" in q:
            return labelled("algebra", "equation")
        return labelled("quant", "general")

    if normalized == "DataInsight":
        if mentions_percent:
            return labelled("percent", "percent_change")
        if has_any("mean", "median", "distribution"):
            return labelled("statistics", "distribution")
        if has_any("correlation", "scatter", "trend", "table", "chart"):
            return labelled("data", "correlation_or_graph")
        return labelled("data", "general")

    if normalized == "Verbal":
        if has_any("meaning", "definition"):
            return labelled("vocabulary", "definition")
        if has_any("grammatically", "sentence correction"):
            return labelled("grammar", "sentence_correction")
        if has_any("argument", "author"):
            return labelled("reasoning", "argument_analysis")
        return labelled("verbal", "general")

    # No usable category: infer one from the text and classify again under it.
    if mentions_ratio or has_any(
        "percent", "%", "remainder", "divisible", "probability", "circle", "triangle", "="
    ):
        return classify_question(question_text, "Quantitative")
    if has_any("table", "chart", "scatter", "trend", "distribution"):
        return classify_question(question_text, "DataInsight")
    return {"category": "General", "topic": "unknown", "type": "unknown"}