import logging
import os
import re
from typing import Dict, List

from larm.data.envs.base_env import StaticEnv
from larm.common.registry import registry
from larm.memory_generator.trainer.verifier import verify_solution_equivalence

logger = logging.getLogger(__name__)

# NOTE(review): the tag name is assumed to be <answer>...</answer>; confirm
# against the prompt template used to generate completions.
_ANSWER_TAG_RE = re.compile(r"<answer>(.*?)</answer>", re.IGNORECASE | re.DOTALL)
_BOXED_PREFIX = "\\boxed{"
# Matches "Answer: A", "answer B" and "The answer is C" (choice letters A-E).
_MC_ANSWER_RE = re.compile(r"(?:answer|Answer|ANSWER)(?:\s+is)?[\s:]*([A-E])\b")
_WHITESPACE_RE = re.compile(r"\s+")
_CHOICE_LETTERS = frozenset("ABCDE")
_TRUTHY_VALUES = ("true", "1", "yes")


def _extract_last_boxed(text: str):
    """Return the content of the last ``\\boxed{...}`` in *text*.

    Braces are counted so nested groups such as ``\\boxed{\\frac{1}{2}}`` are
    returned whole. Returns ``None`` when there is no ``\\boxed{`` or its
    closing brace is missing.
    """
    prefix_at = text.rfind(_BOXED_PREFIX)
    if prefix_at == -1:
        return None

    content_start = prefix_at + len(_BOXED_PREFIX)
    depth = 1
    for pos in range(content_start, len(text)):
        char = text[pos]
        if char == "{":
            depth += 1
        elif char == "}":
            depth -= 1
            if depth == 0:
                return text[content_start:pos].strip()
    return None


def _extract_answer(text) -> str:
    """Extract the final answer from *text*.

    Formats are tried in this order:
      1. ``<answer>...</answer>`` tags (kept for backward compatibility)
      2. the last ``\\boxed{...}`` LaTeX group
      3. a multiple-choice letter (A-E) after the word "answer"

    Returns:
        The extracted answer with surrounding whitespace removed, or ``""``
        when nothing could be extracted or *text* is not a string.
    """
    # Best-effort contract: malformed input yields "" instead of raising.
    if not isinstance(text, str):
        return ""

    tagged = _ANSWER_TAG_RE.search(text)
    if tagged is not None:
        return tagged.group(1).strip()

    boxed = _extract_last_boxed(text)
    if boxed is not None:
        return boxed

    # The last mention wins, since the final answer usually comes at the end.
    letters = _MC_ANSWER_RE.findall(text)
    if letters:
        return letters[-1].strip()

    return ""


def _extract_ground_truth(reference) -> str:
    """Extract the reference answer, falling back to the raw reference text.

    A reference given as a bare answer (e.g. ``"B"`` or ``"42"``) carries none
    of the markers ``_extract_answer`` looks for; in that case the stripped
    reference itself is used so it can still be compared.
    """
    if not reference or not isinstance(reference, str):
        return ""
    return _extract_answer(reference) or reference.strip()


def _normalize_answer(text: str) -> str:
    """Normalize an answer for plain string comparison.

    A single choice letter (A-E, any case) is upper-cased; anything else is
    lower-cased with all whitespace removed.
    """
    if len(text) == 1 and text.upper() in _CHOICE_LETTERS:
        return text.upper()
    return _WHITESPACE_RE.sub("", text.lower().strip())


@registry.register_env("math_vision")
class MathVisionEnv(StaticEnv):
    """Static environment registered under the name ``"math_vision"``."""

    def __init__(self, config):
        super().__init__(config)

    @classmethod
    def _accuracy_reward(cls, completions: List[str], solution: List[str], **kwargs) -> List[float]:
        """Score each completion against its reference solution.

        The answer is extracted from both the completion and the solution.
        They are compared with normalized string matching, or with
        ``verify_solution_equivalence`` when the ``USE_LLM_VERIFIER``
        environment variable is ``true``, ``1`` or ``yes`` (case-insensitive).

        Args:
            completions: Model outputs, one per sample.
            solution: Reference solutions, paired with ``completions`` by
                position (pairs are formed with ``zip``).
            **kwargs: Accepted and ignored.

        Returns:
            One score per pair: ``1.0`` when the answers match, ``0.0`` when
            they differ, when either answer cannot be extracted, or when the
            verifier raises.
        """
        use_llm_verifier = (
            os.environ.get("USE_LLM_VERIFIER", "false").lower() in _TRUTHY_VALUES
        )

        scores: List[float] = []
        for completion, reference in zip(completions, solution):
            candidate = _extract_answer(completion)
            ground_truth = _extract_ground_truth(reference)

            # Without an answer on both sides there is nothing to compare.
            if not candidate or not ground_truth:
                scores.append(0.0)
                continue

            if use_llm_verifier:
                # The verifier is an external call; a failure counts as a
                # wrong answer rather than aborting the whole batch.
                try:
                    ok = verify_solution_equivalence(candidate, ground_truth)
                except Exception as exc:
                    logger.warning("LLM verifier failed; scoring as incorrect: %s", exc)
                    ok = False
            else:
                ok = _normalize_answer(candidate) == _normalize_answer(ground_truth)

            scores.append(1.0 if ok else 0.0)
        return scores

    @classmethod
    def _format_reward(cls, completions: List[str], **kwargs):
        """Placeholder for a format reward; nothing is computed.

        Returns:
            ``None``.
        """
        # NOTE(review): callers expecting one float per completion would need
        # a real implementation here - confirm how this hook is used.
        return None