Spaces:
Sleeping
Sleeping
| #!/usr/bin/python3 | |
| # -*- coding: utf-8 -*- | |
| import argparse | |
| import json | |
| import os | |
| import sys | |
| import time | |
# Append the directory two levels above this file to sys.path before the
# project-local import further down (presumably the repository root that holds
# project_settings — TODO confirm).
pwd = os.path.abspath(os.path.dirname(__file__))
sys.path.append(os.path.join(pwd, "../../"))
| from google import genai | |
| from google.genai import types | |
| from project_settings import environment, project_path | |
def get_args():
    """
    Parse the command-line options of the evaluation script.

    Example invocations:
        python3 eval_gemini_google.py --model_name gemini-2.5-pro --eval_result eval_math_result_gemini-2.5-pro.jsonl
        python3 eval_gemini_google.py --model_name gemini-2.5-flash --eval_result eval_math_result_gemini-2.5-flash.jsonl
        python3 eval_gemini_google.py --model_name gemini-2.5-flash-lite-preview-06-17 --eval_result eval_math_result_gemini-2.5-flash-lite-preview-06-17.jsonl

    :return: argparse.Namespace with the string attributes
        ``google_application_credentials``, ``model_name``, ``eval_data``
        and ``eval_result``.
    """
    # (flag, default) pairs, registered in this order; every option is a string.
    option_defaults = (
        (
            "--google_application_credentials",
            (project_path / "dotenv/potent-veld-462405-t3-8091a29b2894.json").as_posix(),
        ),
        ("--model_name", "gemini-2.5-flash-lite-preview-06-17"),
        ("--eval_data", (project_path / "data/arc-easy.jsonl").as_posix()),
        ("--eval_result", (project_path / "data/eval_math_result.jsonl").as_posix()),
    )

    parser = argparse.ArgumentParser()
    for flag, default_value in option_defaults:
        parser.add_argument(flag, default=default_value, type=str)
    return parser.parse_args()
def _load_progress(result_path):
    """Return (finished_ids, total, total_correct) recovered from a previous run's result file."""
    finished_ids = set()
    total = 0
    total_correct = 0
    if os.path.exists(result_path):
        with open(result_path, "r", encoding="utf-8") as f:
            for line in f:
                # Tolerate blank lines (e.g. a trailing newline left by manual edits).
                if not line.strip():
                    continue
                record = json.loads(line)
                finished_ids.add(record["id"])
                # The counters are cumulative, so the last row carries the running totals.
                total = record["total"]
                total_correct = record["total_correct"]
    return finished_ids, total, total_correct


def _build_prompt(question, choices):
    """Render the single-choice prompt sent to the model for one question."""
    instruct = "Complete this single-choice question."
    choices_str = ""
    for choice in choices:
        label = choice["label"]
        text = choice["text"]
        choices_str += f"If you think the answer is `{text}` output: `{label}`\n"
    prompt = f"""
{instruct}
Question:
{question}
Choices:
{choices_str}
Remember to output ONLY the corresponding letter.
Your output is:
""".strip()
    return prompt


def _extract_text(response):
    """Return the text of the first part of the first candidate, or None if the response has none."""
    try:
        return response.candidates[0].content.parts[0].text
    except (TypeError, AttributeError, IndexError):
        # candidates / content / parts may each be None or empty (e.g. blocked or
        # truncated responses); any of those means there is no usable text.
        return None


def main():
    """
    Evaluate a Gemini model on a JSONL file of single-choice questions.

    Reads the options from ``get_args()``, points ``GOOGLE_APPLICATION_CREDENTIALS``
    at the given credentials file and creates a Vertex AI ``genai.Client``. Each
    input row must provide ``id``, ``question``, ``choices`` (items with ``label``
    and ``text``) and ``answerkey``. One JSON line per scored question is appended
    to ``args.eval_result`` with the prediction, correctness, cumulative counters,
    running score and request latency.

    The run is resumable: ids already present in the result file are skipped and
    the cumulative counters continue from its last row. The loop stops once the
    cumulative ``total`` exceeds 20.

    :return: None
    """
    args = get_args()

    os.environ["GOOGLE_APPLICATION_CREDENTIALS"] = args.google_application_credentials

    client = genai.Client(
        vertexai=True,
        project="potent-veld-462405-t3",
        location="global",
    )
    generate_content_config = types.GenerateContentConfig(
        temperature=1,
        top_p=0.95,
        max_output_tokens=8192,
        response_modalities=["TEXT"],
    )

    finished_idx_set, total, total_correct = _load_progress(args.eval_result)
    print(f"finished count: {len(finished_idx_set)}")

    with open(args.eval_data, "r", encoding="utf-8") as fin, \
            open(args.eval_result, "a+", encoding="utf-8") as fout:
        for row in fin:
            if total > 20:
                break
            if not row.strip():
                continue

            row = json.loads(row)
            idx = row["id"]
            question = row["question"]
            choices = row["choices"]
            answer_key = row["answerkey"]

            if idx in finished_idx_set:
                continue
            finished_idx_set.add(idx)

            prompt = _build_prompt(question, choices)
            contents = [
                types.Content(
                    role="user",
                    parts=[
                        types.Part.from_text(text=prompt)
                    ]
                )
            ]

            time_begin = time.time()
            response: types.GenerateContentResponse = client.models.generate_content(
                model=args.model_name,
                contents=contents,
                config=generate_content_config,
            )
            time_cost = time.time() - time_begin
            print(time_cost)

            prediction = _extract_text(response)
            if prediction is None:
                # Not scored and not written, so the question is retried on the next run.
                print(f"skip {idx}: response contains no text")
                continue

            # Compare on the stripped text so surrounding whitespace or a trailing
            # newline does not turn a correct letter into a miss.
            correct = 1 if prediction.strip() == answer_key else 0
            total += 1
            total_correct += correct
            score = total_correct / total

            row_ = {
                "id": idx,
                "question": question,
                "choices": choices,
                "ground_true": answer_key,
                "prediction": prediction,
                "correct": correct,
                "total": total,
                "total_correct": total_correct,
                "score": score,
                "time_cost": time_cost,
            }
            row_ = json.dumps(row_, ensure_ascii=False)
            fout.write(f"{row_}\n")
            # Flush per row: resumption depends on this file being up to date
            # even if the process is killed mid-run.
            fout.flush()

    return
# Run the evaluation only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()