| | from typing import Dict, List |
| | import os |
| | import re |
| |
|
| | from larm.data.envs.base_env import StaticEnv |
| | from larm.common.registry import registry |
| | from larm.memory_generator.trainer.verifier import verify_solution_equivalence |
| |
|
| |
|
@registry.register_env("math_vision")
class MathVisionEnv(StaticEnv):
    """Environment registered under the name ``math_vision``.

    Supplies the reward hooks used to score model completions against
    reference solutions.
    """

    def __init__(self, config):
        super().__init__(config)

    @classmethod
    def _accuracy_reward(cls, completions: List[str], solution: List[str], **kwargs) -> List[float]:
        """Score each completion 1.0 if its answer matches the reference, else 0.0.

        Args:
            completions: Model outputs, one per sample.
            solution: Reference solutions, paired positionally with
                ``completions``.
            **kwargs: Accepted for interface compatibility; unused here.

        Returns:
            One score per ``(completion, solution)`` pair. Pairs are formed
            with ``zip``, so surplus items in the longer list are ignored.

        Answers are compared by exact match on normalized strings, unless the
        ``USE_LLM_VERIFIER`` environment variable is ``true``/``1``/``yes``
        (case-insensitive), in which case ``verify_solution_equivalence``
        decides. A pair whose completion yields no answer, or whose reference
        is empty, scores 0.0.
        """
        # Compiled once per call rather than once per sample.
        answer_tag_re = re.compile(r"<answer>(.*?)</answer>", re.IGNORECASE | re.DOTALL)
        choice_re = re.compile(r"(?:answer|Answer|ANSWER)[\s:]*([A-E])\b")
        boxed_open = "\\boxed{"

        def _extract_answer(text: str) -> str:
            """Pull the final answer out of ``text``; return "" if none is found.

            Formats are tried in this order:
            1. ``<answer>...</answer>`` tags (case-insensitive, first pair).
            2. The last ``\\boxed{...}``, honouring nested braces.
            3. The last ``answer``/``Answer``/``ANSWER`` followed by a choice
               letter A-E.
            """
            # Non-string input (e.g. None) simply has no extractable answer.
            if not isinstance(text, str):
                return ""

            # Match on the original text: offsets taken from text.lower() can
            # drift because lowercasing may change the string length.
            tagged = answer_tag_re.search(text)
            if tagged:
                return tagged.group(1).strip()

            start = text.rfind(boxed_open)
            if start != -1:
                start += len(boxed_open)
                depth = 1  # the opening brace of \boxed{ is already consumed
                for pos in range(start, len(text)):
                    ch = text[pos]
                    if ch == "{":
                        depth += 1
                    elif ch == "}":
                        depth -= 1
                        if depth == 0:
                            return text[start:pos].strip()
                # Unbalanced braces: fall through to the choice-letter pattern.

            choices = choice_re.findall(text)
            return choices[-1] if choices else ""

        def _normalize_answer(text: str) -> str:
            """Canonicalize an answer: choice letters upper-cased, otherwise
            lower-cased with all whitespace removed."""
            if len(text) == 1 and text.upper() in ("A", "B", "C", "D", "E"):
                return text.upper()
            return re.sub(r"\s+", "", text.lower().strip())

        use_llm_verifier = os.environ.get("USE_LLM_VERIFIER", "false").lower() in ("true", "1", "yes")

        scores: List[float] = []
        for completion, reference in zip(completions, solution):
            candidate = _extract_answer(completion)
            ground_truth = _extract_answer(reference)
            # A reference given as a bare value ("B", "42") carries no wrapper
            # to extract from; use it directly instead of scoring it as empty.
            if not ground_truth and isinstance(reference, str):
                ground_truth = reference.strip()

            if not candidate or not ground_truth:
                scores.append(0.0)
                continue

            try:
                if use_llm_verifier:
                    ok = verify_solution_equivalence(candidate, ground_truth)
                else:
                    ok = _normalize_answer(candidate) == _normalize_answer(ground_truth)
            except Exception:
                # Best effort: a failing verifier counts as "not equivalent"
                # rather than aborting the whole batch.
                ok = False

            scores.append(1.0 if ok else 0.0)
        return scores

    @classmethod
    def _format_reward(cls, completions: List[str], **kwargs):
        """Placeholder: no format reward is implemented; returns ``None``."""
        pass
| |
|
| |
|