"""Prompt construction, answer extraction and scoring utilities for the EMMA task."""

import logging
import os
import re
from collections import defaultdict
from pathlib import Path

import yaml
from latex2sympy2 import latex2sympy
from sympy import simplify
from word2number import w2n

from lmms_eval.llm_judge import get_server
from lmms_eval.llm_judge.protocol import ServerConfig

try:
    from dotenv import load_dotenv

    load_dotenv()
except ImportError:
    pass  # dotenv is optional

eval_logger = logging.getLogger("lmms-eval")

dir_name = os.path.dirname(os.path.abspath(__file__))

emma_config = {
    "Strategy_Instruction": {
        "CoT": "Please solve the problem step by step.",
        "Directly": "Please ensure that your output only contains the final answer without any additional content (such as intermediate reasoning steps).",
    },
    "multi_choice_format": '{context}\n{question}\n{options}\nAnswer with the option\'s letter from the given choices and put the letter in one "\\boxed{{}}". ',
    "open_ended_format": '{context}\n{question}\nAnswer the question using a single word or phrase and put the answer in one "\\boxed{{}}". ',
}

# Image positions are marked in the query text with "<image_N>" tokens. The
# captured group ("image_N") doubles as the document key holding that image.
_IMAGE_TOKEN_RE = re.compile(r"<(image_\d+)>")

with open(Path(__file__).parent / "emma_all.yaml", "r") as f:
    raw_data = f.readlines()

# remove function definition since yaml load cannot handle it
safe_data = [line for line in raw_data if "!function" not in line]

config = yaml.safe_load("".join(safe_data))

# Expand "~" so the default location resolves to a real directory.
hf_home = os.path.expanduser(os.getenv("HF_HOME", "~/.cache/huggingface/"))
cache_dir = os.path.join(hf_home, config["dataset_kwargs"]["cache_dir"])

# Initialize the LLM judge server
if config["metadata"]["use_lmms_judge"]:
    eval_logger.info("Using LMMS judge server for EMMA task.")

    API_TYPE = os.getenv("API_TYPE", "azure")  # Default to azure based on .env
    # For Azure OpenAI, use DEPLOYMENT_NAME as the model_name
    DEPLOYMENT_NAME = os.getenv("DEPLOYMENT_NAME", "gpt-4o")

    server_config = ServerConfig(
        model_name=DEPLOYMENT_NAME,  # Use deployment name for Azure OpenAI
    )
    server = get_server(server_name=API_TYPE, config=server_config)


def build_query(sample: dict) -> dict:
    """Build the text query by combining the context, question and options.

    The <image_n> tokens are left in the query text untouched.

    Args:
        sample: A dataset row. Reads "context", "question", "type" and "answer",
            plus "options" when the type is "multiple choice".

    Returns:
        A dictionary with "query" (the prompt text) and "gt_content" (the
        ground-truth content), updated with every key/value of ``sample``.
        Keys present in ``sample`` take precedence over the computed ones.
    """
    context = sample["context"]
    question = sample["question"]
    strategy_key = "CoT" if config["metadata"]["strategy"] == "CoT" else "Directly"
    instruction = emma_config["Strategy_Instruction"][strategy_key]

    res_dict = {}
    if sample["type"].lower() == "multiple choice":
        options = sample["options"]
        # Label the options A, B, C, ... one per line.
        option_block = "".join(f"{chr(ord('A') + idx)}: {option}\n" for idx, option in enumerate(options))
        prompt = emma_config["multi_choice_format"].format(context=context, question=question, options=option_block)
        # The answer is an option letter; map it back to the option text.
        gt_content = options[ord(sample["answer"].upper()) - ord("A")]
    else:
        prompt = emma_config["open_ended_format"].format(context=context, question=question)
        gt_content = sample["answer"]

    res_dict["query"] = prompt + instruction
    res_dict["gt_content"] = gt_content

    # append existing key and value in data
    res_dict.update(sample)
    return res_dict


def replace_images_tokens(input_string):
    """Replace every numbered <image_i> token with a single generic <image> token.

    Use only for interleaved format.
    """
    return _IMAGE_TOKEN_RE.sub("<image>", input_string)


def emma_doc_to_text(doc):
    """Return the prompt text for ``doc`` (image tokens included)."""
    res_dict = build_query(doc)
    return res_dict["query"]


def emma_doc_to_visual(doc, lmms_eval_specific_kwargs=None):
    """Return the RGB images referenced by the query, ordered by image key.

    Each distinct "<image_N>" token found in the query contributes
    ``doc["image_N"]`` once. The list is returned regardless of the
    ``interleaved_format`` setting, so the function never falls through
    without a value.
    """
    query = build_query(doc)["query"]
    image_keys = sorted(set(_IMAGE_TOKEN_RE.findall(query)))
    return [doc[image_key].convert("RGB") for image_key in image_keys]


def emma_doc_to_messages(doc, lmms_eval_specific_kwargs=None):
    """Build a single user message with text and images interleaved.

    The query is split at every "<image_N>" token. Text fragments become
    ``{"type": "text"}`` parts and each token becomes an ``{"type": "image"}``
    part holding the corresponding RGB image, in the order the tokens occur.

    Side effect: sets ``config["metadata"]["interleaved_format"]`` to True.
    """
    # If you use doc to messages, the interleaved format is always used
    config["metadata"]["interleaved_format"] = True

    query = build_query(doc)["query"]

    # Splitting on a pattern with one capturing group yields
    # [text, key, text, key, ..., text]: odd positions are image keys.
    pieces = _IMAGE_TOKEN_RE.split(query)

    content = []
    rgb_images = {}  # convert each referenced image only once
    for position, piece in enumerate(pieces):
        if position % 2 == 1:
            if piece not in rgb_images:
                rgb_images[piece] = doc[piece].convert("RGB")
            content.append({"type": "image", "url": rgb_images[piece]})
        elif piece.strip():
            # Add text part if not empty
            content.append({"type": "text", "text": piece.strip()})

    return [{"role": "user", "content": content}]


def emma_process_results(doc, results):
    """Score the model output for one document.

    Args:
        doc: The dataset row; reads "pid", "answer", "subject" and "category".
        results: Model predictions for this document. When several are given,
            each is scored and the last one is reported.

    Returns:
        ``{"emma_score": submission}`` where ``submission`` carries the id,
        query, ground truth, prediction, subject, category and "is_correct",
        plus "judge_response" or "judge_error" when the LLM judge is enabled.

    Raises:
        ValueError: If ``results`` is empty.
    """
    key_name = "emma_score"
    if not results:
        raise ValueError("emma_process_results received no predictions to score")

    # These depend only on the document, so compute them once.
    query = build_query(doc)["query"]
    gt = doc["answer"]
    use_judge = config["metadata"]["use_lmms_judge"]

    for pred in results:
        emma_submission = {"id": doc["pid"], "query": query, "gt_content": gt, "pred": pred, "subject": doc["subject"], "category": doc["category"]}

        if use_judge:
            # Use LMM judge to evaluate the prediction
            submit_prompt = create_test_prompt(score_demo_prompt, doc, pred)
            try:
                # Create a Request object for the unified judge API
                from lmms_eval.llm_judge.protocol import Request

                request = Request(messages=[{"role": "user", "content": submit_prompt}], config=server_config)

                # Send the request to the LMM judge server
                judge_response = server.evaluate(request).content
                judge_result = judge_response.strip().lower()

                # "incorrect" contains "correct", so it has to be ruled out explicitly.
                is_correct = "correct" in judge_result and "incorrect" not in judge_result
                emma_submission["judge_response"] = judge_response
                emma_submission["is_correct"] = is_correct
            except Exception as e:
                eval_logger.error(f"Error using LMM judge: {e}")
                # Fallback to fast_extract_answer if judge fails
                pred_extracted = fast_extract_answer(pred)
                emma_submission["judge_error"] = str(e)
                emma_submission["is_correct"] = is_equal(pred_extracted, gt)
        else:
            # for no lmms judge, use fast_extract_answer only
            pred_extracted = fast_extract_answer(pred)
            emma_submission["pred"] = pred_extracted
            emma_submission["is_correct"] = is_equal(pred_extracted, gt)

    # Note: the key name here is very important. It decides which aggregation function will receive the results
    # We note down the question id/category to help us aggregate the results later
    return {key_name: emma_submission}


def emma_aggregate_results(results):
    """Print overall, per-subject and per-category accuracy; return the overall accuracy.

    Args:
        results: Submission dicts as produced by ``emma_process_results``.
            Entries without "is_correct" are scored by comparing "pred"
            against "gt_content".

    Returns:
        The fraction of correct samples, or 0 when ``results`` is empty.
    """
    subject_to_eval_samples = defaultdict(list)
    category_to_eval_samples = defaultdict(list)
    total_samples = len(results)
    total_correct = 0

    for sample in results:
        # Check if using LMM judge results or traditional evaluation
        if "is_correct" in sample:
            is_correct = sample["is_correct"]
        else:
            is_correct = is_equal(sample["pred"], sample["gt_content"])

        score = 1 if is_correct else 0
        total_correct += score
        subject_to_eval_samples[sample["subject"]].append(score)
        category_to_eval_samples[sample["category"]].append(score)

    accuracy = total_correct / total_samples if total_samples > 0 else 0
    subject_accuracies = {subject: sum(scores) / len(scores) for subject, scores in subject_to_eval_samples.items()}
    category_accuracies = {category: sum(scores) / len(scores) for category, scores in category_to_eval_samples.items()}

    print(f"{'Total Samples':<20}: {total_samples}")
    print(f"{'Total Correct':<20}: {total_correct}")
    print(f"{'Overall Accuracy':<20}: {accuracy:.4f}")
    print()

    print(f"{'Per-Subject Accuracy':<40}")
    print("-" * 40)
    for subject, acc in subject_accuracies.items():
        print(f"{subject:<20}: {acc:.4f}")
    print()

    print(f"{'Per-Category Accuracy':<40}")
    print("-" * 40)
    for category, acc in category_accuracies.items():
        print(f"{category:<20}: {acc:.4f}")
    print("=" * 40)

    return accuracy


#################################################
# Helper functions written by official EMMA repo.
#################################################


def extract_full_boxed_content(s):
    r"""
    https://github.com/EMMA-Bench/EMMA/blob/main/evaluation/utils.py#L22

    Extract the full content inside \boxed{}, handling nested braces {{}} properly.

    Returns a list with the content of every closed \boxed{...} found in ``s``,
    in order of appearance. An unterminated \boxed{ contributes nothing.
    """
    results = []

    i = 0
    while i < len(s):
        if s[i : i + 7] == r"\boxed{":
            brace_stack = []
            start = i + 7
            i = start

            while i < len(s):
                if s[i] == "{":
                    brace_stack.append(i)
                elif s[i] == "}":
                    if brace_stack:
                        brace_stack.pop()
                    else:
                        # Unmatched "}" closes the \boxed{ itself.
                        results.append(s[start:i])
                        break
                i += 1
        i += 1

    return results


def is_number(s):
    """https://github.com/EMMA-Bench/EMMA/blob/main/evaluation/utils.py#L14

    Return True when ``float(s)`` succeeds, False when it raises ValueError.
    """
    try:
        float(s)
        return True
    except ValueError:
        return False


def is_equal(md_ans, gt_ans):
    """https://github.com/EMMA-Bench/EMMA/blob/main/evaluation/utils.py#L50

    Decide whether the model answer ``md_ans`` matches the ground truth ``gt_ans``.

    Checks are tried in order: case-insensitive string equality, number words
    converted to digits, then LaTeX parsed with latex2sympy and compared either
    numerically (rounded to 2 decimal places) or symbolically.
    """
    md_ans = md_ans.lower()
    gt_ans = gt_ans.lower()

    if md_ans.strip() == gt_ans.strip():
        return True

    try:
        md_ans_cache = str(w2n.word_to_num(md_ans))
        if md_ans_cache.strip() == gt_ans.strip():
            return True
    except ValueError:
        pass

    # For Math
    try:
        # Parse LaTeX expressions into sympy and compare numerical values
        md_sympy = latex2sympy(md_ans)
        gt_sympy = latex2sympy(gt_ans)

        # Compare evaluated results, rounded to 2 decimal places
        if round(float(md_sympy.evalf()), 2) == round(float(gt_sympy.evalf()), 2):
            return True

        # Additionally, compare simplified symbolic expressions
        if simplify(md_sympy - gt_sympy) == 0:
            return True
    except Exception:
        pass  # Ignore parsing errors or evaluation failures

    return False


def fast_extract_answer(response):
    """
    https://github.com/EMMA-Bench/EMMA/blob/main/evaluation/evaluate.py#L11

    Extract the final answer from a model response with cheap heuristics.

    Returns the option letter, the numeric string, the last boxed content, or
    the text following an "answer is"-style phrase; "" when nothing matches.
    """
    response = response.strip()

    # Direct Strategy Multi-Choice
    # A / A: / A.
    for ch in "ABCDEFGH":
        if response.upper() == ch or response.startswith((f"{ch}:", f"{ch}.")):
            return ch

    # Direct Strategy Open-ended
    # 1
    if is_number(response):
        return response

    # CoT strategy
    if "boxed{" in response:
        try:
            model_answers = extract_full_boxed_content(response)
            if model_answers:
                # for coding
                # \\boxed{\\text{}}
                try:
                    text_content = re.findall(r"\\text{(.*?)}", model_answers[-1])
                    if text_content:
                        return text_content[-1].strip()
                except Exception:
                    pass
                return model_answers[-1].strip()
        except Exception:
            pass

    # for Coding
    # the correct answer is\n D.
    for flag in ["final answer is", "correct answer is", "answer should be", "answer is", "answer:"]:
        if flag in response.lower():
            try:
                model_answer = response.lower().split(flag)[-1].strip()
                return model_answer.split("\n")[0].split(".")[0]
            except Exception:
                pass

    return ""


def create_test_prompt(score_prompt, problem, pred):
    """
    https://github.com/EMMA-Bench/EMMA/blob/main/evaluation/evaluate.py#L54

    Append the response and the reference answer to the judge instructions,
    ending with an open "Correct_or_not:" field for the judge to complete.
    """
    score_prompt = score_prompt.strip()
    # response = problem[label]
    response = pred
    answer = problem["answer"]
    full_prompt = f"{score_prompt}\n" + f"Response: {response}\n" + f"Answer: {answer}\n" + "Correct_or_not:"
    return full_prompt


score_demo_prompt = """Please read the following example. Then determine whether the response is correct and type it at the end of the prompt. It is worth noting that the final answer in the response is usually in \\boxed{}, You only need to compare the final answer in the response with the answer, without considering the logical correctness of the response itself.

Response: The correct answer is:\n\nA

Answer: A

Correct_or_not: Correct

Response: The correct option is:\n\n\\[\n\\boxed{E}\n\\]

Answer: C

Correct_or_not: Incorrect
"""


# Alternative implementation using base64 encoding (currently not used, adopted from EMMA repo)
# def create_message(sample):
#     """
#     Alternative implementation for creating messages with base64 encoded images.
#     This would require implementing encode_image_to_base64 function.
#     """
#     query = sample["query"]
#     all_contents = []
#     matches = re.findall(r"<(image_\d+)>", query)
#     split_text = re.split(r"<image_\d+>", query)
#     for i, fragment in enumerate(split_text):
#         if fragment.strip():
#             all_contents.extend([{"type": "text", "text": fragment}])
#         if i < len(matches):
#             if sample[matches[i]]:
#                 # Would need to implement encode_image_to_base64 function
#                 # img_base64 = encode_image_to_base64(sample[matches[i]])
#                 # all_contents.extend([{"type": "image_url", "image_url": {"url": f"data:image/png;base64,{img_base64}"}}])
#                 all_contents.extend([{"type": "image", "url": sample[matches[i]]}])
#             else:
#                 eval_logger.error(f"The image token {matches[i]} is in the query, but there is no corresponding image provided by the data")
#
#     messages = [{"role": "user", "content": all_contents}]
#     return messages