Spaces:
Running on Zero
Running on Zero
| import argparse | |
| import itertools | |
| import json | |
| import warnings | |
| from pathlib import Path | |
| from typing import Any, cast | |
| import numpy as np | |
| import linalg_zero.generator.difficulty_config as dc | |
| from linalg_zero.generator.analysis.utils import ( | |
| compute_stepwise_value_statistics, | |
| extract_all_numerical_values, | |
| extract_report_metadata, | |
| extract_values_by_combination, | |
| plot_combination_comparison, | |
| plot_overall_histogram, | |
| plot_target_compliance, | |
| print_statistics_summary, | |
| rank_entropy_combinations, | |
| ) | |
| from linalg_zero.generator.core import DatasetGenerator | |
| from linalg_zero.generator.difficulty_config import ( | |
| DETERMINISTIC_MODE, | |
| determine_difficulty, | |
| get_problem_config, | |
| ) | |
| from linalg_zero.generator.models import Question, Task, Topic | |
| from linalg_zero.generator.registry import ( | |
| FactoryRegistry, | |
| register_problem_type, | |
| ) | |
| from linalg_zero.generator.utils import set_seed | |
| from linalg_zero.shared.utils import get_log_file_path, get_logger, setup_logging | |
| MIN_VALUE_ABS = 2 | |
| STEP_SIZE = 0.1 | |
| SAMPLES_PER_TEST = 8000 | |
| DEFAULT_ENTROPY_JITTER = 0.2 | |
| WRITE_PER_PROBLEM_RANKINGS = True | |
| PROBLEM_DIR = Path("results") / "entropy_analysis" | |
| ALL_ENTROPY_RANGES = { | |
| # 1-Tool Call Problems (Foundation Layer) | |
| Task.ONE_DETERMINANT: { | |
| "entropy_ranges": (0.6, 1.2), | |
| "target_min": -500, | |
| "target_max": 500, | |
| }, | |
| Task.ONE_TRACE: { | |
| "entropy_ranges": (1.5, 2.5), | |
| "target_min": -200, | |
| "target_max": 200, | |
| }, | |
| Task.ONE_FROBENIUS_NORM: { | |
| "entropy_ranges": (2.5, 3.3), | |
| "target_min": 0, | |
| "target_max": 600, # Always positive | |
| }, | |
| Task.ONE_RANK: { | |
| "entropy_ranges": (2.5, 3.0), | |
| "target_min": 1, | |
| "target_max": 3, | |
| }, | |
| Task.ONE_TRANSPOSE: { | |
| "entropy_ranges": (3.0, 3.5), | |
| "target_min": -800, | |
| "target_max": 800, | |
| }, | |
| Task.ONE_COFACTOR: { | |
| "entropy_ranges": (1.4, 2.0), | |
| "target_min": -800, | |
| "target_max": 800, | |
| }, | |
| # 2-Tool Call Problems (Sequential Reasoning) | |
| Task.TWO_TRANSPOSE_DETERMINANT: { | |
| "target_min": -400, | |
| "target_max": 400, | |
| Task.ONE_TRANSPOSE: (0.7, 1.8), | |
| Task.ONE_DETERMINANT: (0.0, 0.0), | |
| }, | |
| Task.TWO_COFACTOR_TRACE: { | |
| "target_min": -800, | |
| "target_max": 800, | |
| Task.ONE_COFACTOR: (1.0, 2.0), | |
| Task.ONE_TRACE: (0.0, 0.0), | |
| }, | |
| Task.TWO_COFACTOR_RANK: { | |
| "target_min": -800, | |
| "target_max": 800, | |
| Task.ONE_COFACTOR: (1.0, 2.0), | |
| Task.ONE_RANK: (0.0, 0.0), | |
| }, | |
| Task.TWO_TRANSPOSE_FROBENIUS: { | |
| "target_min": -800, | |
| "target_max": 800, | |
| Task.ONE_TRANSPOSE: (1.8, 3.2), | |
| Task.ONE_FROBENIUS_NORM: (0.0, 0.0), | |
| }, | |
| # 3-Tool Call Problems (Advanced Sequential/Fan-out) | |
| Task.THREE_TRANSPOSE_COFACTOR_RANK: { | |
| "target_min": -800, | |
| "target_max": 800, | |
| Task.ONE_TRANSPOSE: (2.5, 3.6), | |
| Task.ONE_COFACTOR: (0.0, 0.0), | |
| Task.ONE_RANK: (0.0, 0.0), | |
| }, | |
| Task.THREE_COFACTOR_TRANSPOSE_TRACE: { | |
| "target_min": -800, | |
| "target_max": 800, | |
| Task.ONE_COFACTOR: (2.5, 3.6), | |
| Task.ONE_TRANSPOSE: (0.0, 0.0), | |
| Task.ONE_TRACE: (0.0, 0.0), | |
| }, | |
| Task.THREE_TRANSPOSE_COFACTOR_FROBENIUS: { | |
| "target_min": -800, | |
| "target_max": 800, | |
| Task.ONE_TRANSPOSE: (2.8, 3.3), | |
| Task.ONE_COFACTOR: (0.0, 0.0), | |
| Task.ONE_FROBENIUS_NORM: (0.0, 0.0), | |
| }, | |
| } | |
| setup_logging() | |
| logger = get_logger(__name__) | |
| class EntropyOptimizer: | |
| def __init__(self, registry: FactoryRegistry, topic: Topic): | |
| self.registry: FactoryRegistry = registry | |
| self.generator = DatasetGenerator(topic=topic, registry=registry) | |
| def execute( | |
| self, | |
| component_entropy_ranges: dict[Task, tuple[float, float]], | |
| problem_type: Task, | |
| ) -> dict[tuple[float, ...], list[Question]]: | |
| # Get correct config based on problem type (based on whether it is a 1, 2, or 3 step problem) | |
| difficulty = determine_difficulty(problem_type) | |
| # Generate all combinations of entropy values | |
| grid_points = {} | |
| for component, (min_val, max_val) in component_entropy_ranges.items(): | |
| grid_points[component] = np.arange(min_val, max_val + STEP_SIZE / 2, STEP_SIZE) | |
| for component, values in grid_points.items(): | |
| grid_points[component] = np.round(values, 1) | |
| # Create all combinations | |
| component_names = list(grid_points.keys()) | |
| value_combinations = list(itertools.product(*[grid_points[name] for name in component_names])) | |
| logger.info(f"Total configurations to test: {len(value_combinations)}") | |
| dataset_by_combination = {} | |
| failed_configurations = [] | |
| logger.info(f"Testing {problem_type.value} with {SAMPLES_PER_TEST} questions per configuration") | |
| for i, combination in enumerate(value_combinations): | |
| logger.info(f"Testing configuration {combination} {i + 1}/{len(value_combinations)}") | |
| entropy_ranges = { | |
| Task(component_name): combination_value | |
| for component_name, combination_value in zip(component_names, combination, strict=True) | |
| } | |
| try: | |
| register_problem_type( | |
| self.registry, problem_type, entropy_ranges, DEFAULT_ENTROPY_JITTER, MIN_VALUE_ABS | |
| ) | |
| split = self.generator.generate_exact_for_categories( | |
| requests={ | |
| difficulty: SAMPLES_PER_TEST, | |
| } | |
| ) | |
| dataset_by_combination[combination] = split | |
| except RuntimeError as e: | |
| logger.warning(f"Failed to generate for configuration {combination}: {e}") | |
| failed_configurations.append((combination, str(e))) | |
| continue | |
| # Log summary of failures | |
| if failed_configurations: | |
| logger.warning( | |
| f"Failed to generate for {len(failed_configurations)} out of {len(value_combinations)} configurations" | |
| ) | |
| for combination, error in failed_configurations: | |
| logger.debug(f"Failed configuration {combination}: {error}") | |
| else: | |
| logger.info("All configurations generated successfully") | |
| return dataset_by_combination | |
| def execute_analysis( | |
| topic: Topic, | |
| problem_type: Task, | |
| component_entropy_ranges: dict[Task, tuple[float, float]], | |
| target_min: float, | |
| target_max: float, | |
| print_individual_stats: bool = True, | |
| ) -> tuple[dict[tuple, dict[str, Any]], dict[str, Any]]: | |
| logger.info(f"Optimizing entropy for: {problem_type}") | |
| logger.info("This will systematically test different component-wise entropy configurations.") | |
| logger.info("Configuration:") | |
| logger.info(f" Component ranges: {component_entropy_ranges}") | |
| logger.info(f" Samples per test: {SAMPLES_PER_TEST}") | |
| logger.info(f" Target max value: {target_max}") | |
| logger.info(f" Target min value: {target_min}") | |
| config = get_problem_config(determine_difficulty(problem_type)) | |
| logger.info(f" Problem config: {config}") | |
| registry = FactoryRegistry() | |
| optimizer = EntropyOptimizer(registry, topic) | |
| dataset = optimizer.execute( | |
| component_entropy_ranges=component_entropy_ranges, | |
| problem_type=problem_type, | |
| ) | |
| statistics = {} | |
| for combination, split in dataset.items(): | |
| statistics[combination] = compute_stepwise_value_statistics(split) | |
| if print_individual_stats: | |
| for combination, stats in statistics.items(): | |
| logger.info(f"Combination: {combination}") | |
| print_statistics_summary(stats) | |
| # After computing statistics, rank combinations and optionally write per-problem report + plots | |
| PROBLEM_DIR.mkdir(parents=True, exist_ok=True) | |
| # Ranking based on raw values | |
| ranked = rank_entropy_combinations( | |
| statistics, | |
| target_min=target_min, | |
| target_max=target_max, | |
| ) | |
| top_k = min(10, len(ranked)) | |
| report = { | |
| "top": ranked[:top_k], | |
| } | |
| # Prepare per-problem directory path regardless of flag to avoid unbound warnings | |
| per_problem_dir = PROBLEM_DIR / problem_type.value | |
| if WRITE_PER_PROBLEM_RANKINGS: | |
| per_problem_dir.mkdir(parents=True, exist_ok=True) | |
| with (per_problem_dir / "ranking.json").open("w", encoding="utf-8") as f: | |
| json.dump(report, f, indent=2) | |
| logger.info(f"ranking saved to: {per_problem_dir / 'ranking.json'}") | |
| logger.info(f"number of top combinations: {len(ranked)}") | |
| # Optional plots (use raw values) | |
| all_values = extract_all_numerical_values(statistics, use_min_max=False) | |
| values_by_combination = extract_values_by_combination(statistics, use_min_max=False) | |
| if WRITE_PER_PROBLEM_RANKINGS: | |
| plot_overall_histogram(all_values, output_dir=per_problem_dir, target_min=target_min, target_max=target_max) | |
| plot_combination_comparison(values_by_combination, output_dir=per_problem_dir) | |
| plot_target_compliance(statistics, output_dir=per_problem_dir, target_min=target_min, target_max=target_max) | |
| return statistics, report | |
| def main() -> None: | |
| warnings.filterwarnings("ignore", category=UserWarning) | |
| all_reports = {} | |
| all_statistics = {} | |
| for problem_type, config in ALL_ENTROPY_RANGES.items(): | |
| config_dict = cast(dict[str, Any], config) | |
| target_min = config_dict["target_min"] | |
| target_max = config_dict["target_max"] | |
| # Parse simple or nested config | |
| if "entropy_ranges" in config_dict: | |
| # Simple config | |
| entropy_ranges = config_dict["entropy_ranges"] | |
| component_entropy_ranges = {problem_type: entropy_ranges} | |
| else: | |
| # Nested config | |
| component_entropy_ranges = {} | |
| for key, value in config_dict.items(): | |
| if not isinstance(key, str): | |
| component_entropy_ranges[key] = value | |
| statistics, report = execute_analysis( | |
| topic=Topic.LINEAR_ALGEBRA, | |
| problem_type=problem_type, | |
| component_entropy_ranges=component_entropy_ranges, | |
| target_min=target_min, | |
| target_max=target_max, | |
| print_individual_stats=False, | |
| ) | |
| all_reports[problem_type] = report | |
| all_statistics[problem_type] = statistics | |
| logger.info("=" * 80) | |
| logger.info("FINAL ANALYSIS SUMMARY - ALL PROBLEM TYPES") | |
| logger.info("=" * 80) | |
| top_choices = {} | |
| for problem_type, report in all_reports.items(): | |
| logger.info(f"{problem_type.value.upper()} RESULTS:") | |
| logger.info("-" * 50) | |
| # Print top combinations for this problem type | |
| if report["top"]: | |
| config = ALL_ENTROPY_RANGES[problem_type] | |
| config_dict = cast(dict[str, Any], config) | |
| entropy_config = config_dict.get("entropy_ranges", config_dict) | |
| top_choices[problem_type.value] = extract_report_metadata( | |
| top_choice=report["top"][0], | |
| problem_type=problem_type, | |
| entropy_config=entropy_config, | |
| min_value_abs=MIN_VALUE_ABS, | |
| entropy_jitter=DEFAULT_ENTROPY_JITTER, | |
| step_size=STEP_SIZE, | |
| samples_per_test=SAMPLES_PER_TEST, | |
| target_min_value=config_dict["target_min"], | |
| target_max_value=config_dict["target_max"], | |
| ) | |
| top = 5 | |
| logger.info(f"Top entropy combinations (closest to targets) {top}/{len(report['top'])}:") | |
| for i, entry in enumerate(report["top"][:top]): # Show top 5 | |
| combination = entry["combination"] | |
| score = entry["score"] | |
| overall_min = entry["overall_min"] | |
| overall_max = entry["overall_max"] | |
| overall_min_abs = entry["min_abs"] | |
| count = entry["count"] | |
| logger.info( | |
| f" {i + 1}. {combination} -> score={score:.2f}, range=[{overall_min:.2f}, {overall_max:.2f}], min_abs={overall_min_abs:.2f}, count={count}" | |
| ) | |
| # Write consolidated top choices for all problem types | |
| consolidated_path = PROBLEM_DIR / "top_entropy_choices.json" | |
| with consolidated_path.open("w", encoding="utf-8") as f: | |
| json.dump(top_choices, f, indent=2) | |
| logger.info(f"Consolidated top choices saved to: {consolidated_path}") | |
| logger.info(f"Log file path: {get_log_file_path()}") | |
| if __name__ == "__main__": | |
| parser = argparse.ArgumentParser() | |
| parser.add_argument("--seed", type=int, default=42) | |
| argv = parser.parse_args() | |
| if argv.seed is not None: | |
| set_seed(argv.seed) | |
| if DETERMINISTIC_MODE: | |
| dc.DETERMINISTIC_BASE_SEED = int(argv.seed) | |
| main() | |