dpe1's picture
download
raw
6.87 kB
import re
from src.capabilities import lookup_dictionary, evaluate_math
def reward_reasoning_tag(completion):
    """Return 0.1 when the completion contains the literal ``Reasoning:`` tag, else 0.0.

    The check is case-sensitive and position-independent.
    """
    has_tag = "Reasoning:" in completion
    return 0.1 if has_tag else 0.0
def reward_answer_tag(completion):
    """Return 0.1 when the completion contains the literal ``Answer:`` tag, else 0.0.

    The check is case-sensitive and position-independent.
    """
    has_tag = "Answer:" in completion
    return 0.1 if has_tag else 0.0
def reward_capability_syntax(completion):
    """Reward for correct capability syntax: [CAP]payload[CAPABILITY_STOP]result[CAPABILITY_STOP].

    Only the model-generated half is checked, i.e. ``[DEFINE]payload[CAPABILITY_STOP]``
    or ``[SYMPY]payload[CAPABILITY_STOP]``. The result and the second stop marker
    are injected by the sampler.

    The payload may be any non-empty text that does not itself contain the stop
    marker. That includes ``]`` and newlines, so SymPy list/matrix literals such
    as ``Matrix([[1, 2], [3, 4]])`` count as well-formed.

    Returns:
        +0.5 if at least one well-formed call is present, plus -0.5 if the
        completion contains "Error evaluating expression". Both can apply,
        giving 0.0.
    """
    # Tempered token: consume characters one at a time as long as the stop
    # marker does not start here. This excludes the marker itself (and thus
    # empty payloads) without forbidding ']' inside the payload.
    pattern = r"\[(DEFINE|SYMPY)\](?:(?!\[CAPABILITY_STOP\]).)+\[CAPABILITY_STOP\]"
    reward = 0.0
    if re.search(pattern, completion, re.DOTALL):
        reward += 0.5
    # NOTE(review): presumably this text is injected when the SYMPY capability
    # fails to evaluate the payload; confirm against the capability implementation.
    if "Error evaluating expression" in completion:
        reward -= 0.5
    return reward
def reward_correctness(completion, reference_answer, task_type):
    """Reward for matching the reference answer.

    The answer is taken from the first ``Answer:`` in *completion*. Whitespace
    after the tag (including newlines) is skipped, then the rest of that line
    is captured and stripped.

    * ``task_type == "math"``: 1.0 if both sides parse as floats and compare
      equal. If either side does not parse, 1.0 on exact string equality.
    * Any other task type (dictionary lookups): 1.0 if either text contains
      the other, case-insensitively. The SFT data uses the exact first
      definition, so this tolerates extra or missing words on either side.

    Returns:
        1.0 on a match. 0.0 when there is no ``Answer:`` tag, when the
        extracted answer is empty, or when nothing matches.
    """
    match = re.search(r"Answer:\s*(.*)", completion)
    if not match:
        return 0.0
    extracted_answer = match.group(1).strip()
    # An empty answer must never be rewarded: "" is a substring of every
    # string, so the containment test below would otherwise score it 1.0.
    if not extracted_answer:
        return 0.0

    if task_type == "math":
        # Compare numerically when possible so that "4" matches "4.0".
        try:
            numeric_match = float(extracted_answer) == float(reference_answer)
        except (TypeError, ValueError, OverflowError):
            # Non-numeric answer or reference: fall back to exact string equality.
            return 1.0 if extracted_answer == reference_answer else 0.0
        return 1.0 if numeric_match else 0.0

    extracted = extracted_answer.lower()
    reference = reference_answer.lower()
    # A blank reference would likewise be "contained" in any answer.
    if not reference.strip():
        return 0.0
    if reference in extracted or extracted in reference:
        return 1.0
    return 0.0
def reward_length_penalty(completion):
    """Return -0.1 for completions longer than 400 characters, else 0.0.

    Exactly 400 characters is still within budget.
    """
    too_long = len(completion) > 400
    return -0.1 if too_long else 0.0
def reward_use_tool_result(completion):
    """Reward for using the tool result in the reasoning or answer.

    Expects the layout ``[CAP]payload[CAPABILITY_STOP]result[CAPABILITY_STOP]...``.
    Splitting on the stop marker therefore puts the injected tool results at
    the odd indices, and the last piece is the text after the final tool call.

    Returns:
        0.3 for each non-empty result (other than "No definition found.")
        whose text appears, case-insensitively, after the last tool call.
        The total is capped at 0.6. Returns 0.0 when there is no complete
        payload/result pair.
    """
    parts = completion.split("[CAPABILITY_STOP]")
    # At least two stop markers are needed for one payload/result pair. Relying
    # on the split alone keeps multi-line results (which a non-DOTALL regex
    # pre-check would miss) in scope.
    if len(parts) < 3:
        return 0.0
    after_all_tools = parts[-1].lower()
    reward = 0.0
    # parts[1], parts[3], ... are the injected results; parts[-1] is excluded.
    for raw_result in parts[1:-1:2]:
        result = raw_result.strip()
        if result and result != "No definition found." and result.lower() in after_all_tools:
            reward += 0.3
    return min(reward, 0.6)
# Payload strings flagged as "known mode-collapse words": each occurrence in a
# capability payload that is not also in the prompt costs -20.0.
_MODE_COLLAPSE_WORDS = (
    "elephant", "jacket", "moss", "bat", "sout",
    "guitar", "cat", "banana", "tomss", "seet",
)


def _overlap_score(payload_tokens, prompt_tokens):
    """Score payload tokens against the prompt: +10.0 each if present, -20.0 each if not.

    Returns:
        ``(score, grounded)``. An empty token set scores -10.0. ``grounded`` is
        True only when at least one payload token occurs in *prompt_tokens*.
    """
    if not payload_tokens:
        return -10.0, False
    score = 0.0
    grounded = False
    for token in payload_tokens:
        if token in prompt_tokens:
            score += 10.0
            grounded = True
        else:
            score -= 20.0
    return score, grounded


def reward_grounding(prompt, completion):
    """Reward for using entities from the prompt in capability calls, and penalize hallucinations.

    Each ``[DEFINE]payload[CAPABILITY_STOP]`` or ``[SYMPY]payload[CAPABILITY_STOP]``
    call in *completion* is scored as follows, and the scores are summed.
    Payloads spanning a newline are not detected by the call regex.

    * -0.5 for each exact repeat of an earlier (tag, payload) call.
    * Per payload token: +10.0 if it also occurs in *prompt*, -20.0 if not; an
      empty token set costs -10.0. Tokens are digit runs for SYMPY (so "3.5"
      yields "3" and "5"), and lower-cased words of at least 3 characters for
      DEFINE.
    * -10.0 if no payload token occurs in the prompt (a hallucinated call).
    * -20.0 for each known mode-collapse word found in the payload but not in
      the prompt (case-insensitive substring test).
    * +10.0 for DEFINE when the word after "definition of" in the prompt occurs
      in the payload. +10.0 for SYMPY when every digit run in the prompt occurs
      in the payload (substring test, so "1" is satisfied by "12").

    Returns:
        0.0 when there are no calls; otherwise the total clamped to
        [-100.0, 40.0]. This range is far larger than the 0.1-1.0 increments
        used by the other reward functions in this module.
    """
    calls = re.findall(r"\[(DEFINE|SYMPY)\](.*?)\[CAPABILITY_STOP\]", completion)
    if not calls:
        return 0.0

    # Everything derived from the prompt is loop-invariant: compute it once.
    prompt_lower = prompt.lower()
    prompt_num_list = re.findall(r"\d+", prompt)
    prompt_nums = set(prompt_num_list)
    prompt_words = {w.lower() for w in re.findall(r"\w+", prompt) if len(w) >= 3}
    # Target word from prompts like "What is the definition of apple?"
    target_match = re.search(r"definition of ([\w'-]+)", prompt_lower)
    target_word = target_match.group(1) if target_match else None

    reward = 0.0
    seen_calls = set()
    for cap_type, payload in calls:
        payload_lower = payload.lower()

        # Penalty for duplicate calls (every repeat after the first).
        call_sig = f"{cap_type}:{payload}"
        if call_sig in seen_calls:
            reward -= 0.5
        seen_calls.add(call_sig)

        # Per-token grounding against the prompt.
        if cap_type == "SYMPY":
            payload_nums = set(re.findall(r"\d+", payload))
            score, call_grounded = _overlap_score(payload_nums, prompt_nums)
        else:  # "DEFINE": the only other tag the call regex captures
            payload_words = {w.lower() for w in re.findall(r"\w+", payload) if len(w) >= 3}
            score, call_grounded = _overlap_score(payload_words, prompt_words)
        reward += score
        if not call_grounded:
            reward -= 10.0  # Strong penalty for each hallucinated call

        # Specific anti-hallucination for known mode-collapse words.
        for mode_word in _MODE_COLLAPSE_WORDS:
            if mode_word in payload_lower and mode_word not in prompt_lower:
                reward -= 20.0

        # Target grounding bonus.
        if cap_type == "DEFINE":
            if target_word is not None and target_word in payload_lower:
                reward += 10.0
        elif prompt_num_list and all(num in payload for num in prompt_num_list):
            reward += 10.0  # Extra bonus for full math grounding

    return min(max(reward, -100.0), 40.0)
def reward_correct_capability_type(completion, task_type):
    """Score every capability tag by whether it suits the task.

    For ``task_type == "math"`` each ``[SYMPY]`` tag earns +0.5 and each
    ``[DEFINE]`` tag costs -1.0. For ``task_type == "dict"`` the roles are
    swapped. Any other task type, or a completion with no tags, scores 0.0.
    The total is not capped, so it grows with the number of tags.
    """
    tags = re.findall(r"\[(DEFINE|SYMPY)\]", completion)
    if not tags:
        return 0.0
    if task_type == "math":
        wanted = "SYMPY"
    elif task_type == "dict":
        wanted = "DEFINE"
    else:
        return 0.0
    # Every term is a multiple of 0.5, so the float sum is exact.
    return sum((0.5 if tag == wanted else -1.0 for tag in tags), 0.0)
def get_total_reward(prompt, completion, reference_answer, task_type):
    """Return the sum of every reward component for one completion.

    Components are evaluated in a fixed order and added left to right, starting
    from 0.0. The order is: reasoning tag, answer tag, capability syntax,
    correctness, length penalty, tool-result use, grounding, capability type.
    """
    components = (
        reward_reasoning_tag(completion),
        reward_answer_tag(completion),
        reward_capability_syntax(completion),
        reward_correctness(completion, reference_answer, task_type),
        reward_length_penalty(completion),
        reward_use_tool_result(completion),
        reward_grounding(prompt, completion),
        reward_correct_capability_type(completion, task_type),
    )
    # Accumulate explicitly (not sum()) to keep plain left-to-right float addition.
    total = 0.0
    for component in components:
        total += component
    return total

Xet Storage Details

Size:
6.87 kB
·
Xet hash:
84832bd0f94a1ffbf27e7a99d00d453525657cdb707a24e5aa4890ef3c41b639

Xet efficiently stores files, intelligently splitting them into unique chunks and accelerating uploads and downloads. More info.