from __future__ import annotations import math import re from statistics import mean, median from typing import Dict, List, Optional, Tuple try: import sympy as sp except Exception: sp = None from models import SolverResult from utils import clean_math_text, normalize_spaces def extract_choices(text: str) -> Dict[str, str]: text = text or "" matches = list( re.finditer( r"(?i)\b([A-E])[\)\.:]\s*(.*?)(?=\s+\b[A-E][\)\.:]\s*|$)", text, ) ) return {m.group(1).upper(): normalize_spaces(m.group(2)) for m in matches} def has_answer_choices(text: str) -> bool: return len(extract_choices(text)) >= 3 def is_quant_question(text: str) -> bool: lower = clean_math_text(text).lower() keywords = [ "solve", "equation", "percent", "ratio", "probability", "mean", "median", "average", "sum", "difference", "product", "quotient", "triangle", "circle", "rectangle", "area", "perimeter", "volume", "algebra", "integer", "divisible", "number", "fraction", "decimal", "geometry", "distance", "speed", "work", "remainder", "discount", ] if any(k in lower for k in keywords): return True if "=" in lower and re.search(r"[a-z]", lower): return True if re.search(r"\d", lower) and ("?" in lower or has_answer_choices(lower)): return True return False def _prepare_expression(expr: str) -> str: expr = clean_math_text(expr).strip() expr = expr.replace("^", "**") # remove common prompt wrappers expr = re.sub(r"(?i)^\s*(solve|simplify|evaluate|find|what is|compute)\s*:?\s*", "", expr) # implicit multiplication expr = re.sub(r"(\d)\s*\(", r"\1*(", expr) expr = re.sub(r"\)\s*(\d)", r")*\1", expr) expr = re.sub(r"(\d)\s*([a-zA-Z])", r"\1*\2", expr) expr = re.sub(r"([a-zA-Z])\s*\(", r"\1*(", expr) expr = re.sub(r"\)\s*([a-zA-Z])", r")*\1", expr) expr = re.sub(r"([a-zA-Z])\s+([a-zA-Z])", r"\1*\2", expr) return expr def _clean_equation_candidate(text: str) -> str: s = clean_math_text(text).strip() # remove leading prompt phrases but keep equation content s = re.sub(r"(?i)^\s*(solve|simplify|evaluate|find)\s*:?\s*", "", s) s = re.sub(r"(?i)^\s*how do i solve\s*:?\s*", "", s) s = re.sub(r"(?i)^\s*what is\s+", "", s) s = normalize_spaces(s) return s def _extract_equation(text: str) -> Optional[str]: cleaned = _clean_equation_candidate(text) if "=" not in cleaned: return None # first try: take the full equation-looking span patterns = [ r"([A-Za-z0-9\.\+\-\*/\^\(\)\s]+=[A-Za-z0-9\.\+\-\*/\^\(\)\s]+)", ] for pattern in patterns: for m in re.finditer(pattern, cleaned): candidate = normalize_spaces(m.group(1)) if "=" not in candidate: continue if re.search(r"[a-z]", candidate.lower()): return candidate # fallback: split once on equals and trim to likely expression zones parts = cleaned.split("=", 1) if len(parts) != 2: return None lhs = parts[0].strip(" .,:;!?") rhs = parts[1].strip(" .,:;!?") if not lhs or not rhs: return None candidate = f"{lhs} = {rhs}" if re.search(r"[a-z]", candidate.lower()): return candidate return None def _extract_variable_names(expr: str) -> List[str]: # catches x in "3x + 5 = 20" as well as standalone x vars_found = sorted(set(re.findall(r"[a-z]", expr.lower()))) # avoid treating common words as many variables by keeping only likely algebra variables first preferred = [v for v in vars_found if v in {"x", "y", "z", "n"}] return preferred or vars_found def _parse_number(text: str) -> Optional[float]: raw = clean_math_text(text).strip().lower() pct = re.fullmatch(r"(-?\d+(?:\.\d+)?)%", raw.replace(" ", "")) if pct: return float(pct.group(1)) / 100.0 frac = re.fullmatch(r"(-?\d+)\s*/\s*(-?\d+)", raw) if frac: den = float(frac.group(2)) if den == 0: return None return float(frac.group(1)) / den try: return float( eval( _prepare_expression(raw), {"__builtins__": {}}, {"sqrt": math.sqrt, "pi": math.pi}, ) ) except Exception: return None def _best_choice(answer_value: float, choices: Dict[str, str]) -> Optional[str]: best_letter = None best_diff = float("inf") for letter, raw in choices.items(): parsed = _parse_number(raw) if parsed is None: continue diff = abs(parsed - answer_value) if diff < best_diff: best_diff = diff best_letter = letter if best_letter is not None and best_diff <= 1e-6: return best_letter return None def _make_result( *, topic: str, answer_value: str, internal_answer: Optional[str] = None, steps: Optional[List[str]] = None, choices_text: str = "", ) -> SolverResult: answer_float = _parse_number(answer_value) choices = extract_choices(choices_text) answer_letter = _best_choice(answer_float, choices) if (answer_float is not None and choices) else None return SolverResult( domain="quant", solved=True, topic=topic, answer_value=answer_value, answer_letter=answer_letter, internal_answer=internal_answer or answer_value, steps=steps or [], ) def _solve_successive_percent(text: str) -> Optional[SolverResult]: lower = clean_math_text(text).lower() pattern = re.findall( r"(increase|decrease|discount|mark(?:ed)?\s*up|mark(?:ed)?\s*down|rise|fall)\s+by\s+(\d+(?:\.\d+)?)\s*(?:%|percent)", lower, ) if len(pattern) < 2: pattern = re.findall( r"(\d+(?:\.\d+)?)\s*(?:%|percent)\s+(increase|decrease|discount|rise|fall)", lower, ) pattern = [(op, pct) for pct, op in pattern] if len(pattern) < 2: return None multiplier = 1.0 step_lines: List[str] = [] for op, pct_raw in pattern: pct = float(pct_raw) if any(k in op for k in ["decrease", "discount", "down", "fall"]): factor = 1 - pct / 100.0 step_lines.append(f"A {pct:g}% decrease means multiply by {factor:g}.") else: factor = 1 + pct / 100.0 step_lines.append(f"A {pct:g}% increase means multiply by {factor:g}.") multiplier *= factor net_change = (multiplier - 1.0) * 100.0 direction = "increase" if net_change >= 0 else "decrease" magnitude = abs(net_change) return _make_result( topic="percent", answer_value=f"{magnitude:g}%", internal_answer=f"net {direction} of {magnitude:g}%", steps=step_lines + [f"Combine the multipliers to find the overall percent change."], choices_text=text, ) def _extract_ratio_labels(text: str) -> Optional[Tuple[str, str]]: m = re.search(r"ratio of ([a-z ]+?) to ([a-z ]+?) is \d+\s*:\s*\d+", text.lower()) if not m: return None left = normalize_spaces(m.group(1)).rstrip("s") right = normalize_spaces(m.group(2)).rstrip("s") return left, right def _solve_ratio_total(text: str) -> Optional[SolverResult]: lower = clean_math_text(text).lower() ratio_match = re.search(r"(\d+)\s*:\s*(\d+)", lower) total_match = re.search(r"(?:total|altogether|in all|sum)\s*(?:is|=|of)?\s*(\d+)", lower) if not ratio_match or not total_match: return None a = int(ratio_match.group(1)) b = int(ratio_match.group(2)) total = int(total_match.group(1)) part_sum = a + b if part_sum == 0: return None unit = total / part_sum left_value = a * unit right_value = b * unit labels = _extract_ratio_labels(lower) requested_value = left_value if labels: left_label, right_label = labels if left_label in lower and re.search(rf"how many {re.escape(left_label)}", lower): requested_value = left_value elif right_label in lower and re.search(rf"how many {re.escape(right_label)}", lower): requested_value = right_value else: requested_value = left_value return _make_result( topic="ratio", answer_value=f"{requested_value:g}", internal_answer=f"{requested_value:g}", steps=[ f"Add the ratio parts: {a} + {b} = {part_sum}.", f"Find the value of one ratio part using the total.", f"Multiply by the required ratio part.", ], choices_text=text, ) def _solve_remainder(text: str) -> Optional[SolverResult]: lower = clean_math_text(text).lower() m = re.search(r"remainder .*? when (\d+) is divided by (\d+)", lower) if not m: m = re.search(r"(\d+)\s*(?:mod|%)\s*(\d+)", lower) if not m: return None a = int(m.group(1)) b = int(m.group(2)) if b == 0: return None r = a % b return _make_result( topic="number_theory", answer_value=str(r), internal_answer=str(r), steps=[ f"Divide {a} by {b}.", "Keep the amount left over after division.", ], choices_text=text, ) def _solve_percent(text: str) -> Optional[SolverResult]: lower = clean_math_text(text).lower() choices = extract_choices(text) patterns = [ r"(\d+(?:\.\d+)?)\s*(?:%|percent)\s+of\s+(?:a\s+)?number\s+is\s+(\d+(?:\.\d+)?)", r"(\d+(?:\.\d+)?)\s*(?:%|percent)\s+of\s+([a-z])\s+is\s+(\d+(?:\.\d+)?)", r"(\d+(?:\.\d+)?)\s*(?:%|percent)\s+of\s+([a-z])\s*=\s*(\d+(?:\.\d+)?)", ] for pat in patterns: m = re.search(pat, lower) if not m: continue if len(m.groups()) == 2: p = float(m.group(1)) value = float(m.group(2)) else: p = float(m.group(1)) value = float(m.group(3)) if p == 0: return None ans = value / (p / 100.0) answer_letter = _best_choice(ans, choices) if choices else None return SolverResult( domain="quant", solved=True, topic="percent", answer_value=f"{ans:g}", answer_letter=answer_letter, internal_answer=f"{ans:g}", steps=[ "Let the unknown quantity be a variable.", f"Convert {p:g}% into its decimal form.", "Write an equation for the part and whole relationship.", "Solve that equation for the unknown.", ], ) m = re.search(r"what is\s+(\d+(?:\.\d+)?)\s*(?:%|percent)\s+of\s+(\d+(?:\.\d+)?)", lower) if m: p = float(m.group(1)) n = float(m.group(2)) ans = p / 100.0 * n answer_letter = _best_choice(ans, choices) if choices else None return SolverResult( domain="quant", solved=True, topic="percent", answer_value=f"{ans:g}", answer_letter=answer_letter, internal_answer=f"{ans:g}", steps=[ f"Convert {p:g}% to a decimal.", f"Multiply that decimal by {n}.", ], ) return None def _solve_mean_median(text: str) -> Optional[SolverResult]: lower = clean_math_text(text).lower() nums = [float(n) for n in re.findall(r"-?\d+(?:\.\d+)?", lower)] if not nums: return None if "mean" in lower or "average" in lower: ans = mean(nums) return SolverResult( domain="quant", solved=True, topic="statistics", answer_value=f"{ans:g}", internal_answer=f"{ans:g}", steps=["Add the values.", f"Divide by {len(nums)}."], ) if "median" in lower: ans = median(nums) return SolverResult( domain="quant", solved=True, topic="statistics", answer_value=f"{ans:g}", internal_answer=f"{ans:g}", steps=["Order the values.", "Take the middle value."], ) return None def _solve_linear_equation(text: str) -> Optional[SolverResult]: if sp is None: return None expr = _extract_equation(text) if not expr: return None try: lhs, rhs = expr.split("=", 1) lhs_prepped = _prepare_expression(lhs) rhs_prepped = _prepare_expression(rhs) symbols = _extract_variable_names(expr) if not symbols: return None var_name = symbols[0] var = sp.symbols(var_name) equation = sp.Eq(sp.sympify(lhs_prepped), sp.sympify(rhs_prepped)) sol = sp.solve(equation, var) if not sol: return None value = sol[0] try: as_float = float(value) except Exception: as_float = None choices = extract_choices(text) return SolverResult( domain="quant", solved=True, topic="algebra", answer_value=str(value), answer_letter=_best_choice(as_float, choices) if (as_float is not None and choices) else None, internal_answer=f"{var_name} = {value}", steps=[ "Treat the statement as an equation.", "Undo operations on both sides to isolate the variable.", "Keep simplifying until the variable is alone.", ], ) except Exception: return None def solve_quant(text: str) -> SolverResult: text = text or "" for fn in ( _solve_successive_percent, _solve_ratio_total, _solve_remainder, _solve_percent, _solve_mean_median, _solve_linear_equation, ): result = fn(text) if result is not None: return result return SolverResult( domain="quant", solved=False, topic="general_quant", reply="This looks quantitative, but it does not match a strong rule-based pattern yet.", steps=[ "Identify the quantity the question wants.", "Translate the wording into an equation, ratio, or diagram.", "Carry out the calculation carefully.", ], )