|
|
import io |
|
|
import os |
|
|
from collections import defaultdict |
|
|
from concurrent.futures import ThreadPoolExecutor, as_completed |
|
|
from multiprocessing.pool import ThreadPool |
|
|
from typing import Any, Callable |
|
|
|
|
|
import jinja2 |
|
|
import numpy as np |
|
|
import requests |
|
|
from tqdm import tqdm |
|
|
|
|
|
from .types import EvalResult, Message, SamplerBase, SingleEvalResult |
|
|
|
|
|
# Prompt for a four-option multiple-choice question. It is filled with
# ``str.format`` using the keys ``Question``, ``A``, ``B``, ``C`` and ``D``.
# The surrounding newlines of the literal are stripped once at import time.
QUERY_TEMPLATE_MULTICHOICE = """
Answer the following multiple choice question. The last line of your response should be of the following format: 'Answer: $LETTER' (without quotes) where LETTER is one of ABCD. Think step by step before answering.

{Question}

A) {A}
B) {B}
C) {C}
D) {D}
""".strip()
|
|
|
|
|
# Regex for the final "Answer: X" line of a multiple-choice response.
# Case-insensitive; allows spaces/tabs around the colon and an optional "$"
# on either side of the letter. Group 1 is the chosen letter (A-D).
ANSWER_PATTERN_MULTICHOICE = r"(?i)Answer[ \t]*:[ \t]*\$?([A-D])\$?"

# Regex for a free-form "Answer: ..." line. Group 1 is the rest of that line.
ANSWER_PATTERN = r"(?i)Answer\s*:\s*([^\n]+)"

# Template for multilingual multiple-choice answers. "{}" is filled (via
# ``str.format``) with a language-specific regex for the "Answer:" prefix.
# Group 1 captures a single option letter: ASCII A-D, the Arabic range,
# the four Bengali letters, or a fullwidth Latin letter (U+FF21..U+FF24).
# The fullwidth letters are written as escapes so that Unicode
# normalization cannot fold them into the ASCII letters that [A-D]
# already covers (which would leave four unreachable alternatives).
MULTILINGUAL_ANSWER_PATTERN_TEMPLATE = (
    "(?i){}[ \t]*([A-D]|[أ-د]|[অ]|[ব]|[ড]|[ঢ]|[\uff21]|[\uff22]|[\uff23]|[\uff24])"
)
|
|
|
|
|
# Regexes for the "Answer:" prefix as written in different languages. Each
# entry is meant to be substituted into MULTILINGUAL_ANSWER_PATTERN_TEMPLATE.
#
# All entries are raw strings: "\s" in a normal string literal is an invalid
# escape sequence (a SyntaxWarning on recent Python versions). The resulting
# string values are unchanged.
#
# "\uff1a" is the fullwidth colon; the ``re`` module resolves the escape.
# Those entries take the place of exact duplicates of the ASCII-colon entry
# directly above them, so no previously matching text stops matching.
#
# NOTE(review): the second "Answer" entry and the second Bengali entry are
# exact duplicates of earlier entries. They are retained to keep the list
# length and order stable -- confirm whether a distinct variant was intended.
MULTILINGUAL_ANSWER_REGEXES = [
    r"Answer\s*:",
    r"Answer\s*:",
    r"উত্তর\s*:",
    r"उत्तर\s*:",
    r"উত্তরঃ",
    r"উত্তর\s*:",
    r"Antwort\s*:",
    r"답변\s*:",
    r"정답\s*:",
    r"답\s*:",
    r"答案\s*:",
    r"答案\s*\uff1a",
    r"答\s*:",
    r"答\s*\uff1a",
    r"答复\s*:",
    r"答曰\s*:",
    r"الإجابة:",
    r"الجواب:",
    r"إجابة:",
    r"الإجابة النهائية:",
    r"الإجابة الصحيحة:",
    r"الإجابة الصحيحة هي:",
    r"الإجابة هي:",
    r"الجواب النهائي:",
    r"Respuesta\s*:",
    r"Risposta\s*:",
    r"答え\s*:",
    r"答え\s*\uff1a",
    r"回答\s*:",
    r"回答\s*\uff1a",
    r"解答\s*:",
    r"Jawaban\s*:",
    r"Réponse\s*:",
    r"Resposta\s*:",
    r"Jibu\s*:",
    r"Idahun\s*:",
    r"Ìdáhùn\s*:",
    r"Idáhùn\s*:",
    r"Àmọ̀nà\s*:",
    r"Àdáhùn\s*:",
    r"Ànúgọ\s*:",
    r"Àṣàyàn\s*:",
]
|
|
|
|
|
|
|
|
# Grader prompt asking whether two math answers are equivalent. It is filled
# with %-formatting using the keys ``expression1`` and ``expression2``. The
# few-shot examples come first, then the task section the grader answers
# with a bare "Yes" or "No".
EQUALITY_TEMPLATE = r"""
Look at the following two expressions (answers to a math problem) and judge whether they are equivalent. Only perform trivial simplifications

Examples:

    Expression 1: $2x+3$
    Expression 2: $3+2x$

Yes

    Expression 1: 3/2
    Expression 2: 1.5

Yes

    Expression 1: $x^2+2x+1$
    Expression 2: $y^2+2y+1$

No

    Expression 1: $x^2+2x+1$
    Expression 2: $(x+1)^2$

Yes

    Expression 1: 3245/5
    Expression 2: 649

No
(these are actually equal, don't mark them equivalent if you need to do nontrivial simplifications)

    Expression 1: 2/(-3)
    Expression 2: -2/3

Yes
(trivial simplifications are allowed)

    Expression 1: 72 degrees
    Expression 2: 72

Yes
(give benefit of the doubt to units)

    Expression 1: 64
    Expression 2: 64 square feet

Yes
(give benefit of the doubt to units)

---

YOUR TASK


Respond with only "Yes" or "No" (without quotes). Do not include a rationale.

    Expression 1: %(expression1)s
    Expression 2: %(expression2)s
""".strip()
|
|
|
|
|
|
|
|
# Jinja template for one evaluated example. It renders the prompt messages,
# the sampled reply, and the grading outcome. It expects the variables
# ``prompt_messages``, ``next_message``, ``correct_answer``,
# ``extracted_answer`` and ``score``, and calls ``message_to_html`` from the
# template.
HTML_JINJA = """
<h3>Prompt conversation</h3>
{% for message in prompt_messages %}
{{ message_to_html(message) | safe }}
{% endfor %}
<h3>Sampled message</h3>
{{ message_to_html(next_message) | safe }}
<h3>Results</h3>
<p>Correct Answer: {{ correct_answer }}</p>
<p>Extracted Answer: {{ extracted_answer }}</p>
<p>Score: {{ score }}</p>
"""
|
|
|
|
|
|
|
|
def format_multichoice_question(row):
    """
    Render the multiple-choice prompt for one question.

    ``row`` is unpacked as keyword arguments into QUERY_TEMPLATE_MULTICHOICE,
    so it must be a mapping that supplies every placeholder the template uses.
    Returns the filled-in prompt string.
    """
    prompt = QUERY_TEMPLATE_MULTICHOICE.format(**row)
    return prompt
|
|
|
|
|
|
|
|
def check_equality(sampler: SamplerBase, expr1: str, expr2: str):
    """
    Ask ``sampler`` whether two expressions are equivalent.

    EQUALITY_TEMPLATE is filled with both expressions and sent as a single
    user message. Returns True only when the reply, lower-cased and stripped
    of surrounding whitespace, is exactly "yes"; any other reply gives False.
    """
    grader_prompt = EQUALITY_TEMPLATE % {"expression1": expr1, "expression2": expr2}
    user_message = {"content": grader_prompt, "role": "user"}
    reply = sampler([user_message]).response_text
    verdict = reply.lower().strip()
    return verdict == "yes"
|
|
|
|
|
|
|
|
def _compute_stat(values: list, stat: str): |
|
|
if stat == "mean": |
|
|
return np.mean(values) |
|
|
elif stat == "std": |
|
|
return np.std(values) |
|
|
elif stat == "min": |
|
|
return np.min(values) |
|
|
elif stat == "max": |
|
|
return np.max(values) |
|
|
elif stat == "n_samples": |
|
|
return len(values) |
|
|
elif stat == "bootstrap_std": |
|
|
return np.std( |
|
|
[np.mean(np.random.choice(values, len(values))) for _ in range(1000)] |
|
|
) |
|
|
else: |
|
|
raise ValueError(f"Unknown {stat =}") |
|
|
|
|
|
|
|
|
def aggregate_results(
    single_eval_results: list[SingleEvalResult],
    default_stats: tuple[str, ...] = ("mean", "std"),
    name2stats: dict[str, tuple[str, ...]] | None = None,
) -> EvalResult:
    """
    Aggregate results from multiple evaluations into a single EvalResult.

    Every metric value (plus ``score`` when it is not None) is collected per
    name across all results. For each name the statistics listed in
    ``name2stats[name]`` are computed with ``_compute_stat``; names without
    an entry use ``default_stats``. A "mean" statistic is stored under the
    plain metric name, any other statistic under "<name>:<stat>".

    The aggregated "score" entry is removed from the metrics and returned as
    the result's ``score`` (None when no result carried a score). The html,
    convo and example-level metadata of every result are passed through in
    input order.
    """
    name2stats = name2stats or {}
    name2values = defaultdict(list)
    htmls = []
    convos = []
    metadata = []
    for single_eval_result in single_eval_results:
        for name, value in single_eval_result.metrics.items():
            name2values[name].append(value)
        if single_eval_result.score is not None:
            name2values["score"].append(single_eval_result.score)
        htmls.append(single_eval_result.html)
        convos.append(single_eval_result.convo)
        metadata.append(single_eval_result.example_level_metadata)

    final_metrics = {}
    for name, values in name2values.items():
        stats = name2stats.get(name, default_stats)
        for stat in stats:
            # The mean keeps the bare metric name; other stats get a suffix.
            key = name if stat == "mean" else f"{name}:{stat}"
            final_metrics[key] = _compute_stat(values, stat)

    return EvalResult(
        score=final_metrics.pop("score", None),
        metrics=final_metrics,
        htmls=htmls,
        convos=convos,
        metadata={"example_level_metadata": metadata},
    )
|
|
|
|
|
|
|
|
def map_with_progress(
    f: Callable,
    xs: list[Any],
    num_threads: int = os.cpu_count() or 10,
    pbar: bool = True,
):
    """
    Apply f to each element of xs, using a ThreadPool, and show progress.

    Returns the list of results in the same order as ``xs``. An empty ``xs``
    returns an empty list without creating a pool.

    ``num_threads`` caps the pool size; the pool never has more workers than
    there are items. Its default is computed once, when the function is
    defined. ``pbar`` toggles the tqdm progress bar.

    If the environment variable ``debug`` is set to a non-empty value, the
    items are processed sequentially in the calling thread instead.
    """
    if len(xs) == 0:
        # min(num_threads, 0) would ask for a zero-worker ThreadPool, which
        # raises ValueError; there is nothing to do anyway.
        return []

    pbar_fn = tqdm if pbar else lambda x, *args, **kwargs: x

    if os.getenv("debug"):
        return list(map(f, pbar_fn(xs, total=len(xs))))

    with ThreadPool(min(num_threads, len(xs))) as pool:
        # imap yields results in input order as they become available.
        return list(pbar_fn(pool.imap(f, xs), total=len(xs)))
|
|
|
|
|
|
|
|
# Shared Jinja2 environment used for all HTML rendering in this module.
# Templates are supplied as strings (BaseLoader performs no lookup of its
# own), undefined template variables use StrictUndefined, and autoescaping
# is decided by select_autoescape for "html" and "xml".
jinja_env = jinja2.Environment(
    autoescape=jinja2.select_autoescape(["html", "xml"]),
    undefined=jinja2.StrictUndefined,
    loader=jinja2.BaseLoader(),
)
|
|
# Jinja template for one chat message. It expects ``role``, ``content`` and
# ``variant``; the variant label is shown only when ``variant`` is truthy.
# The role is also used as a CSS class on the outer <div>.
_message_template = """
<div class="message {{ role }}">
<div class="role">
{{ role }}
{% if variant %}<span class="variant">({{ variant }})</span>{% endif %}
</div>
<div class="content">
<pre>{{ content }}</pre>
</div>
</div>
"""
|
|
|
|
|
|
|
|
def message_to_html(message: Message) -> str:
    """
    Render one message as an HTML snippet wrapped in a <div>.

    Reads ``message["role"]`` and ``message["content"]``; ``variant`` is
    optional and treated as None when the key is missing.
    """
    template = jinja_env.from_string(_message_template)
    fields = {
        "role": message["role"],
        "content": message["content"],
        "variant": message.get("variant", None),
    }
    return template.render(**fields)
|
|
|
|
|
|
|
|
# Make the helper callable from inside templates as message_to_html(...).
jinja_env.globals.update(message_to_html=message_to_html)
|
|
|
|
|
|
|
|
# Jinja template for the standalone HTML report. It expects ``score``,
# ``metrics`` (a mapping) and ``htmls`` (pre-rendered example snippets).
# The metrics table, including the score row, is emitted only when
# ``metrics`` is truthy; each example snippet is inserted as-is ("safe")
# and followed by a horizontal rule.
_report_template = """<!DOCTYPE html>
<html>
<head>
<style>
.message {
padding: 8px 16px;
margin-bottom: 8px;
border-radius: 4px;
}
.message.user {
background-color: #B2DFDB;
color: #00695C;
}
.message.assistant {
background-color: #B39DDB;
color: #4527A0;
}
.message.system {
background-color: #EEEEEE;
color: #212121;
}
.role {
font-weight: bold;
margin-bottom: 4px;
}
.variant {
color: #795548;
}
table, th, td {
border: 1px solid black;
}
pre {
white-space: pre-wrap;
}
</style>
</head>
<body>
{% if metrics %}
<h1>Metrics</h1>
<table>
<tr>
<th>Metric</th>
<th>Value</th>
</tr>
<tr>
<td><b>Score</b></td>
<td>{{ score | float | round(3) }}</td>
</tr>
{% for name, value in metrics.items() %}
<tr>
<td>{{ name }}</td>
<td>{{ value }}</td>
</tr>
{% endfor %}
</table>
{% endif %}
<h1>Examples</h1>
{% for html in htmls %}
{{ html | safe }}
<hr>
{% endfor %}
</body>
</html>
"""
|
|
|
|
|
|
|
|
def make_report(eval_result: EvalResult) -> str:
    """
    Build a standalone HTML report for an EvalResult.

    The report template is rendered with the result's ``score``, ``metrics``
    and ``htmls``.
    """
    report = jinja_env.from_string(_report_template)
    context = {
        "score": eval_result.score,
        "metrics": eval_result.metrics,
        "htmls": eval_result.htmls,
    }
    return report.render(**context)
|
|
|
|
|
|
|
|
def make_report_from_example_htmls(htmls: list[str]):
    """
    Build a standalone HTML report from pre-rendered example snippets.

    The report template is rendered with no score and an empty metrics
    mapping, so only the examples are supplied as content.
    """
    report = jinja_env.from_string(_report_template)
    context = {"score": None, "metrics": {}, "htmls": htmls}
    return report.render(**context)
|
|
|
|
|
|
|
|
def normalize_response(response: str) -> str:
    """
    Strip markdown and LaTeX wrappers that could prevent an answer match.

    The tokens below are deleted one after another, in this exact order.
    The order matters: longer wrappers such as "$\\boxed{" must be removed
    before the shorter tokens ("$", "{") they contain.
    """
    tokens_to_drop = (
        "**",
        "$\\boxed{",
        "}$",
        "\\$",
        "$\\text{",
        "$",
        "\\mathrm{",
        "\\{",
        "\\text",
        "\\(",
        "\\mathbf{",
        "{",
        "\\boxed",
    )
    cleaned = response
    for token in tokens_to_drop:
        cleaned = cleaned.replace(token, "")
    return cleaned
|
|
|
|
|
|
|
|
def normalize_extracted_answer(extracted_answer: str) -> str:
    """
    Map an extracted option letter from another script onto ASCII A-D.

    Arabic, Bengali and fullwidth Latin option letters are replaced by their
    ASCII counterpart. Every replacement inserts a leading space, and ASCII
    A-D are padded the same way, so adjacent letters end up separated; the
    result is then stripped of surrounding whitespace. A single-letter input
    therefore comes back as exactly one of "A", "B", "C" or "D".

    Text containing none of the listed characters is returned stripped but
    otherwise unchanged.
    """
    replacements = (
        # Arabic option letters.
        ("أ", " A"),
        ("ب", " B"),
        ("ج", " C"),
        ("د", " D"),
        # Bengali option letters.
        ("অ", " A"),
        ("ব", " B"),
        ("ড", " C"),
        ("ঢ", " D"),
        # Fullwidth Latin letters (U+FF21..U+FF24), written as escapes so
        # they cannot be confused with, or normalized into, ASCII letters.
        ("\uff21", " A"),
        ("\uff22", " B"),
        ("\uff23", " C"),
        ("\uff24", " D"),
        # ASCII letters get the same space padding; this step is kept so
        # results for ASCII input are unchanged.
        ("A", " A"),
        ("B", " B"),
        ("C", " C"),
        ("D", " D"),
    )
    normalized = extracted_answer
    # Applied sequentially, in order, exactly like a chain of str.replace.
    for source, target in replacements:
        normalized = normalized.replace(source, target)
    return normalized.strip()
|
|
|
|
|
|
|
|
def url_to_fileobj(url: str, binary=False, timeout: float | None = 60.0) -> Any:
    """
    Download ``url`` and return its body as an in-memory file object.

    Returns an ``io.BytesIO`` over the raw bytes when ``binary`` is true,
    otherwise an ``io.StringIO`` over the decoded text.

    ``timeout`` is passed to ``requests.get``. Without one, a request can
    block indefinitely on an unresponsive server. Pass ``timeout=None`` to
    disable it.

    Raises ``requests.HTTPError`` (via ``raise_for_status``) on a 4xx/5xx
    response.
    """
    response = requests.get(url, timeout=timeout)
    response.raise_for_status()
    return io.BytesIO(response.content) if binary else io.StringIO(response.text)
|
|
|
|
|
|
|
|
def has_only_user_assistant_messages(messages: list[Message]) -> bool:
    """
    Return True when every message has the role "user" or "assistant".

    An empty list yields True.
    """
    allowed_roles = ("user", "assistant")
    for msg in messages:
        if msg["role"] not in allowed_roles:
            return False
    return True