GameAI / quant_solver.py
j-js's picture
Update quant_solver.py
670ac1a verified
raw
history blame
13.1 kB
from __future__ import annotations
import math
import re
from statistics import mean, median
from typing import Dict, List, Optional, Tuple
try:
import sympy as sp
except Exception:
sp = None
from models import SolverResult
from utils import clean_math_text, normalize_spaces
def extract_choices(text: str) -> Dict[str, str]:
    """Collect lettered answer options (A-E) from *text*.

    An option is a letter A-E (any case) followed by ")", "." or ":". Its text
    runs until the next such option marker or the end of the string. Keys are
    upper-cased letters, and values are passed through ``normalize_spaces``. If
    a letter appears more than once, the last occurrence wins. ``None`` or ""
    yields an empty dict.
    """
    choice_pattern = r"(?i)\b([A-E])[\)\.:]\s*(.*?)(?=\s+\b[A-E][\)\.:]\s*|$)"
    found: Dict[str, str] = {}
    for hit in re.finditer(choice_pattern, text or ""):
        letter, body = hit.group(1), hit.group(2)
        found[letter.upper()] = normalize_spaces(body)
    return found
def has_answer_choices(text: str) -> bool:
    """Return True when ``extract_choices`` finds at least three options in *text*."""
    option_count = len(extract_choices(text))
    return option_count >= 3
def is_quant_question(text: str) -> bool:
    """Heuristically decide whether *text* is a quantitative question.

    The cleaned, lower-cased text counts as quantitative if any of the
    following holds (checked in this order):

    1. it contains one of a fixed list of math words (plain substring match);
    2. it contains "=" together with at least one letter;
    3. it contains a digit and either a "?" or at least three lettered
       answer choices.
    """
    lowered = clean_math_text(text).lower()
    vocabulary = (
        "solve", "equation", "percent", "ratio", "probability", "mean", "median",
        "average", "sum", "difference", "product", "quotient", "triangle", "circle",
        "rectangle", "area", "perimeter", "volume", "algebra", "integer", "divisible",
        "number", "fraction", "decimal", "geometry", "distance", "speed", "work",
        "remainder", "discount",
    )
    for word in vocabulary:
        if word in lowered:
            return True
    if "=" in lowered and re.search(r"[a-z]", lowered):
        return True
    if re.search(r"\d", lowered) is None:
        return False
    # Choices are searched in the lower-cased text; extract_choices is
    # case-insensitive, so "a)" markers still count.
    return "?" in lowered or has_answer_choices(lowered)
def _prepare_expression(expr: str) -> str:
    """Rewrite a human-style math expression into Python/SymPy syntax.

    After ``clean_math_text`` and stripping, the following rewrites are applied:

    * ``^`` becomes ``**`` (exponentiation);
    * implicit multiplication becomes explicit:
      ``2(x)`` -> ``2*(x)``, ``(a)2`` -> ``(a)*2``,
      ``(a)(b)`` -> ``(a)*(b)``, ``2x`` -> ``2*x``.
    """
    expr = clean_math_text(expr).strip()
    expr = expr.replace("^", "**")
    expr = re.sub(r"(\d)\s*\(", r"\1*(", expr)
    expr = re.sub(r"\)\s*(\d)", r")*\1", expr)
    # Adjacent parenthesised groups, e.g. "(x+1)(x-1)". Without the explicit
    # "*", Python/SymPy would try to *call* the first group and fail.
    expr = re.sub(r"\)\s*\(", ")*(", expr)
    expr = re.sub(r"(\d)([a-zA-Z])", r"\1*\2", expr)
    return expr
def _extract_equation(text: str) -> Optional[str]:
    """Return the first ``lhs = rhs`` substring of *text* that contains a letter.

    The result is raw text; it has not been passed through
    ``_prepare_expression``. Returns None when the cleaned text has no "=",
    or when no candidate containing a letter is found.
    """
    cleaned = clean_math_text(text)
    if "=" not in cleaned:
        return None
    # Pattern 1 requires a letter before "=". Pattern 2 accepts any run of
    # digits, letters, operators, parentheses and whitespace around "=".
    # Both classes include letters and whitespace, so a match can absorb
    # neighbouring words as well as the equation itself.
    patterns = [
        r"([A-Za-z0-9\.\+\-\*/\^\(\)\s]*[a-zA-Z][A-Za-z0-9\.\+\-\*/\^\(\)\s]*=[A-Za-z0-9\.\+\-\*/\^\(\)\s]+)",
        r"([0-9A-Za-z\.\+\-\*/\^\(\)\s]+=[0-9A-Za-z\.\+\-\*/\^\(\)\s]+)",
    ]
    for pattern in patterns:
        for m in re.finditer(pattern, cleaned):
            candidate = m.group(1).strip()
            # Reject candidates that begin with an instruction phrase.
            # NOTE(review): such candidates are rejected, not trimmed. If
            # clean_math_text keeps the leading word, "Solve 2x + 3 = 7"
            # produces only the candidate "Solve 2x + 3 = 7", which is
            # rejected here, and the fallback below yields "3 = 7", which
            # has no letter. Consider stripping the prefix instead.
            if re.search(r"[a-z]", candidate.lower()) and not candidate.lower().startswith(
                ("how do", "can you", "please", "what is", "solve ")
            ):
                return candidate
    # Fallback: join the last whitespace-separated token before the first "="
    # to the first token after it.
    eq_index = cleaned.find("=")
    left = re.findall(r"[A-Za-z0-9\.\+\-\*/\^\(\)\s]+$", cleaned[:eq_index])
    right = re.findall(r"^[A-Za-z0-9\.\+\-\*/\^\(\)\s]+", cleaned[eq_index + 1:])
    # Either run may be pure whitespace (e.g. "2 + 3 = ?"). split() then
    # returns [], and indexing it would raise IndexError.
    left_tokens = left[0].split() if left else []
    right_tokens = right[0].split() if right else []
    if left_tokens and right_tokens:
        candidate = left_tokens[-1] + " = " + right_tokens[0]
        if re.search(r"[a-z]", candidate.lower()):
            return candidate
    return None
def _parse_number(text: str) -> Optional[float]:
    """Best-effort conversion of a short answer string to a float.

    Recognised forms, tried in order:

    * a percentage such as ``"12.5%"``, returned as a fraction (0.125);
    * an integer fraction such as ``"3/4"`` (None if the denominator is 0);
    * a small arithmetic expression built from numbers, ``+ - * / // % **``
      (``^`` is rewritten to ``**`` by ``_prepare_expression``), parentheses,
      unary signs, ``sqrt(...)`` and ``pi``.

    Anything else returns None: words, units, complex results, or errors.

    Security: the text typically comes from the user's question (answer
    choices), so it is evaluated by a small whitelist AST walker instead of
    ``eval``. ``eval`` with ``{"__builtins__": {}}`` can still be escaped
    through attribute access, and it would happily try to compute ``9**9**9``.
    """
    import ast
    import operator

    raw = clean_math_text(text).strip().lower()
    pct = re.fullmatch(r"(-?\d+(?:\.\d+)?)%", raw.replace(" ", ""))
    if pct:
        return float(pct.group(1)) / 100.0
    frac = re.fullmatch(r"(-?\d+)\s*/\s*(-?\d+)", raw)
    if frac:
        den = float(frac.group(2))
        if den == 0:
            return None
        return float(frac.group(1)) / den

    binary_ops = {
        ast.Add: operator.add,
        ast.Sub: operator.sub,
        ast.Mult: operator.mul,
        ast.Div: operator.truediv,
        ast.FloorDiv: operator.floordiv,
        ast.Mod: operator.mod,
    }
    unary_ops = {ast.UAdd: operator.pos, ast.USub: operator.neg}

    def evaluate(node):
        # Numeric literals only; bool is a subclass of int, so exclude it.
        if isinstance(node, ast.Constant):
            value = node.value
            if isinstance(value, (int, float)) and not isinstance(value, bool):
                return value
            raise ValueError("unsupported literal")
        if isinstance(node, ast.BinOp):
            left = evaluate(node.left)
            right = evaluate(node.right)
            if isinstance(node.op, ast.Pow):
                # Float pow raises OverflowError quickly, instead of
                # building astronomically large integers (e.g. 9**9**9).
                return float(left) ** float(right)
            op = binary_ops.get(type(node.op))
            if op is None:
                raise ValueError("unsupported operator")
            return op(left, right)
        if isinstance(node, ast.UnaryOp):
            op = unary_ops.get(type(node.op))
            if op is None:
                raise ValueError("unsupported unary operator")
            return op(evaluate(node.operand))
        if isinstance(node, ast.Name) and node.id == "pi":
            return math.pi
        if (
            isinstance(node, ast.Call)
            and isinstance(node.func, ast.Name)
            and node.func.id == "sqrt"
            and len(node.args) == 1
            and not node.keywords
        ):
            return math.sqrt(evaluate(node.args[0]))
        raise ValueError(f"unsupported expression node: {type(node).__name__}")

    try:
        tree = ast.parse(_prepare_expression(raw), mode="eval")
        return float(evaluate(tree.body))
    except Exception:
        # Deliberate best-effort: any failure (syntax, unsupported node,
        # division by zero, overflow, complex result) means "not a number".
        return None
def _best_choice(answer_value: float, choices: Dict[str, str]) -> Optional[str]:
best_letter = None
best_diff = float("inf")
for letter, raw in choices.items():
parsed = _parse_number(raw)
if parsed is None:
continue
diff = abs(parsed - answer_value)
if diff < best_diff:
best_diff = diff
best_letter = letter
if best_letter is not None and best_diff <= 1e-6:
return best_letter
return None
def _make_result(
    *,
    topic: str,
    answer_value: str,
    internal_answer: Optional[str] = None,
    steps: Optional[List[str]] = None,
    choices_text: str = "",
) -> SolverResult:
    """Build a solved quant ``SolverResult``, mapping the answer to a choice letter.

    *answer_value* is re-parsed with ``_parse_number``. When that succeeds
    and *choices_text* contains lettered options, the closest matching
    option letter (if any) is recorded. *internal_answer* falls back to
    *answer_value* when empty, and *steps* falls back to an empty list.
    """
    numeric_answer = _parse_number(answer_value)
    options = extract_choices(choices_text)
    letter: Optional[str] = None
    if numeric_answer is not None and options:
        letter = _best_choice(numeric_answer, options)
    return SolverResult(
        domain="quant",
        solved=True,
        topic=topic,
        answer_value=answer_value,
        answer_letter=letter,
        internal_answer=internal_answer if internal_answer else answer_value,
        steps=steps if steps else [],
    )
def _solve_successive_percent(text: str) -> Optional[SolverResult]:
    """Combine two or more successive percentage changes into one net change.

    Recognises phrases such as "increased by 20%", "decreases by 10 percent",
    "marked up by 25%", "rose by 5%" or "fell by 5%". As a fallback it also
    recognises the "20% increase" word order. At least two changes are
    required. The answer is the magnitude of the net change as a percentage;
    the direction goes into ``internal_answer``.
    """
    lower = clean_math_text(text).lower()
    # Verb inflections matter: word problems usually say "increased by",
    # "decreases by", "rises by", so the bare stems alone rarely matched.
    changes = re.findall(
        r"(increase[ds]?|decrease[ds]?|discount(?:ed|s)?|mark(?:ed|s)?\s*up"
        r"|mark(?:ed|s)?\s*down|rise[ns]?|rose|fall(?:s|en)?|fell)"
        r"\s+by\s+(\d+(?:\.\d+)?)\s*(?:%|percent)",
        lower,
    )
    if len(changes) < 2:
        changes = re.findall(
            r"(\d+(?:\.\d+)?)\s*(?:%|percent)\s+(increase|decrease|discount|rise|fall)",
            lower,
        )
        changes = [(op, pct) for pct, op in changes]
    if len(changes) < 2:
        return None
    # Any op containing one of these words reduces the value. "fell" must be
    # listed explicitly because it does not contain "fall".
    decrease_words = ("decrease", "discount", "down", "fall", "fell")
    multiplier = 1.0
    step_lines: List[str] = []
    for op, pct_raw in changes:
        pct = float(pct_raw)
        if any(k in op for k in decrease_words):
            factor = 1 - pct / 100.0
            step_lines.append(f"A {pct:g}% decrease means multiply by {factor:g}.")
        else:
            factor = 1 + pct / 100.0
            step_lines.append(f"A {pct:g}% increase means multiply by {factor:g}.")
        multiplier *= factor
    net_change = (multiplier - 1.0) * 100.0
    direction = "increase" if net_change >= 0 else "decrease"
    magnitude = abs(net_change)
    return _make_result(
        topic="percent",
        answer_value=f"{magnitude:g}%",
        internal_answer=f"net {direction} of {magnitude:g}%",
        steps=step_lines + [f"The combined multiplier gives a net {direction} of {magnitude:g}%."],
        choices_text=text,
    )
def _extract_ratio_labels(text: str) -> Optional[Tuple[str, str]]:
m = re.search(r"ratio of ([a-z ]+?) to ([a-z ]+?) is \d+\s*:\s*\d+", text.lower())
if not m:
return None
left = normalize_spaces(m.group(1)).rstrip("s")
right = normalize_spaces(m.group(2)).rstrip("s")
return left, right
def _solve_ratio_total(text: str) -> Optional[SolverResult]:
    """Split a stated total according to a ratio such as "2:3" or "2:3:5".

    Needs both a ratio and a total ("total", "altogether", "in all" or "sum"
    followed by a number). By default the answer is the share of the first
    ratio part. For two-part ratios phrased "ratio of X to Y is a:b", a
    "how many Y" question selects the second share instead.

    Three or more parts are now summed in full. Previously only the first
    two were used, which silently produced a wrong share for "2:3:5".
    """
    lower = clean_math_text(text).lower()
    ratio_match = re.search(r"\d+(?:\s*:\s*\d+)+", lower)
    total_match = re.search(r"(?:total|altogether|in all|sum)\s*(?:is|=|of)?\s*(\d+)", lower)
    if not ratio_match or not total_match:
        return None
    parts = [int(p) for p in re.findall(r"\d+", ratio_match.group(0))]
    total = int(total_match.group(1))
    part_sum = sum(parts)
    if part_sum == 0:
        return None
    unit = total / part_sum
    values = [p * unit for p in parts]

    requested_index = 0
    requested_label = "first quantity"
    # Label extraction only understands "X to Y is a:b", so restrict it to
    # two-part ratios. Its regex would otherwise match the prefix of "a:b:c".
    labels = _extract_ratio_labels(lower) if len(parts) == 2 else None
    if labels:
        left_label, right_label = labels
        requested_label = left_label
        asks_left = left_label in lower and re.search(rf"how many {re.escape(left_label)}", lower)
        if not asks_left and right_label in lower and re.search(
            rf"how many {re.escape(right_label)}", lower
        ):
            requested_index = 1
            requested_label = right_label
    requested_value = values[requested_index]

    parts_text = " + ".join(str(p) for p in parts)
    return _make_result(
        topic="ratio",
        answer_value=f"{requested_value:g}",
        internal_answer=f"{requested_label} = {requested_value:g}",
        steps=[
            f"Add the ratio parts: {parts_text} = {part_sum}.",
            f"Each ratio unit is {total} / {part_sum} = {unit:g}.",
            f"Multiply by the required ratio part to get {requested_value:g}.",
        ],
        choices_text=text,
    )
def _solve_remainder(text: str) -> Optional[SolverResult]:
    """Compute ``a mod b`` for remainder questions.

    Recognised forms, tried in order:

    * "remainder when A is divided by B" (words may sit between the two parts);
    * "when A is divided by B ... remainder" (any order, if "remainder" appears);
    * "A mod B" / "A % B".

    Returns None if no form matches or the divisor is 0.
    """
    lower = clean_math_text(text).lower()
    # The old pattern "remainder .*? when" required a space both before and
    # after the lazy gap, so the common "remainder when 17 is divided by 5"
    # never matched.
    m = re.search(r"remainders?\b.*?\bwhen\s+(\d+)\s+is\s+divided\s+by\s+(\d+)", lower)
    if not m and "remainder" in lower:
        m = re.search(r"\bwhen\s+(\d+)\s+is\s+divided\s+by\s+(\d+)", lower)
    if not m:
        m = re.search(r"(\d+)\s*(?:mod|%)\s*(\d+)", lower)
    if not m:
        return None
    a = int(m.group(1))
    b = int(m.group(2))
    if b == 0:
        return None
    r = a % b
    return _make_result(
        topic="number_theory",
        answer_value=str(r),
        internal_answer=str(r),
        steps=[
            f"Divide {a} by {b}.",
            f"The remainder is {a} mod {b} = {r}.",
        ],
        choices_text=text,
    )
def _solve_percent(text: str) -> Optional[SolverResult]:
    """Solve the two basic percent templates.

    * "p% of a number is v"  -> the number is v / (p / 100)
    * "what is p% of n"      -> p / 100 * n

    When the text has lettered choices, the result is matched to one.
    Returns None if neither template matches, or if p is 0 in the first
    template. In that case 0 * n = v has no unique solution, and the old code
    raised ZeroDivisionError.
    """
    lower = clean_math_text(text).lower()
    choices = extract_choices(text)
    m = re.search(r"(\d+(?:\.\d+)?)\s*(?:%|percent)\s+of\s+(?:a\s+)?number\s+is\s+(\d+(?:\.\d+)?)", lower)
    if m:
        p = float(m.group(1))
        value = float(m.group(2))
        if p == 0:
            return None
        ans = value / (p / 100.0)
        answer_letter = _best_choice(ans, choices) if choices else None
        return SolverResult(
            domain="quant",
            solved=True,
            topic="percent",
            answer_value=f"{ans:g}",
            answer_letter=answer_letter,
            internal_answer=f"{ans:g}",
            steps=[
                "Let the number be n.",
                f"Write {p}% of n as {p / 100:g}n.",
                f"Set {p / 100:g}n = {value} and solve for n.",
            ],
        )
    m = re.search(r"what is\s+(\d+(?:\.\d+)?)\s*(?:%|percent)\s+of\s+(\d+(?:\.\d+)?)", lower)
    if m:
        p = float(m.group(1))
        n = float(m.group(2))
        ans = p / 100.0 * n
        answer_letter = _best_choice(ans, choices) if choices else None
        return SolverResult(
            domain="quant",
            solved=True,
            topic="percent",
            answer_value=f"{ans:g}",
            answer_letter=answer_letter,
            internal_answer=f"{ans:g}",
            steps=[
                f"Convert {p}% to {p / 100:g}.",
                f"Multiply by {n}.",
            ],
        )
    return None
def _solve_mean_median(text: str) -> Optional[SolverResult]:
    """Answer mean/average or median questions over the numbers in *text*.

    Every number in the question body is treated as a data value. When at
    least three lettered options are present, everything from the first "A"
    option onward is ignored, so option values are no longer averaged in with
    the data. The result is then matched to a choice letter, as the other
    solvers do. "mean"/"average" takes precedence over "median". Returns None
    if no numbers remain or neither keyword appears.
    """
    lower = clean_math_text(text).lower()
    question = lower
    # Same option pattern as extract_choices; used here only to learn where
    # the option list starts.
    option_matches = list(
        re.finditer(
            r"(?i)\b([A-E])[\)\.:]\s*(.*?)(?=\s+\b[A-E][\)\.:]\s*|$)",
            lower,
        )
    )
    if len(option_matches) >= 3:
        # Cut at the "A" option rather than the first match, so a stray
        # "e." (as in "i.e.") earlier in the question is not taken as the start.
        first_option = next(
            (m for m in option_matches if m.group(1).lower() == "a"), None
        )
        if first_option is not None:
            question = lower[: first_option.start()]
    nums = [float(n) for n in re.findall(r"-?\d+(?:\.\d+)?", question)]
    if not nums:
        return None
    if "mean" in lower or "average" in lower:
        ans = mean(nums)
        steps = ["Add the values.", f"Divide by {len(nums)}."]
    elif "median" in lower:
        ans = median(nums)
        steps = ["Order the values.", "Take the middle value."]
    else:
        return None
    choices = extract_choices(text)
    return SolverResult(
        domain="quant",
        solved=True,
        topic="statistics",
        answer_value=f"{ans:g}",
        answer_letter=_best_choice(ans, choices) if choices else None,
        internal_answer=f"{ans:g}",
        steps=steps,
    )
def _solve_linear_equation(text: str) -> Optional[SolverResult]:
    """Solve the first equation found in *text* for one variable using SymPy.

    Returns None if SymPy is unavailable, no equation is found, the sides do
    not parse, no single-letter lowercase variable is present, or SymPy finds
    no solution. When several variables are present, the alphabetically first
    one is solved for. When several solutions exist, the first one SymPy
    returns is used.
    """
    if sp is None:
        return None
    try:
        expr = _extract_equation(text)
        if not expr:
            return None
        lhs, rhs = expr.split("=", 1)
        # NOTE(review): sympify evaluates its string input with eval(), and
        # this text comes from the user's question. Acceptable only for
        # trusted input; consider sympy.parsing.sympy_parser.parse_expr with a
        # restricted namespace if this runs on untrusted data.
        left = sp.sympify(_prepare_expression(lhs))
        right = sp.sympify(_prepare_expression(rhs))
        # Take variables from the *parsed* sides. Scanning the raw text with
        # r"\b[a-z]\b" missed coefficient forms like "2x" (there is no word
        # boundary between "2" and "x"), so "2x + 3 = 7" was rejected.
        candidates = sorted(
            (
                s
                for s in (left.free_symbols | right.free_symbols)
                if re.fullmatch(r"[a-z]", s.name)
            ),
            key=lambda s: s.name,
        )
        if not candidates:
            return None
        var = candidates[0]
        var_name = var.name
        sol = sp.solve(sp.Eq(left, right), var)
        if not sol:
            return None
        value = sol[0]
        try:
            as_float = float(value)
        except Exception:
            # Symbolic or complex solution: no numeric choice matching.
            as_float = None
        choices = extract_choices(text)
        return SolverResult(
            domain="quant",
            solved=True,
            topic="algebra",
            answer_value=str(value),
            answer_letter=_best_choice(as_float, choices) if (as_float is not None and choices) else None,
            internal_answer=f"{var_name} = {value}",
            steps=[
                "Treat the statement as an equation.",
                "Undo operations on both sides to isolate the variable.",
                f"That gives {var_name} = {value}.",
            ],
        )
    except Exception:
        # Deliberate best-effort: any extraction, parse or solve failure
        # means this solver does not apply.
        return None
def solve_quant(text: str) -> SolverResult:
    """Run the rule-based quant solvers in priority order.

    Each solver is tried in turn on the (None-safe) text, and the first
    non-None result is returned. If none applies, an unsolved
    ``general_quant`` result with generic guidance steps is returned.
    """
    question = text or ""
    strategies = (
        _solve_successive_percent,
        _solve_ratio_total,
        _solve_remainder,
        _solve_percent,
        _solve_mean_median,
        _solve_linear_equation,
    )
    # The generator is lazy, so solvers after the first hit are never called.
    attempts = (strategy(question) for strategy in strategies)
    solved = next((outcome for outcome in attempts if outcome is not None), None)
    if solved is not None:
        return solved
    guidance = [
        "Identify the quantity the question wants.",
        "Translate the wording into an equation, ratio, or diagram.",
        "Carry out the calculation carefully.",
    ]
    return SolverResult(
        domain="quant",
        solved=False,
        topic="general_quant",
        reply="This looks quantitative, but it does not match a strong rule-based pattern yet.",
        steps=guidance,
    )