# NOTE: removed Hugging Face upload-page residue ("AnLan577's picture",
# "Upload 2 files", "d248e7a verified") that was accidentally pasted above
# the code and broke the Python syntax of this file.
import argparse, json, os, re, subprocess, sys, tempfile
from typing import Dict, Any, List, Optional, Tuple
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, GenerationConfig
import json
# System prompt advertising the two available tools (python_interpreter and
# unit_conversion) in Qwen-style <tools>/<tool_call> XML format. The JSON
# function schemas below are sent verbatim to the model — do not reformat or
# re-indent them; the doubled backslashes are intentional JSON escapes.
YOUR_SYSTEM_PROMPT = """# Tools
You may call one or more functions to assist with the user query.
You are provided with function signatures within <tools></tools> XML tags:
<tools>
{"type": "function", "function": {"name": "python_interpreter", "description": "Executes a complete, runnable Python code. The code must be well-structured, preferably with a 'main()' function, and wrapped in a Python markdown block as shown in the parameter description.\\nNote: Common packages (math, collections, re, etc.) are pre-imported, so do NOT include standard 'import' statements.", "parameters": {"type": "object", "properties": {"code": {"type": "string", "description": "A string containing a full Python code.\\nThe code should define its logic within a 'main()' function and call it using an 'if __name__ == \\\\\"__main__\\\\":' block. \\nThe final, concise result must be printed to the console using the 'print()' function.\\n\\nFor example:\\ndef main():\\n # Your code logic here\\n result = 10 * 5\\n print(result)\\n\\nif __name__ == \\\\\"__main__\\\\":\\n main()"}}, "required": ["code"]}}}
{"type": "function", "function": {"name": "unit_conversion", "description": "Performs precise physical unit conversions. Handles length, mass, time, pressure, energy, and complex compound units (e.g., velocity, acceleration).\\nIMPORTANT: For temperature, use 'degC', 'degF', 'K'. For scientific notation like '1.8 * 10^5 Pa', the value must be fully resolved to a number.", "parameters": {"type": "object", "properties": {"value": {"type": "number", "description": "The complete numerical value to be converted, as a single float or integer. If the value is in scientific notation (e.g., 1.8e5 or 1.8 * 10**5), calculate and pass the final number. Example: for '1.8 * 10^5', pass 180000."}, "source_unit": {"type": "string", "description": "The unit of the input 'value'. Must be a standard physical unit symbol ONLY. Do NOT include numbers or scientific notation like '10^5' or 'e5'. Supports compound units with '/' for division and '^' or '**' for powers. Examples: 'kg', 'km/h', 'm/s^2', 'Pa', 'MJ', 'degC'."}, "target_unit": {"type": "string", "description": "The desired unit to convert to. Must be dimensionally compatible with the source unit. Examples: 'g', 'm/s', 'ft/s^2', 'atm', 'kWh', 'degF'."}}, "required": ["value", "source_unit", "target_unit"]}}}
</tools>
For each function call, return a json object with function name and arguments within <tool_call></tool_call> XML tags:
<tool_call>
{"name": <function-name>, "arguments": <args-json-object>}
</tool_call>"""
# Judge-prompt template; placeholders {question}, {pred}, {reference} are
# filled per case via str.format. Literal LaTeX braces are escaped as {{}}.
# BUG FIX: corrected numerous typos that were sent verbatim to the judge model
# ("corectly", "reasoing", "equivalenct", "simplfying", "quuestions",
# "mathcing", "Ohter", "Intepreter", "inteperter", "converse the value").
USER_TMPL = '''You are a diligent and precise assistant tasked with evaluating the correctness of responses. You will receive a question, an output sentence, and the correct answer. Your task is to determine if the output sentence correctly answers the question based on the provided correct answer. You can perform a short tool call and a short reasoning process. After a short reasoning process, put your response in the final with either [Correct] or [Incorrect] wrapped in \\boxed{{}}
Evaluation Protocol:
1. Reference Standard:
- The standard (gold) answer is definitive and always correct.
- The question is always valid — never challenge it.
- Allow equivalent meaning answers.
- Do not regenerate answers; only compare candidate's final answer with the gold answer.
- You only need to compare correct answer and output sentence, do not regenerate or judge correct answer.
2. Comparison Method:
- Analyze the question's requirements and the gold answer's structure.
- Determine if the question requires exact matching or allows equivalence.
- Compare ONLY the candidate's final answer. Ignore reasoning errors.
- Ignore differences in formatting or style.
- For math expressions: check algebraic equivalence step by step; if uncertain, test numerically at multiple points.
- For multiple-choice: only compare the final choice and its content.
3. Multi-part Answers:
- All parts must match the gold answer exactly.
- Partial matches are incorrect.
- If not specified, answer order may vary. For example, \\frac{{27}}{{7}}, -\\frac{{8}}{{7}} and -\\frac{{8}}{{7}}, \\frac{{27}}{{7}} are equivalent.
-
Special considerations:
1. **Mathematical Problems**:
- If the formats differ but the answers are mathematically equivalent after simplifying or rounding to two decimal places(e.g. 2.909 vs \\frac{{32}}{{11}}, \\frac{{32}}{{11}} vs \\frac{{96}}{{33}}), respond with [Correct].
- You only need to verify the correctness of the mathematical expression, not values unrelated to the overall expression, such as the domain or units(e.g. 16 vs 16km, 20 vs 20db), these cases will be considered as [Correct].
- You may need to calculate the value or convert the value to different unit when needed to match the reference answer.
2. **Multi-choice questions**:
- If the question provides explicit candidate answers(e.g. multi-choice questions), the output will be considered correct if it clearly indicates the correct option's content or the correct option's code.
3. **Fact questions**:
- If the question provides fact-seeking answers, the output must align with the correct answer in content to be considered [Correct].
4. **Multiple Reference Answers**:
-If multiple reference answers are equivalent, just matching one answer will be considered [Correct].
-If multiple reference answers are inequivalent, only matching all answers will be considered [Correct].
5. **Other conditions**:
- If incomplete (cut off, unfinished sentence) → Label as [Incorrect].
- If repetitive (looping words/phrases) → Label as [Incorrect].
- Gives an answer but then negates it at the end. → Label as [Incorrect].
- Numerically correct but without units. → Label as [Correct].
-
-
You can use following tools to help your verification process:
1. **Python Interpreter**: When you feel needed, you can use a python interpreter to help you determine your verification result.
2. **Unit Conversion Tool**: When faced with different physical units, you can use a unit conversion tool to convert them into the same unit.
-
Question: """{question}"""
Output sentence: """{pred}"""
Correct answer: {reference}
Judgement:'''
# ======= Tool execution (demo) =======
def run_python_interpreter(args: Dict[str, Any]) -> str:
    """Execute the tool-call's Python code in a subprocess and report the result.

    Args:
        args: Tool arguments; "code" holds the full program text and an
              optional "timeout" (seconds, default 5.0) bounds execution.

    Returns:
        A JSON string with "stdout", "stderr" and "returncode" on completion,
        or an "ERROR: ..." string on timeout / unexpected failure.
    """
    source = str(args.get("code", ""))
    # SECURITY NOTE: This executes arbitrary code. Demo only.
    with tempfile.NamedTemporaryFile("w", suffix=".py", delete=False) as script:
        script.write(source)
        script_path = script.name
    try:
        completed = subprocess.run(
            [sys.executable, script_path],
            capture_output=True,
            text=True,
            timeout=float(args.get("timeout", 5.0)),
        )
        result = {
            "stdout": completed.stdout.strip(),
            "stderr": completed.stderr.strip(),
            "returncode": completed.returncode,
        }
        return json.dumps(result, ensure_ascii=False)
    except subprocess.TimeoutExpired:
        return "ERROR: Timeout"
    except Exception as e:
        return f"ERROR: {e}"
    finally:
        # Best-effort cleanup of the temp script.
        try:
            os.remove(script_path)
        except Exception:
            pass
def run_unit_conversion(args: Dict[str, Any]) -> str:
    """Convert a numeric value between physical units using pint.

    Args:
        args: Tool arguments with keys "value", "source_unit", "target_unit".

    Returns:
        "<magnitude> <units>" on success, otherwise an "ERROR: ..." string
        (missing pint, bad arguments, incompatible or unknown units).
    """
    try:
        import pint
    except Exception:
        return "ERROR: pint not installed. pip install pint"
    try:
        amount = float(args["value"])
        from_unit = str(args["source_unit"])
        to_unit = str(args["target_unit"])
    except Exception as e:
        return f"ERROR: bad args: {e}"
    registry = pint.UnitRegistry(system="mks")
    try:
        converted = registry.Quantity(amount, from_unit).to(to_unit)
        return f"{float(converted.magnitude)} {converted.units}"
    except pint.errors.DimensionalityError as e:
        return f"ERROR: Dimensionality mismatch: {e}"
    except pint.errors.UndefinedUnitError as e:
        return f"ERROR: Undefined unit: {e}"
    except Exception as e:
        return f"ERROR: {e}"
# Dispatch table mapping tool names (as advertised in the system prompt's
# <tools> schemas) to their local Python implementations.
TOOL_IMPLS = {
    "python_interpreter": run_python_interpreter,
    "unit_conversion": run_unit_conversion,
}
# ======= Helpers =======
# Matches the first JSON object wrapped in <tool_call>...</tool_call> tags.
TC_RE = re.compile(r"<tool_call>\s*(\{.*?\})\s*</tool_call>", re.DOTALL)
def parse_tool_call(text: str) -> Optional[Dict[str, Any]]:
    """Extract the first well-formed tool call from model output.

    Returns:
        {"name": <str>, "arguments": <dict>} when a syntactically valid call
        is found, otherwise None (no tag, bad JSON, or wrong field types).
    """
    match = TC_RE.search(text)
    if match is None:
        return None
    try:
        payload = json.loads(match.group(1))
        tool_name = payload.get("name")
        tool_args = payload.get("arguments", {})
    except Exception:
        return None
    if not (isinstance(tool_name, str) and isinstance(tool_args, dict)):
        return None
    return {"name": tool_name, "arguments": tool_args}
def apply_chat(tokenizer, messages: List[Dict[str, str]]):
    """Render chat *messages* into model inputs ("pt" tensors).

    Prefers the tokenizer's own chat template; otherwise serializes the
    messages in ChatML style (<|im_start|>role ... <|im_end|>).
    """
    if hasattr(tokenizer, "apply_chat_template"):
        rendered = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    else:
        # Manual ChatML-style fallback.
        rendered = "".join(
            f"<|im_start|>{m['role']}\n{m['content']}<|im_end|>\n" for m in messages
        )
    return tokenizer(rendered, return_tensors="pt")
def parse_verdict(text: str) -> str:
    """Map a \\boxed{[Correct]/[Incorrect]} marker in *text* to a label.

    Case-insensitive; the square brackets are optional. Defaults to
    "INCORRECT" when no verdict marker is present.
    """
    found = re.search(r"\\boxed\{\s*\[?(Correct|Incorrect)\]?\s*\}", text, flags=re.IGNORECASE)
    if found is not None and found.group(1).lower() == "correct":
        return "CORRECT"
    return "INCORRECT"
def load_model(model_path: str):
    """Load the verifier causal-LM and its tokenizer from *model_path*.

    On CUDA machines uses bf16 (or fp16 when bf16 is unsupported) with
    device_map="auto"; on CPU-only machines falls back to fp32 on CPU.

    Returns:
        (model, tokenizer) with the model in eval mode.
    """
    if torch.cuda.is_available():
        # Prefer bf16 when the hardware supports it; fall back to fp16.
        dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16
        device_map = "auto"
    else:
        dtype = torch.float32
        device_map = None
    tok = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
    if tok.pad_token is None:
        # Generation needs a pad token; reuse EOS as is conventional.
        tok.pad_token = tok.eos_token
    model = AutoModelForCausalLM.from_pretrained(
        model_path,
        torch_dtype=dtype,
        low_cpu_mem_usage=True,
        trust_remote_code=True,
        # BUG FIX: device_map was computed but never used, and an
        # unconditional .cuda() crashed on CPU-only hosts despite the
        # fp32 CPU branch above. Let device_map handle placement instead.
        device_map=device_map,
    ).eval()
    return model, tok
@torch.inference_mode()
def generate_reply(model, tokenizer, messages, max_new_tokens=128) -> str:
    """Generate one assistant turn for *messages* and return the decoded text.

    Args:
        model: Causal LM with a .generate() method.
        tokenizer: Tokenizer providing eos_token_id and decode().
        messages: Chat messages passed to apply_chat().
        max_new_tokens: Generation budget for the new turn.

    Returns:
        Only the newly generated continuation (prompt tokens stripped),
        decoded without special tokens.
    """
    inputs = apply_chat(tokenizer, messages)
    if torch.cuda.is_available():
        inputs = {k: v.to("cuda") for k, v in inputs.items()}
    # BUG FIX: a GenerationConfig(temperature=0.6, top_p=0.85, ...) was
    # previously constructed here but never passed to model.generate(); the
    # dead object duplicated the kwargs below and has been removed.
    out = model.generate(
        **inputs,
        do_sample=True,
        temperature=0.6,
        top_p=0.85,
        max_new_tokens=max_new_tokens,
        eos_token_id=tokenizer.eos_token_id,
        pad_token_id=tokenizer.eos_token_id,
    )
    # Slice off the prompt; decode only the generated continuation.
    gen_ids = out[0][inputs["input_ids"].shape[1]:]
    return tokenizer.decode(gen_ids, skip_special_tokens=True)
def run_case(model, tok, question: str, pred: str, ref: str, allow_tool_rounds=5) -> Tuple[str, List[Dict[str,str]], str]:
    """Judge one (question, prediction, reference) triple with the verifier.

    Builds the system + user prompt, lets the model respond, then serves up to
    `allow_tool_rounds` tool calls (<tool_call> -> <tool_response> exchanges)
    before parsing the final \\boxed{...} verdict.

    Returns:
        (verdict, full message transcript, final assistant text) where verdict
        is "CORRECT" or "INCORRECT".
    """
    msgs = [
        {"role": "system", "content": YOUR_SYSTEM_PROMPT},
        {"role": "user", "content": USER_TMPL.format(question=question, pred=pred, reference=ref)},
    ]
    # First response (may include <tool_call>)
    assistant = generate_reply(model, tok, msgs, max_new_tokens=2048)
    msgs.append({"role": "assistant", "content": assistant})
    print(assistant)
    # Tool loop: keep serving tool results while the latest assistant turn
    # contains a parseable <tool_call>, up to allow_tool_rounds exchanges.
    for _ in range(allow_tool_rounds):
        call = parse_tool_call(assistant)
        if not call:
            # No (further) tool request — fall through to verdict parsing.
            break
        name = call["name"]
        args = call["arguments"]
        impl = TOOL_IMPLS.get(name)
        if impl is None:
            # Model asked for a tool we don't implement; report the error
            # back so it can recover instead of crashing the loop.
            result_payload = {
                "name": name,
                "status": "Error",
                "error": f"Unknown tool '{name}'"
            }
        else:
            try:
                raw = impl(args)
                if isinstance(raw, str):
                    # Both current tools return strings; wrap them in a
                    # uniform run_result envelope for the model.
                    result_payload = {
                        "name": name,
                        "status": "Success",
                        "run_result": {
                            "stdout": raw,
                            "stderr": "",
                            "exit_success": True,
                            "return_code": 0,
                            "status": "Finished"
                        }
                    }
                else:
                    # Future-proofing: a dict-returning tool is merged as-is.
                    result_payload = {"name": name, **raw}
                    result_payload.setdefault("status", "Success")
            except Exception as e:
                result_payload = {
                    "name": name,
                    "status": "Error",
                    "error": repr(e)
                }
        # Tool output is fed back as a user turn wrapped in <tool_response>,
        # matching the protocol announced in the system prompt.
        tool_response_msg = {
            "role": "user",
            "content": "<tool_response>\n"
                + json.dumps(result_payload, ensure_ascii=False)
                + "\n</tool_response>"
        }
        msgs.append(tool_response_msg)
        print(tool_response_msg)
        assistant = generate_reply(model, tok, msgs, max_new_tokens=2048)
        print(assistant)
        msgs.append({"role": "assistant", "content": assistant})
    # Verdict is parsed from the last assistant turn only.
    verdict = parse_verdict(assistant)
    return verdict, msgs, assistant
def main():
    """CLI entry point: load the verifier and judge the built-in demo cases.

    Usage: python script.py --model <path-or-hf-id> [--tool_turns N]
    """
    ap = argparse.ArgumentParser()
    ap.add_argument("--model", required=True, help="Local path or HF id of your verifier")
    # BUG FIX: was parsed as a string and never forwarded to run_case().
    ap.add_argument("--tool_turns", type=int, required=False, help="Number of Turns", default=5)
    args = ap.parse_args()
    model, tok = load_model(args.model)
    # Demo cases: each carries the question, candidate output (pred), gold
    # reference (ref), and the expected judge label (gt) for eyeballing.
    cases = [
        {
            "question": "A car accelerates at 0.02 km/s². Convert this acceleration to m/s².",
            "pred": "20 m/s^2",
            "ref": "20 m/s²",
            "gt": "[Correct]"
        },
        {
            "question": "Please predict the dot-bracket notation of the secondary structure directly from the RNA sequence: CAACUAAAUCCACCCUUGCGGGUGGGUGAAAUAUUGCUUCGCAAUAUGAAAUACGCUUUCAGCGUAUGAAAUCGCUG",
            "ref": ".......((((((((....))))))))...((((((.....))))))...(((((((...)))))))..........",
            "pred": "......((((((((....))))))))...((((((.....))))))...(((((((...)))))))..........",
            "gt": "[Incorrect]"
        },
        {
            "question": "The number of particles in a classical ideal gas of monoatomic molecules is $N$, the volume is $V$, and the temperature is $T$. Each atom has two internal energy levels $\\varepsilon_1 = 0$ and $\\varepsilon_2 = \\Delta$. Find the chemical potential $\\mu$?",
            "ref": "['\\\\boxed{\\\\mu = kT \\\\ln(n V_q) - kT \\\\ln z_{\\\\text{int}}}']",
            "pred": "\\mu = kT \\ln\\left( \\frac{N V_q}{V z_{\\text{int}}}",
            "gt": "[Correct]",
        },
        {
            "question": "A particle $m$ undergoes projectile motion with an initial velocity of $v_{0}$ and a horizontal angle of elevation $\\alpha$. Solve using the Hamilton-Jacobi equation.",
            "ref": "['y=x \\\\tan \\\\alpha-\\\\frac{g x^{2}}{2 v_{0}^{2} \\\\cos ^{2} \\\\alpha}']",
            "pred": "y = x \\tan \\alpha - \\frac{g x^2 \\tan^2 \\alpha}{2 v_0^2} - \\frac{g x^2}{2 v_0^2}",
            "gt": "[Correct]"
        },
        {
            "question": """以下是中国数学竞赛中的解答题,答案类型为表达式。请根据题目的要求和所提供的信息计算得出答案。解答过程和结果中使用的变量和公式请使用LaTeX格式表示。请在最后以\"所以最终答案是\\\\boxed{答案}。\"显式给出结果。\\n$已知a_n为等差数列,前n项和为S_n(n\\\\in N^),b_n是首项为2的等比数列,且公比大于0,b_2+b_3=12,b_3=a_4-2a_1,S_{11}=11b_4.$\\n$求数列{a_{2n}b_{n}}的前n项和(n\\\\in N^{}).$\\n\\n请通过逐步推理来解答问题,并把最终答案放置于\\\\boxed{}中。""",
            "ref": "$(3n-4)2^{n+2}+16$",
            "pred": "16 + (12n - 16)2^n",
            "gt": "[Correct]"
        },
    ]
    for i, c in enumerate(cases, 1):
        print(f"==================Case {i}==================")
        print(c)
        verdict, msgs, final_output = run_case(
            model, tok, c["question"], c["pred"], c["ref"],
            # BUG FIX: --tool_turns was previously ignored.
            allow_tool_rounds=args.tool_turns,
        )
        # Compare the model's verdict with the expected label for this demo case.
        expected = c["gt"].strip("[]").upper()
        print(f"Verdict: {verdict} (expected: {expected}) -> {'MATCH' if verdict == expected else 'MISMATCH'}")
if __name__ == "__main__":
    main()