# larm/data/envs/math_vision_env.py
# Repository: model111. Commit e34b94f (verified) by LCZZZZ: "Upload MemGen code and data".
import logging
import os
import re
from typing import Dict, List

from larm.common.registry import registry
from larm.data.envs.base_env import StaticEnv
from larm.memory_generator.trainer.verifier import verify_solution_equivalence
@registry.register_env("math_vision")
class MathVisionEnv(StaticEnv):
    """Static environment registered under the name ``"math_vision"``.

    The accuracy reward extracts a final answer from each completion and
    compares it with the corresponding ground-truth solution.
    """

    def __init__(self, config):
        super().__init__(config)

    @classmethod
    def _accuracy_reward(cls, completions: List[str], solution: List[str], **kwargs) -> List[float]:
        """Score each completion 1.0 if its answer matches the solution, else 0.0.

        Args:
            completions: Model outputs, one per sample.
            solution: Ground truths aligned with ``completions``. Each may be a
                bare answer (e.g. ``"B"`` or ``"42"``) or text that wraps the
                answer in ``<answer>`` tags or ``\\boxed{}``.
            **kwargs: Ignored; accepted for a uniform reward-function signature.

        Returns:
            One score per (completion, solution) pair. Pairs are formed with
            ``zip``, so trailing items of the longer list are ignored.

        The ``USE_LLM_VERIFIER`` environment variable (``"true"``/``"1"``/``"yes"``,
        case-insensitive) switches comparison from normalized string equality
        to ``verify_solution_equivalence``.
        """

        def _extract_answer(text: str) -> str:
            """Extract the final answer from ``text``; return ``""`` if none is found.

            Strategies, in priority order:
            1. Content of the first ``<answer>...</answer>`` pair (tag match is
               case-insensitive).
            2. Content of the last ``\\boxed{...}``, with nested braces balanced.
            3. The last multiple-choice letter following "answer", e.g.
               "Answer: A", "The answer is B", "answer: (C)".
            4. The whole text, if it is just a single letter A-E.
            """
            if not isinstance(text, str):
                return ""

            # Match the tags with a case-insensitive regex rather than indexing into
            # text.lower(): lowercasing can change string length for some Unicode
            # characters, which would misalign indices into the original text.
            tag_match = re.search(r"<answer>(.*?)</answer>", text, re.IGNORECASE | re.DOTALL)
            if tag_match:
                return tag_match.group(1).strip()

            boxed_starts = list(re.finditer(r"\\boxed\{", text))
            if boxed_starts:
                start_pos = boxed_starts[-1].end()  # first char after the last "\boxed{"
                depth = 1
                pos = start_pos
                while pos < len(text) and depth > 0:
                    char = text[pos]
                    if char == "{":
                        depth += 1
                    elif char == "}":
                        depth -= 1
                    pos += 1
                if depth == 0:
                    # pos is one past the matching closing brace.
                    return text[start_pos : pos - 1].strip()
                # Unbalanced (e.g. truncated) final \boxed{: fall through.

            # Only the keyword is case-insensitive; the letter must be uppercase so
            # prose such as "the answer is a triangle" is not read as a choice.
            mc_matches = re.findall(r"(?i:answer(?:\s+is)?)[\s:]*\(?([A-E])\b", text)
            if mc_matches:
                return mc_matches[-1]

            stripped = text.strip()
            if len(stripped) == 1 and stripped in "ABCDE":
                return stripped

            return ""

        def _normalize_answer(text: str) -> str:
            """Normalize an answer for comparison.

            A single choice letter (either case) is uppercased; anything else is
            lowercased with all whitespace removed.
            """
            if len(text) == 1 and text.upper() in ["A", "B", "C", "D", "E"]:
                return text.upper()
            return re.sub(r"\s+", "", text.lower().strip())

        # Off by default: plain normalized string matching needs no API calls.
        use_llm_verifier = os.environ.get("USE_LLM_VERIFIER", "false").strip().lower() in ("true", "1", "yes")

        scores: List[float] = []
        for completion, gt in zip(completions, solution):
            candidate = _extract_answer(completion)
            # A ground truth given as a bare answer (e.g. "42") matches none of the
            # extraction patterns; fall back to its raw text so it can still score.
            gt_text = "" if gt is None else str(gt)
            ground_truth = _extract_answer(gt_text) or gt_text.strip()

            # Nothing extractable on either side counts as wrong.
            if not candidate or not ground_truth:
                scores.append(0.0)
                continue

            if use_llm_verifier:
                try:
                    ok = verify_solution_equivalence(candidate, ground_truth)
                except Exception as err:
                    # Best effort: a verifier/API failure scores as incorrect, but is
                    # logged so a systematic outage does not silently zero rewards.
                    logging.getLogger(__name__).warning(
                        "LLM verifier failed; scoring sample as incorrect: %s", err
                    )
                    ok = False
            else:
                ok = _normalize_answer(candidate) == _normalize_answer(ground_truth)
            scores.append(1.0 if ok else 0.0)
        return scores

    @classmethod
    def _format_reward(cls, completions: List[str], **kwargs):
        """Format reward placeholder; not implemented for this environment.

        Always returns ``None``.
        """
        return None