JustinTX's picture
Add files using upload-large-folder tool
14c9c2b verified
# Adapted from https://github.com/SamuelSchmidgall/AgentLaboratory/blob/main/utils.py
import os
import re
import backoff
import re
from pathlib import Path
import openai
from dotenv import load_dotenv
# Path to the .env file: two directories above the directory that contains
# this module.
env_path = Path(__file__).parent.parent.parent / ".env"
# Loaded at import time. override=True lets values from the file replace
# variables that are already set in the process environment.
load_dotenv(dotenv_path=env_path, override=True)
# One million. The figures in costs_per_token are written per million tokens
# and divided by M to obtain a per-token cost.
M = 1_000_000
# Optional minus sign, digits (optionally comma-grouped), optional decimals.
ANSWER_REGEX = re.compile(r"-?\d+(?:,\d{3})*(?:\.\d+)?")


def extract_numeric_answer(text: str) -> str:
    """Return the last number found in the model's reply.

    Commas are removed from ``text`` before matching, so ``1,234`` is read as
    ``1234``. Redundant leading zeros are dropped, but one digit is always
    kept in front of the decimal point, so ``"0"`` stays ``"0"`` and ``"0.5"``
    stays ``"0.5"`` instead of collapsing to ``""`` / ``".5"``.

    Args:
        text: Raw model output.

    Returns:
        The last numeric match as a string, or ``text.strip()`` when the
        reply contains no number at all.
    """
    matches = ANSWER_REGEX.findall(text.replace(",", ""))
    if not matches:
        return text.strip()
    # Strip zeros only while another digit follows, keeping the sign in place.
    return re.sub(r"^(-?)0+(?=\d)", r"\1", matches[-1])
def backoff_handler(details):
    """Print a one-line retry notice for a failed OpenAI call.

    Args:
        details: Mapping handed over by ``backoff``; the keys read here are
            ``exception``, ``tries`` and ``wait``.

    Nothing is printed when ``details`` carries no (truthy) exception.
    """
    exc = details.get("exception")
    if not exc:
        return
    message = (
        f"OpenAI - Retry {details['tries']} due to error: {exc}. "
        f"Waiting {details['wait']:0.1f}s..."
    )
    print(message)
# Per-token cost of each supported model, keyed by model name and then by
# "input" / "output". The table is written per million tokens (input, output)
# and scaled down by M; presumably the unit is USD -- confirm against the
# provider's price list.
costs_per_token = {
    model: {"input": per_m_in / M, "output": per_m_out / M}
    for model, (per_m_in, per_m_out) in {
        "gpt-4.1-nano": (0.1, 0.4),
        "gpt-4.1-mini": (0.4, 1.6),
        "gpt-4.1": (2.0, 8.0),
        "gpt-4o-mini": (0.15, 0.6),
        "o4-mini": (1.1, 4.4),
    }.items()
}
class MaxCallsExceededError(Exception):
    """Signals that the allowed budget of LLM calls has been used up."""
def create_call_limited_query_llm(base_query_llm, max_calls=3):
    """Wrap ``base_query_llm`` so it can be called at most ``max_calls`` times.

    The counter lives in a ``threading.local`` object, so every thread gets
    its own independent budget. A call counts against the budget as soon as
    it is attempted, even if ``base_query_llm`` then raises.

    Args:
        base_query_llm: The original query_llm function.
        max_calls: Maximum number of calls allowed (default: 3).

    Returns:
        A wrapped function that forwards all arguments to ``base_query_llm``
        and carries the wrapped function's name/docstring. Two helpers are
        attached to it:

        * ``reset_calls()`` -- set the calling thread's counter back to 0.
        * ``get_call_count()`` -- number of calls made by the calling thread.

    Raises:
        MaxCallsExceededError: From the wrapped function, once the calling
            thread has already made ``max_calls`` calls.
    """
    import functools
    import threading

    state = threading.local()

    # functools.wraps keeps the original name, docstring and signature
    # visible on the wrapper instead of a generic "limited_query_llm".
    @functools.wraps(base_query_llm)
    def limited_query_llm(*args, **kwargs):
        used = getattr(state, "call_count", 0)
        if used >= max_calls:
            raise MaxCallsExceededError(
                f"Maximum number of LLM calls ({max_calls}) exceeded"
            )
        # Count before delegating so a failing call still uses up a slot.
        state.call_count = used + 1
        return base_query_llm(*args, **kwargs)

    def reset_calls():
        state.call_count = 0

    def get_call_count():
        return getattr(state, "call_count", 0)

    # Expose the counter controls as attributes of the wrapper.
    limited_query_llm.reset_calls = reset_calls
    limited_query_llm.get_call_count = get_call_count
    return limited_query_llm
# Retry with exponential backoff: at most 20 attempts, each wait capped at
# 20 seconds, with backoff_handler called before every retry.
# NOTE(review): openai.APIStatusError appears to be the base class of every
# HTTP-status error, including non-retryable 4xx ones (bad request, auth);
# consider narrowing the tuple or adding a `giveup` predicate -- confirm
# against the openai-python error hierarchy.
@backoff.on_exception(
    backoff.expo,
    (
        openai.APIConnectionError,
        openai.APIStatusError,
        openai.RateLimitError,
        openai.APITimeoutError,
    ),
    max_tries=20,
    max_value=20,
    on_backoff=backoff_handler,
)
def query_llm(prompt, system, temperature=0.0, model_name="gpt-4.1-nano"):
    """Send one chat-completion request and return the reply with its cost.

    Args:
        prompt: Content of the user message.
        system: Content of the system message, or ``None`` to send the user
            message alone.
        temperature: Sampling temperature; replaced by 1.0 for ``o4-mini``.
        model_name: Model to query; must be a key of ``costs_per_token``.

    Returns:
        A ``(out_text, cost)`` tuple: the content of the first choice and the
        cost computed from the reported prompt/completion token counts.

    Raises:
        KeyError: If ``model_name`` has no entry in ``costs_per_token``.
            Raised before any request is sent.
    """
    # Resolve pricing first: previously an unknown model only failed with
    # KeyError after the completion had already been requested (and billed),
    # throwing the response away.
    pricing = costs_per_token[model_name]

    # To target Azure instead, build openai.AzureOpenAI from the
    # AZURE_OPENAI_API_KEY / AZURE_API_VERSION / AZURE_API_ENDPOINT
    # environment variables.
    client = openai.OpenAI(api_key=os.getenv("OPENAI_API_KEY"))

    messages = [{"role": "user", "content": prompt}]
    if system is not None:
        messages.insert(0, {"role": "system", "content": system})

    # o4-mini is always queried at temperature 1.0 (presumably the model
    # does not accept other values -- confirm against the API docs).
    if model_name == "o4-mini":
        temperature = 1.0

    response = client.chat.completions.create(
        model=model_name,
        messages=messages,
        temperature=temperature,
        # max_tokens=16384,
    )
    out_text = response.choices[0].message.content
    usage = response.usage
    cost = (
        usage.prompt_tokens * pricing["input"]
        + usage.completion_tokens * pricing["output"]
    )
    return out_text, cost
# string normalization from:
# https://github.com/EleutherAI/lm-evaluation-harness/blob/master/lm_eval/tasks/hendrycks_math.py
def is_equiv(str1, str2, verbose=False):
    """Decide whether two answer strings match after normalisation.

    Args:
        str1: First answer, or ``None``.
        str2: Second answer, or ``None``.
        verbose: When true, print both normalised strings before comparing.

    Returns:
        ``True`` if both are ``None`` (a warning is printed), ``False`` if
        exactly one is ``None``; otherwise whether the ``strip_string``
        results are equal. If normalisation raises, the raw inputs are
        compared instead.
    """
    missing = (str1 is None, str2 is None)
    if all(missing):
        print("WARNING: Both None")
        return True
    if any(missing):
        return False

    try:
        left, right = strip_string(str1), strip_string(str2)
        if verbose:
            print(left, right)
        return left == right
    except Exception:
        # Best effort: fall back to comparing the raw strings.
        return str1 == str2
def clean_answer(s):
    r"""Normalise cosmetic LaTeX in an extracted answer.

    Rewrites ``\dfrac`` as ``\frac``, drops every ``x \in`` and unwraps
    ``\mathbf{...}`` / ``\textbf{...}`` so that only their contents remain.
    """
    # \dfrac vs \frac makes no difference to the answer but can lead to
    # errors downstream, so only one spelling is kept.
    s = s.replace("\\dfrac", "\\frac").replace("x \\in", "")
    # Unwrap the bold commands, \mathbf first and then \textbf.
    for bold_pattern in (r"\\mathbf\s*{([^}]*)}", r"\\textbf\s*{([^}]*)}"):
        s = re.sub(bold_pattern, r"\1", s)
    return s
def remove_boxed(s):
    r"""Strip the ``\boxed`` wrapper from a boxed answer string.

    Two spellings are recognised:

    * ``\boxed <answer>`` (space form): everything after the prefix is
      returned as is. The prefix must be at the very start of ``s``.
    * ``\boxed{<answer>}`` (brace form): the text between the outer braces
      is passed through ``clean_answer`` and returned.

    Returns:
        The unwrapped answer, or ``None`` when ``s`` uses neither form.

    Raises:
        AssertionError: If the space form appears but not at the start, or
            the brace form does not end with ``}``.
    """
    space_prefix = "\\boxed "
    if space_prefix in s:
        assert s.startswith(space_prefix)
        return s[len(space_prefix) :]

    brace_prefix = "\\boxed{"
    if s[: len(brace_prefix)] != brace_prefix:
        return None
    assert s.endswith("}")
    inner = s[len(brace_prefix) : -1]
    return clean_answer(inner)
def last_boxed_only_string(string: str) -> str:
    r"""Return the last ``\boxed{...}`` command found in ``string``.

    ``\fbox{...}`` is used as a fallback when no ``\boxed`` occurs. Nested
    braces inside the argument are balanced. The result spans from the
    command name through its matching closing brace.

    Returns:
        The command text, or ``""`` when there is no such command, no
        opening brace after it, or the braces never balance.
    """
    start = string.rfind("\\boxed")
    if start == -1:
        start = string.rfind("\\fbox")
    if start == -1:
        return ""

    # First opening brace at or after the command.
    open_at = string.find("{", start)
    if open_at == -1:
        return ""  # No braces, return empty for robustness.

    # Walk forward until the brace opened at open_at is closed again.
    depth = 0
    for pos, char in enumerate(string[open_at:], start=open_at):
        if char == "{":
            depth += 1
        elif char == "}":
            depth -= 1
            if depth == 0:
                return string[start : pos + 1]
    return ""  # Mismatched braces
def fix_fracs(string):
    r"""Add the missing braces to shorthand ``\frac`` arguments.

    ``\frac12`` becomes ``\frac{1}{2}``, ``\frac1b`` becomes ``\frac{1}{b}``
    and ``\frac1{72}`` becomes ``\frac{1}{72}``. A ``\frac`` that is already
    followed by ``{`` is left untouched.

    Returns:
        The rewritten string, or the input unchanged when some ``\frac`` is
        followed by fewer than two characters (including a ``\frac`` at the
        very end of the string, which used to raise ``IndexError``).
    """
    head, *tails = string.split("\\frac")
    pieces = [head]
    for tail in tails:
        pieces.append("\\frac")
        if tail.startswith("{"):
            # Numerator is already braced: keep the rest verbatim.
            pieces.append(tail)
            continue
        if len(tail) < 2:
            # Not enough characters for a numerator and a denominator.
            return string
        numerator, denominator, rest = tail[0], tail[1], tail[2:]
        if denominator == "{":
            # \frac1{72}: only the numerator needs braces.
            pieces.append("{" + numerator + "}" + denominator + rest)
        else:
            pieces.append("{" + numerator + "}{" + denominator + "}" + rest)
    return "".join(pieces)
def fix_a_slash_b(string):
    r"""Rewrite a plain integer ratio ``a/b`` as ``\frac{a}{b}``.

    The rewrite only happens when ``string`` is exactly two canonically
    written integers separated by a single ``/`` (no spaces, no leading
    zeros, nothing else).

    Returns:
        The ``\frac`` form, or the input unchanged for anything else.
        Non-integer operands such as ``x/y`` are now returned unchanged
        instead of raising ``ValueError``.
    """
    parts = string.split("/")
    if len(parts) != 2:
        return string
    try:
        numerator, denominator = int(parts[0]), int(parts[1])
    except ValueError:
        # At least one side is not an integer literal.
        return string
    # Reject inputs int() accepts but that are not in canonical form
    # (surrounding whitespace, leading zeros, explicit "+", ...).
    if string != f"{numerator}/{denominator}":
        return string
    return "\\frac{" + str(numerator) + "}{" + str(denominator) + "}"
def remove_right_units(string):
    r"""Drop a trailing unit written as ``\text{ ...}``.

    Everything from the ``\text{ `` marker (note the space) onwards is cut
    off; strings without the marker are returned unchanged.

    Raises:
        AssertionError: If the marker occurs more than once.
    """
    # "\\text{ " only ever occurs (at least in the val set) when describing
    # units
    marker = "\\text{ "
    if marker not in string:
        return string
    pieces = string.split(marker)
    assert len(pieces) == 2
    value, _units = pieces
    return value
def fix_sqrt(string):
    r"""Brace the argument of shorthand square roots.

    ``\sqrt3`` becomes ``\sqrt{3}``: when ``\sqrt`` is not followed by ``{``,
    the single next character is wrapped in braces. Already-braced roots are
    left alone, and so is a ``\sqrt`` with nothing after it (which used to
    raise ``IndexError``).
    """
    if "\\sqrt" not in string:
        return string
    head, *tails = string.split("\\sqrt")
    pieces = [head]
    for tail in tails:
        if tail and tail[0] != "{":
            pieces.append("\\sqrt{" + tail[0] + "}" + tail[1:])
        else:
            # Already braced, or nothing follows the command.
            pieces.append("\\sqrt" + tail)
    return "".join(pieces)
def strip_string(string):
    r"""Normalise a LaTeX answer string so equivalent answers compare equal.

    The steps run in a fixed order: drop line breaks, ``\!``, ``\left`` /
    ``\right``, degree marks, ``\$`` and ``\%``; unify ``tfrac``/``dfrac``
    to ``frac``; cut trailing ``\text{ ...}`` units; add a leading ``0`` to
    bare decimals; drop a short ``k =`` style prefix; brace shorthand
    ``\sqrt`` and ``\frac`` arguments; remove spaces; and map a few specific
    values onto one canonical spelling.

    Args:
        string: Answer text to normalise.

    Returns:
        The normalised string (``""`` if nothing is left after the removals).

    Raises:
        AssertionError: Propagated from ``remove_right_units`` when the
            ``\text{ `` marker occurs more than once.
    """
    # linebreaks
    string = string.replace("\n", "")
    # remove inverse spaces
    string = string.replace("\\!", "")
    # replace \\ with \
    string = string.replace("\\\\", "\\")
    # replace tfrac and dfrac with frac
    string = string.replace("tfrac", "frac")
    string = string.replace("dfrac", "frac")
    # remove \left and \right
    string = string.replace("\\left", "")
    string = string.replace("\\right", "")
    # Remove circ (degrees)
    string = string.replace("^{\\circ}", "")
    string = string.replace("^\\circ", "")
    # remove dollar signs
    string = string.replace("\\$", "")
    # remove units (on the right)
    string = remove_right_units(string)
    # remove percentage. (The former second replace used the literal "\%",
    # an invalid escape sequence that denotes the same two characters as
    # "\\%", so it was a no-op that only triggered a syntax warning.)
    string = string.replace("\\%", "")
    # " 0." equivalent to " ." and "{0." equivalent to "{."
    # Alternatively, add "0" if "." is the start of the string
    string = string.replace(" .", " 0.")
    string = string.replace("{.", "{0.")
    # if empty, return empty string
    if not string:
        return string
    if string[0] == ".":
        string = "0" + string
    # to consider: get rid of e.g. "k = " or "q = " at beginning
    sides = string.split("=")
    if len(sides) == 2 and len(sides[0]) <= 2:
        string = sides[1]
    # fix sqrt3 --> sqrt{3}
    string = fix_sqrt(string)
    # remove spaces
    string = string.replace(" ", "")
    # \frac1b or \frac12 --> \frac{1}{b} and \frac{1}{2}, etc.
    # Even works with \frac1{72} (but not \frac{72}1).
    string = fix_fracs(string)
    # manually change 0.5 --> \frac{1}{2} and 5.5 --> \frac{11}{2}
    if string == "0.5":
        string = "\\frac{1}{2}"
    if string == "5.5":
        string = "\\frac{11}{2}"
    # Canonical factor order. Spaces are already gone at this point, so the
    # pattern must be the space-free form; the earlier "(x - 3)(x + 3)"
    # pattern could never match.
    string = string.replace("(x-3)(x+3)", "(x+3)(x-3)")
    # NOTE: X/Y changed to \frac{X}{Y} in dataset, but in simple cases
    # fix in case the model output is X/Y
    string = fix_a_slash_b(string)
    return string