# Provenance: authored by atomwalk12 ("initial commit", 0dd6c2f).
import argparse
import itertools
import json
import warnings
from pathlib import Path
from typing import Any, cast
import numpy as np
import linalg_zero.generator.difficulty_config as dc
from linalg_zero.generator.analysis.utils import (
compute_stepwise_value_statistics,
extract_all_numerical_values,
extract_report_metadata,
extract_values_by_combination,
plot_combination_comparison,
plot_overall_histogram,
plot_target_compliance,
print_statistics_summary,
rank_entropy_combinations,
)
from linalg_zero.generator.core import DatasetGenerator
from linalg_zero.generator.difficulty_config import (
DETERMINISTIC_MODE,
determine_difficulty,
get_problem_config,
)
from linalg_zero.generator.models import Question, Task, Topic
from linalg_zero.generator.registry import (
FactoryRegistry,
register_problem_type,
)
from linalg_zero.generator.utils import set_seed
from linalg_zero.shared.utils import get_log_file_path, get_logger, setup_logging
# Lower bound passed to register_problem_type as min_value_abs — presumably a
# floor on the absolute value of generated entries; confirm against the registry.
MIN_VALUE_ABS = 2
# Grid resolution of the entropy sweep; grid points are rounded to one decimal
# to match this step.
STEP_SIZE = 0.1
# Questions generated per entropy configuration tested.
SAMPLES_PER_TEST = 8000
# Jitter forwarded to register_problem_type for each configuration.
DEFAULT_ENTROPY_JITTER = 0.2
# When True, a ranking.json and diagnostic plots are written per problem type.
WRITE_PER_PROBLEM_RANKINGS = True
# Root output directory for all analysis artifacts.
PROBLEM_DIR = Path("results") / "entropy_analysis"

# Sweep configuration per Task. Two schemas coexist:
#   * simple (1-tool tasks): an "entropy_ranges" (min, max) tuple for the task
#     itself, plus "target_min"/"target_max" bounds for acceptable result values;
#   * nested (2- and 3-tool tasks): non-string keys are component Tasks mapped
#     to their own (min, max) entropy range, alongside the string target bounds.
# NOTE(review): a (0.0, 0.0) component range appears to pin that component to a
# single grid point (no independent entropy sweep) — confirm with the registry.
ALL_ENTROPY_RANGES = {
    # 1-Tool Call Problems (Foundation Layer)
    Task.ONE_DETERMINANT: {
        "entropy_ranges": (0.6, 1.2),
        "target_min": -500,
        "target_max": 500,
    },
    Task.ONE_TRACE: {
        "entropy_ranges": (1.5, 2.5),
        "target_min": -200,
        "target_max": 200,
    },
    Task.ONE_FROBENIUS_NORM: {
        "entropy_ranges": (2.5, 3.3),
        "target_min": 0,
        "target_max": 600,  # Always positive
    },
    Task.ONE_RANK: {
        "entropy_ranges": (2.5, 3.0),
        "target_min": 1,
        "target_max": 3,
    },
    Task.ONE_TRANSPOSE: {
        "entropy_ranges": (3.0, 3.5),
        "target_min": -800,
        "target_max": 800,
    },
    Task.ONE_COFACTOR: {
        "entropy_ranges": (1.4, 2.0),
        "target_min": -800,
        "target_max": 800,
    },
    # 2-Tool Call Problems (Sequential Reasoning)
    Task.TWO_TRANSPOSE_DETERMINANT: {
        "target_min": -400,
        "target_max": 400,
        Task.ONE_TRANSPOSE: (0.7, 1.8),
        Task.ONE_DETERMINANT: (0.0, 0.0),
    },
    Task.TWO_COFACTOR_TRACE: {
        "target_min": -800,
        "target_max": 800,
        Task.ONE_COFACTOR: (1.0, 2.0),
        Task.ONE_TRACE: (0.0, 0.0),
    },
    Task.TWO_COFACTOR_RANK: {
        "target_min": -800,
        "target_max": 800,
        Task.ONE_COFACTOR: (1.0, 2.0),
        Task.ONE_RANK: (0.0, 0.0),
    },
    Task.TWO_TRANSPOSE_FROBENIUS: {
        "target_min": -800,
        "target_max": 800,
        Task.ONE_TRANSPOSE: (1.8, 3.2),
        Task.ONE_FROBENIUS_NORM: (0.0, 0.0),
    },
    # 3-Tool Call Problems (Advanced Sequential/Fan-out)
    Task.THREE_TRANSPOSE_COFACTOR_RANK: {
        "target_min": -800,
        "target_max": 800,
        Task.ONE_TRANSPOSE: (2.5, 3.6),
        Task.ONE_COFACTOR: (0.0, 0.0),
        Task.ONE_RANK: (0.0, 0.0),
    },
    Task.THREE_COFACTOR_TRANSPOSE_TRACE: {
        "target_min": -800,
        "target_max": 800,
        Task.ONE_COFACTOR: (2.5, 3.6),
        Task.ONE_TRANSPOSE: (0.0, 0.0),
        Task.ONE_TRACE: (0.0, 0.0),
    },
    Task.THREE_TRANSPOSE_COFACTOR_FROBENIUS: {
        "target_min": -800,
        "target_max": 800,
        Task.ONE_TRANSPOSE: (2.8, 3.3),
        Task.ONE_COFACTOR: (0.0, 0.0),
        Task.ONE_FROBENIUS_NORM: (0.0, 0.0),
    },
}

# Module-level logging setup: executed on import.
setup_logging()
logger = get_logger(__name__)
class EntropyOptimizer:
    """Grid-searches component-wise entropy configurations for one problem type.

    Builds a Cartesian grid over the supplied per-component entropy ranges,
    registers the problem type with each configuration, and generates a fixed
    number of questions per configuration via the shared DatasetGenerator.
    """

    def __init__(self, registry: FactoryRegistry, topic: Topic):
        self.registry: FactoryRegistry = registry
        self.generator = DatasetGenerator(topic=topic, registry=registry)

    def execute(
        self,
        component_entropy_ranges: dict[Task, tuple[float, float]],
        problem_type: Task,
    ) -> dict[tuple[float, ...], list[Question]]:
        """Generate SAMPLES_PER_TEST questions per grid point.

        Args:
            component_entropy_ranges: (min, max) entropy bounds per component.
            problem_type: the task being swept (fixes its difficulty bucket).

        Returns:
            Mapping from each entropy-value combination to its generated split.
            Combinations whose generation raised RuntimeError are omitted.
        """
        # Difficulty follows from whether this is a 1-, 2-, or 3-step problem.
        difficulty = determine_difficulty(problem_type)

        # One rounded grid axis per component; the upper bound gets half a step
        # of slack so np.arange includes the endpoint despite float error.
        grid_points = {
            component: np.round(np.arange(lo, hi + STEP_SIZE / 2, STEP_SIZE), 1)
            for component, (lo, hi) in component_entropy_ranges.items()
        }

        # Cartesian product of all axes, in a stable component order.
        component_names = list(grid_points)
        value_combinations = list(itertools.product(*(grid_points[name] for name in component_names)))
        logger.info(f"Total configurations to test: {len(value_combinations)}")

        dataset_by_combination = {}
        failed_configurations = []
        logger.info(f"Testing {problem_type.value} with {SAMPLES_PER_TEST} questions per configuration")

        for idx, combination in enumerate(value_combinations):
            logger.info(f"Testing configuration {combination} {idx + 1}/{len(value_combinations)}")
            entropy_ranges = {
                Task(name): value for name, value in zip(component_names, combination, strict=True)
            }
            try:
                # Re-register the problem type with this configuration's ranges
                # before generating its sample.
                register_problem_type(
                    self.registry, problem_type, entropy_ranges, DEFAULT_ENTROPY_JITTER, MIN_VALUE_ABS
                )
                dataset_by_combination[combination] = self.generator.generate_exact_for_categories(
                    requests={difficulty: SAMPLES_PER_TEST}
                )
            except RuntimeError as e:
                # Best-effort sweep: record the failure and keep going.
                logger.warning(f"Failed to generate for configuration {combination}: {e}")
                failed_configurations.append((combination, str(e)))
                continue

        # Summarize any failures once the sweep is done.
        if failed_configurations:
            logger.warning(
                f"Failed to generate for {len(failed_configurations)} out of {len(value_combinations)} configurations"
            )
            for combination, error in failed_configurations:
                logger.debug(f"Failed configuration {combination}: {error}")
        else:
            logger.info("All configurations generated successfully")

        return dataset_by_combination
def execute_analysis(
    topic: Topic,
    problem_type: Task,
    component_entropy_ranges: dict[Task, tuple[float, float]],
    target_min: float,
    target_max: float,
    print_individual_stats: bool = True,
) -> tuple[dict[tuple, dict[str, Any]], dict[str, Any]]:
    """Sweep entropy configurations for one problem type and rank the results.

    Runs the grid search, computes stepwise value statistics per combination,
    ranks combinations against [target_min, target_max], and — when
    WRITE_PER_PROBLEM_RANKINGS is enabled — writes a JSON ranking plus
    diagnostic plots under PROBLEM_DIR / problem_type.value.

    Args:
        topic: topic used to construct the DatasetGenerator.
        problem_type: the task being optimized.
        component_entropy_ranges: (min, max) entropy bounds per component.
        target_min: lower bound of the desired value range.
        target_max: upper bound of the desired value range.
        print_individual_stats: log a statistics summary per combination.

    Returns:
        (statistics, report): per-combination statistics and the ranking report.
    """
    logger.info(f"Optimizing entropy for: {problem_type}")
    logger.info("This will systematically test different component-wise entropy configurations.")
    logger.info("Configuration:")
    logger.info(f" Component ranges: {component_entropy_ranges}")
    logger.info(f" Samples per test: {SAMPLES_PER_TEST}")
    logger.info(f" Target max value: {target_max}")
    logger.info(f" Target min value: {target_min}")
    config = get_problem_config(determine_difficulty(problem_type))
    logger.info(f" Problem config: {config}")

    optimizer = EntropyOptimizer(FactoryRegistry(), topic)
    dataset = optimizer.execute(
        component_entropy_ranges=component_entropy_ranges,
        problem_type=problem_type,
    )

    statistics = {combo: compute_stepwise_value_statistics(split) for combo, split in dataset.items()}

    if print_individual_stats:
        for combo, stats in statistics.items():
            logger.info(f"Combination: {combo}")
            print_statistics_summary(stats)

    # Rank combinations by how well their raw values hit the target interval.
    PROBLEM_DIR.mkdir(parents=True, exist_ok=True)
    ranked = rank_entropy_combinations(
        statistics,
        target_min=target_min,
        target_max=target_max,
    )
    report = {"top": ranked[: min(10, len(ranked))]}

    # The per-problem path is computed unconditionally so later references are
    # always bound, even when writing is disabled.
    per_problem_dir = PROBLEM_DIR / problem_type.value
    if WRITE_PER_PROBLEM_RANKINGS:
        per_problem_dir.mkdir(parents=True, exist_ok=True)
        with (per_problem_dir / "ranking.json").open("w", encoding="utf-8") as f:
            json.dump(report, f, indent=2)
        logger.info(f"ranking saved to: {per_problem_dir / 'ranking.json'}")
        logger.info(f"number of top combinations: {len(ranked)}")

    # Plots are built from raw values, not min/max summaries.
    all_values = extract_all_numerical_values(statistics, use_min_max=False)
    values_by_combination = extract_values_by_combination(statistics, use_min_max=False)
    if WRITE_PER_PROBLEM_RANKINGS:
        plot_overall_histogram(all_values, output_dir=per_problem_dir, target_min=target_min, target_max=target_max)
        plot_combination_comparison(values_by_combination, output_dir=per_problem_dir)
        plot_target_compliance(statistics, output_dir=per_problem_dir, target_min=target_min, target_max=target_max)

    return statistics, report
def main() -> None:
    """Run the entropy sweep for every entry of ALL_ENTROPY_RANGES and report.

    Writes a consolidated top-choices JSON to PROBLEM_DIR and logs a summary
    of the best combinations per problem type.
    """
    warnings.filterwarnings("ignore", category=UserWarning)

    all_reports = {}
    all_statistics = {}
    for problem_type, config in ALL_ENTROPY_RANGES.items():
        config_dict = cast(dict[str, Any], config)
        target_min = config_dict["target_min"]
        target_max = config_dict["target_max"]
        # Simple configs carry one "entropy_ranges" tuple for the task itself;
        # nested configs map component Tasks (the non-string keys) to ranges.
        if "entropy_ranges" in config_dict:
            component_entropy_ranges = {problem_type: config_dict["entropy_ranges"]}
        else:
            component_entropy_ranges = {k: v for k, v in config_dict.items() if not isinstance(k, str)}

        statistics, report = execute_analysis(
            topic=Topic.LINEAR_ALGEBRA,
            problem_type=problem_type,
            component_entropy_ranges=component_entropy_ranges,
            target_min=target_min,
            target_max=target_max,
            print_individual_stats=False,
        )
        all_reports[problem_type] = report
        all_statistics[problem_type] = statistics

    logger.info("=" * 80)
    logger.info("FINAL ANALYSIS SUMMARY - ALL PROBLEM TYPES")
    logger.info("=" * 80)

    top_choices = {}
    for problem_type, report in all_reports.items():
        logger.info(f"{problem_type.value.upper()} RESULTS:")
        logger.info("-" * 50)
        if report["top"]:
            config = ALL_ENTROPY_RANGES[problem_type]
            config_dict = cast(dict[str, Any], config)
            # For simple configs the entropy tuple is recorded; nested configs
            # record the whole per-component mapping.
            entropy_config = config_dict.get("entropy_ranges", config_dict)
            top_choices[problem_type.value] = extract_report_metadata(
                top_choice=report["top"][0],
                problem_type=problem_type,
                entropy_config=entropy_config,
                min_value_abs=MIN_VALUE_ABS,
                entropy_jitter=DEFAULT_ENTROPY_JITTER,
                step_size=STEP_SIZE,
                samples_per_test=SAMPLES_PER_TEST,
                target_min_value=config_dict["target_min"],
                target_max_value=config_dict["target_max"],
            )

            top = 5
            logger.info(f"Top entropy combinations (closest to targets) {top}/{len(report['top'])}:")
            for i, entry in enumerate(report["top"][:top]):
                logger.info(
                    f" {i + 1}. {entry['combination']} -> score={entry['score']:.2f}, "
                    f"range=[{entry['overall_min']:.2f}, {entry['overall_max']:.2f}], "
                    f"min_abs={entry['min_abs']:.2f}, count={entry['count']}"
                )

    # Consolidated best choice per problem type, for downstream configuration.
    consolidated_path = PROBLEM_DIR / "top_entropy_choices.json"
    with consolidated_path.open("w", encoding="utf-8") as f:
        json.dump(top_choices, f, indent=2)
    logger.info(f"Consolidated top choices saved to: {consolidated_path}")
    logger.info(f"Log file path: {get_log_file_path()}")
if __name__ == "__main__":
    # CLI: only a --seed knob; defaults to 42, so the None-check below is a
    # guard for future callers passing default=None.
    cli = argparse.ArgumentParser()
    cli.add_argument("--seed", type=int, default=42)
    args = cli.parse_args()
    if args.seed is not None:
        set_seed(args.seed)
        if DETERMINISTIC_MODE:
            # Propagate the seed to the deterministic-mode base seed as well.
            dc.DETERMINISTIC_BASE_SEED = int(args.seed)
    main()