GameAI / quant_solver.py
j-js's picture
Update quant_solver.py
074e4ab verified
from __future__ import annotations
import math
import re
from statistics import mean, median
from typing import Dict, List, Optional, Tuple
try:
import sympy as sp
except Exception:
sp = None
from models import SolverResult
from utils import clean_math_text, normalize_spaces
def extract_choices(text: str) -> Dict[str, str]:
    """Pull lettered answer options (A-E) out of *text*.

    Markers such as ``A)``, ``b.`` or ``C:`` are recognised case-insensitively.
    An option's text runs (without crossing a newline) up to the next marker
    that is preceded by whitespace, or to the end of the input.

    Returns:
        Mapping of upper-case letter to whitespace-normalised option text.
        If a letter occurs more than once, the later occurrence's text wins.
    """
    choice_pattern = r"(?i)\b([A-E])[\)\.:]\s*(.*?)(?=\s+\b[A-E][\)\.:]\s*|$)"
    found: Dict[str, str] = {}
    for match in re.finditer(choice_pattern, text or ""):
        letter, body = match.groups()
        found[letter.upper()] = normalize_spaces(body)
    return found
def has_answer_choices(text: str) -> bool:
    """Return True when *text* contains at least three lettered answer options."""
    minimum_options = 3
    option_count = len(extract_choices(text))
    return option_count >= minimum_options
def is_quant_question(text: str) -> bool:
    """Heuristically decide whether *text* is a quantitative question.

    Checks, in order:
      1. any math-flavoured keyword appears (plain substring match, so e.g.
         "sum" also fires inside "summary");
      2. an ``=`` appears together with at least one letter;
      3. a digit appears together with either a ``?`` or answer choices.
    """
    normalized = clean_math_text(text).lower()
    math_words = (
        "solve", "equation", "percent", "ratio", "probability", "mean", "median",
        "average", "sum", "difference", "product", "quotient", "triangle", "circle",
        "rectangle", "area", "perimeter", "volume", "algebra", "integer", "divisible",
        "number", "fraction", "decimal", "geometry", "distance", "speed", "work",
        "remainder", "discount",
    )
    for word in math_words:
        if word in normalized:
            return True

    looks_like_equation = "=" in normalized and re.search(r"[a-z]", normalized) is not None
    if looks_like_equation:
        return True

    if not re.search(r"\d", normalized):
        return False
    return "?" in normalized or has_answer_choices(normalized)
def _prepare_expression(expr: str) -> str:
    """Turn loosely written math into a Python/SymPy-evaluable expression.

    Cleans the text, maps ``^`` to ``**``, strips one leading prompt word
    (solve / simplify / evaluate / find / what is / compute) and inserts an
    explicit ``*`` for implicit multiplication such as ``3x``, ``2(x+1)``,
    ``(a)(b)``, ``x(y)`` and ``a b``.

    A known function name (``sqrt``, ``sin``, ``cos``, ``tan``, ``log``,
    ``exp``, ``abs``) stays attached to its opening parenthesis, so
    ``sqrt(2)`` remains a call instead of becoming ``sqrt*(2)``.
    """
    # Names that must stay glued to "(" so they keep working as function calls
    # (e.g. the ``sqrt`` that _parse_number exposes to eval, or SymPy's log).
    known_functions = {"sqrt", "sin", "cos", "tan", "log", "exp", "abs"}

    def _word_before_paren(match) -> str:
        # Other words followed by "(" are treated as a factor: "x(2)" -> "x*(2)".
        word = match.group(1)
        if word.lower() in known_functions:
            return word + "("
        return word + "*("

    expr = clean_math_text(expr).strip()
    expr = expr.replace("^", "**")
    # remove common prompt wrappers
    expr = re.sub(r"(?i)^\s*(solve|simplify|evaluate|find|what is|compute)\s*:?\s*", "", expr)
    # implicit multiplication
    expr = re.sub(r"(\d)\s*\(", r"\1*(", expr)
    expr = re.sub(r"\)\s*(\d)", r")*\1", expr)
    expr = re.sub(r"(\d)\s*([a-zA-Z])", r"\1*\2", expr)
    expr = re.sub(r"([a-zA-Z]+)\s*\(", _word_before_paren, expr)
    expr = re.sub(r"\)\s*([a-zA-Z])", r")*\1", expr)
    expr = re.sub(r"([a-zA-Z])\s+([a-zA-Z])", r"\1*\2", expr)
    return expr
def _clean_equation_candidate(text: str) -> str:
    """Strip leading prompt phrases from *text* but keep the equation itself.

    Removes, one after another: a leading "solve/simplify/evaluate/find"
    (optionally followed by a colon), a leading "how do i solve", and a
    leading "what is". Whitespace is normalised at the end.
    """
    leading_phrases = (
        r"(?i)^\s*(solve|simplify|evaluate|find)\s*:?\s*",
        r"(?i)^\s*how do i solve\s*:?\s*",
        r"(?i)^\s*what is\s+",
    )
    candidate = clean_math_text(text).strip()
    for phrase in leading_phrases:
        candidate = re.sub(phrase, "", candidate)
    return normalize_spaces(candidate)
def _extract_equation(text: str) -> Optional[str]:
    """Find an equation containing at least one letter inside *text*.

    First scans the cleaned text for runs of letters, digits, arithmetic
    symbols, parentheses and whitespace around an ``=``; the first run that
    contains a letter is returned (whitespace-normalised). Otherwise it falls
    back to splitting the cleaned text at its first ``=`` and trimming spaces
    and punctuation from both sides.

    Returns None when there is no ``=``, a side is empty, or no candidate
    contains a letter.
    """
    cleaned = _clean_equation_candidate(text)
    if "=" not in cleaned:
        return None

    span_pattern = r"([A-Za-z0-9\.\+\-\*/\^\(\)\s]+=[A-Za-z0-9\.\+\-\*/\^\(\)\s]+)"
    for match in re.finditer(span_pattern, cleaned):
        candidate = normalize_spaces(match.group(1))
        if "=" in candidate and re.search(r"[a-z]", candidate.lower()):
            return candidate

    # No tidy span qualified: use whatever surrounds the first "=".
    left, _, right = cleaned.partition("=")
    punctuation = " .,:;!?"
    left = left.strip(punctuation)
    right = right.strip(punctuation)
    if not (left and right):
        return None
    candidate = f"{left} = {right}"
    return candidate if re.search(r"[a-z]", candidate.lower()) else None
def _extract_variable_names(expr: str) -> List[str]:
# catches x in "3x + 5 = 20" as well as standalone x
vars_found = sorted(set(re.findall(r"[a-z]", expr.lower())))
# avoid treating common words as many variables by keeping only likely algebra variables first
preferred = [v for v in vars_found if v in {"x", "y", "z", "n"}]
return preferred or vars_found
def _parse_number(text: str) -> Optional[float]:
    """Parse an answer-choice string into a float; return None if not numeric.

    Tried in order:
      * a percentage such as "25%" -> 0.25 (spaces are ignored);
      * an integer fraction such as "3/4" (None for a zero denominator);
      * a small arithmetic expression, after ``_prepare_expression``, made only
        of digits, ``+ - * / % ( ) .``, whitespace and the names ``sqrt`` and
        ``pi``.
    """
    raw = clean_math_text(text).strip().lower()

    pct = re.fullmatch(r"(-?\d+(?:\.\d+)?)%", raw.replace(" ", ""))
    if pct:
        return float(pct.group(1)) / 100.0

    frac = re.fullmatch(r"(-?\d+)\s*/\s*(-?\d+)", raw)
    if frac:
        den = float(frac.group(2))
        if den == 0:
            return None
        return float(frac.group(1)) / den

    # SECURITY: choice text comes from the user's question and reaches eval().
    # An empty __builtins__ is not a sandbox (attribute access such as
    # "().__class__" can escape it), so only plain arithmetic over the
    # whitelisted names below is allowed through.
    # NOTE(review): nested powers such as 9^9^9 still pass the whitelist and can
    # be very slow to evaluate; an AST-based evaluator would close that gap.
    allowed_names = {"sqrt": math.sqrt, "pi": math.pi}
    try:
        prepared = _prepare_expression(raw)
        if not re.fullmatch(r"[0-9A-Za-z.\s+\-*/%()]+", prepared):
            return None
        if not set(re.findall(r"[A-Za-z]+", prepared)).issubset(allowed_names):
            return None
        return float(eval(prepared, {"__builtins__": {}}, allowed_names))
    except Exception:
        # Anything that does not evaluate to a plain number is "not numeric".
        return None
def _best_choice(answer_value: float, choices: Dict[str, str]) -> Optional[str]:
    """Return the letter of the choice numerically equal to *answer_value*.

    Each choice is parsed with ``_parse_number``; unparseable choices are
    skipped. The closest parsed choice is returned if it is within tolerance.
    Otherwise, or when nothing parses, the result is None.

    The tolerance is relative (1e-5) as well as absolute (1e-6). Answers are
    often formatted with ``:g`` (six significant digits) and parsed back, so
    100/3 arrives as 33.3333 and must still match a "100/3" choice. The
    ``:g`` rounding error is at most 5e-6 relative.
    """
    best_letter: Optional[str] = None
    best_value: Optional[float] = None
    best_diff = float("inf")
    for letter, raw in choices.items():
        parsed = _parse_number(raw)
        if parsed is None:
            continue
        diff = abs(parsed - answer_value)
        if diff < best_diff:
            best_letter, best_value, best_diff = letter, parsed, diff
    if best_value is not None and math.isclose(
        best_value, answer_value, rel_tol=1e-5, abs_tol=1e-6
    ):
        return best_letter
    return None
def _make_result(
    *,
    topic: str,
    answer_value: str,
    internal_answer: Optional[str] = None,
    steps: Optional[List[str]] = None,
    choices_text: str = "",
) -> SolverResult:
    """Build a solved quant ``SolverResult``, matching an answer letter if possible.

    *answer_value* is parsed back into a number and compared against the
    lettered options found in *choices_text*. The letter is None when parsing
    fails or there are no options. An empty/None *internal_answer* falls back
    to *answer_value*, and empty/None *steps* becomes an empty list.
    """
    numeric_answer = _parse_number(answer_value)
    options = extract_choices(choices_text)
    letter = None
    if numeric_answer is not None and options:
        letter = _best_choice(numeric_answer, options)
    return SolverResult(
        domain="quant",
        solved=True,
        topic=topic,
        answer_value=answer_value,
        answer_letter=letter,
        internal_answer=internal_answer if internal_answer else answer_value,
        steps=steps if steps else [],
    )
def _solve_successive_percent(text: str) -> Optional[SolverResult]:
    """Solve "successive percentage change" questions.

    Two phrasings are recognised, and at least two changes are required:
      * verb first: "increased by 20%", "is reduced by 10 percent",
        "marked down by 15%", "fell by 5%";
      * number first: "a 20% increase", "a 10% discount" (only tried when
        the verb-first form found fewer than two changes).
    Each change becomes a multiplier 1 +/- p/100, and their product gives the
    net change. ``answer_value`` holds the magnitude (e.g. "4%"), while
    ``internal_answer`` also states the direction ("net decrease of 4%").

    Only the question stem is scanned. When the text carries at least three
    lettered choices, everything from the first choice marker onward is
    ignored, so options such as "4% decrease" are not counted as changes.
    """
    stem = text
    if has_answer_choices(text):
        first_choice = re.search(
            r"(?i)\b([A-E])[\)\.:]\s*(.*?)(?=\s+\b[A-E][\)\.:]\s*|$)", text
        )
        if first_choice:
            stem = text[: first_choice.start()]
    lower = clean_math_text(stem).lower()

    # Accept inflected verbs ("increased", "decreases", "fell", ...) because
    # questions are usually phrased in the past tense.
    changes = re.findall(
        r"(increas(?:e|es|ed)|decreas(?:e|es|ed)|reduc(?:e|es|ed)|drop(?:s|ped)?|"
        r"discount(?:s|ed)?|mark(?:s|ed)?\s*up|mark(?:s|ed)?\s*down|"
        r"rises?|rose|falls?|fell)"
        r"\s+by\s+(\d+(?:\.\d+)?)\s*(?:%|percent)",
        lower,
    )
    if len(changes) < 2:
        number_first = re.findall(
            r"(\d+(?:\.\d+)?)\s*(?:%|percent)\s+(increase|decrease|discount|rise|fall)",
            lower,
        )
        changes = [(op, pct) for pct, op in number_first]
    if len(changes) < 2:
        return None

    decrease_markers = ("decreas", "reduc", "drop", "discount", "down", "fall", "fell")
    multiplier = 1.0
    step_lines: List[str] = []
    for op, pct_raw in changes:
        pct = float(pct_raw)
        if any(marker in op for marker in decrease_markers):
            factor = 1 - pct / 100.0
            step_lines.append(f"A {pct:g}% decrease means multiply by {factor:g}.")
        else:
            factor = 1 + pct / 100.0
            step_lines.append(f"A {pct:g}% increase means multiply by {factor:g}.")
        multiplier *= factor

    net_change = (multiplier - 1.0) * 100.0
    direction = "increase" if net_change >= 0 else "decrease"
    magnitude = abs(net_change)
    return _make_result(
        topic="percent",
        answer_value=f"{magnitude:g}%",
        internal_answer=f"net {direction} of {magnitude:g}%",
        steps=step_lines + ["Combine the multipliers to find the overall percent change."],
        choices_text=text,
    )
def _extract_ratio_labels(text: str) -> Optional[Tuple[str, str]]:
    """Return the two quantity names from "ratio of X to Y is a:b".

    Each label is whitespace-normalised and loses at most ONE trailing "s"
    (boys -> boy, glasses -> glasse, glass -> glas). A label is therefore
    still a prefix of its plural. ``str.rstrip("s")`` would instead strip every
    trailing "s" (bass -> ba), which can make a label a false prefix of an
    unrelated word ("ba" vs "balls").

    Returns None when the phrase is not present.
    """
    m = re.search(r"ratio of ([a-z ]+?) to ([a-z ]+?) is \d+\s*:\s*\d+", text.lower())
    if not m:
        return None

    def _singular(label: str) -> str:
        label = normalize_spaces(label)
        return label[:-1] if label.endswith("s") else label

    return _singular(m.group(1)), _singular(m.group(2))
def _solve_ratio_total(text: str) -> Optional[SolverResult]:
    """Split a stated total according to a two-part ratio ``a:b``.

    Uses the first ``a:b`` in the text and a total introduced by "total",
    "altogether", "in all" or "sum" (optionally followed by "is", "=" or
    "of"). The second share is returned only when the labels from
    "ratio of X to Y is a:b" can be read and the text asks "how many Y" but
    not "how many X". In every other case the first share is returned.

    Returns None if either piece is missing, the ratio parts sum to zero, or
    the ratio has three or more parts (e.g. ``2:3:5``). A two-way split
    cannot answer those correctly.
    """
    lower = clean_math_text(text).lower()
    ratio_match = re.search(r"(\d+)\s*:\s*(\d+)(\s*:\s*\d+)?", lower)
    total_match = re.search(
        r"(?:total|altogether|in all|sum)\s*(?:is|=|of)?\s*(\d+(?:\.\d+)?)", lower
    )
    if not ratio_match or not total_match:
        return None
    if ratio_match.group(3):
        # a:b:c - using only the first two parts would give a wrong answer.
        return None

    a = int(ratio_match.group(1))
    b = int(ratio_match.group(2))
    total = float(total_match.group(1))
    part_sum = a + b
    if part_sum == 0:
        return None
    unit = total / part_sum
    left_value = a * unit
    right_value = b * unit

    requested_value = left_value
    labels = _extract_ratio_labels(lower)
    if labels:
        left_label, right_label = labels
        asks_left = re.search(rf"how many {re.escape(left_label)}", lower)
        asks_right = re.search(rf"how many {re.escape(right_label)}", lower)
        if asks_right and not asks_left:
            requested_value = right_value

    return _make_result(
        topic="ratio",
        answer_value=f"{requested_value:g}",
        internal_answer=f"{requested_value:g}",
        steps=[
            f"Add the ratio parts: {a} + {b} = {part_sum}.",
            "Find the value of one ratio part using the total.",
            "Multiply by the required ratio part.",
        ],
        choices_text=text,
    )
def _solve_remainder(text: str) -> Optional[SolverResult]:
    """Solve integer-remainder questions.

    Recognised phrasings:
      * "... remainder when 17 is divided by 5", "remainder of 17 divided by 5",
        "remainder obtained when 17 is divided by 5";
      * "17 mod 5" or "17 % 5".
    Returns None when nothing matches or the divisor is 0.
    """
    lower = clean_math_text(text).lower()
    # "remainder when ..." must match directly; filler words between
    # "remainder" and "when"/"of" are optional, not required.
    m = re.search(
        r"remainders?\b.*?\b(?:when|of)\s+(\d+)\s+(?:is\s+)?divided\s+by\s+(\d+)",
        lower,
    )
    if not m:
        m = re.search(r"(\d+)\s*(?:mod|%)\s*(\d+)", lower)
    if not m:
        return None
    a = int(m.group(1))
    b = int(m.group(2))
    if b == 0:
        return None
    r = a % b
    return _make_result(
        topic="number_theory",
        answer_value=str(r),
        internal_answer=str(r),
        steps=[
            f"Divide {a} by {b}.",
            "Keep the amount left over after division.",
        ],
        choices_text=text,
    )
def _solve_percent(text: str) -> Optional[SolverResult]:
    """Solve basic single-step percent questions.

    Two forms are handled:
      * reverse percent: "p% of a number is v", "p% of x is v" or
        "p% of x = v", answered with v / (p / 100);
      * direct percent: "what is p% of n", answered with p / 100 * n.
    The answer is matched against any lettered choices in *text*.

    Returns None when neither form matches, or when a reverse-percent match
    has p == 0 (no unique whole exists).
    """
    lower = clean_math_text(text).lower()
    choices = extract_choices(text)

    def _percent_result(ans: float, steps: List[str]) -> SolverResult:
        # Shared builder so both question forms are reported identically.
        return SolverResult(
            domain="quant",
            solved=True,
            topic="percent",
            answer_value=f"{ans:g}",
            answer_letter=_best_choice(ans, choices) if choices else None,
            internal_answer=f"{ans:g}",
            steps=steps,
        )

    reverse_patterns = [
        r"(\d+(?:\.\d+)?)\s*(?:%|percent)\s+of\s+(?:a\s+)?number\s+is\s+(\d+(?:\.\d+)?)",
        r"(\d+(?:\.\d+)?)\s*(?:%|percent)\s+of\s+([a-z])\s+is\s+(\d+(?:\.\d+)?)",
        r"(\d+(?:\.\d+)?)\s*(?:%|percent)\s+of\s+([a-z])\s*=\s*(\d+(?:\.\d+)?)",
    ]
    for pat in reverse_patterns:
        m = re.search(pat, lower)
        if not m:
            continue
        p = float(m.group(1))
        # The known part is always the last capture group (the optional
        # variable-name group sits in the middle for the 3-group patterns).
        value = float(m.groups()[-1])
        if p == 0:
            return None
        ans = value / (p / 100.0)
        return _percent_result(
            ans,
            [
                "Let the unknown quantity be a variable.",
                f"Convert {p:g}% into its decimal form.",
                "Write an equation for the part and whole relationship.",
                "Solve that equation for the unknown.",
            ],
        )

    m = re.search(r"what is\s+(\d+(?:\.\d+)?)\s*(?:%|percent)\s+of\s+(\d+(?:\.\d+)?)", lower)
    if m:
        p = float(m.group(1))
        n = float(m.group(2))
        ans = p / 100.0 * n
        return _percent_result(
            ans,
            [
                f"Convert {p:g}% to a decimal.",
                # :g keeps "50" from being shown as "50.0", matching other steps.
                f"Multiply that decimal by {n:g}.",
            ],
        )
    return None
def _solve_mean_median(text: str) -> Optional[SolverResult]:
    """Compute the mean/average or the median of the numbers in a question.

    Only the question stem is used as data. When the text carries at least
    three lettered choices, everything from the first choice marker onward is
    dropped, so option values do not join the data set. The answer letter is
    then matched against those options via ``_make_result``.

    "mean"/"average" takes precedence over "median". Every number left in the
    stem is treated as a data value. Returns None if the stem has no numbers
    or asks for neither statistic.
    """
    stem = text
    if has_answer_choices(text):
        first_choice = re.search(
            r"(?i)\b([A-E])[\)\.:]\s*(.*?)(?=\s+\b[A-E][\)\.:]\s*|$)", text
        )
        if first_choice:
            stem = text[: first_choice.start()]
    lower = clean_math_text(stem).lower()

    nums = [float(n) for n in re.findall(r"-?\d+(?:\.\d+)?", lower)]
    if not nums:
        return None
    if "mean" in lower or "average" in lower:
        ans = mean(nums)
        steps = ["Add the values.", f"Divide by {len(nums)}."]
    elif "median" in lower:
        ans = median(nums)
        steps = ["Order the values.", "Take the middle value."]
    else:
        return None
    return _make_result(
        topic="statistics",
        answer_value=f"{ans:g}",
        internal_answer=f"{ans:g}",
        steps=steps,
        choices_text=text,
    )
def _solve_linear_equation(text: str) -> Optional[SolverResult]:
    """Solve a single equation found in *text* with SymPy.

    The equation comes from ``_extract_equation``. Both sides go through
    ``_prepare_expression`` and are lower-cased. The equation is solved for
    the first candidate from ``_extract_variable_names`` (x, y, z or n are
    preferred). The first solution SymPy returns is reported. When it converts
    to a float, it is also matched against any lettered choices.

    Returns None when SymPy is unavailable, no equation or variable is found,
    there is no solution, or parsing/solving raises.
    """
    if sp is None:
        return None
    expr = _extract_equation(text)
    if not expr:
        return None
    try:
        lhs, rhs = expr.split("=", 1)
        # _extract_variable_names lower-cases letters, so lower-case the sides
        # too. Otherwise "3X + 5 = 20" would be solved for x while SymPy only
        # sees X (no solution). Upper-case single letters such as E, I, N or S
        # would also parse as SymPy built-ins rather than as plain symbols.
        lhs_prepped = _prepare_expression(lhs).lower()
        rhs_prepped = _prepare_expression(rhs).lower()
        symbols = _extract_variable_names(expr)
        if not symbols:
            return None
        var_name = symbols[0]
        var = sp.symbols(var_name)
        # NOTE(review): sympify evaluates its string input with eval(), and
        # this text comes from the user's question; treat it as untrusted.
        equation = sp.Eq(sp.sympify(lhs_prepped), sp.sympify(rhs_prepped))
        sol = sp.solve(equation, var)
        if not sol:
            return None
        value = sol[0]
        try:
            as_float = float(value)
        except Exception:
            # Symbolic or complex solutions cannot be matched to choices.
            as_float = None
        choices = extract_choices(text)
        return SolverResult(
            domain="quant",
            solved=True,
            topic="algebra",
            answer_value=str(value),
            answer_letter=_best_choice(as_float, choices) if (as_float is not None and choices) else None,
            internal_answer=f"{var_name} = {value}",
            steps=[
                "Treat the statement as an equation.",
                "Undo operations on both sides to isolate the variable.",
                "Keep simplifying until the variable is alone.",
            ],
        )
    except Exception:
        # Best-effort solver: any parse/solve failure means "not handled here".
        return None
def solve_quant(text: str) -> SolverResult:
    """Run the rule-based quant solvers in priority order; return the first hit.

    Order: successive percent, ratio/total, remainder, percent, mean/median,
    then equation solving. Later solvers are not called once one succeeds.
    If none applies, an unsolved ``general_quant`` result with generic
    guidance steps is returned.
    """
    question = text or ""
    pipeline = (
        _solve_successive_percent,
        _solve_ratio_total,
        _solve_remainder,
        _solve_percent,
        _solve_mean_median,
        _solve_linear_equation,
    )
    attempts = (solver(question) for solver in pipeline)
    found = next((outcome for outcome in attempts if outcome is not None), None)
    if found is not None:
        return found

    return SolverResult(
        domain="quant",
        solved=False,
        topic="general_quant",
        reply="This looks quantitative, but it does not match a strong rule-based pattern yet.",
        steps=[
            "Identify the quantity the question wants.",
            "Translate the wording into an equation, ratio, or diagram.",
            "Carry out the calculation carefully.",
        ],
    )