Spaces:
Running on Zero
Running on Zero
File size: 14,378 Bytes
0dd6c2f | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 | import ast
import json
import random
from collections import Counter, defaultdict
from pathlib import Path
from typing import Any
import numpy as np
from datasets import Dataset, DatasetDict
from sympy.core.random import seed
from linalg_zero.generator.models import Question
from linalg_zero.grpo.verify import parse_string, verify_answers
from linalg_zero.shared.lib import get_lib
from linalg_zero.shared.utils import get_logger
# Module-level logger named after this module; used by the verification and
# printing helpers in this file.
logger = get_logger(__name__)
def load_entropy_settings(path: str) -> dict[str, Any]:
    """
    Read the JSON document written by the entropy analysis and return it.

    Args:
        path: Location of the entropy analysis results file.

    Returns:
        The decoded JSON content (presumably a mapping of problem type to its
        recommended settings — confirm against the analysis output).

    Raises:
        FileNotFoundError: If nothing exists at ``path``.
    """
    settings_path = Path(path)
    if settings_path.exists():
        with settings_path.open("r", encoding="utf-8") as handle:
            return json.load(handle)
    raise FileNotFoundError(
        f"Entropy analysis results not found at {settings_path}. "
        "Run the entropy analysis first using linalg_zero.generator.analysis.analyse"
    )
def print_entropy_settings(settings: dict[str, Any]) -> None:
    """Log the loaded entropy settings, one line per problem type.

    Args:
        settings: Mapping of problem type to a config dict; every config must
            provide a ``"combination"`` entry and a numeric ``"score"``.
    """
    separator = "=" * 60
    logger.info(separator)
    logger.info("RECOMMENDED ENTROPY SETTINGS")
    logger.info(separator)
    for problem_type, config in settings.items():
        # Lazy %-style arguments, consistent with the other logging calls in this module.
        logger.info("%s: %s (score: %.2f)", problem_type, config["combination"], config["score"])
    logger.info(separator)
def _verify_step_result(step: dict[str, Any], lib: dict[str, Any]) -> None:
    """Recompute a step with its library function and compare it to the recorded result.

    Args:
        step: Stepwise record with ``step_id``, ``result``, ``tool`` and a
            ``verification`` dict whose ``input`` is a JSON-encoded keyword-argument
            mapping for the tool.
        lib: Mapping from tool name to the callable implementing it.

    Raises:
        ValueError: If the result is missing or unparsable, or if it disagrees
            with the library function's output.
    """
    step_id = step["step_id"]
    if "result" not in step:
        raise ValueError(f"Step {step_id} has no result - implementation is bugged")

    recorded = parse_string(step["result"])
    if recorded is None:
        raise ValueError(f"Step {step_id}: invalid result - implementation is bugged")

    tool_name = step["tool"]
    tool_fn = lib[tool_name]
    kwargs = json.loads(step["verification"]["input"])
    recomputed = tool_fn(**kwargs)
    if verify_answers(recorded, recomputed):
        return
    raise ValueError(f"Step mismatch - step - {json.dumps(step)} - lib_fn - {tool_name}")
def _verify_step_dependencies(step: dict[str, Any], question_stepwise: list[dict[str, Any]]) -> None:
    """Check that a step's dependent inputs equal the results of the steps they reference.

    When ``step["verification"]["dependent_on"]`` is present it maps each ``input_*``
    key of the verification dict to an index into ``question_stepwise``. The
    JSON-decoded value of every ``input_*`` field must match the parsed result of
    the step at that index: ``verify_answers`` must accept the pair AND the two
    values must compare exactly equal.

    Raises:
        TypeError: If ``dependent_on`` is not a dict, or an index is not an int.
        KeyError: If an ``input_*`` field has no entry in ``dependent_on``.
        ValueError: If an index is out of range, the referenced result cannot be
            parsed, or the values disagree.
    """
    step_id = step.get("step_id", "unknown")
    verification = step["verification"]
    dependent_on = verification.get("dependent_on", None)

    # Steps without dependencies have nothing to cross-check.
    if dependent_on is None:
        return
    if not isinstance(dependent_on, dict):
        raise TypeError(f"Step {step_id}: dependent_on must be a dict, got {type(dependent_on)}")

    step_count = len(question_stepwise)
    for input_name, input_value in verification.items():
        if not input_name.startswith("input_"):
            continue

        ref_index = dependent_on[input_name]
        if not isinstance(ref_index, int):
            raise TypeError(
                f"Step {step_id}: dependency index for '{input_name}' must be an integer, got {type(ref_index)}"
            )
        if not 0 <= ref_index < step_count:
            raise ValueError(
                f"Step {step_id}: dependent_on index {ref_index} for '{input_name}' out of bounds "
                f"(stepwise has {step_count} steps)"
            )

        referenced_result = parse_string(question_stepwise[ref_index]["result"])
        if referenced_result is None:
            raise ValueError(f"Step {step_id}: referenced step {ref_index} has invalid result")

        field_value = json.loads(input_value)
        if not verify_answers(field_value, referenced_result) or field_value != referenced_result:
            raise ValueError(
                f"Step {step_id}: dependency verification failed - "
                f"{input_name} ({field_value}) does not match referenced step {ref_index} result ({referenced_result})"
            )
def _verify_golden_answer(question: Question, question_index: int) -> None:
    """Compare a question's golden final answer with the result of its last step.

    Args:
        question: Question providing ``golden["final_answer"]`` and
            ``stepwise[-1]["result"]``.
        question_index: Position of the question in its dataset; used only in
            error messages.

    Raises:
        ValueError: If the golden answer is missing or unparsable, the last step's
            result is unparsable, or ``verify_answers`` rejects the pair.
    """
    golden = question.golden
    has_final_answer = bool(golden) and "final_answer" in golden
    if not has_final_answer:
        raise ValueError(f"Question {question_index} has no golden final answer - implementation is bugged")

    golden_value = parse_string(golden["final_answer"])
    final_step_value = parse_string(question.stepwise[-1]["result"])

    if golden_value is None:
        raise ValueError(f"Question {question_index}: invalid golden answer - implementation is bugged")
    if final_step_value is None:
        raise ValueError(f"Question {question_index}: invalid formatted answer - implementation is bugged")

    if verify_answers(golden_value, final_step_value):
        return
    raise ValueError(
        f"Question {question_index}: Golden answer mismatch - implementation is bugged. "
        f"Golden={golden_value}, Answer={final_step_value}"
    )
def verify_dataset(dataset: list[Question]) -> dict[str, Any]:
    """
    Check every question's step results, step dependencies and golden answer.

    Verification stops at the first failure; the per-step and golden-answer
    checks raise their own exceptions.

    Args:
        dataset: Questions to verify.

    Returns:
        Counters keyed ``total_questions``, ``verified_questions``,
        ``stepwise_verifications`` and ``golden_verifications``.

    Raises:
        ValueError: If a question has no stepwise results.
    """
    # NOTE: this function is temporary
    total = len(dataset)
    lib = get_lib()
    verified = 0
    step_checks = 0
    golden_checks = 0

    for index, question in enumerate(dataset):
        steps = question.stepwise
        if len(steps) == 0:
            raise ValueError(f"Question {index} has no stepwise results - implementation is bugged")

        for step in steps:
            _verify_step_result(step, lib)
            _verify_step_dependencies(step, steps)
            step_checks += 1

        _verify_golden_answer(question, index)
        golden_checks += 1
        verified += 1

    logger.info(
        "Dataset verification complete: All %d questions verified successfully (%d stepwise checks, %d golden checks)",
        total,
        step_checks,
        golden_checks,
    )
    return {
        "total_questions": total,
        "verified_questions": verified,
        "stepwise_verifications": step_checks,
        "golden_verifications": golden_checks,
    }
def check_constraints(dataset: list[Question], config: dict[str, Any], statistics: dict[str, Any]) -> None:
    """Validate observed per-problem-type value ranges against configured bounds.

    For every problem type in ``statistics["per_problem_type"]`` that also appears
    in ``config``, the observed ``min``/``max`` must not fall outside
    ``config[problem_type]["metadata"]``'s ``target_min_value`` /
    ``target_max_value``. A check is skipped when either side of it is missing.

    Args:
        dataset: Not used by this function.
        config: Per-problem-type configuration; each entry needs a ``metadata`` dict.
        statistics: Dataset statistics with an optional ``per_problem_type`` mapping.

    Raises:
        ValueError: On the first observed value outside its configured bound.
    """
    per_type_stats = statistics.get("per_problem_type", {})
    for problem_type, stats in per_type_stats.items():
        if problem_type not in config:
            continue

        actual_min = stats.get("min")
        actual_max = stats.get("max")
        bounds = config[problem_type]["metadata"]
        expected_min = bounds.get("target_min_value")
        expected_max = bounds.get("target_max_value")

        below_min = actual_min is not None and expected_min is not None and actual_min < expected_min
        if below_min:
            raise ValueError(f"{problem_type}: min {actual_min} < expected {expected_min}")

        above_max = actual_max is not None and expected_max is not None and actual_max > expected_max
        if above_max:
            raise ValueError(f"{problem_type}: max {actual_max} > expected {expected_max}")
def set_seed(seed_val: int = 42) -> None:
    """Seed Python's ``random``, NumPy's global RNG and SymPy's RNG for reproducible generation.

    Args:
        seed_val: Seed passed unchanged to each generator.
    """
    for seeder in (random.seed, np.random.seed, seed):
        seeder(seed_val)
def _format_answer(answer: Any) -> Any:
    """Parse ``answer`` as a Python literal for display, returning it unchanged if that fails.

    ``ast.literal_eval`` rejects anything that is not a plain literal (names, calls,
    non-string inputs, ...); a display helper must not crash on such values.
    """
    try:
        return ast.literal_eval(answer)
    except (ValueError, TypeError, SyntaxError):
        return answer


def _difficulty_sort_key(item: tuple[Any, Any]) -> Any:
    """Order ``(difficulty, ...)`` pairs by the difficulty's ``value`` if it has one, else by ``str``."""
    difficulty = item[0]
    return difficulty.value if hasattr(difficulty, "value") else str(difficulty)


def print_dataset(questions: list[Question], include_invalid: bool = False) -> None:  # pragma: no cover
    """Log a formatted listing of questions followed by summary statistics.

    Args:
        questions: Questions to display.
        include_invalid: If False (default), only questions with ``is_valid`` set are
            shown; if True, invalid ones are shown too and tagged ``[INVALID]``.
    """
    questions_to_print = questions if include_invalid else [q for q in questions if q.is_valid]
    if not questions_to_print:
        logger.info("No questions to display.")
        return

    logger.info("=" * 30)
    logger.info("GENERATED DATASET")
    logger.info("=" * 30)

    # Questions
    for i, question in enumerate(questions_to_print, 1):
        status = " [INVALID]" if not question.is_valid else ""
        logger.info("Question %d:%s", i, status)
        logger.info("Q: %s", question.question)
        logger.info("A: %s", _format_answer(question.answer))
        logger.info("")

    # Metadata
    topics = {q.topic for q in questions_to_print}
    problem_types = {q.problem_type for q in questions_to_print}
    difficulties = {q.difficulty for q in questions_to_print}
    entropy_values = [q.entropy_used for q in questions_to_print]
    tool_calls = [q.tool_calls_required for q in questions_to_print]

    # Summary
    logger.info("Dataset Summary:")
    logger.info(" Total Questions: %d", len(questions_to_print))
    logger.info(" Topics: %s", ", ".join(sorted(topic.value for topic in topics)))
    logger.info(" Problem Types: %s", ", ".join(sorted(pt.value for pt in problem_types)))
    logger.info(" Difficulties: %s", ", ".join(sorted(str(difficulty) for difficulty in difficulties)))
    logger.info(
        " Entropy Used: %.2f - %.2f (avg: %.2f)",
        min(entropy_values),
        max(entropy_values),
        sum(entropy_values) / len(entropy_values),
    )
    logger.info(
        " Tool Calls Required: %d - %d (avg: %.1f)",
        min(tool_calls),
        max(tool_calls),
        sum(tool_calls) / len(tool_calls),
    )

    # Group once by difficulty; both the counts and the per-difficulty averages use it.
    buckets: dict[Any, list[Any]] = defaultdict(list)
    for q in questions_to_print:
        buckets[q.difficulty].append(q)
    ordered_buckets = sorted(buckets.items(), key=_difficulty_sort_key)

    # Distributions
    logger.info(" By Difficulty:")
    for diff, qs in ordered_buckets:
        logger.info(" %s: %d", str(diff), len(qs))

    # Per-difficulty averages (entropy and tool calls)
    logger.info(" Per-Difficulty Averages:")
    for diff, qs in ordered_buckets:
        avg_entropy = sum(q.entropy_used for q in qs) / len(qs)
        avg_tool_calls = sum(q.tool_calls_required for q in qs) / len(qs)
        logger.info(" %s -> entropy avg: %.2f, tool calls avg: %.2f", str(diff), avg_entropy, avg_tool_calls)
    logger.info("=" * 30)
def print_split_statistics(splits: DatasetDict) -> None:  # pragma: no cover
    """Log the size and share of every split in ``splits``.

    The conventional ``train``/``validation``/``test`` splits are listed first, in
    that order. Any other split names follow in their existing order, so the
    listed rows always account for the whole total.
    """
    logger.info("=" * 30)
    logger.info("DATASET SPLITS")
    logger.info("=" * 30)
    total = sum(len(splits[split_name]) for split_name in splits)

    preferred = ["train", "validation", "test"]
    ordered = [name for name in preferred if name in splits]
    # Extra splits contribute to `total`, so they must be listed as well.
    ordered += [name for name in splits if name not in preferred]

    for split_name in ordered:
        count = len(splits[split_name])
        percentage = 100 * count / total if total > 0 else 0
        logger.info(" %s: %d samples (%.1f%%)", split_name.capitalize(), count, percentage)
    logger.info(" Total: %d samples", total)
    logger.info("=" * 30)
def _question_to_example(q: Question) -> dict[str, Any]:
    """Map a Question to a flat example row for Hugging Face datasets.

    Returns:
        Dict with keys:
          - ``query``: the question text.
          - ``ground_truth``: ``golden["final_answer"]``, falling back to ``q.answer``
            when no golden final answer is recorded.
          - ``stepwise_ground_truths``: JSON list of ``{tool: parsed_result}`` per step;
            steps without a tool or with an unparsable result are skipped.
          - ``difficulty`` / ``problem_type``: enum ``name`` / ``value``, or ``str`` fallback.
          - ``composition_type``: ``"sequential"`` for multi-step questions, else ``"single"``.
          - ``composition_dependencies``: always ``"strict"``.
          - ``dependency_edges``: JSON list of ``[from_step, to_step]`` pairs taken from
            each step's ``verification["dependent_on"]``, or None when there are none.
    """
    stepwise_truths: list[dict[str, Any]] = []
    for step in q.stepwise:
        tool_name = step.get("tool")
        result_value = parse_string(step.get("result"))
        if tool_name is None or result_value is None:
            continue
        stepwise_truths.append({tool_name: result_value})

    # Derive composition metadata from stepwise verification and problem_type
    composition_type = "sequential" if len(q.stepwise) > 1 else "single"
    dependency_edges: list[tuple[int, int]] = []
    for idx, step in enumerate(q.stepwise):
        verification = step.get("verification", {})
        if not isinstance(verification, dict):
            continue
        dependent_on = verification.get("dependent_on")
        if not isinstance(dependent_on, dict):
            continue
        # Only integer step indices form edges; other values are ignored here.
        dependency_edges.extend((from_idx, idx) for from_idx in dependent_on.values() if isinstance(from_idx, int))
    dependency_type = "strict"

    # `golden` may be falsy (None/empty); fall back to the plain answer instead of crashing.
    golden = q.golden or {}
    return {
        "query": q.question,
        "ground_truth": golden.get("final_answer", q.answer),
        "stepwise_ground_truths": json.dumps(stepwise_truths),
        "difficulty": getattr(q.difficulty, "name", str(q.difficulty)),
        "problem_type": getattr(q.problem_type, "value", str(q.problem_type)),
        "composition_type": composition_type,
        "composition_dependencies": dependency_type,
        "dependency_edges": json.dumps(dependency_edges) if dependency_edges else None,
    }
def convert_to_dataset_dict(questions: list[Question]) -> DatasetDict:
    """Build a DatasetDict with a single ``train`` split holding every valid question.

    Args:
        questions: Questions to convert; those with ``is_valid`` unset are dropped.
    """
    valid_questions = (q for q in questions if q.is_valid)
    rows = list(map(_question_to_example, valid_questions))
    return DatasetDict({"train": Dataset.from_list(rows)})
def convert_to_dataset_splits(
    questions: list[Question],
    test_size: float = 0.1,
    val_size: float = 0.1,
    seed: int = 42,
    stratify_by: str | None = None,
) -> DatasetDict:
    """Create a train/validation/test DatasetDict using HF's split utilities.

    Args:
        questions: Questions to convert; those with ``is_valid`` unset are dropped.
        test_size: Fraction of all examples reserved for the test split.
        val_size: Fraction of all examples reserved for validation; a falsy or
            non-positive value skips the validation split.
        seed: Seed for the initial shuffle and for both splits.
        stratify_by: Optional column to stratify on; ignored when the column does
            not exist. A non-``ClassLabel`` column is cast to ``ClassLabel`` first,
            so it holds integer class labels in the returned splits.

    Returns:
        DatasetDict with ``train``/``validation``/``test`` splits, or only
        ``train``/``test`` when no validation split is requested.

    Raises:
        ValueError: If there are no valid questions, or if ``test_size + val_size``
            leaves no room for a training split.
    """
    from datasets import ClassLabel

    examples = [_question_to_example(q) for q in questions if q.is_valid]
    if not examples:
        raise ValueError("No valid questions to split into train/validation/test sets")

    want_validation = bool(val_size) and val_size > 0
    if want_validation and test_size + val_size >= 1:
        # Also guards the division below against test_size == 1.
        raise ValueError(
            "test_size + val_size must be < 1 to leave room for a train split "
            f"(got test_size={test_size}, val_size={val_size})"
        )

    ds = Dataset.from_list(examples).shuffle(seed=seed)
    stratify_column = stratify_by if stratify_by in ds.column_names else None

    # Stratified splitting requires a ClassLabel column, so cast other column types first.
    if stratify_column and not isinstance(ds.features[stratify_column], ClassLabel):
        unique_values = ds.unique(stratify_column)
        ds = ds.cast_column(stratify_column, ClassLabel(names=sorted(unique_values)))

    split = ds.train_test_split(test_size=test_size, seed=seed, stratify_by_column=stratify_column)
    if not want_validation:
        return DatasetDict(train=split["train"], test=split["test"])  # type: ignore[reportCallIssue]

    # val_size is a fraction of the whole dataset; rescale it to the remaining train portion.
    relative_val = val_size / (1 - test_size)
    train_val = split["train"].train_test_split(
        test_size=relative_val, seed=seed, stratify_by_column=stratify_column
    )
    return DatasetDict(train=train_val["train"], validation=train_val["test"], test=split["test"])  # type: ignore[reportCallIssue]
|