# DebateCoder - src/main.py
# Uploaded by smallan13 in commit 01f199c ("Upload 105 files"); 3.04 kB.
import sys
import os
from datetime import datetime
from constants.paths import *
from models.Gemini import Gemini
from models.OpenAI import OpenAIModel
from results.Results import Results
from promptings.PromptingFactory import PromptingFactory
from datasets.DatasetFactory import DatasetFactory
from models.ModelFactory import ModelFactory
import argparse
# Command-line interface. Every option has a default, so the script can be run
# with no arguments (ChatGPT + MapCoder on HumanEval, Python3).
parser = argparse.ArgumentParser()
parser.add_argument(
    "--dataset",
    type=str,
    default="HumanEval",
    choices=[
        "HumanEval",
        "MBPP",
        "APPS",
        "xCodeEval",
        "CC",
    ],
    help="Benchmark dataset to run on.",
)
parser.add_argument(
    "--strategy",
    type=str,
    default="MapCoder",
    choices=[
        "Direct",
        "CoT",
        "SelfPlanning",
        "Analogical",
        "MapCoder",
        "DebateCoder",
    ],
    help="Prompting strategy to apply.",
)
parser.add_argument(
    "--model",
    type=str,
    default="ChatGPT",
    choices=[
        "ChatGPT",
        "GPT4",
        "Gemini",
        "DeepSeek",
        "Pangu",
        "Qwen",
        "Pangu72B",
    ],
    help="Model backend to instantiate.",
)
parser.add_argument(
    "--temperature",
    type=float,
    # Keep this a float literal: argparse applies `type` only to *string*
    # defaults, so an int default would stay an int and the run name would
    # read "...-0-..." for a default run but "...-0.0-..." for an explicit
    # `--temperature 0`, splitting one configuration across two result files.
    default=0.0,
    help="Sampling temperature passed to the model constructor.",
)
parser.add_argument(
    "--pass_at_k",
    type=int,
    default=1,
    help="The k of pass@k; forwarded to the prompting strategy.",
)
parser.add_argument(
    "--language",
    type=str,
    default="Python3",
    choices=[
        "C",
        "C#",
        "C++",
        "Go",
        "PHP",
        "Python3",
        "Ruby",
        "Rust",
    ],
    help="Programming language of the generated code.",
)
args = parser.parse_args()

# Unpack the parsed options into module-level run settings.
DATASET, STRATEGY, MODEL_NAME = args.dataset, args.strategy, args.model
TEMPERATURE, PASS_AT_K, LANGUAGE = args.temperature, args.pass_at_k, args.language

# The run name encodes the whole configuration. It appears in the console
# banners and also names the JSONL results file under ./outputs/.
RUN_NAME = "-".join(
    str(setting)
    for setting in (MODEL_NAME, STRATEGY, DATASET, LANGUAGE, TEMPERATURE, PASS_AT_K)
)
RESULTS_PATH = "./outputs/" + RUN_NAME + ".jsonl"
print(f"#########################\nRunning start {RUN_NAME}, Time: {datetime.now()}\n##########################\n")

# Build the model up front (instead of inline in the strategy call) so that its
# configuration can be logged before the run starts.
model_class = ModelFactory.get_model_class(MODEL_NAME)
model_instance = model_class(temperature=TEMPERATURE)
# Build a compact model info string
def _format_model_info(model_obj):
info_lines = []
info_lines.append(f"model_class={model_obj.__class__.__name__}")
# include simple attributes
for k, v in getattr(model_obj, "__dict__", {}).items():
try:
if isinstance(v, (str, int, float, bool)):
info_lines.append(f"{k}={v}")
else:
info_lines.append(f"{k}=<{type(v).__name__}>")
except Exception:
info_lines.append(f"{k}=<unrepr>")
return "; ".join(info_lines)
# Log a one-line summary of the model's configuration before the run starts.
model_info_str = _format_model_info(model_instance)
print("[MODEL INFO] " + model_info_str)

# Resolve the strategy, dataset and results sink by name, then run the strategy.
prompting_cls = PromptingFactory.get_prompting_class(STRATEGY)
dataset_obj = DatasetFactory.get_dataset_class(DATASET)()
results_sink = Results(RESULTS_PATH)
strategy = prompting_cls(
    model=model_instance,
    data=dataset_obj,
    language=LANGUAGE,
    pass_at_k=PASS_AT_K,
    results=results_sink,
)
strategy.run()
print(f"#########################\nRunning end {RUN_NAME}, Time: {datetime.now()}\n##########################\n")