| |
| import os |
| import re |
| import backoff |
| import re |
| from pathlib import Path |
| import openai |
| from dotenv import load_dotenv |
|
|
|
|
# The .env file is looked up three directory levels above this module and
# loaded at import time.
_env_dir = Path(__file__).parent.parent.parent
env_path = _env_dir / ".env"
load_dotenv(dotenv_path=env_path, override=True)


# One million; used below to turn per-million-token prices into per-token ones.
M = 10**6
|
|
# Matches an optionally negative integer or decimal, allowing comma-grouped
# thousands (e.g. "-1,234.56").
ANSWER_REGEX = re.compile(r"-?\d+(?:,\d{3})*(?:\.\d+)?")


def extract_numeric_answer(text: str) -> str:
    """Return the last number found in the model's reply, normalised.

    Commas are removed from ``text`` before matching, so "1,234" is read as
    "1234". Redundant leading zeros are dropped from the integer part, but at
    least one digit is always kept: "0" stays "0", "0.5" stays "0.5",
    "007" becomes "7" and "-007" becomes "-7".

    Args:
        text: Raw reply text to scan.

    Returns:
        The last numeric token as a string, or ``text.strip()`` when the
        reply contains no number at all.
    """
    matches = ANSWER_REGEX.findall(text.replace(",", ""))
    if not matches:
        return text.strip()

    number = matches[-1]
    sign = "-" if number.startswith("-") else ""
    integer_part, dot, fraction = number[len(sign):].partition(".")
    # A bare lstrip("0") on the whole token would turn "0" into "" and
    # "0.5" into ".5", so only the integer part is stripped and a single
    # "0" is restored when nothing is left.
    integer_part = integer_part.lstrip("0") or "0"
    return f"{sign}{integer_part}{dot}{fraction}"
|
|
|
|
def backoff_handler(details):
    """Print a one-line notice describing a retry.

    Args:
        details: Mapping describing the retry. The keys "exception",
            "tries" and "wait" are read; nothing is printed when
            "exception" is missing or falsy.
    """
    error = details.get("exception")
    if not error:
        return
    attempt, delay = details["tries"], details["wait"]
    message = (
        f"OpenAI - Retry {attempt} due to error: {error}. "
        f"Waiting {delay:0.1f}s..."
    )
    print(message)
|
|
|
|
# Prices quoted per one million tokens as (input, output) pairs, keyed by
# model name.
_PRICES_PER_MILLION_TOKENS = {
    "gpt-4.1-nano": (0.1, 0.4),
    "gpt-4.1-mini": (0.4, 1.6),
    "gpt-4.1": (2.0, 8.0),
    "gpt-4o-mini": (0.15, 0.6),
    "o4-mini": (1.1, 4.4),
}

# Per-token price table: model name -> {"input": ..., "output": ...}.
costs_per_token = {
    model: {"input": input_price / M, "output": output_price / M}
    for model, (input_price, output_price) in _PRICES_PER_MILLION_TOKENS.items()
}
|
|
|
|
class MaxCallsExceededError(Exception):
    """Signals that the allowed number of LLM calls has been used up."""
|
|
|
|
def create_call_limited_query_llm(base_query_llm, max_calls=3):
    """Wrap ``base_query_llm`` so it can only be called ``max_calls`` times.

    The counter lives in a ``threading.local`` object, so every thread that
    uses the returned wrapper has its own independent count. The count only
    goes back to zero when ``reset_calls()`` is invoked.

    Args:
        base_query_llm: The original query function to delegate to.
        max_calls: Maximum number of calls allowed between resets
            (default: 3).

    Returns:
        A wrapper that forwards all arguments to ``base_query_llm`` and
        carries two helper attributes:

        * ``reset_calls()`` sets the calling thread's counter back to 0.
        * ``get_call_count()`` returns the calling thread's current count.

        The wrapper raises ``MaxCallsExceededError`` once the calling
        thread's count has reached ``max_calls``.
    """
    import functools
    import threading

    state = threading.local()

    # functools.wraps keeps the wrapped function's name and docstring
    # visible on the wrapper.
    @functools.wraps(base_query_llm)
    def limited_query_llm(*args, **kwargs):
        used = getattr(state, "call_count", 0)
        if used >= max_calls:
            raise MaxCallsExceededError(
                f"Maximum number of LLM calls ({max_calls}) exceeded"
            )
        # The attempt is counted before delegating, so a call that raises
        # still consumes one unit of the budget.
        state.call_count = used + 1
        return base_query_llm(*args, **kwargs)

    def reset_calls():
        state.call_count = 0

    def get_call_count():
        return getattr(state, "call_count", 0)

    # Expose the counter controls on the wrapper itself.
    limited_query_llm.reset_calls = reset_calls
    limited_query_llm.get_call_count = get_call_count

    return limited_query_llm
|
|
|
|
def _is_non_retryable_openai_error(exc):
    """Return True when retrying ``exc`` cannot succeed (most HTTP 4xx)."""
    status = getattr(exc, "status_code", None)
    if not isinstance(status, int):
        # No HTTP status on the exception: keep retrying as before.
        return False
    # 408 (timeout), 409 (conflict) and 429 (rate limit) may succeed on a
    # later attempt; other 4xx responses will fail the same way every time.
    return 400 <= status < 500 and status not in (408, 409, 429)


@backoff.on_exception(
    backoff.expo,
    (
        openai.APIConnectionError,
        openai.APIStatusError,
        openai.RateLimitError,
        openai.APITimeoutError,
    ),
    max_tries=20,
    max_value=20,
    giveup=_is_non_retryable_openai_error,
    on_backoff=backoff_handler,
)
def query_llm(prompt, system, temperature=0.0, model_name="gpt-4.1-nano"):
    """Send one chat-completion request and return the reply and its cost.

    The call is retried with exponential backoff on the OpenAI errors listed
    in the decorator, except for HTTP 4xx responses other than 408/409/429,
    which are re-raised immediately.

    Args:
        prompt: Content of the user message.
        system: Content of the system message, or ``None`` to send only the
            user message.
        temperature: Sampling temperature. Forced to 1.0 when ``model_name``
            is "o4-mini".
        model_name: Model to query; must be a key of ``costs_per_token``.

    Returns:
        A ``(out_text, cost)`` tuple: the content of the first choice and
        the cost computed from the reported prompt/completion token counts
        and the per-token prices for ``model_name``.

    Raises:
        KeyError: If ``model_name`` has no entry in ``costs_per_token``.
            Raised before any request is sent.
    """
    # Look the prices up first so an unknown model fails before a request is
    # made, rather than after the completion has been produced and billed.
    pricing = costs_per_token[model_name]

    client = openai.OpenAI(api_key=os.getenv("OPENAI_API_KEY"))

    messages = [{"role": "user", "content": prompt}]
    if system is not None:
        messages.insert(0, {"role": "system", "content": system})

    if model_name == "o4-mini":
        temperature = 1.0

    response = client.chat.completions.create(
        model=model_name,
        messages=messages,
        temperature=temperature,
    )
    out_text = response.choices[0].message.content
    usage = response.usage
    cost = (
        usage.prompt_tokens * pricing["input"]
        + usage.completion_tokens * pricing["output"]
    )
    return out_text, cost
|
|
|
|
| |
| |
def is_equiv(str1, str2, verbose=False):
    """Compare two answer strings after normalising them with strip_string.

    Args:
        str1: First answer, or ``None``.
        str2: Second answer, or ``None``.
        verbose: When true, print both normalised strings before comparing.

    Returns:
        ``True`` if both inputs are ``None`` (a warning is printed),
        ``False`` if exactly one is ``None``, otherwise whether the
        normalised strings are equal. If anything raises during
        normalisation or comparison, the raw inputs are compared instead.
    """
    first_missing = str1 is None
    second_missing = str2 is None
    if first_missing and second_missing:
        print("WARNING: Both None")
        return True
    if first_missing or second_missing:
        return False

    try:
        normalized_1, normalized_2 = strip_string(str1), strip_string(str2)
        if verbose:
            print(normalized_1, normalized_2)
        return normalized_1 == normalized_2
    except Exception:
        # Best-effort fallback: compare the inputs exactly as given.
        return str1 == str2
|
|
|
|
def clean_answer(s):
    """Apply a few textual clean-ups to a LaTeX answer string.

    Rewrites ``\\dfrac`` to ``\\frac``, deletes the text ``x \\in`` and
    unwraps ``\\mathbf{...}`` / ``\\textbf{...}`` so only their contents
    remain.

    Args:
        s: The answer string to clean.

    Returns:
        The cleaned string.
    """
    for old, new in (("\\dfrac", "\\frac"), ("x \\in", "")):
        s = s.replace(old, new)

    # Keep only the captured contents of each bold wrapper.
    for wrapper in (r"\\mathbf\s*{([^}]*)}", r"\\textbf\s*{([^}]*)}"):
        s = re.sub(wrapper, r"\1", s)
    return s
|
|
|
|
def remove_boxed(s):
    """Strip a leading ``\\boxed`` wrapper and return what it contains.

    Two forms are recognised:

    * ``\\boxed <content>`` (followed by a space): everything after the
      prefix is returned unchanged.
    * ``\\boxed{<content>}``: the text between the outer braces is returned
      after being passed through ``clean_answer``.

    Args:
        s: String expected to start with a ``\\boxed`` wrapper.

    Returns:
        The unwrapped content, or ``None`` when ``s`` contains no
        "\\boxed " and does not start with "\\boxed{".

    Raises:
        AssertionError: If "\\boxed " occurs in ``s`` but not at the start,
            or if the braced form does not end with "}". Raised explicitly
            (instead of via ``assert``) so the checks also run under
            ``python -O``.
    """
    spaced_prefix = "\\boxed "
    if spaced_prefix in s:
        if not s.startswith(spaced_prefix):
            raise AssertionError("'\\boxed ' must be at the start of the string")
        return s[len(spaced_prefix):]

    braced_prefix = "\\boxed{"
    if not s.startswith(braced_prefix):
        return None

    if not s.endswith("}"):
        raise AssertionError("'\\boxed{' expression must end with '}'")

    return clean_answer(s[len(braced_prefix):-1])
|
|
|
|
def last_boxed_only_string(string: str) -> str:
    """
    Return the last ``\\boxed{...}`` command in ``string``, braces included.

    ``\\fbox`` is only searched for when no ``\\boxed`` occurs at all.
    Nested braces are balanced. An empty string is returned when neither
    command is present, when no "{" follows the command, or when the
    opening brace is never closed.
    """
    start = string.rfind("\\boxed")
    if start < 0:
        start = string.rfind("\\fbox")
    if start < 0:
        return ""

    # First opening brace at or after the command.
    open_pos = string.find("{", start)
    if open_pos < 0:
        return ""

    # Walk forward tracking nesting depth until the opening brace is matched.
    depth = 0
    for pos, char in enumerate(string[open_pos:], start=open_pos):
        if char == "{":
            depth += 1
        elif char == "}":
            depth -= 1
            if depth == 0:
                return string[start:pos + 1]

    return ""
|
|
|
|
def fix_fracs(string):
    """Add the braces missing from shorthand ``\\frac`` arguments.

    ``\\frac12`` becomes ``\\frac{1}{2}`` and ``\\frac1{2}`` becomes
    ``\\frac{1}{2}``; occurrences whose first argument is already braced are
    left alone. Each unbraced argument is taken to be a single character.

    Args:
        string: LaTeX text to rewrite.

    Returns:
        The rewritten text. If any ``\\frac`` is followed by fewer than two
        characters (including a ``\\frac`` at the very end), the input is
        returned unchanged.
    """
    head, *tails = string.split("\\frac")
    pieces = [head]
    for tail in tails:
        pieces.append("\\frac")
        if tail.startswith("{"):
            pieces.append(tail)
            continue
        if len(tail) < 2:
            # Malformed fraction (this also covers an empty tail, which would
            # otherwise raise IndexError): give up and keep the input as is.
            return string
        numerator, second, rest = tail[0], tail[1], tail[2:]
        if second == "{":
            # Only the numerator lacks braces; the denominator follows as is.
            pieces.append("{" + numerator + "}" + second + rest)
        else:
            pieces.append("{" + numerator + "}{" + second + "}" + rest)
    return "".join(pieces)
|
|
|
|
def fix_a_slash_b(string):
    """Rewrite a plain integer ratio ``a/b`` as ``\\frac{a}{b}``.

    The rewrite only happens when the whole string is exactly two integers
    separated by one slash and is already in canonical form (no padding,
    leading zeros or explicit plus sign).

    Args:
        string: Candidate text such as "3/4".

    Returns:
        The ``\\frac{a}{b}`` form, or the input unchanged when it is not a
        canonical integer ratio.
    """
    parts = string.split("/")
    if len(parts) != 2:
        return string
    try:
        numerator, denominator = int(parts[0]), int(parts[1])
    except ValueError:
        # Non-integer operand such as "x/2" or "1.5/2": not a simple ratio.
        return string
    # Reject spellings that int() accepts but that are not canonical,
    # e.g. " 1/2" or "01/2".
    if string != "{}/{}".format(numerator, denominator):
        return string
    return "\\frac{" + str(numerator) + "}{" + str(denominator) + "}"
|
|
|
|
def remove_right_units(string):
    """Drop a trailing unit written as ``\\text{ ...}``.

    Args:
        string: Text that may contain one "\\text{ " marker (note the space
            after the brace).

    Returns:
        Everything before the marker, or the input unchanged when the
        marker is absent.

    Raises:
        AssertionError: If the marker occurs more than once. Raised
            explicitly (instead of via ``assert``) so the check also runs
            under ``python -O``.
    """
    marker = "\\text{ "
    if marker not in string:
        return string
    before, *after = string.split(marker)
    if len(after) != 1:
        raise AssertionError("expected exactly one '\\text{ ' unit marker")
    return before
|
|
|
|
def fix_sqrt(string):
    """Brace the argument of every shorthand ``\\sqrt``.

    ``\\sqrt3`` becomes ``\\sqrt{3}``; only the single character directly
    after ``\\sqrt`` is wrapped. Occurrences that already have a braced
    argument, or that have nothing after them, are kept as they are.

    Args:
        string: LaTeX text to rewrite.

    Returns:
        The rewritten text.
    """
    if "\\sqrt" not in string:
        return string
    head, *tails = string.split("\\sqrt")
    pieces = [head]
    for tail in tails:
        if not tail or tail.startswith("{"):
            # Already braced, or nothing follows (a trailing or doubled
            # "\sqrt", which indexing tail[0] would crash on).
            pieces.append("\\sqrt" + tail)
        else:
            pieces.append("\\sqrt{" + tail[0] + "}" + tail[1:])
    return "".join(pieces)
|
|
|
|
def strip_string(string):
    """Normalise a LaTeX answer so equivalent spellings compare equal.

    The steps run in a fixed order because later ones see the output of
    earlier ones:

    1. Literal removals/replacements (newlines, ``\\!``, doubled
       backslashes, ``tfrac``/``dfrac``, ``\\left``/``\\right``, degree
       marks, ``\\$``).
    2. ``remove_right_units`` drops a trailing ``\\text{ ...}`` unit.
    3. ``\\%`` is removed and a bare leading decimal point gets a "0".
    4. A short left-hand side such as "x=" (at most two characters) is
       dropped when there is exactly one "=".
    5. ``fix_sqrt``, removal of all spaces, ``fix_fracs``.
    6. A few hard-coded special cases, then ``fix_a_slash_b``.

    Args:
        string: The raw answer text.

    Returns:
        The normalised string. An input that is empty after step 3 is
        returned as the empty string.
    """
    leading_replacements = (
        ("\n", ""),        # line breaks
        ("\\!", ""),       # "\!" spacing command
        ("\\\\", "\\"),    # doubled backslash -> single backslash
        ("tfrac", "frac"),
        ("dfrac", "frac"),
        ("\\left", ""),
        ("\\right", ""),
        ("^{\\circ}", ""),  # degree marks
        ("^\\circ", ""),
        ("\\$", ""),       # escaped dollar signs
    )
    for old, new in leading_replacements:
        string = string.replace(old, new)

    string = remove_right_units(string)

    # "\\%" is written once with a valid escape; the former duplicate "\%"
    # literal denoted the same two characters but is an invalid escape
    # sequence that newer Python versions warn about.
    string = string.replace("\\%", "")

    # Give bare decimals a leading zero: " .5" -> " 0.5", "{.5" -> "{0.5".
    string = string.replace(" .", " 0.")
    string = string.replace("{.", "{0.")
    if not string:
        return string
    if string[0] == ".":
        string = "0" + string

    # Drop a short variable assignment such as "k=" in front of the value.
    sides = string.split("=")
    if len(sides) == 2 and len(sides[0]) <= 2:
        string = sides[1]

    string = fix_sqrt(string)

    # From here on the string contains no spaces.
    string = string.replace(" ", "")

    string = fix_fracs(string)

    # Hard-coded equivalences.
    if string == "0.5":
        string = "\\frac{1}{2}"
    if string == "5.5":
        string = "\\frac{11}{2}"
    # Spaces were removed above, so the factor pair is matched in its
    # space-free spelling; the spaced pattern could never occur here.
    string = string.replace("(x-3)(x+3)", "(x+3)(x-3)")

    # Turn a plain integer ratio "a/b" into "\frac{a}{b}".
    return fix_a_slash_b(string)
|
|