from __future__ import annotations import hashlib import random import re from typing import Any from env.generator import HIDDEN_TEST_COUNT, TOTAL_TEST_CASES, VISIBLE_TEST_COUNT MAX_IO_CHARS = 4096 DEFAULT_DATASET_NAME = "deepmind/code_contests" DEFAULT_SPLIT = "train" DEFAULT_MAX_PROBLEMS = 5000 def _load_raw_dataset( dataset_name: str, split: str = DEFAULT_SPLIT, max_problems: int = DEFAULT_MAX_PROBLEMS, ) -> list[dict[str, Any]]: from datasets import load_dataset dataset = load_dataset(dataset_name, split=split, trust_remote_code=True) rows: list[dict[str, Any]] = [] for raw_row in dataset: row = dict(raw_row) statement = str(row.get("description") or row.get("question") or "").strip() if not statement: continue if not _extract_pairs(row.get("public_tests")): continue if not (_extract_pairs(row.get("private_tests")) or _extract_pairs(row.get("generated_tests"))): continue rows.append(row) if len(rows) >= int(max_problems): break return rows def _normalise_row(raw_row: dict[str, Any], dataset_name: str) -> dict[str, Any] | None: statement = _extract_problem_statement(raw_row) if not statement: return None public_pairs = _extract_pairs(raw_row.get("public_tests")) private_pairs = _extract_pairs(raw_row.get("private_tests")) generated_pairs = _extract_pairs(raw_row.get("generated_tests")) visible_pairs = public_pairs[:VISIBLE_TEST_COUNT] hidden_pairs = (private_pairs + generated_pairs)[:HIDDEN_TEST_COUNT] if len(visible_pairs) != VISIBLE_TEST_COUNT or len(hidden_pairs) != HIDDEN_TEST_COUNT: return None test_cases: list[dict[str, Any]] = [] seen_inputs: set[str] = set() for index, (raw_input, raw_output) in enumerate(visible_pairs + hidden_pairs): if len(str(raw_output or "")) > MAX_IO_CHARS: return None normalized_input = _normalize_io_text(raw_input, ensure_trailing_newline=True) normalized_output = _normalize_io_text(raw_output, ensure_trailing_newline=False) if not normalized_input or normalized_input in seen_inputs: return None seen_inputs.add(normalized_input) test_cases.append( { "input": normalized_input, "output": normalized_output, "is_visible": index < VISIBLE_TEST_COUNT, } ) if len(test_cases) != TOTAL_TEST_CASES: return None difficulty_label, difficulty_value = _difficulty_fields(raw_row, dataset_name) input_format = _extract_section(statement, "input") or "Read from stdin." constraints = _extract_constraints(statement) problem_type = _infer_problem_type(raw_row, statement) problem_id = _problem_id(raw_row, dataset_name) visible_examples = [dict(test_case) for test_case in test_cases[:VISIBLE_TEST_COUNT]] return { "problem_id": problem_id, "problem_type": problem_type, "difficulty": difficulty_value, "difficulty_label": difficulty_label, "problem": statement, "input_format": input_format, "constraints": constraints, "test_cases": test_cases, "visible_problem": { "problem": statement, "input_format": input_format, "constraints": constraints, "examples": visible_examples, }, "generation_mode": "dataset", "validity_bonus": 1.0, } class DatasetProblemBank: def __init__( self, dataset_name: str = DEFAULT_DATASET_NAME, split: str = DEFAULT_SPLIT, max_problems: int = DEFAULT_MAX_PROBLEMS, ) -> None: self.dataset_name = dataset_name self.split = split self.max_problems = int(max_problems) self._by_difficulty: dict[str, list[dict[str, Any]]] = { "easy": [], "medium": [], "hard": [], } self._by_id: dict[str, dict[str, Any]] = {} raw_rows = _load_raw_dataset(dataset_name=dataset_name, split=split, max_problems=max_problems) for raw_row in raw_rows: normalized = _normalise_row(raw_row, dataset_name) if normalized is None: continue problem_id = str(normalized["problem_id"]) if problem_id in self._by_id: continue difficulty = str(normalized.get("difficulty_label", "medium")).lower() if difficulty not in self._by_difficulty: difficulty = "medium" normalized["difficulty_label"] = difficulty stored = _copy_problem(normalized) self._by_difficulty[difficulty].append(stored) self._by_id[problem_id] = stored if not self._by_id: raise ValueError( f"No usable problems were found in dataset `{dataset_name}` split `{split}` with max_problems={max_problems}." ) def sample(self, difficulty: str, rng: random.Random, recent_types: list[str]) -> dict[str, Any] | None: requested = str(difficulty).strip().lower() candidates = list(self._by_difficulty.get(requested, [])) if not candidates: candidates = [problem for bucket in self._by_difficulty.values() for problem in bucket] if not candidates: return None recent = {problem_type for problem_type in recent_types[-3:] if problem_type} diverse = [problem for problem in candidates if str(problem.get("problem_type", "")) not in recent] pool = diverse or candidates return _copy_problem(rng.choice(pool)) def all_problem_ids(self) -> list[str]: return sorted(self._by_id) def get_by_id(self, problem_id: str) -> dict[str, Any]: return _copy_problem(self._by_id[str(problem_id)]) def problem_types_for_difficulty(self, difficulty: str) -> list[str]: requested = str(difficulty).strip().lower() candidates = self._by_difficulty.get(requested, []) return sorted({str(problem.get("problem_type", "")) for problem in candidates if problem.get("problem_type")}) _BANK: DatasetProblemBank | None = None _BANK_CONFIG: tuple[str, str, int] | None = None def get_problem_bank(**kwargs: Any) -> DatasetProblemBank: global _BANK, _BANK_CONFIG dataset_name = str(kwargs.get("dataset_name", DEFAULT_DATASET_NAME)) split = str(kwargs.get("split", DEFAULT_SPLIT)) max_problems = int(kwargs.get("max_problems", DEFAULT_MAX_PROBLEMS)) config = (dataset_name, split, max_problems) if _BANK is None or _BANK_CONFIG != config: _BANK = DatasetProblemBank( dataset_name=dataset_name, split=split, max_problems=max_problems, ) _BANK_CONFIG = config return _BANK def _extract_problem_statement(raw_row: dict[str, Any]) -> str: value = raw_row.get("description") or raw_row.get("question") or raw_row.get("problem") return str(value or "").strip() def _extract_pairs(raw_value: Any) -> list[tuple[str, str]]: pairs: list[tuple[str, str]] = [] if raw_value is None: return pairs if isinstance(raw_value, dict): inputs = raw_value.get("input") or raw_value.get("inputs") or raw_value.get("stdin") or [] outputs = raw_value.get("output") or raw_value.get("outputs") or raw_value.get("stdout") or [] if isinstance(inputs, str): inputs = [inputs] if isinstance(outputs, str): outputs = [outputs] for raw_input, raw_output in zip(list(inputs), list(outputs)): pairs.append((str(raw_input), str(raw_output))) return pairs if isinstance(raw_value, list): for item in raw_value: if isinstance(item, dict): raw_input = item.get("input") or item.get("stdin") or item.get("in") raw_output = item.get("output") or item.get("stdout") or item.get("out") if raw_input is None or raw_output is None: continue pairs.append((str(raw_input), str(raw_output))) elif isinstance(item, (list, tuple)) and len(item) >= 2: pairs.append((str(item[0]), str(item[1]))) return pairs return pairs def _normalize_io_text(value: Any, *, ensure_trailing_newline: bool) -> str: text = str(value or "")[:MAX_IO_CHARS] if ensure_trailing_newline: text = text.rstrip("\n") return f"{text}\n" if text else "\n" return text.strip() def _difficulty_fields(raw_row: dict[str, Any], dataset_name: str) -> tuple[str, float]: dataset_key = dataset_name.lower() if "code_contests" in dataset_key or "code-contests" in dataset_key: raw_difficulty = raw_row.get("difficulty") try: rating = float(raw_difficulty) except (TypeError, ValueError): return "medium", 0.5 normalized = max(0.0, min((rating - 800.0) / (3500.0 - 800.0), 1.0)) if rating <= 1200: label = "easy" elif rating <= 1800: label = "medium" else: label = "hard" return label, round(normalized, 4) raw_difficulty = str(raw_row.get("difficulty") or "").strip().lower() if raw_difficulty in {"easy", "medium", "hard"}: return raw_difficulty, {"easy": 0.25, "medium": 0.5, "hard": 0.75}[raw_difficulty] return "medium", 0.5 def _problem_id(raw_row: dict[str, Any], dataset_name: str) -> str: prefix = "cc" if "code_contests" in dataset_name.lower() else "ds" for key in ("problem_id", "id", "name", "source"): value = raw_row.get(key) if value is not None and str(value).strip(): candidate = re.sub(r"[^a-zA-Z0-9_]+", "_", str(value).strip()).strip("_") if candidate: return f"{prefix}_{candidate}" digest = hashlib.sha256(repr(sorted(raw_row.items())).encode("utf-8")).hexdigest() return f"{prefix}_{digest[:12]}" def _extract_section(statement: str, heading: str) -> str: pattern = re.compile( rf"{heading}\s*:?[\r\n]+(.*?)(?=\n[A-Z][A-Za-z ]{{1,30}}:?[\r\n]|\Z)", flags=re.IGNORECASE | re.DOTALL, ) match = pattern.search(statement) return match.group(1).strip() if match else "" def _extract_constraints(statement: str) -> str: constraints = _extract_section(statement, "constraints") return constraints or "See problem statement." def _infer_problem_type(raw_row: dict[str, Any], statement: str) -> str: parts: list[str] = [statement] tags = raw_row.get("tags") if isinstance(tags, list): parts.extend(str(tag) for tag in tags) elif isinstance(tags, str): parts.append(tags) for key in ("source", "name"): value = raw_row.get(key) if value: parts.append(str(value)) text = " ".join(parts).lower() keyword_map = { "graph": "graph", "tree": "tree", "dynamic programming": "dp", " dp ": "dp", "string": "string", "array": "array", "greedy": "greedy", "math": "math", "sort": "sorting", "binary search": "search", } padded = f" {text} " for needle, problem_type in keyword_map.items(): haystack = padded if needle.startswith(" ") and needle.endswith(" ") else text if needle in haystack: return problem_type return "implementation" def _copy_problem(problem: dict[str, Any]) -> dict[str, Any]: copied = dict(problem) copied["test_cases"] = [dict(test_case) for test_case in problem.get("test_cases", [])] copied["visible_problem"] = dict(problem.get("visible_problem", {})) examples = copied["visible_problem"].get("examples") if isinstance(examples, list): copied["visible_problem"]["examples"] = [dict(example) for example in examples] return copied