| # Adapted from https://github.com/openai/simple-evals/ | |
| import os | |
| import resource | |
| import time | |
| from collections import defaultdict | |
| from dataclasses import dataclass, field | |
| from multiprocessing.pool import ThreadPool | |
| from typing import Any, Dict, List, Optional, Tuple | |
| import httpx | |
| import jinja2 | |
| import numpy as np | |
| import openai | |
| import requests | |
| from openai import OpenAI | |
| from tqdm import tqdm | |
# Canned system prompts. Nothing in the visible part of this module reads
# them; presumably callers pass one as a sampler's ``system_message`` --
# verify at the call sites.
OPENAI_SYSTEM_MESSAGE_API = "You are a helpful assistant."
OPENAI_SYSTEM_MESSAGE_CHATGPT = (
    "You are ChatGPT, a large language model trained by OpenAI, based on the GPT-4 architecture."
    "\nKnowledge cutoff: 2023-12\nCurrent date: 2024-04-01"
)

# A chat message is a plain dict; the expected keys are "role" and "content".
Message = Dict[str, Any]
# A conversation: a list of messages.
MessageList = List[Message]
class SamplerBase:
    """
    Interface for a sampling model.

    A sampler is a callable that receives a conversation and returns the
    model's reply as text. It can be the model being evaluated, or it can be
    used as part of the grading process.
    """

    def __call__(self, message_list: MessageList) -> str:
        # Abstract hook: concrete samplers must override this.
        raise NotImplementedError()
@dataclass
class EvalResult:
    """
    Result of running an evaluation (usually consisting of many samples).

    Declared as a dataclass so that it can be constructed with keyword
    arguments, which is how ``aggregate_results`` builds it. Without the
    decorator the annotations below would not generate an ``__init__``.
    """

    score: Optional[float]  # top-line metric
    metrics: Optional[Dict[str, float]]  # other metrics
    htmls: List[str]  # strings of valid HTML
    # Quoted forward reference to the module-level ``MessageList`` alias, so
    # the class body does not depend on where the alias is defined.
    convos: List["MessageList"]  # sampled conversations
@dataclass
class SingleEvalResult:
    """
    Result of evaluating a single sample.

    Declared as a dataclass: ``field(default_factory=dict)`` only takes
    effect under ``@dataclass``, giving every instance its own ``metrics``
    dict. Only ``score`` is required; the other fields default to empty.
    """

    score: Optional[float]
    metrics: Dict[str, float] = field(default_factory=dict)
    html: Optional[str] = None
    # Quoted forward reference to the module-level ``MessageList`` alias.
    convo: Optional["MessageList"] = None  # sampled conversation
class Eval:
    """
    Interface for an evaluation.

    An evaluation is a callable that takes a sampler and returns an
    ``EvalResult``.
    """

    def __call__(self, sampler: SamplerBase) -> EvalResult:
        # Abstract hook: concrete evaluations must override this.
        raise NotImplementedError()
class LargerHttpxClient(httpx.Client):
    """
    ``httpx.Client`` preconfigured with a timeout of 3600 and a connection
    pool capped at 3600 connections (total and keep-alive alike).
    """

    def __init__(self):
        super().__init__(
            timeout=httpx.Timeout(3600),
            limits=httpx.Limits(
                max_keepalive_connections=3600,
                max_connections=3600,
            ),
        )
class ChatCompletionSampler(SamplerBase):
    """
    Sample from OpenAI's chat completion API

    The sampler always returns a string: the reply text on success, and an
    empty string when the request is rejected as a bad request, when the
    reply carries no text content, or when every retry has failed.
    """

    # Number of attempts made for one request before giving up.
    _MAX_TRIALS = 6

    def __init__(
        self,
        base_url: Optional[str] = None,
        model: Optional[str] = None,
        system_message: Optional[str] = None,
        temperature: float = 0.0,
        reasoning_effort: Optional[str] = None,
        max_tokens: int = 2048,
        extra_body: Optional[Dict[str, Any]] = None,
    ):
        """
        Create the API client and remember the sampling parameters.

        Args:
            base_url: Passed to ``OpenAI(base_url=...)``.
            model: Model id to query. When ``None``, the first model
                returned by ``client.models.list()`` is used.
            system_message: Optional system prompt prepended to every
                conversation.
            temperature: Forwarded to the completion request.
            reasoning_effort: Forwarded to the completion request.
            max_tokens: Forwarded to the completion request.
            extra_body: Forwarded to the completion request.
        """
        self.client = OpenAI(base_url=base_url, http_client=LargerHttpxClient())
        if model is None:
            # No model requested: fall back to the first one the server lists.
            model = self.client.models.list().data[0].id
        self.model = model
        self.system_message = system_message
        self.temperature = temperature
        self.max_tokens = max_tokens
        self.reasoning_effort = reasoning_effort
        self.extra_body = extra_body
        self.image_format = "url"
        print(
            f"ChatCompletionSampler initialized with {self.system_message=} {self.temperature=} {self.max_tokens=} {self.reasoning_effort=} {self.extra_body=}"
        )

    def _handle_image(
        self,
        image: str,
        encoding: str = "base64",
        format: str = "png",
        fovea: int = 768,
    ):
        """
        Build an ``image_url`` content part holding the image as a data URL.

        ``image`` is inserted verbatim after ``data:image/{format};{encoding},``
        (presumably already encoded by the caller -- verify at call sites).
        ``fovea`` is accepted but not used here.
        """
        new_image = {
            "type": "image_url",
            "image_url": {
                "url": f"data:image/{format};{encoding},{image}",
            },
        }
        return new_image

    def _handle_text(self, text: str):
        """Build a ``text`` content part."""
        return {"type": "text", "text": text}

    def _pack_message(self, role: str, content: Any):
        """Build a message dict with the given role and content."""
        return {"role": str(role), "content": content}

    def __call__(self, message_list: MessageList) -> str:
        """
        Request a completion for ``message_list`` and return the reply text.

        The configured system message, if any, is prepended (the caller's
        list is not modified). A ``BadRequestError`` is not retried. Any
        other exception is retried with exponential backoff of 1, 2, 4, 8
        and 16 seconds between the attempts.

        Returns:
            The reply text, or ``""`` if the request was rejected, the reply
            had no text content, or all attempts failed.
        """
        if self.system_message:
            message_list = [
                self._pack_message("system", self.system_message)
            ] + message_list

        final_trial = self._MAX_TRIALS - 1
        for trial in range(self._MAX_TRIALS):
            try:
                response = self.client.chat.completions.create(
                    model=self.model,
                    messages=message_list,
                    temperature=self.temperature,
                    max_tokens=self.max_tokens,
                    reasoning_effort=self.reasoning_effort,
                    extra_body=self.extra_body,
                )
                content = response.choices[0].message.content
                # The API may return no text content; honour the ``-> str``
                # contract instead of leaking None to callers.
                return "" if content is None else content
            # NOTE: BadRequestError was observed once on MMMU. Retrying an
            # invalid request cannot help, so answer with an empty string.
            except openai.BadRequestError as e:
                print("Bad Request Error", e)
                return ""
            except Exception as e:
                if trial == final_trial:
                    # Nothing left to retry; do not sleep before giving up.
                    print(f"Giving up after {self._MAX_TRIALS} failed attempts", e)
                    break
                exception_backoff = 2**trial  # exponential back off
                print(
                    f"Rate limit exception so wait and retry {trial} after {exception_backoff} sec",
                    e,
                )
                time.sleep(exception_backoff)
        # All attempts failed: return an empty answer rather than None.
        return ""
# Prompt for four-option multiple-choice questions. It is filled in with
# str.format, so it needs the keys Question, A, B, C and D (see
# format_multichoice_question).
# NOTE(review): blank lines inside this literal may have been lost when the
# file was extracted -- compare with upstream simple-evals before relying on
# the exact whitespace.
QUERY_TEMPLATE_MULTICHOICE = """
Answer the following multiple choice question. The last line of your response should be of the following format: 'Answer: $LETTER' (without quotes) where LETTER is one of ABCD. Think step by step before answering.
{Question}
A) {A}
B) {B}
C) {C}
D) {D}
""".strip()
# Case-insensitive match of "Answer: X"; group 1 is a single letter A-D
# (lower-case letters match too because of the (?i) flag).
ANSWER_PATTERN_MULTICHOICE = r"(?i)Answer\s*:\s*([A-D])"
# Case-insensitive match of "Answer: ..."; group 1 is the rest of that line.
ANSWER_PATTERN = r"(?i)Answer\s*:\s*([^\n]+)"
# Few-shot grading prompt asking whether two math answers are equivalent.
# It is filled in with %-formatting using the keys "expression1" and
# "expression2" (see check_equality), and asks for a bare "Yes" or "No".
# NOTE(review): blank lines inside this literal may have been lost when the
# file was extracted -- compare with upstream simple-evals before relying on
# the exact whitespace.
EQUALITY_TEMPLATE = r"""
Look at the following two expressions (answers to a math problem) and judge whether they are equivalent. Only perform trivial simplifications
Examples:
Expression 1: $2x+3$
Expression 2: $3+2x$
Yes
Expression 1: 3/2
Expression 2: 1.5
Yes
Expression 1: $x^2+2x+1$
Expression 2: $y^2+2y+1$
No
Expression 1: $x^2+2x+1$
Expression 2: $(x+1)^2$
Yes
Expression 1: 3245/5
Expression 2: 649
No
(these are actually equal, don't mark them equivalent if you need to do nontrivial simplifications)
Expression 1: 2/(-3)
Expression 2: -2/3
Yes
(trivial simplifications are allowed)
Expression 1: 72 degrees
Expression 2: 72
Yes
(give benefit of the doubt to units)
Expression 1: 64
Expression 2: 64 square feet
Yes
(give benefit of the doubt to units)
---
YOUR TASK
Respond with only "Yes" or "No" (without quotes). Do not include a rationale.
Expression 1: %(expression1)s
Expression 2: %(expression2)s
""".strip()
# Jinja template for the HTML write-up of one sample. It expects the
# variables prompt_messages, next_message, correct_answer, extracted_answer
# and score, and calls message_to_html, which this module registers as a
# global on jinja_env. It is not rendered anywhere in the visible part of
# this file; presumably the individual evals render it.
HTML_JINJA = """
<h3>Prompt conversation</h3>
{% for message in prompt_messages %}
{{ message_to_html(message) | safe }}
{% endfor %}
<h3>Sampled message</h3>
{{ message_to_html(next_message) | safe }}
<h3>Results</h3>
<p>Correct Answer: {{ correct_answer }}</p>
<p>Extracted Answer: {{ extracted_answer }}</p>
<p>Score: {{ score }}</p>
"""
def format_multichoice_question(row):
    """
    Render the multiple-choice prompt for one question.

    ``row`` is unpacked as keyword arguments into
    ``QUERY_TEMPLATE_MULTICHOICE.format``, so it must be a mapping that
    supplies every placeholder of that template.
    """
    prompt = QUERY_TEMPLATE_MULTICHOICE.format(**row)
    return prompt
def check_equality(sampler: SamplerBase, expr1: str, expr2: str):
    """
    Ask ``sampler`` whether two expressions are equivalent.

    The two expressions are substituted into ``EQUALITY_TEMPLATE`` and sent
    to the sampler as a single user message.

    Returns:
        True only if the reply, lower-cased and stripped, is exactly "yes".
        Any other reply, including no reply at all, yields False.
    """
    prompt = EQUALITY_TEMPLATE % {"expression1": expr1, "expression2": expr2}
    response = sampler([dict(content=prompt, role="user")])
    if response is None:
        # Samplers are declared to return str, but a failed request can
        # surface as None; treat that as "not confirmed equal" rather than
        # crashing on None.lower().
        return False
    return response.lower().strip() == "yes"
| def _compute_stat(values: list, stat: str): | |
| if stat == "mean": | |
| return np.mean(values) | |
| elif stat == "std": | |
| return np.std(values) | |
| elif stat == "min": | |
| return np.min(values) | |
| elif stat == "max": | |
| return np.max(values) | |
| else: | |
| raise ValueError(f"Unknown {stat =}") | |
def aggregate_results(
    single_eval_results: List[Optional[SingleEvalResult]],
    default_stats: Tuple[str, ...] = ("mean", "std"),
    name2stats: Optional[Dict[str, Tuple[str, ...]]] = None,
) -> EvalResult:
    """
    Aggregate results from multiple evaluations into a single EvalResult.

    Args:
        single_eval_results: Per-sample results; ``None`` entries are
            skipped.
        default_stats: Statistics computed for every metric that has no
            entry in ``name2stats``. Each name must be one accepted by
            ``_compute_stat``.
        name2stats: Optional per-metric override of the statistics to
            compute.

    Returns:
        An ``EvalResult`` whose ``score`` is the mean of the non-``None``
        sample scores (``None`` if "mean" was not computed for "score" or no
        sample had a score), whose ``metrics`` hold every other computed
        statistic, and whose ``htmls`` / ``convos`` collect the per-sample
        values in input order.
    """
    name2stats = name2stats or {}
    name2values = defaultdict(list)
    htmls = []
    convos = []
    for single_eval_result in single_eval_results:
        # Skip None results
        if single_eval_result is None:
            continue
        for name, value in single_eval_result.metrics.items():
            name2values[name].append(value)
        # The score is pooled under the reserved name "score" and reduced
        # like any other metric.
        if single_eval_result.score is not None:
            name2values["score"].append(single_eval_result.score)
        htmls.append(single_eval_result.html)
        convos.append(single_eval_result.convo)

    final_metrics = {}
    for name, values in name2values.items():
        stats = name2stats.get(name, default_stats)
        for stat in stats:
            # The mean is stored under the bare metric name; every other
            # statistic gets a "name:stat" key.
            key = name if stat == "mean" else f"{name}:{stat}"
            final_metrics[key] = _compute_stat(values, stat)
    return EvalResult(
        # Popping moves the top-line score out of the metrics dict.
        score=final_metrics.pop("score", None),
        metrics=final_metrics,
        htmls=htmls,
        convos=convos,
    )
def map_with_progress(f: callable, xs: List[Any], num_threads: int):
    """
    Apply f to each element of xs, using a ThreadPool, and show progress.

    Args:
        f: Function applied to every element.
        xs: Inputs; results are returned in the same order.
        num_threads: Upper bound on the number of worker threads. The pool
            never has more workers than there are inputs.

    Returns:
        The list of results. An empty ``xs`` yields an empty list.

    When the environment variable ``debug`` is set to a non-empty value the
    work runs sequentially in the calling thread instead of in a pool.
    """
    if len(xs) == 0:
        # ThreadPool(0) raises ValueError, so there must be no pool at all
        # when there is nothing to do.
        return []
    if os.getenv("debug"):
        return list(map(f, tqdm(xs, total=len(xs))))
    with ThreadPool(min(num_threads, len(xs))) as pool:
        # imap preserves input order while letting tqdm tick per result.
        return list(tqdm(pool.imap(f, xs), total=len(xs)))
# Jinja environment shared by all HTML rendering in this module. Templates
# are supplied as strings through from_string, hence the bare BaseLoader.
# StrictUndefined is used for undefined template variables, and autoescaping
# is configured through select_autoescape; the templates mark pre-rendered
# HTML with the `safe` filter.
jinja_env = jinja2.Environment(
    loader=jinja2.BaseLoader(),
    undefined=jinja2.StrictUndefined,
    autoescape=jinja2.select_autoescape(["html", "xml"]),
)
# Template for one chat message: a <div> whose CSS class includes the role,
# an optional "(variant)" label, and the content inside a <pre>.
_message_template = """
<div class="message {{ role }}">
<div class="role">
{{ role }}
{% if variant %}<span class="variant">({{ variant }})</span>{% endif %}
</div>
<div class="content">
<pre>{{ content }}</pre>
</div>
</div>
"""
def message_to_html(message: Message) -> str:
    """
    Render one message as an HTML snippet (a single <div>).

    The message must have "role" and "content" keys; "variant" is optional
    and is passed to the template as ``None`` when absent.
    """
    template = jinja_env.from_string(_message_template)
    return template.render(
        role=message["role"],
        content=message["content"],
        variant=message.get("variant"),
    )
# Make message_to_html callable from templates rendered through jinja_env
# (HTML_JINJA calls it).
jinja_env.globals["message_to_html"] = message_to_html
# Standalone HTML report page. It expects the variables score, metrics and
# htmls. The metrics table, including the Score row, is emitted only when
# `metrics` is truthy; each entry of `htmls` is inserted unescaped via the
# `safe` filter, followed by an <hr>.
_report_template = """<!DOCTYPE html>
<html>
<head>
<style>
.message {
padding: 8px 16px;
margin-bottom: 8px;
border-radius: 4px;
}
.message.user {
background-color: #B2DFDB;
color: #00695C;
}
.message.assistant {
background-color: #B39DDB;
color: #4527A0;
}
.message.system {
background-color: #EEEEEE;
color: #212121;
}
.role {
font-weight: bold;
margin-bottom: 4px;
}
.variant {
color: #795548;
}
table, th, td {
border: 1px solid black;
}
pre {
white-space: pre-wrap;
}
</style>
</head>
<body>
{% if metrics %}
<h1>Metrics</h1>
<table>
<tr>
<th>Metric</th>
<th>Value</th>
</tr>
<tr>
<td><b>Score</b></td>
<td>{{ score | float | round(3) }}</td>
</tr>
{% for name, value in metrics.items() %}
<tr>
<td>{{ name }}</td>
<td>{{ value }}</td>
</tr>
{% endfor %}
</table>
{% endif %}
<h1>Examples</h1>
{% for html in htmls %}
{{ html | safe }}
<hr>
{% endfor %}
</body>
</html>
"""
def make_report(eval_result: EvalResult) -> str:
    """
    Render an EvalResult as a standalone HTML report.

    The result's score, metrics and example htmls are passed to
    ``_report_template``.
    """
    template = jinja_env.from_string(_report_template)
    return template.render(
        score=eval_result.score,
        metrics=eval_result.metrics,
        htmls=eval_result.htmls,
    )
def make_report_from_example_htmls(htmls: List[str]):
    """
    Render a standalone HTML report that shows only example htmls.

    The score is passed as ``None`` and the metrics as an empty dict, so the
    template's metrics table is left out.
    """
    template = jinja_env.from_string(_report_template)
    return template.render(score=None, metrics={}, htmls=htmls)
def download_dataset(path, url, timeout: float = 60.0):
    """
    Stream ``url`` to the file at ``path``, showing a progress bar.

    Args:
        path: Destination file; opened in binary write mode.
        url: Location to download from.
        timeout: Passed to ``requests.get`` as its ``timeout`` argument, so
            an unresponsive server cannot block the call forever.

    Raises:
        Exception: If the request fails (any ``requests.RequestException``,
            including a non-success HTTP status). The original error is
            chained as ``__cause__``.
    """
    print(f"Downloading dataset {path} from {url}")
    try:
        # The context manager releases the streamed connection even when
        # writing the file fails part-way.
        with requests.get(url, stream=True, timeout=timeout) as response:
            response.raise_for_status()

            # A missing content-length header yields a total of 0.
            total_size = int(response.headers.get("content-length", 0))
            block_size = 8192

            with open(path, "wb") as f, tqdm(
                desc="Downloading",
                total=total_size,
                unit="iB",
                unit_scale=True,
                unit_divisor=1024,
            ) as progress_bar:
                for data in response.iter_content(block_size):
                    size = f.write(data)
                    progress_bar.update(size)

        print(f"Dataset downloaded and saved to {path}")
    except requests.RequestException as e:
        raise Exception(f"Failed to download dataset: {e}") from e
def set_ulimit(target_soft_limit=65535):
    """
    Raise the soft limit on open file descriptors (RLIMIT_NOFILE).

    The soft limit is raised to ``target_soft_limit``, or to the hard limit
    when the target exceeds it. The limit is never lowered and the hard
    limit is never changed. Failures are printed, not raised.

    Args:
        target_soft_limit: Desired minimum soft limit.
    """
    resource_type = resource.RLIMIT_NOFILE
    current_soft, current_hard = resource.getrlimit(resource_type)

    if current_soft == resource.RLIM_INFINITY:
        # Already unlimited. RLIM_INFINITY can be -1, which a plain "<"
        # comparison would mistake for a very low limit.
        return

    if current_hard != resource.RLIM_INFINITY and target_soft_limit > current_hard:
        # A soft limit above the hard limit is rejected outright, so settle
        # for the highest value that is allowed.
        print(
            f"RLIMIT_NOFILE target {target_soft_limit} exceeds hard limit "
            f"{current_hard}; using the hard limit"
        )
        target_soft_limit = current_hard

    if current_soft < target_soft_limit:
        try:
            resource.setrlimit(resource_type, (target_soft_limit, current_hard))
        except (ValueError, OSError) as e:
            print(f"Fail to set RLIMIT_NOFILE: {e}")
Xet Storage Details
- Size:
- 13 kB
- Xet hash:
- b78d04a1f5849030ab7c751c0f2129ed0941b331e2ca636b95882498acdd3520
·
Xet efficiently stores files, intelligently splitting them into unique chunks and accelerating uploads and downloads. More info.