Spaces:
Sleeping
Sleeping
| import io | |
| import os | |
| from collections import defaultdict | |
| from concurrent.futures import ThreadPoolExecutor, as_completed | |
| from multiprocessing.pool import ThreadPool | |
| from typing import Any, Callable | |
| import jinja2 | |
| import numpy as np | |
| import requests | |
| from tqdm import tqdm | |
| from .types import EvalResult, Message, SamplerBase, SingleEvalResult | |
# Prompt template for a four-option multiple-choice question. It is filled in
# with str.format using the keys Question, A, B, C and D, and instructs the
# model to finish with a line of the form "Answer: $LETTER". Surrounding
# whitespace is stripped once at import time.
QUERY_TEMPLATE_MULTICHOICE = """
Answer the following multiple choice question. The last line of your response should be of the following format: 'Answer: $LETTER' (without quotes) where LETTER is one of ABCD. Think step by step before answering.
{Question}
A) {A}
B) {B}
C) {C}
D) {D}
""".strip()
# Case-insensitive pattern for "Answer: X" where X is one of A-D, optionally
# wrapped in literal dollar signs; group 1 captures the letter.
ANSWER_PATTERN_MULTICHOICE = r"(?i)Answer[ \t]*:[ \t]*\$?([A-D])\$?"
# Case-insensitive pattern for a free-form "Answer: ..." line; group 1 captures
# the run of non-newline characters that follows the colon and any whitespace.
ANSWER_PATTERN = r"(?i)Answer\s*:\s*([^\n]+)"
# Regex template for extracting a multiple-choice answer in several scripts.
# The "{}" placeholder is filled (via str.format) with one localized
# "Answer:" prefix regex; group 1 then captures the option letter, which may be
# a Latin A-D, an Arabic letter in the range U+0623..U+062F, one of the Bengali
# letters listed, or a fullwidth Latin letter (U+FF21..U+FF24).
#
# The fullwidth letters are written as escapes on purpose: as plain ASCII
# "[A]|[B]|[C]|[D]" they merely duplicate "[A-D]" and fullwidth answers are
# never matched.
MULTILINGUAL_ANSWER_PATTERN_TEMPLATE = (
    "(?i){}[ \t]*([A-D]|[أ-د]|[অ]|[ব]|[ড]|[ঢ]|[\uff21]|[\uff22]|[\uff23]|[\uff24])"
)
# All the different ways "Answer" is written in different languages.
# Each entry is a regex fragment for a localized "Answer:" prefix, intended to
# be substituted into MULTILINGUAL_ANSWER_PATTERN_TEMPLATE.
#
# Raw strings are used so that "\s" reaches the regex engine without being an
# invalid string escape. "\uff1a" is the fullwidth colon, written as a regex
# escape so it cannot be silently folded into an ASCII ":" by text
# normalization.
# NOTE(review): the four fullwidth-colon entries replace entries that were
# exact duplicates of their ASCII-colon neighbours — confirm against upstream.
MULTILINGUAL_ANSWER_REGEXES = [
    r"Answer\s*:",
    # NOTE(review): this entry is currently identical to the one above; the
    # invisible character the comment refers to may have been lost — confirm.
    r"Answer\s*:",  # Korean invisible character
    r"উত্তর\s*:",
    r"उत्तर\s*:",
    r"উত্তরঃ",
    r"উত্তর\s*:",
    r"Antwort\s*:",
    r"답변\s*:",
    r"정답\s*:",
    r"답\s*:",
    r"答案\s*\uff1a",
    r"答案\s*:",
    r"答\s*\uff1a",
    r"答\s*:",
    r"答复\s*:",
    r"答曰\s*:",
    r"الإجابة:",
    r"الجواب:",
    r"إجابة:",
    r"الإجابة النهائية:",
    r"الإجابة الصحيحة:",
    r"الإجابة الصحيحة هي:",
    r"الإجابة هي:",
    r"الجواب النهائي:",
    r"Respuesta\s*:",
    r"Risposta\s*:",
    r"答え\s*:",
    r"答え\s*\uff1a",
    r"回答\s*:",
    r"回答\s*\uff1a",
    r"解答\s*:",
    r"Jawaban\s*:",
    r"Réponse\s*:",
    r"Resposta\s*:",
    r"Jibu\s*:",
    r"Idahun\s*:",
    r"Ìdáhùn\s*:",
    r"Idáhùn\s*:",
    r"Àmọ̀nà\s*:",
    r"Àdáhùn\s*:",
    r"Ànúgọ\s*:",
    r"Àṣàyàn\s*:",
]
# Few-shot judge prompt asking whether two math expressions are equivalent.
# It is filled with the %-operator using a mapping that provides the keys
# "expression1" and "expression2", and asks for a bare "Yes" or "No" reply.
# Declared as a raw string so the LaTeX-style examples are kept verbatim.
EQUALITY_TEMPLATE = r"""
Look at the following two expressions (answers to a math problem) and judge whether they are equivalent. Only perform trivial simplifications
Examples:
Expression 1: $2x+3$
Expression 2: $3+2x$
Yes
Expression 1: 3/2
Expression 2: 1.5
Yes
Expression 1: $x^2+2x+1$
Expression 2: $y^2+2y+1$
No
Expression 1: $x^2+2x+1$
Expression 2: $(x+1)^2$
Yes
Expression 1: 3245/5
Expression 2: 649
No
(these are actually equal, don't mark them equivalent if you need to do nontrivial simplifications)
Expression 1: 2/(-3)
Expression 2: -2/3
Yes
(trivial simplifications are allowed)
Expression 1: 72 degrees
Expression 2: 72
Yes
(give benefit of the doubt to units)
Expression 1: 64
Expression 2: 64 square feet
Yes
(give benefit of the doubt to units)
---
YOUR TASK
Respond with only "Yes" or "No" (without quotes). Do not include a rationale.
Expression 1: %(expression1)s
Expression 2: %(expression2)s
""".strip()
# Jinja template for the HTML write-up of one evaluated example. It expects
# the variables prompt_messages, next_message, correct_answer,
# extracted_answer and score, and calls message_to_html for every message.
# The message_to_html output is piped through "safe" so it is inserted as
# markup rather than escaped.
HTML_JINJA = """
<h3>Prompt conversation</h3>
{% for message in prompt_messages %}
{{ message_to_html(message) | safe }}
{% endfor %}
<h3>Sampled message</h3>
{{ message_to_html(next_message) | safe }}
<h3>Results</h3>
<p>Correct Answer: {{ correct_answer }}</p>
<p>Extracted Answer: {{ extracted_answer }}</p>
<p>Score: {{ score }}</p>
"""
def format_multichoice_question(row):
    """
    Render QUERY_TEMPLATE_MULTICHOICE for one question.

    `row` is unpacked as keyword arguments to str.format, so it must be a
    mapping that supplies every placeholder the template names.
    """
    prompt = QUERY_TEMPLATE_MULTICHOICE.format(**row)
    return prompt
def check_equality(sampler: SamplerBase, expr1: str, expr2: str):
    """
    Ask `sampler` whether two expressions are equivalent.

    EQUALITY_TEMPLATE is filled with both expressions and sent as a single
    user message. Returns True only when the reply, lower-cased and stripped
    of surrounding whitespace, is exactly "yes".
    """
    substitutions = {"expression1": expr1, "expression2": expr2}
    judge_message = {"content": EQUALITY_TEMPLATE % substitutions, "role": "user"}
    verdict = sampler([judge_message]).response_text
    return verdict.lower().strip() == "yes"
| def _compute_stat(values: list, stat: str): | |
| if stat == "mean": | |
| return np.mean(values) | |
| elif stat == "std": | |
| return np.std(values) | |
| elif stat == "min": | |
| return np.min(values) | |
| elif stat == "max": | |
| return np.max(values) | |
| elif stat == "n_samples": | |
| return len(values) | |
| elif stat == "bootstrap_std": | |
| return np.std( | |
| [np.mean(np.random.choice(values, len(values))) for _ in range(1000)] | |
| ) | |
| else: | |
| raise ValueError(f"Unknown {stat =}") | |
def aggregate_results(
    single_eval_results: list[SingleEvalResult],
    default_stats: tuple[str, ...] = ("mean", "std"),
    name2stats: dict[str, tuple[str]] | None = None,
) -> EvalResult:
    """
    Combine per-example results into one EvalResult.

    Every metric value (and every non-None score, collected under the name
    "score") is gathered by name. Each name is then summarised with the
    statistics listed for it in `name2stats`, or with `default_stats` when it
    has no entry. The "mean" statistic is stored under the bare metric name;
    any other statistic is stored under "<name>:<stat>".

    The aggregated "score" entry is removed from the metrics and returned as
    the result's score (None when no example carried a score). Per-example
    html, convo and example_level_metadata are passed through in input order.
    """
    stats_by_name = name2stats or {}
    values_by_name = defaultdict(list)
    example_htmls, example_convos, example_metadata = [], [], []

    for result in single_eval_results:
        for metric_name, metric_value in result.metrics.items():
            values_by_name[metric_name].append(metric_value)
        if result.score is not None:
            values_by_name["score"].append(result.score)
        example_htmls.append(result.html)
        example_convos.append(result.convo)
        example_metadata.append(result.example_level_metadata)

    final_metrics = {
        (name if stat == "mean" else f"{name}:{stat}"): _compute_stat(values, stat)
        for name, values in values_by_name.items()
        for stat in stats_by_name.get(name, default_stats)
    }

    # The score is popped first so that it does not also appear in `metrics`.
    overall_score = final_metrics.pop("score", None)
    return EvalResult(
        score=overall_score,
        metrics=final_metrics,
        htmls=example_htmls,
        convos=example_convos,
        metadata={"example_level_metadata": example_metadata},
    )
def map_with_progress(
    f: Callable,
    xs: list[Any],
    num_threads: int = os.cpu_count() or 10,
    pbar: bool = True,
):
    """
    Apply f to each element of xs, using a ThreadPool, and show progress.

    Args:
        f: Function applied to every element.
        xs: Inputs; results are returned in the same order.
        num_threads: Upper bound on the number of worker threads. The pool is
            never larger than len(xs).
        pbar: When true, wrap the iteration in a tqdm progress bar.

    Returns:
        A list with f(x) for every x in xs. An empty xs yields an empty list.

    When the environment variable "debug" is set to a non-empty value the
    work is done sequentially in the calling thread instead of in a pool.
    """
    pbar_fn = tqdm if pbar else lambda x, *args, **kwargs: x

    if os.getenv("debug"):
        return list(map(f, pbar_fn(xs, total=len(xs))))

    # ThreadPool rejects a size of zero, so an empty input must not reach it.
    if len(xs) == 0:
        return []

    with ThreadPool(min(num_threads, len(xs))) as pool:
        return list(pbar_fn(pool.imap(f, xs), total=len(xs)))
# Shared Jinja environment for the templates in this module, which are all
# compiled from in-memory strings via from_string. StrictUndefined is
# configured so that use of an undefined template variable is reported
# instead of silently rendering as empty.
# NOTE(review): how select_autoescape treats string templates (which have no
# file name) depends on its defaults — confirm against the Jinja docs.
jinja_env = jinja2.Environment(
    loader=jinja2.BaseLoader(),
    undefined=jinja2.StrictUndefined,
    autoescape=jinja2.select_autoescape(["html", "xml"]),
)
# Jinja template for one chat message. It expects the variables role, content
# and variant; the role doubles as a CSS class name, and the variant label is
# only emitted when variant is truthy.
_message_template = """
<div class="message {{ role }}">
<div class="role">
{{ role }}
{% if variant %}<span class="variant">({{ variant }})</span>{% endif %}
</div>
<div class="content">
<pre>{{ content }}</pre>
</div>
</div>
"""
def message_to_html(message: Message) -> str:
    """
    Render a single message as an HTML <div> snippet.

    The message must provide "role" and "content"; "variant" is optional and
    is passed to the template as None when absent.
    """
    template = jinja_env.from_string(_message_template)
    context = {
        "role": message["role"],
        "content": message["content"],
        "variant": message.get("variant", None),
    }
    return template.render(**context)
# Make message_to_html callable from templates rendered through jinja_env;
# HTML_JINJA invokes it for every message it displays.
jinja_env.globals["message_to_html"] = message_to_html
# Jinja template for the standalone HTML report. It expects the variables
# score, metrics and htmls. The metrics table (including the score row) is
# only emitted when metrics is truthy; every entry of htmls is inserted
# through the "safe" filter, i.e. as pre-rendered markup, followed by a rule.
_report_template = """<!DOCTYPE html>
<html>
<head>
<style>
.message {
padding: 8px 16px;
margin-bottom: 8px;
border-radius: 4px;
}
.message.user {
background-color: #B2DFDB;
color: #00695C;
}
.message.assistant {
background-color: #B39DDB;
color: #4527A0;
}
.message.system {
background-color: #EEEEEE;
color: #212121;
}
.role {
font-weight: bold;
margin-bottom: 4px;
}
.variant {
color: #795548;
}
table, th, td {
border: 1px solid black;
}
pre {
white-space: pre-wrap;
}
</style>
</head>
<body>
{% if metrics %}
<h1>Metrics</h1>
<table>
<tr>
<th>Metric</th>
<th>Value</th>
</tr>
<tr>
<td><b>Score</b></td>
<td>{{ score | float | round(3) }}</td>
</tr>
{% for name, value in metrics.items() %}
<tr>
<td>{{ name }}</td>
<td>{{ value }}</td>
</tr>
{% endfor %}
</table>
{% endif %}
<h1>Examples</h1>
{% for html in htmls %}
{{ html | safe }}
<hr>
{% endfor %}
</body>
</html>
"""
def make_report(eval_result: EvalResult) -> str:
    """
    Render an EvalResult as a standalone HTML page.

    The result's score, metrics and per-example htmls are handed to
    _report_template.
    """
    report = jinja_env.from_string(_report_template)
    context = {
        "score": eval_result.score,
        "metrics": eval_result.metrics,
        "htmls": eval_result.htmls,
    }
    return report.render(**context)
def make_report_from_example_htmls(htmls: list[str]):
    """
    Render a standalone HTML page from pre-rendered example snippets only.

    No score or metrics are supplied, so the template's metrics table is
    skipped and just the examples are listed.
    """
    report = jinja_env.from_string(_report_template)
    rendered = report.render(score=None, metrics={}, htmls=htmls)
    return rendered
def normalize_response(response: str) -> str:
    """
    Strip markdown and LaTeX wrappers that could stop an answer from matching.

    Each token below is deleted from the text, one after another. The order is
    significant: compound tokens such as "$\\boxed{" are removed before the
    shorter tokens they contain ("$", "{", "\\boxed").
    """
    noise_tokens = (
        "**",
        "$\\boxed{",
        "}$",
        "\\$",
        "$\\text{",
        "$",
        "\\mathrm{",
        "\\{",
        "\\text",
        "\\(",
        "\\mathbf{",
        "{",
        "\\boxed",
    )
    cleaned = response
    for token in noise_tokens:
        cleaned = cleaned.replace(token, "")
    return cleaned
def normalize_extracted_answer(extracted_answer: str) -> str:
    """
    Map a localized option letter onto the Latin letters A-D.

    Arabic, Bengali and fullwidth Latin option letters are each replaced by a
    space followed by the matching ASCII letter, and the result is stripped
    of surrounding whitespace. Text that contains none of these letters is
    returned unchanged apart from the strip.

    The fullwidth letters are written as escapes (U+FF21..U+FF24) so they
    cannot be confused with, or normalized into, ASCII A-D; replacing ASCII
    letters here would leave fullwidth answers unnormalized.
    """
    return (
        # In arabic these are the letters used for A-D in multiple choice questions
        extracted_answer.replace("أ", " A")
        .replace("ب", " B")
        .replace("ج", " C")
        .replace("د", " D")
        # In Bengali these are the letters used for A-D in multiple choice questions
        .replace("অ", " A")
        .replace("ব", " B")
        .replace("ড", " C")
        .replace("ঢ", " D")
        # In Japanese these (fullwidth) letters are sometimes used for A-D in
        # multiple choice questions
        .replace("\uff21", " A")
        .replace("\uff22", " B")
        .replace("\uff23", " C")
        .replace("\uff24", " D")
        .strip()
    )
def url_to_fileobj(url: str, binary=False, timeout: float | None = 60.0) -> Any:
    """
    Download `url` and return its body as an in-memory file object.

    Args:
        url: Address to fetch with an HTTP GET.
        binary: When true return an io.BytesIO over the raw bytes, otherwise
            an io.StringIO over the decoded text.
        timeout: Seconds to wait for the server before giving up, passed
            straight to requests.get. None waits indefinitely, which is what
            happened before this parameter existed.

    Raises:
        requests.HTTPError: if the server answers with an error status.
        requests.Timeout: if the server does not respond within `timeout`.
    """
    # Without a timeout a stalled server would block the caller forever.
    response = requests.get(url, timeout=timeout)
    response.raise_for_status()
    return io.BytesIO(response.content) if binary else io.StringIO(response.text)
def has_only_user_assistant_messages(messages: list[Message]) -> bool:
    """
    Report whether every message has the role "user" or "assistant".

    An empty list counts as satisfying the condition.
    """
    allowed_roles = ("user", "assistant")
    for message in messages:
        if message["role"] not in allowed_roles:
            return False
    return True