TradingGameAI / quant_solver.py
j-js's picture
Update quant_solver.py
36e5b98 verified
from __future__ import annotations
import math
import re
from fractions import Fraction
from statistics import mean, median
from typing import Dict, List, Optional
from models import SolverResult
from solver_router import route_solver
from utils import clean_math_text, normalize_spaces
try:
from math_normalizer import normalize_for_solver
except Exception:
def normalize_for_solver(text: str) -> str:
return text
try:
import sympy as sp
except Exception:
sp = None
CHOICE_LETTERS = ["A", "B", "C", "D", "E"]
def has_answer_choices(text: str) -> bool:
    """Report whether ``text`` looks like it carries multiple-choice options.

    An option marker is one of the letters A-E (matched case-insensitively)
    starting at a word boundary and immediately followed by ``)``, ``.`` or
    ``:``.  The text qualifies once at least three different letters have
    such a marker somewhere in it.
    """
    labelled_letters = 0
    for letter in "ABCDE":
        if re.search(rf"\b{letter}[\)\.\:]", text, flags=re.I):
            labelled_letters += 1
    return labelled_letters >= 3
def extract_choices(text: str) -> Dict[str, str]:
    """Collect lettered answer options from ``text``.

    Each option starts with a letter A-E (either case) followed by ``)``,
    ``.`` or ``:`` and runs until the next option marker or the end of the
    line.

    Returns:
        A mapping from the upper-cased letter to the option text after it has
        been passed through ``normalize_spaces``.  When a letter occurs more
        than once, the last occurrence wins.
    """
    option_pattern = r"(?im)(?:^|\n|\s)([A-E])[\)\.\:]\s*(.*?)(?=(?:\n?\s*[A-E][\)\.\:]\s)|$)"
    return {
        hit.group(1).upper(): normalize_spaces(hit.group(2))
        for hit in re.finditer(option_pattern, text)
    }
def extract_numbers(text: str) -> List[float]:
    """Return every decimal number found in ``text``, in order of appearance.

    The text is first passed through ``normalize_for_solver``.  A comma is
    removed only where it acts as a thousands separator, i.e. it sits between
    a digit and a group of exactly three digits.  ``"1,250"`` therefore yields
    ``[1250.0]`` while a tightly written list such as ``"2,4,6"`` yields
    ``[2.0, 4.0, 6.0]`` rather than being fused into the single number 246.

    A ``-`` directly in front of the digits is read as a sign.

    Args:
        text: Arbitrary question or answer-option text.

    Returns:
        The numbers as floats; an empty list when none are present.
    """
    normalized = normalize_for_solver(text)
    # Strip grouping commas only; list-separating commas must stay so that
    # adjacent values are not glued together.
    ungrouped = re.sub(r"(?<=\d),(?=\d{3}(?!\d))", "", normalized)
    # Every token matched here is a valid float literal, so no guard is needed.
    return [float(token) for token in re.findall(r"-?\d+(?:\.\d+)?", ungrouped)]
def is_quant_question(text: str) -> bool:
    """Decide heuristically whether ``text`` is a quantitative question.

    The normalized, lower-cased text counts as quantitative when it contains
    any of the keywords below as a substring, or when it contains a digit
    together with a ``?``, an ``=`` or a set of answer choices.
    """
    quant_keywords = (
        "solve",
        "equation",
        "integer",
        "percent",
        "ratio",
        "probability",
        "mean",
        "median",
        "average",
        "sum",
        "difference",
        "product",
        "quotient",
        "triangle",
        "circle",
        "rectangle",
        "perimeter",
        "area",
        "volume",
        "number line",
        "positive",
        "negative",
        "multiple",
        "factor",
        "prime",
        "distance",
        "speed",
        "work",
        "mixture",
        "consecutive",
        "algebra",
        "value of x",
        "value of y",
        "what is x",
        "what is y",
        "divisible",
        "sqrt",
        "standard deviation",
        "radius",
        "diameter",
    )
    lowered = normalize_for_solver(text).lower()

    # Keyword hit: plain substring test, so "sum" also matches inside longer words.
    for keyword in quant_keywords:
        if keyword in lowered:
            return True

    # No keyword: require at least one digit plus a question-like signal.
    if re.search(r"[0-9]", lowered) is None:
        return False
    return "?" in lowered or has_answer_choices(lowered) or "=" in lowered
def _prepare_expression(expr: str) -> str:
    """Rewrite a human-written expression into Python arithmetic syntax.

    After ``normalize_for_solver`` and ``clean_math_text`` have run, the text
    is trimmed and then, in this order:

    1. ``^`` becomes ``**``; ``%`` and the word ``percent`` become ``/100``.
    2. Implicit products get an explicit ``*``: digit before ``(``, ``)``
       before digit, digit before letter, and letter before digit.
    """
    prepared = clean_math_text(normalize_for_solver(expr)).strip()

    literal_swaps = (("^", "**"), ("%", "/100"), ("percent", "/100"))
    for old, new in literal_swaps:
        prepared = prepared.replace(old, new)

    implicit_products = (
        (r"(\d)\s*\(", r"\1*("),
        (r"\)\s*(\d)", r")*\1"),
        (r"(\d)([a-zA-Z])", r"\1*\2"),
        (r"([a-zA-Z])(\d)", r"\1*\2"),
    )
    for pattern, replacement in implicit_products:
        prepared = re.sub(pattern, replacement, prepared)
    return prepared
def try_eval_expression(expr: str) -> Optional[float]:
    """Evaluate a simple arithmetic expression and return it as a float.

    The text is rewritten by ``_prepare_expression`` and stripped of every
    character other than digits, letters, ``_ . + - * / ( )`` and whitespace.
    What remains is parsed with ``ast`` and evaluated by a small interpreter
    that accepts only:

    * integer and float literals,
    * the binary operators ``+ - * / // **`` and unary ``+`` / ``-``,
    * the constant ``pi`` and the one-argument function ``sqrt``.

    Args:
        expr: Expression text such as ``"3 + 4 * 2"`` or ``"sqrt(16)"``.

    Returns:
        The numeric value, or ``None`` when the text is empty, is not a
        supported expression, or cannot be evaluated (division by zero,
        overflow, a non-real result, an oversized exponent, ...).
    """
    import ast
    import operator

    prepared = _prepare_expression(expr)
    if not prepared:
        return None

    # Drop stray punctuation (e.g. a trailing "?").  ``sqrt`` and ``pi`` are
    # deliberately kept in the text: removing them would turn "sqrt(16)" into
    # "(16)" and silently produce 16 instead of 4.
    cleaned = re.sub(r"[^0-9A-Za-z_\.\+\-\*\/\(\)\s]", "", prepared).strip()
    if not cleaned:
        return None

    binary_ops = {
        ast.Add: operator.add,
        ast.Sub: operator.sub,
        ast.Mult: operator.mul,
        ast.Div: operator.truediv,
        ast.FloorDiv: operator.floordiv,
        ast.Pow: operator.pow,
    }
    unary_ops = {ast.UAdd: operator.pos, ast.USub: operator.neg}
    constants = {"pi": math.pi}
    functions = {"sqrt": math.sqrt}
    # Caps "9**9**9"-style inputs that would otherwise burn CPU and memory.
    max_exponent = 1024

    def evaluate(node):
        """Recursively evaluate one whitelisted AST node."""
        if isinstance(node, ast.Constant):
            literal = node.value
            if isinstance(literal, bool) or not isinstance(literal, (int, float)):
                raise ValueError("unsupported literal")
            return literal
        if isinstance(node, ast.Name):
            if node.id not in constants:
                raise ValueError("unknown name")
            return constants[node.id]
        if isinstance(node, ast.UnaryOp) and type(node.op) in unary_ops:
            return unary_ops[type(node.op)](evaluate(node.operand))
        if isinstance(node, ast.BinOp) and type(node.op) in binary_ops:
            left = evaluate(node.left)
            right = evaluate(node.right)
            if isinstance(node.op, ast.Pow) and abs(right) > max_exponent:
                raise ValueError("exponent too large")
            return binary_ops[type(node.op)](left, right)
        if (
            isinstance(node, ast.Call)
            and isinstance(node.func, ast.Name)
            and node.func.id in functions
            and len(node.args) == 1
            and not node.keywords
        ):
            return functions[node.func.id](evaluate(node.args[0]))
        # Attribute access, subscripts, comparisons, lambdas, ... are rejected.
        raise ValueError("unsupported expression")

    # The input is user text, so it is never handed to eval(); only the
    # node types listed above can be executed.
    try:
        tree = ast.parse(cleaned, mode="eval")
        return float(evaluate(tree.body))
    except (SyntaxError, ValueError, TypeError, ArithmeticError, RecursionError, MemoryError):
        return None
def _parse_numeric_text(text: str) -> Optional[float]:
    """Convert a short piece of text (typically one answer option) to a float.

    Interpretations are tried in this order, the first success winning:

    1. ``"25%"`` (spaces ignored) -> the percentage as a fraction of 1.
    2. ``"a/b"`` with integer parts -> the quotient, or ``None`` if ``b`` is 0.
    3. Anything ``try_eval_expression`` can evaluate.
    4. Text containing exactly one number according to ``extract_numbers``.

    Returns ``None`` when none of these applies.
    """
    candidate = normalize_for_solver(text).strip().lower()

    percent = re.fullmatch(r"(-?\d+(?:\.\d+)?)%", candidate.replace(" ", ""))
    if percent is not None:
        return float(percent.group(1)) / 100.0

    fraction = re.fullmatch(r"(-?\d+)\s*/\s*(-?\d+)", candidate)
    if fraction is not None:
        numerator, denominator = (float(part) for part in fraction.groups())
        # A zero denominator ends the parse here; later strategies are not tried.
        return numerator / denominator if denominator != 0 else None

    evaluated = try_eval_expression(candidate)
    if evaluated is not None:
        return evaluated

    found = extract_numbers(candidate)
    return found[0] if len(found) == 1 else None
def parse_direct_expression_question(text: str) -> Optional[str]:
    """Extract the expression from a "what is / calculate / find ..." prompt.

    Two patterns are tried against the normalized, lower-cased text: first a
    purely numeric one, then one that also allows letters and ``^``.  The
    first pattern whose captured text is non-empty after trimming supplies
    the result.

    Returns:
        The captured expression text, or ``None`` if neither pattern yields
        anything.
    """
    lowered = normalize_for_solver(text).lower()
    candidate_patterns = (
        r"(?:what is|calculate|compute|evaluate|find)\s+([-+*/().\d\s]+)\??",
        r"(?:what is|calculate|compute|evaluate|find)\s+([a-z0-9\^\+\-\*\/\(\)\.\s]+)\??",
    )
    for pattern in candidate_patterns:
        found = re.search(pattern, lowered)
        captured = found.group(1).strip() if found else ""
        if captured:
            return captured
    return None
def compare_to_choices_numeric(
    answer_value: float,
    choices: Dict[str, str],
    tolerance: float = 1e-9,
) -> Optional[str]:
    """Find the answer option whose numeric value equals ``answer_value``.

    Every option is parsed with ``_parse_numeric_text``; unparseable options
    are ignored.  The option with the smallest absolute difference is kept
    (the earliest one on ties) and returned only if that difference does not
    exceed ``max(tolerance, 1e-6)``.

    Returns:
        The matching option letter, or ``None`` when nothing is close enough.
    """
    closest_letter: Optional[str] = None
    smallest_gap = math.inf
    for label, option_text in choices.items():
        parsed = _parse_numeric_text(option_text)
        if parsed is not None:
            gap = abs(parsed - answer_value)
            if gap < smallest_gap:
                closest_letter, smallest_gap = label, gap

    # The effective tolerance never drops below 1e-6, whatever the caller passes.
    threshold = max(tolerance, 1e-6)
    if closest_letter is None or not (smallest_gap <= threshold):
        return None
    return closest_letter
def compare_fraction_to_choices(
    fraction_text: str,
    choices: Dict[str, str],
) -> Optional[str]:
    """Match a fraction written as text (e.g. ``"1/4"``) against the options.

    The text is converted with ``_parse_numeric_text`` and the resulting
    number is handed to ``compare_to_choices_numeric``.  Returns ``None`` if
    the text cannot be parsed or no option matches.
    """
    numeric_target = _parse_numeric_text(fraction_text)
    return (
        None
        if numeric_target is None
        else compare_to_choices_numeric(numeric_target, choices)
    )
def solve_basic_expression(text: str, help_mode: str) -> Optional[SolverResult]:
    """Answer prompts that directly ask for the value of an expression.

    Args:
        text: The question text.
        help_mode: ``"hint"`` and ``"walkthrough"`` select those reply styles;
            any other value produces the direct answer.

    Returns:
        A solved ``SolverResult``, or ``None`` when no expression is found or
        it cannot be evaluated.
    """
    expression = parse_direct_expression_question(text)
    value = try_eval_expression(expression) if expression else None
    if value is None:
        return None

    shown = f"{value:g}"
    styled_replies = {
        "hint": "Focus on order of operations: parentheses, multiplication/division, then addition/subtraction.",
        "walkthrough": f"Evaluate the expression {expression} using order of operations.\nThat gives {shown}.",
    }
    reply = styled_replies.get(help_mode, f"The value is {shown}.")

    return SolverResult(
        reply=reply,
        domain="quant",
        solved=True,
        help_mode=help_mode,
        answer_value=shown,
    )
def solve_percent_question(text: str, help_mode: str) -> Optional[SolverResult]:
    """Solve the supported families of percent questions.

    Patterns are tried in this order against the normalized, lower-cased text:

    1. "what is P% of N"
    2. "X is what percent of Y"
    3. "N (is) increased/decreased by P%"
    4. "P% ... equals / as a fraction / fraction" (whole-number P only)

    The percent-change pattern is checked before the fraction conversion
    because the latter matches any "P% ... equals" wording; in the opposite
    order "80 increased by 25% equals what?" is answered with "25% = 1/4".

    Args:
        text: The question, optionally followed by lettered answer options.
        help_mode: ``"hint"``, ``"walkthrough"``, or anything else for a
            direct answer.

    Returns:
        A solved ``SolverResult`` (with the matching option letter when one
        is found), or ``None`` if no pattern applies.
    """
    lower = normalize_for_solver(text).lower()
    choices = extract_choices(text)

    def match_choice(value: float) -> Optional[str]:
        """Letter of the option equal to ``value``, if options exist."""
        return compare_to_choices_numeric(value, choices) if choices else None

    def finish(hint, walkthrough, direct, letter, answer_value):
        """Pick the reply for ``help_mode`` and wrap it in a SolverResult."""
        if help_mode == "hint":
            reply = hint
        elif help_mode == "walkthrough":
            reply = walkthrough
            if letter:
                reply += f"\nThat matches choice {letter}."
        else:
            reply = direct
            if letter:
                reply += f" So the correct choice is {letter}."
        return SolverResult(
            reply=reply,
            domain="quant",
            solved=True,
            help_mode=help_mode,
            answer_letter=letter,
            answer_value=answer_value,
        )

    # 1. "what is P% of N"
    m = re.search(r"what is\s+(\d+(?:\.\d+)?)\s*(?:%|percent)\s+of\s+(\d+(?:\.\d+)?)", lower)
    if m:
        p = float(m.group(1))
        n = float(m.group(2))
        ans = p / 100 * n
        return finish(
            hint="Convert the percent to a decimal, then multiply.",
            walkthrough=f"{p}% of {n} = {p/100:g} × {n} = {ans:g}.",
            direct=f"The answer is {ans:g}.",
            letter=match_choice(ans),
            answer_value=f"{ans:g}",
        )

    # 2. "X is what percent of Y"
    m = re.search(r"(\d+(?:\.\d+)?)\s+is\s+what percent of\s+(\d+(?:\.\d+)?)", lower)
    if m:
        x = float(m.group(1))
        y = float(m.group(2))
        if y == 0:
            return None
        ans = x / y * 100
        # Options written "25%" parse to 0.25, options written "25" parse to
        # 25.  Try the percent form first, then the plain number, so both
        # ways of listing the options can be matched.
        letter = match_choice(ans / 100.0) or match_choice(ans)
        return finish(
            hint="Use part ÷ whole × 100.",
            walkthrough=f"Percent = ({x} / {y}) × 100 = {ans:g}%.",
            direct=f"The answer is {ans:g}%.",
            letter=letter,
            answer_value=f"{ans:g}%",
        )

    # 3. "N (is) increased/decreased by P%"
    m = re.search(
        r"(\d+(?:\.\d+)?)\s+(?:is\s+)?(increased|decreased)\s+by\s+(\d+(?:\.\d+)?)\s*(?:%|percent)",
        lower,
    )
    if m:
        base = float(m.group(1))
        direction = m.group(2)
        p = float(m.group(3)) / 100
        mult = 1 + p if direction == "increased" else 1 - p
        ans = base * mult
        return finish(
            hint="Use multiplier form: increase → 1 + p, decrease → 1 - p.",
            walkthrough=f"Multiplier = {mult:g}, so {base:g} × {mult:g} = {ans:g}.",
            direct=f"The result is {ans:g}.",
            letter=match_choice(ans),
            answer_value=f"{ans:g}",
        )

    # 4. Percent-to-fraction conversion.
    m = re.search(r"(\d+(?:\.\d+)?)\s*(?:%|percent).*?(?:equals|as a fraction|fraction)", lower)
    if not m:
        # The first pattern cannot span a line break between the percent and
        # "as a fraction"; this one can.
        m = re.search(r"what is\s+(\d+(?:\.\d+)?)\s*(?:%|percent)\s+as a fraction", lower)
    if m:
        p = float(m.group(1))
        # Only whole-number percentages are converted.
        if p.is_integer():
            frac = Fraction(int(p), 100)
            fraction_text = f"{frac.numerator}/{frac.denominator}"
            letter = compare_fraction_to_choices(fraction_text, choices) if choices else None
            return finish(
                hint="Think of percent as 'out of 100', then simplify the fraction.",
                walkthrough=f"{int(p)}% = {int(p)}/100 = {fraction_text}.",
                direct=f"{int(p)}% = {fraction_text}.",
                letter=letter,
                answer_value=fraction_text,
            )

    return None
def solve_ratio_question(text: str, help_mode: str) -> Optional[SolverResult]:
    """Split a stated total according to an ``a:b`` ratio.

    Four phrasings are recognised (ratio followed by "total is/of N", or by
    "there is/are N ... total", each with or without a leading "ratio of ...
    is").  The reported answer is always the quantity belonging to the first
    ratio term; the second one only appears in the walkthrough.

    Returns:
        A solved ``SolverResult``, or ``None`` when no phrasing matches or
        the ratio terms add up to zero.
    """
    lowered = normalize_for_solver(text).lower()
    choices = extract_choices(text)

    phrasings = (
        r"ratio of .*? is\s+(\d+)\s*:\s*(\d+).*?total (?:is|of)\s+(\d+(?:\.\d+)?)",
        r"(\d+)\s*:\s*(\d+).*?total (?:is|of)\s+(\d+(?:\.\d+)?)",
        r"ratio of .*? is\s+(\d+)\s*:\s*(\d+).*?there (?:is|are)\s+(\d+(?:\.\d+)?).*?total",
        r"(\d+)\s*:\s*(\d+).*?there (?:is|are)\s+(\d+(?:\.\d+)?).*?total",
    )
    # First phrasing that matches wins; later ones are not evaluated.
    hit = next(
        (found for found in (re.search(p, lowered) for p in phrasings) if found),
        None,
    )
    if hit is None:
        return None

    term_a, term_b, whole = (float(group) for group in hit.groups())
    part_count = term_a + term_b
    if part_count == 0:
        return None

    share_a = whole * term_a / part_count
    share_b = whole * term_b / part_count
    letter = compare_to_choices_numeric(share_a, choices) if choices else None

    if help_mode == "hint":
        reply = "Add the ratio parts, then take that fraction of the total."
    elif help_mode == "walkthrough":
        reply = (
            f"Total parts = {term_a:g} + {term_b:g} = {part_count:g}.\n"
            f"First quantity = {term_a:g}/{part_count:g} × {whole:g} = {share_a:g}.\n"
            f"Second quantity = {term_b:g}/{part_count:g} × {whole:g} = {share_b:g}."
        )
        if letter:
            reply += f"\nThat matches choice {letter}."
    else:
        reply = f"The first quantity is {share_a:g}."
        if letter:
            reply += f" So the correct choice is {letter}."

    return SolverResult(
        reply=reply,
        domain="quant",
        solved=True,
        help_mode=help_mode,
        answer_letter=letter,
        answer_value=f"{share_a:g}",
    )
def solve_average_question(text: str, help_mode: str) -> Optional[SolverResult]:
    """Compute the mean or median of the numbers listed in a question.

    "mean"/"average" is checked before "median".  When the text carries
    answer options, the data set is read from the part of the text before the
    first option marker, so that the option values themselves are not counted
    as data.  If that stem holds fewer than two numbers, the whole text is
    used instead.

    Args:
        text: The question, optionally followed by lettered answer options.
        help_mode: ``"hint"``, ``"walkthrough"``, or anything else for a
            direct answer.

    Returns:
        A solved ``SolverResult``, or ``None`` when fewer than two numbers
        are found or neither keyword is present.
    """
    lower = normalize_for_solver(text).lower()
    choices = extract_choices(text)

    # Cut the text at the first "A)" / "b." / "C:" style marker so option
    # values do not pollute the data set.
    stem = text
    if choices:
        first_marker = re.search(r"(?im)(?:^|\s)[A-E][\)\.\:]", text)
        if first_marker is not None:
            stem = text[: first_marker.start()]
    nums = extract_numbers(stem)
    if len(nums) < 2:
        # The marker may have been a false positive; fall back to everything.
        nums = extract_numbers(text)
    if len(nums) < 2:
        return None

    def listed(values) -> str:
        """Comma-separated, compactly formatted rendering of ``values``."""
        return ", ".join(f"{x:g}" for x in values)

    def finish(ans, hint, walkthrough, direct):
        """Pick the reply for ``help_mode`` and wrap it in a SolverResult."""
        letter = compare_to_choices_numeric(ans, choices) if choices else None
        if help_mode == "hint":
            reply = hint
        elif help_mode == "walkthrough":
            reply = walkthrough
            if letter:
                reply += f"\nThat matches choice {letter}."
        else:
            reply = direct
            if letter:
                reply += f" So the correct choice is {letter}."
        return SolverResult(
            reply=reply,
            domain="quant",
            solved=True,
            help_mode=help_mode,
            answer_letter=letter,
            answer_value=f"{ans:g}",
        )

    if "mean" in lower or "average" in lower:
        ans = mean(nums)
        return finish(
            ans,
            hint="Add the numbers, then divide by how many there are.",
            walkthrough=f"The numbers are {listed(nums)}.\nTheir average is {ans:g}.",
            direct=f"The average is {ans:g}.",
        )

    if "median" in lower:
        ans = median(nums)
        return finish(
            ans,
            hint="Sort the numbers first, then take the middle value.",
            walkthrough=f"Sorted numbers: {listed(sorted(nums))}.\nThe median is {ans:g}.",
            direct=f"The median is {ans:g}.",
        )

    return None
def solve_probability_question(text: str, help_mode: str) -> Optional[SolverResult]:
    """Solve "R red and B blue ... probability" single-draw questions.

    The text must mention the red count, then the blue count, then the word
    "probability" (all on one line).  The favourable colour is taken from the
    wording after "probability":

    * only "blue" is mentioned -> blue (or red if it is phrased "not ... blue"),
    * only "red" is mentioned  -> red (or blue if it is phrased "not ... red"),
    * both or neither          -> red.

    Args:
        text: The question, optionally followed by lettered answer options.
        help_mode: ``"hint"``, ``"walkthrough"``, or anything else for a
            direct answer.

    Returns:
        A solved ``SolverResult``, or ``None`` when the pattern is absent or
        both counts are zero.
    """
    lower = normalize_for_solver(text).lower()
    choices = extract_choices(text)

    m = re.search(r"(\d+)\s+red.*?(\d+)\s+blue.*?probability", lower)
    if not m:
        return None

    counts = {"red": float(m.group(1)), "blue": float(m.group(2))}
    total = counts["red"] + counts["blue"]
    if total == 0:
        return None

    # Work out which colour the question is about from what follows
    # "probability"; default to red when the wording is ambiguous.
    asked = lower[m.end():]
    mentions_red = re.search(r"\bred\b", asked) is not None
    mentions_blue = re.search(r"\bblue\b", asked) is not None
    target = "red"
    if mentions_red != mentions_blue:
        named = "blue" if mentions_blue else "red"
        other = "red" if named == "blue" else "blue"
        # "not blue", "not a blue", "not drawing a blue" all mean the other colour.
        negated = re.search(rf"\bnot\s+(?:\w+\s+){{0,2}}{named}\b", asked) is not None
        target = other if negated else named

    favorable = counts[target]
    ans = favorable / total
    letter = compare_to_choices_numeric(ans, choices) if choices else None

    if help_mode == "hint":
        reply = "Probability = favorable outcomes ÷ total outcomes."
    elif help_mode == "walkthrough":
        reply = (
            f"Favorable outcomes = {favorable:g}, total outcomes = {total:g}, "
            f"so probability = {favorable:g}/{total:g} = {ans:g}."
        )
        if letter:
            reply += f"\nThat matches choice {letter}."
    else:
        reply = f"The probability is {ans:g}."
        if letter:
            reply += f" So the correct choice is {letter}."

    return SolverResult(
        reply=reply,
        domain="quant",
        solved=True,
        help_mode=help_mode,
        answer_letter=letter,
        answer_value=f"{ans:g}",
    )
def solve_geometry_question(text: str, help_mode: str) -> Optional[SolverResult]:
    """Solve four fixed geometry templates.

    Checked in order, the first applicable one producing the result:

    1. rectangle: perimeter and length given, text mentions "width" -> width
    2. rectangle: length and width given, text mentions "area" -> area
    3. circle: radius given, text mentions "area" -> area (4 decimal places)
    4. triangle: base and height given, text mentions "area" -> area

    Returns:
        A solved ``SolverResult``, or ``None`` if no template applies.
    """
    lowered = normalize_for_solver(text).lower()
    choices = extract_choices(text)

    def conclude(answer, hint, walkthrough, direct, shown):
        """Match the options, choose the reply style, and build the result."""
        letter = compare_to_choices_numeric(answer, choices) if choices else None
        if help_mode == "hint":
            reply = hint
        elif help_mode == "walkthrough":
            reply = walkthrough
            if letter:
                reply += f"\nThat matches choice {letter}."
        else:
            reply = direct
            if letter:
                reply += f" So the correct choice is {letter}."
        return SolverResult(
            reply=reply,
            domain="quant",
            solved=True,
            help_mode=help_mode,
            answer_letter=letter,
            answer_value=shown,
        )

    # 1. Rectangle width from perimeter and length.
    found = re.search(
        r"rectangle.*?perimeter\s*(?:is|=)?\s*(\d+(?:\.\d+)?).*?length\s*(?:is|=)?\s*(\d+(?:\.\d+)?)",
        lowered,
    )
    if found and "width" in lowered:
        perimeter = float(found.group(1))
        length = float(found.group(2))
        width = perimeter / 2 - length
        return conclude(
            width,
            hint="Use the perimeter formula for a rectangle: P = 2(L + W).",
            walkthrough=(
                "For a rectangle, P = 2(L + W).\n"
                f"So {perimeter:g} = 2({length:g} + W).\n"
                f"Divide by 2: {perimeter/2:g} = {length:g} + W.\n"
                f"So W = {width:g}."
            ),
            direct=f"The width is {width:g}.",
            shown=f"{width:g}",
        )

    # 2. Rectangle area from length and width.
    found = re.search(
        r"rectangle.*?length\s*(?:is|=)?\s*(\d+(?:\.\d+)?)\D+width\s*(?:is|=)?\s*(\d+(?:\.\d+)?)",
        lowered,
    )
    if found and "area" in lowered:
        length = float(found.group(1))
        width = float(found.group(2))
        area = length * width
        return conclude(
            area,
            hint="For a rectangle, area = length × width.",
            walkthrough=f"Area = {length:g} × {width:g} = {area:g}.",
            direct=f"The area is {area:g}.",
            shown=f"{area:g}",
        )

    # 3. Circle area from radius.
    found = re.search(r"circle.*?radius\s*(?:is|=)?\s*(\d+(?:\.\d+)?)", lowered)
    if found and "area" in lowered:
        radius = float(found.group(1))
        area = math.pi * radius * radius
        return conclude(
            area,
            hint="For a circle, area = πr².",
            walkthrough=f"Area = π({radius:g})² = {area:.4f}.",
            direct=f"The area is {area:.4f}.",
            shown=f"{area:.4f}",
        )

    # 4. Triangle area from base and height.
    found = re.search(
        r"triangle.*?base\s*(?:is|=)?\s*(\d+(?:\.\d+)?)\D+height\s*(?:is|=)?\s*(\d+(?:\.\d+)?)",
        lowered,
    )
    if found and "area" in lowered:
        base = float(found.group(1))
        height = float(found.group(2))
        area = 0.5 * base * height
        return conclude(
            area,
            hint="Triangle area = 1/2 × base × height.",
            walkthrough=f"Area = 1/2 × {base:g} × {height:g} = {area:g}.",
            direct=f"The area is {area:g}.",
            shown=f"{area:g}",
        )

    return None
def solve_divisibility_question(text: str, help_mode: str) -> Optional[SolverResult]:
    """Pick the option ``x`` that makes ``a*x ± b`` divisible by ``d``.

    The question must contain "divisible by <d>" and a linear expression of
    the form ``ax + b`` or ``ax - b`` (``a`` may be omitted, meaning 1).
    Each answer option holding exactly one whole number is substituted for
    ``x``; the first option that satisfies the condition is returned.
    Options with non-integer values are skipped rather than truncated.

    Args:
        text: The question followed by lettered answer options.
        help_mode: ``"hint"``, ``"walkthrough"``, or anything else for a
            direct answer.

    Returns:
        A solved ``SolverResult``, or ``None`` when the wording is not
        recognised, the divisor is zero, or no option works.
    """
    lower = normalize_for_solver(text).lower()
    choices = extract_choices(text)

    divisor_match = re.search(r"divisible by\s+(\d+)", lower)
    # The look-behind keeps "max + 2" from being read as "x + 2"; the
    # look-ahead keeps "2y" or "2.5" from being read as the constant 2.
    linear_match = re.search(
        r"(?<![a-z0-9.])(\d*)\s*x\s*([+-])\s*(\d+)(?![\d.]|[a-z])",
        lower,
    )
    if not divisor_match or not linear_match:
        return None

    divisor = int(divisor_match.group(1))
    if divisor == 0:
        return None
    coefficient = int(linear_match.group(1)) if linear_match.group(1) else 1
    sign = linear_match.group(2)
    magnitude = int(linear_match.group(3))
    constant = magnitude if sign == "+" else -magnitude

    # Human-readable forms, e.g. "3x + 2" and "3(".
    coefficient_text = "" if coefficient == 1 else str(coefficient)
    expression_text = f"{coefficient_text}x {sign} {magnitude}"

    valid_letter = None
    valid_x = None
    for letter, raw in choices.items():
        nums = extract_numbers(raw)
        if len(nums) != 1 or not nums[0].is_integer():
            continue
        x = int(nums[0])
        if (coefficient * x + constant) % divisor == 0:
            valid_letter = letter
            valid_x = x
            break
    if valid_x is None:
        return None

    if help_mode == "hint":
        reply = (
            f"Test the answer choices in {expression_text} and see which makes "
            f"a multiple of {divisor}."
        )
    elif help_mode == "walkthrough":
        reply = (
            f"Substitute the choices into {expression_text}.\n"
            f"When x = {valid_x}, {coefficient_text}({valid_x}) {sign} {magnitude} = "
            f"{coefficient * valid_x + constant}, which is divisible by {divisor}.\n"
            f"So the correct choice is {valid_letter}."
        )
    else:
        reply = f"x = {valid_x}, so the correct choice is {valid_letter}."

    return SolverResult(
        reply=reply,
        domain="quant",
        solved=True,
        help_mode=help_mode,
        answer_letter=valid_letter,
        answer_value=str(valid_x),
    )
def solve_linear_equation(text: str, help_mode: str) -> Optional[SolverResult]:
    """Solve a single equation in ``x`` with SymPy and match it to the options.

    The equation is located with a regex (preferring text that follows
    "solve for x", "find x" or "value of x"), implicit products such as
    ``2x`` are made explicit, and both sides are parsed with ``sympify``.
    Only the first solution SymPy reports is used.

    Returns:
        A solved ``SolverResult``, or ``None`` when SymPy is unavailable, no
        equation with exactly one ``=`` is found, there is no solution, or
        anything raises while parsing or solving.
    """
    if sp is None:
        return None

    lowered = normalize_for_solver(text).lower().replace("^", "**")
    choices = extract_choices(text)

    located = re.search(
        r"(?:solve for x|find x|value of x).*?([\-0-9a-z\+\*/\s\(\)\.=]+=[\-0-9a-z\+\*/\s\(\)\.=]+)",
        lowered,
    ) or re.search(r"([\-0-9a-z\+\*/\s\(\)\.]+=[\-0-9a-z\+\*/\s\(\)\.]+)", lowered)
    if located is None:
        return None

    equation = located.group(1).strip()
    # Insert "*" between a digit and a letter in either order (2x -> 2*x).
    for adjacency in (r"(\d)([a-z])", r"([a-z])(\d)"):
        equation = re.sub(adjacency, r"\1*\2", equation)

    sides = equation.split("=")
    if len(sides) != 2:
        return None

    try:
        unknown = sp.symbols("x")
        # NOTE(review): sympify evaluates its string argument; the text here
        # originates from the question, so treat this as untrusted input.
        lhs = sp.sympify(sides[0], locals={"x": unknown})
        rhs = sp.sympify(sides[1], locals={"x": unknown})
        roots = sp.solve(sp.Eq(lhs, rhs), unknown)
        if not roots:
            return None
        root = sp.simplify(roots[0])

        # A symbolic root has no float value and therefore no option match.
        try:
            numeric_root = float(root)
        except Exception:
            numeric_root = None

        letter = (
            compare_to_choices_numeric(numeric_root, choices)
            if choices and numeric_root is not None
            else None
        )

        if help_mode == "hint":
            reply = "Collect the x-terms on one side and constants on the other."
        elif help_mode == "walkthrough":
            reply = f"Solve {equation} for x.\nThis gives x = {root}."
            if letter:
                reply += f"\nThat matches choice {letter}."
        else:
            reply = f"x = {root}."
            if letter:
                reply += f" So the correct choice is {letter}."

        return SolverResult(
            reply=reply,
            domain="quant",
            solved=True,
            help_mode=help_mode,
            answer_letter=letter,
            answer_value=str(root),
        )
    except Exception:
        return None
def solve_by_option_checking(text: str, help_mode: str) -> Optional[SolverResult]:
    """Compute a value directly and report it only if an option matches.

    Requires at least three answer options.  Two computations are attempted
    in turn: evaluating a directly stated expression, then "what is P% of N".
    Each yields a result only when its value equals one of the options.

    Returns:
        A solved ``SolverResult`` naming the matching option, or ``None``.
    """
    choices = extract_choices(text)
    if len(choices) < 3:
        return None

    def report(value, letter, hint, walkthrough):
        """Build the result for a value already matched to ``letter``."""
        if help_mode == "hint":
            reply = hint
        elif help_mode == "walkthrough":
            reply = walkthrough
        else:
            reply = f"The correct choice is {letter}."
        return SolverResult(
            reply=reply,
            domain="quant",
            solved=True,
            help_mode=help_mode,
            answer_letter=letter,
            answer_value=f"{value:g}",
        )

    # Attempt 1: a directly stated arithmetic expression.
    expression = parse_direct_expression_question(text)
    if expression:
        evaluated = try_eval_expression(expression)
        if evaluated is not None:
            letter = compare_to_choices_numeric(evaluated, choices)
            if letter:
                return report(
                    evaluated,
                    letter,
                    hint="Evaluate the expression first, then match it to the answer choices.",
                    walkthrough=f"The expression evaluates to {evaluated:g}, which matches choice {letter}.",
                )

    # Attempt 2: "what is P% of N".
    lowered = normalize_for_solver(text).lower()
    percent_of = re.search(
        r"what is\s+(\d+(?:\.\d+)?)\s*(?:%|percent)\s+of\s+(\d+(?:\.\d+)?)",
        lowered,
    )
    if percent_of:
        rate_text, base_text = percent_of.group(1), percent_of.group(2)
        portion = float(rate_text) / 100 * float(base_text)
        letter = compare_to_choices_numeric(portion, choices)
        if letter:
            return report(
                portion,
                letter,
                hint="Compute the percentage, then match it to the options.",
                walkthrough=f"{rate_text}% of {base_text} = {portion:g}, so the correct choice is {letter}.",
            )

    return None
def solve_quant(text: str, help_mode: str) -> SolverResult:
    """Route a quantitative question to the solvers and return the first answer.

    The normalized text is classified by ``route_solver``.  The solvers
    registered for that category run first, followed by every remaining
    solver from the ``basic_solver`` list.  Each solver runs at most once:
    one that already declined or failed in the routed pass is not retried in
    the fallback pass.  A solver that raises is logged at debug level and
    skipped.

    Args:
        text: The raw question text.
        help_mode: ``"hint"``, ``"walkthrough"``, or anything else for a
            direct answer.

    Returns:
        The first truthy ``SolverResult`` a solver produces, otherwise an
        unsolved ``SolverResult`` with a generic reply for ``help_mode``.
    """
    import logging

    log = logging.getLogger(__name__)

    text = normalize_for_solver(text)
    solver_type = route_solver(text)
    routed_solvers = {
        "percent_solver": [
            solve_by_option_checking,
            solve_percent_question,
            solve_basic_expression,
        ],
        "ratio_solver": [
            solve_by_option_checking,
            solve_ratio_question,
            solve_basic_expression,
        ],
        "algebra_solver": [
            solve_by_option_checking,
            solve_linear_equation,
            solve_divisibility_question,
            solve_basic_expression,
        ],
        "statistics_solver": [
            solve_by_option_checking,
            solve_average_question,
            solve_basic_expression,
        ],
        "probability_solver": [
            solve_by_option_checking,
            solve_probability_question,
            solve_basic_expression,
        ],
        "geometry_solver": [
            solve_by_option_checking,
            solve_geometry_question,
            solve_basic_expression,
        ],
        "basic_solver": [
            solve_by_option_checking,
            solve_percent_question,
            solve_ratio_question,
            solve_average_question,
            solve_probability_question,
            solve_geometry_question,
            solve_divisibility_question,
            solve_linear_equation,
            solve_basic_expression,
        ],
    }

    everything = routed_solvers["basic_solver"]
    primary = routed_solvers.get(solver_type, everything)
    # Routed solvers first, then the rest of the full list in its own order.
    pipeline = list(primary) + [s for s in everything if s not in primary]

    for solver in pipeline:
        try:
            out = solver(text, help_mode)
        except Exception:
            # Best effort: one broken solver must not block the others.
            log.debug("quant solver %s failed", solver.__name__, exc_info=True)
            continue
        if out:
            return out

    if help_mode == "hint":
        reply = (
            "Identify the question type first, then use the key relationship or formula. "
            "If there are answer choices, you can also test them directly."
        )
    elif help_mode == "walkthrough":
        reply = (
            "I can help with this, but I cannot confidently solve it from the current parse alone. "
            "I can still talk through the first step or eliminate options if needed."
        )
    else:
        reply = "I can help with this, but I can’t confidently solve it from the current parse alone yet."

    return SolverResult(
        reply=reply,
        domain="quant",
        solved=False,
        help_mode=help_mode,
    )