"""Shared helpers for the eval harness: prompt templates, answer-extraction
regex patterns, metric aggregation, threaded mapping with a progress bar, and
Jinja2-based HTML report rendering.

NOTE(review): this module was recovered from a whitespace/markup-mangled copy.
The HTML markup inside the three Jinja templates below was reconstructed from
the surviving template text; verify a rendered report against a known-good one.
"""

import io
import os
from collections import defaultdict
from concurrent.futures import ThreadPoolExecutor, as_completed
from multiprocessing.pool import ThreadPool
from typing import Any, Callable

import jinja2
import numpy as np
import requests
from tqdm import tqdm

from .types import EvalResult, Message, SamplerBase, SingleEvalResult

QUERY_TEMPLATE_MULTICHOICE = """
Answer the following multiple choice question. The last line of your response should be of the following format: 'Answer: $LETTER' (without quotes) where LETTER is one of ABCD. Think step by step before answering.

{Question}

A) {A}
B) {B}
C) {C}
D) {D}
""".strip()

# "Answer: X" where X is a single letter A-D, optionally wrapped in $...$.
ANSWER_PATTERN_MULTICHOICE = r"(?i)Answer[ \t]*:[ \t]*\$?([A-D])\$?"
# "Answer: <rest of line>" — captures everything up to the newline.
ANSWER_PATTERN = r"(?i)Answer\s*:\s*([^\n]+)"

# {} is filled with one entry of MULTILINGUAL_ANSWER_REGEXES below; the group
# captures the answer letter in Latin, Arabic, Bengali, or fullwidth
# (Japanese Ａ-Ｄ) form.
MULTILINGUAL_ANSWER_PATTERN_TEMPLATE = (
    "(?i){}[ \t]*([A-D]|[أ-د]|[অ]|[ব]|[ড]|[ঢ]|[Ａ]|[Ｂ]|[Ｃ]|[Ｄ])"
)
# All the different ways "Answer" is written in different languages.
# Raw strings: \s in a non-raw string is an invalid escape (SyntaxWarning on
# modern CPython) even though it happens to keep the backslash.
# NOTE(review): pairs that were exact duplicates in the mangled copy are
# restored here with a fullwidth colon (：) variant — confirm against upstream.
MULTILINGUAL_ANSWER_REGEXES = [
    r"Answer\s*:",
    r"Answer\s*:​​​​​​",  # Korean invisible character
    r"উত্তর\s*:",
    r"उत्तर\s*:",
    r"উত্তরঃ",
    r"উত্তর\s*:",
    r"Antwort\s*:",
    r"답변\s*:",
    r"정답\s*:",
    r"답\s*:",
    r"答案\s*：",
    r"答案\s*:",
    r"答\s*：",
    r"答\s*:",
    r"答复\s*：",
    r"答曰\s*：",
    r"الإجابة:",
    r"الجواب:",
    r"إجابة:",
    r"الإجابة النهائية:",
    r"الإجابة الصحيحة:",
    r"الإجابة الصحيحة هي:",
    r"الإجابة هي:",
    r"الجواب النهائي:",
    r"Respuesta\s*:",
    r"Risposta\s*:",
    r"答え\s*：",
    r"答え\s*:",
    r"回答\s*：",
    r"回答\s*:",
    r"解答\s*:",
    r"Jawaban\s*:",
    r"Réponse\s*:",
    r"Resposta\s*:",
    r"Jibu\s*:",
    r"Idahun\s*:",
    r"Ìdáhùn\s*:",
    r"Idáhùn\s*:",
    r"Àmọ̀nà\s*:",
    r"Àdáhùn\s*:",
    r"Ànúgọ\s*:",
    r"Àṣàyàn\s*:",
]

# %-style template (uses %(expression1)s / %(expression2)s) asking a grader
# model to judge two math answers for equivalence with only trivial
# simplification allowed.
EQUALITY_TEMPLATE = r"""
Look at the following two expressions (answers to a math problem) and judge whether they are equivalent. Only perform trivial simplifications

Examples:

    Expression 1: $2x+3$
    Expression 2: $3+2x$

Yes

    Expression 1: 3/2
    Expression 2: 1.5

Yes

    Expression 1: $x^2+2x+1$
    Expression 2: $y^2+2y+1$

No

    Expression 1: $x^2+2x+1$
    Expression 2: $(x+1)^2$

Yes

    Expression 1: 3245/5
    Expression 2: 649

No
(these are actually equal, don't mark them equivalent if you need to do nontrivial simplifications)

    Expression 1: 2/(-3)
    Expression 2: -2/3

Yes
(trivial simplifications are allowed)

    Expression 1: 72 degrees
    Expression 2: 72

Yes
(give benefit of the doubt to units)

    Expression 1: 64
    Expression 2: 64 square feet

Yes
(give benefit of the doubt to units)

---

YOUR TASK


Respond with only "Yes" or "No" (without quotes). Do not include a rationale.

    Expression 1: %(expression1)s
    Expression 2: %(expression2)s
""".strip()

# Per-example HTML fragment: prompt conversation, sampled message, and the
# scoring results. Rendered with jinja_env so message_to_html is available.
HTML_JINJA = """
<h3>Prompt conversation</h3>
{% for message in prompt_messages %}
{{ message_to_html(message) | safe }}
{% endfor %}
<h3>Sampled message</h3>
{{ message_to_html(next_message) | safe }}
<h3>Results</h3>
<p>Correct Answer: {{ correct_answer }}</p>
<p>Extracted Answer: {{ extracted_answer }}</p>
<p>Score: {{ score }}</p>
"""


def format_multichoice_question(row):
    """Fill QUERY_TEMPLATE_MULTICHOICE from a mapping with keys
    Question, A, B, C, D."""
    return QUERY_TEMPLATE_MULTICHOICE.format(**row)


def check_equality(sampler: SamplerBase, expr1: str, expr2: str):
    """Ask the grader sampler whether expr1 and expr2 are equivalent.

    Returns True iff the sampler's (stripped, lowercased) reply is "yes".
    """
    prompt = EQUALITY_TEMPLATE % {"expression1": expr1, "expression2": expr2}
    sampler_response = sampler([dict(content=prompt, role="user")])
    response_text = sampler_response.response_text
    return response_text.lower().strip() == "yes"


def _compute_stat(values: list, stat: str):
    """Compute a single named statistic over values.

    Supported: mean, std, min, max, n_samples, bootstrap_std (std of 1000
    bootstrap-resampled means). Raises ValueError for anything else.
    """
    if stat == "mean":
        return np.mean(values)
    elif stat == "std":
        return np.std(values)
    elif stat == "min":
        return np.min(values)
    elif stat == "max":
        return np.max(values)
    elif stat == "n_samples":
        return len(values)
    elif stat == "bootstrap_std":
        return np.std(
            [np.mean(np.random.choice(values, len(values))) for _ in range(1000)]
        )
    else:
        raise ValueError(f"Unknown {stat =}")


def aggregate_results(
    single_eval_results: list[SingleEvalResult],
    default_stats: tuple[str, ...] = ("mean", "std"),
    name2stats: dict[str, tuple[str]] | None = None,
) -> EvalResult:
    """
    Aggregate results from multiple evaluations into a single EvalResult.

    Each metric (and the per-example score, collected under "score") is
    reduced with the stats from name2stats, falling back to default_stats.
    The "mean" stat keeps the bare metric name; others are suffixed
    "name:stat". The aggregated "score" mean is popped out into
    EvalResult.score.
    """
    name2stats = name2stats or {}
    name2values = defaultdict(list)
    htmls = []
    convos = []
    metadata = []
    for single_eval_result in single_eval_results:
        for name, value in single_eval_result.metrics.items():
            name2values[name].append(value)
        if single_eval_result.score is not None:
            name2values["score"].append(single_eval_result.score)
        htmls.append(single_eval_result.html)
        convos.append(single_eval_result.convo)
        metadata.append(single_eval_result.example_level_metadata)
    final_metrics = {}
    for name, values in name2values.items():
        stats = name2stats.get(name, default_stats)
        for stat in stats:
            key = name if stat == "mean" else f"{name}:{stat}"
            final_metrics[key] = _compute_stat(values, stat)
    return EvalResult(
        score=final_metrics.pop("score", None),
        metrics=final_metrics,
        htmls=htmls,
        convos=convos,
        metadata={"example_level_metadata": metadata},
    )


def map_with_progress(
    f: Callable,
    xs: list[Any],
    num_threads: int = os.cpu_count() or 10,
    pbar: bool = True,
):
    """
    Apply f to each element of xs, using a ThreadPool, and show progress.

    With the "debug" environment variable set, runs serially in-process
    instead (easier stack traces / breakpoints). Results keep xs's order.
    """
    pbar_fn = tqdm if pbar else lambda x, *args, **kwargs: x

    if os.getenv("debug"):
        return list(map(f, pbar_fn(xs, total=len(xs))))

    # Guard: ThreadPool(0) raises ValueError, so short-circuit empty input.
    if not xs:
        return []
    with ThreadPool(min(num_threads, len(xs))) as pool:
        return list(pbar_fn(pool.imap(f, xs), total=len(xs)))


jinja_env = jinja2.Environment(
    loader=jinja2.BaseLoader(),
    undefined=jinja2.StrictUndefined,
    autoescape=jinja2.select_autoescape(["html", "xml"]),
)

# One chat message as an HTML card; "variant" is an optional qualifier shown
# next to the role.
_message_template = """
<div class="message {{ role }}">
    <div class="role">
    {{ role }}
    {% if variant %}<span class="variant">({{ variant }})</span>{% endif %}
    </div>
    <div class="content">
    <pre>{{ content }}</pre>
    </div>
</div>
"""


def message_to_html(message: Message) -> str:
    """
    Generate HTML snippet (inside a <div>) for a message.
    """
    return jinja_env.from_string(_message_template).render(
        role=message["role"],
        content=message["content"],
        variant=message.get("variant", None),
    )


# Make message_to_html callable from inside templates (used by HTML_JINJA).
jinja_env.globals["message_to_html"] = message_to_html

# Standalone report page: optional metrics table followed by the per-example
# HTML fragments.
_report_template = """<!DOCTYPE html>
<html>
    <head>
        <style>
            .message { padding: 8px 16px; margin-bottom: 8px; border-radius: 4px; }
            .message.user { background-color: #B2DFDB; color: #00695C; }
            .message.assistant { background-color: #B39DDB; color: #4527A0; }
            .message.system { background-color: #EEEEEE; color: #212121; }
            .role { font-weight: bold; margin-bottom: 4px; }
            .variant { color: #795548; }
            table, th, td { border: 1px solid black; }
            pre { white-space: pre-wrap; }
        </style>
    </head>
    <body>
    {% if metrics %}
    <h1>Metrics</h1>
    <table>
    <tr>
        <th>Metric</th>
        <th>Value</th>
    </tr>
    <tr>
        <td><b>Score</b></td>
        <td>{{ score | float | round(3) }}</td>
    </tr>
    {% for name, value in metrics.items() %}
    <tr>
        <td>{{ name }}</td>
        <td>{{ value }}</td>
    </tr>
    {% endfor %}
    </table>
    {% endif %}
    <h1>Examples</h1>
    {% for html in htmls %}
    {{ html | safe }}
    <hr>
    {% endfor %}
    </body>
</html>
"""


def make_report(eval_result: EvalResult) -> str:
    """
    Create a standalone HTML report from an EvalResult.
    """
    return jinja_env.from_string(_report_template).render(
        score=eval_result.score,
        metrics=eval_result.metrics,
        htmls=eval_result.htmls,
    )


def make_report_from_example_htmls(htmls: list[str]):
    """
    Create a standalone HTML report from a list of example htmls
    """
    return jinja_env.from_string(_report_template).render(
        score=None, metrics={}, htmls=htmls
    )


def normalize_response(response: str) -> str:
    """
    Normalize the response by removing markdown and LaTeX formatting that may
    prevent a match.
    """
    return (
        response.replace("**", "")
        .replace("$\\boxed{", "")
        .replace("}$", "")
        .replace("\\$", "")
        .replace("$\\text{", "")
        .replace("$", "")
        .replace("\\mathrm{", "")
        .replace("\\{", "")
        .replace("\\text", "")
        .replace("\\(", "")
        .replace("\\mathbf{", "")
        .replace("{", "")
        .replace("\\boxed", "")
    )


def normalize_extracted_answer(extracted_answer: str) -> str:
    """Map localized multiple-choice letters to a space-prefixed Latin A-D,
    then strip surrounding whitespace."""
    return (
        # In arabic these are the letters used for A-D in multiple choice questions
        extracted_answer.replace("أ", " A")
        .replace("ب", " B")
        .replace("ج", " C")
        .replace("د", " D")
        # In Bengali these are the letters used for A-D in multiple choice questions
        .replace("অ", " A")
        .replace("ব", " B")
        .replace("ড", " C")
        .replace("ঢ", " D")
        # In Japanese these are the fullwidth letters sometimes used for A-D
        # in multiple choice questions.
        # NOTE(review): restored to fullwidth Ａ-Ｄ; the mangled copy showed
        # ASCII A-D here, which would insert spurious spaces into answers.
        .replace("Ａ", " A")
        .replace("Ｂ", " B")
        .replace("Ｃ", " C")
        .replace("Ｄ", " D")
        .strip()
    )


def url_to_fileobj(url: str, binary=False) -> Any:
    """Fetch url and return its body as an in-memory file object
    (BytesIO when binary, else StringIO). Raises on HTTP error status."""
    response = requests.get(url)
    response.raise_for_status()
    return io.BytesIO(response.content) if binary else io.StringIO(response.text)


def has_only_user_assistant_messages(messages: list[Message]) -> bool:
    """
    Check if the messages only contain user and assistant messages.
    """
    return all(m["role"] in ("user", "assistant") for m in messages)