| | import sys |
| | import os |
| | from datetime import datetime |
| | from constants.paths import * |
| |
|
| | from models.Gemini import Gemini |
| | from models.OpenAI import OpenAIModel |
| |
|
| | from results.Results import Results |
| |
|
| | from promptings.PromptingFactory import PromptingFactory |
| | from datasets.DatasetFactory import DatasetFactory |
| | from models.ModelFactory import ModelFactory |
| |
|
| | import argparse |
| |
|
# Command-line interface for one experiment run: which dataset/strategy/model
# to use plus sampling knobs (temperature, pass@k) and target language.
parser = argparse.ArgumentParser()

parser.add_argument(
    "--dataset",
    type=str,
    default="HumanEval",
    choices=["HumanEval", "MBPP", "APPS", "xCodeEval", "CC"],
)
parser.add_argument(
    "--strategy",
    type=str,
    default="MapCoder",
    choices=["Direct", "CoT", "SelfPlanning", "Analogical", "MapCoder", "DebateCoder"],
)
parser.add_argument(
    "--model",
    type=str,
    default="ChatGPT",
    choices=["ChatGPT", "GPT4", "Gemini", "DeepSeek", "Pangu", "Qwen", "Pangu72B"],
)
# Sampling temperature forwarded to the model constructor (0 = greedy-ish).
parser.add_argument("--temperature", type=float, default=0)
# Number of candidate completions per task (pass@k evaluation).
parser.add_argument("--pass_at_k", type=int, default=1)
parser.add_argument(
    "--language",
    type=str,
    default="Python3",
    choices=["C", "C#", "C++", "Go", "PHP", "Python3", "Ruby", "Rust"],
)

args = parser.parse_args()
| |
|
# Unpack the parsed CLI arguments into module-level run configuration.
DATASET, STRATEGY, MODEL_NAME = args.dataset, args.strategy, args.model
TEMPERATURE, PASS_AT_K, LANGUAGE = args.temperature, args.pass_at_k, args.language

# A run is uniquely identified by its full configuration; results are
# written to a JSONL file named after it under ./outputs/.
RUN_NAME = f"{MODEL_NAME}-{STRATEGY}-{DATASET}-{LANGUAGE}-{TEMPERATURE}-{PASS_AT_K}"
RESULTS_PATH = f"./outputs/{RUN_NAME}.jsonl"

print(f"#########################\nRunning start {RUN_NAME}, Time: {datetime.now()}\n##########################\n")

# Resolve the model class by name, then instantiate it with the requested
# sampling temperature.
model_class = ModelFactory.get_model_class(MODEL_NAME)
model_instance = model_class(temperature=TEMPERATURE)
| |
|
| | |
| | def _format_model_info(model_obj): |
| | info_lines = [] |
| | info_lines.append(f"model_class={model_obj.__class__.__name__}") |
| | |
| | for k, v in getattr(model_obj, "__dict__", {}).items(): |
| | try: |
| | if isinstance(v, (str, int, float, bool)): |
| | info_lines.append(f"{k}={v}") |
| | else: |
| | info_lines.append(f"{k}=<{type(v).__name__}>") |
| | except Exception: |
| | info_lines.append(f"{k}=<unrepr>") |
| | return "; ".join(info_lines) |
| |
|
# Log a one-line summary of the instantiated model's configuration so the
# exact run settings are visible in the console output.
model_info_str = _format_model_info(model_instance)
print(f"[MODEL INFO] {model_info_str}")

# Wire the chosen prompting strategy to the model, dataset, and results
# sink, then execute the whole run.
prompting_class = PromptingFactory.get_prompting_class(STRATEGY)
dataset_instance = DatasetFactory.get_dataset_class(DATASET)()
strategy = prompting_class(
    model=model_instance,
    data=dataset_instance,
    language=LANGUAGE,
    pass_at_k=PASS_AT_K,
    results=Results(RESULTS_PATH),
)

strategy.run()

print(f"#########################\nRunning end {RUN_NAME}, Time: {datetime.now()}\n##########################\n")
| |
|
| |
|