|
|
|
|
|
|
|
|
import json |
|
|
import zlib |
|
|
import pickle |
|
|
import base64 |
|
|
from enum import Enum |
|
|
from datetime import datetime |
|
|
from dataclasses import dataclass |
|
|
|
|
|
from datasets import load_dataset |
|
|
|
|
|
|
|
|
class Platform(Enum):
    """Source platform of a problem, keyed by its lowercase string value."""

    LEETCODE = "leetcode"
    CODEFORCES = "codeforces"
    ATCODER = "atcoder"
|
|
|
|
|
|
|
|
class Difficulty(Enum):
    """Difficulty label of a problem, keyed by its lowercase string value."""

    EASY = "easy"
    MEDIUM = "medium"
    HARD = "hard"
|
|
|
|
|
|
|
|
class TestType(Enum):
    """Kind of a test case, keyed by its lowercase string value."""

    STDIN = "stdin"
    FUNCTIONAL = "functional"
|
|
|
|
|
|
|
|
@dataclass
class Test:
    """A single test case: an input, its expected output and the test kind.

    ``testtype`` may be given as a ``TestType`` member or as its raw value;
    it is stored as a ``TestType`` member after construction.
    """

    input: str
    output: str
    testtype: TestType

    def __post_init__(self):
        """Coerce ``testtype`` into a ``TestType`` member."""
        coerced = TestType(self.testtype)
        self.testtype = coerced
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@dataclass
class CodeGenerationProblem:
    """One code-generation problem built from a serialized dataset row.

    The constructor receives the row's fields in serialized form;
    ``__post_init__`` then replaces them in place with parsed values
    (enum members, a ``datetime``, lists of ``Test`` and a metadata dict).
    """

    question_title: str
    question_content: str
    platform: Platform
    question_id: str
    contest_id: str
    contest_date: datetime
    starter_code: str
    difficulty: Difficulty
    public_test_cases: list[Test]
    private_test_cases: list[Test]
    metadata: dict

    def __post_init__(self):
        """Parse the serialized fields into their typed representations."""
        self.platform = Platform(self.platform)
        self.difficulty = Difficulty(self.difficulty)
        self.contest_date = datetime.fromisoformat(self.contest_date)

        # Public tests arrive as a JSON string holding a list of dicts.
        self.public_test_cases = [
            Test(**case) for case in json.loads(self.public_test_cases)
        ]

        # Private tests are either plain JSON or, when that parse fails,
        # base64 text wrapping a zlib-compressed pickle of a JSON string.
        try:
            private_cases = json.loads(self.private_test_cases)
        except json.JSONDecodeError:
            # SECURITY: pickle.loads can execute arbitrary code while
            # unpickling. Only load rows from a dataset source you trust.
            private_cases = json.loads(
                pickle.loads(
                    zlib.decompress(
                        base64.b64decode(self.private_test_cases.encode("utf-8"))
                    )
                )
            )
        self.private_test_cases = [Test(**case) for case in private_cases]

        self.metadata = json.loads(self.metadata)

    def insert_output(self, output_list: list[str], code_list: list[str]) -> dict:
        """Return the problem's descriptive fields plus generated outputs.

        Args:
            output_list: stored under the ``"output_list"`` key.
            code_list: stored under the ``"code_list"`` key.

        Returns:
            A dict with the problem's title, content, platform value, ids,
            ISO-formatted contest date, starter code, difficulty value and
            the two lists passed in. Test cases and metadata are not included.
        """
        return {
            "question_title": self.question_title,
            "question_content": self.question_content,
            "platform": self.platform.value,
            "question_id": self.question_id,
            "contest_id": self.contest_id,
            "contest_date": self.contest_date.isoformat(),
            "starter_code": self.starter_code,
            "difficulty": self.difficulty.value,
            "output_list": output_list,
            "code_list": code_list,
        }

    def insert_output_evaluation(
        self,
        output_list: list[str],
        code_list: list[str],
        graded_list: list[bool],
        **kwargs,
    ) -> dict:
        """Extend ``insert_output`` with grading results.

        Args:
            output_list: forwarded to ``insert_output``.
            code_list: forwarded to ``insert_output``.
            graded_list: per-sample verdicts, stored under ``"graded_list"``.
            **kwargs: extra entries copied into the result; a key that
                already exists in the result is overwritten.

        Returns:
            The ``insert_output`` dict with ``"graded_list"``, ``"pass@1"``
            (fraction of ``True`` entries in ``graded_list``) and the extra
            keyword entries added.
        """
        output = self.insert_output(output_list, code_list)
        output["graded_list"] = graded_list
        # With nothing graded the ratio is undefined; report 0.0 rather than
        # raising ZeroDivisionError.
        output["pass@1"] = (
            graded_list.count(True) / len(graded_list) if graded_list else 0.0
        )
        output.update(kwargs)
        return output

    def get_evaluation_sample(self):
        """Return the test data as a dict with one JSON-encoded entry.

        Returns:
            ``{"input_output": <json str>}`` where the JSON object holds
            ``"inputs"`` and ``"outputs"`` (public tests first, then private
            tests) and ``"fn_name"`` (``metadata["func_name"]`` or ``None``).
        """
        all_tests = self.public_test_cases + self.private_test_cases
        return {
            "input_output": json.dumps(
                {
                    "inputs": [t.input for t in all_tests],
                    "outputs": [t.output for t in all_tests],
                    "fn_name": self.metadata.get("func_name", None),
                }
            ),
        }
|
|
|
|
|
|
|
|
def load_code_generation_dataset(
    release_version="release_v1",
    cache_dir: str = None,
    start_date=None,
    end_date=None,
) -> list[CodeGenerationProblem]:
    """Load the ``livecodebench/code_generation_lite`` test split as problems.

    Args:
        release_version: passed to ``load_dataset`` as ``version_tag``.
        cache_dir: passed through to ``load_dataset``.
        start_date: optional ``"%Y-%m-%d"`` string; problems whose
            ``contest_date`` is earlier than it are dropped.
        end_date: optional ``"%Y-%m-%d"`` string; problems whose
            ``contest_date`` is later than it are dropped. The bound is
            parsed without a time component, so it compares as midnight.

    Returns:
        The remaining rows as ``CodeGenerationProblem`` instances.
    """
    rows = load_dataset(
        "livecodebench/code_generation_lite",
        split="test",
        version_tag=release_version,
        trust_remote_code=True,
        cache_dir=cache_dir,
    )
    problems = [CodeGenerationProblem(**row) for row in rows]

    if start_date is not None:
        earliest = datetime.strptime(start_date, "%Y-%m-%d")
        problems = [prob for prob in problems if prob.contest_date >= earliest]

    if end_date is not None:
        latest = datetime.strptime(end_date, "%Y-%m-%d")
        problems = [prob for prob in problems if prob.contest_date <= latest]

    return problems
|
|
|
|
|
|
|
|
def load_code_generation_dataset_not_fast(
    release_version="release_v1",
) -> list[CodeGenerationProblem]:
    """Load the ``livecodebench/code_generation`` test split as problems.

    Args:
        release_version: accepted but not used by this function.

    Returns:
        Every row of the split as a ``CodeGenerationProblem`` instance.
    """
    rows = load_dataset("livecodebench/code_generation", split="test")
    return [CodeGenerationProblem(**row) for row in rows]
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    # Script entry point: load the dataset with all default arguments.
    dataset = load_code_generation_dataset()