"""Utilities to collect, summarize, and extract stepwise numeric statistics."""

import ast
from pathlib import Path
from typing import Any

import matplotlib.pyplot as plt
import numpy as np
from matplotlib.ticker import MaxNLocator

from linalg_zero.generator.models import DifficultyCategory, Question, Task
from linalg_zero.shared.utils import get_logger

logger = get_logger(__name__)


def print_statistics_summary(statistics: dict[str, Any]) -> None:  # pragma: no cover
    """Log a formatted, human-readable summary of dataset statistics.

    Args:
        statistics: Dict shaped like the return value of
            ``compute_stepwise_value_statistics``; missing keys are tolerated.
    """
    if not statistics:
        logger.info("No statistics available.")
        return

    logger.info("=" * 50)
    logger.info("DATASET STATISTICS SUMMARY")
    logger.info("=" * 50)

    # Overall statistics
    overall_min = statistics.get("overall_min")
    overall_max = statistics.get("overall_max")
    overall_min_abs = statistics.get("overall_min_abs")
    logger.info(f"Overall Range: {overall_min} to {overall_max}")
    logger.info(f"Overall Min Absolute: {overall_min_abs}")

    # Per-step statistics, in ascending step order
    per_step = statistics.get("per_step", {})
    if per_step:
        logger.info("Per-Step Statistics:")
        for step_idx in sorted(per_step.keys()):
            step_stats = per_step[step_idx]
            logger.info(
                f" Step {step_idx}: min={step_stats.get('min')}, max={step_stats.get('max')}, min_abs={step_stats.get('min_abs')}, count={step_stats.get('count')}"
            )

    # Per-problem-type statistics
    per_problem_type = statistics.get("per_problem_type", {})
    if per_problem_type:
        logger.info("Per-Problem-Type Statistics:")
        for problem_type, type_stats in per_problem_type.items():
            logger.info(
                f" {problem_type}: min={type_stats.get('min')}, max={type_stats.get('max')}, min_abs={type_stats.get('min_abs')}, count={type_stats.get('count')}"
            )

    # Per-question statistics: only the first three are shown as examples
    per_question = statistics.get("per_question", [])
    if per_question:
        logger.info(f"Per-Question Statistics: {len(per_question)} questions analyzed")
        for i, q_stats in enumerate(per_question[:3]):
            logger.info(
                f" Q{i + 1}: min={q_stats.get('min')}, max={q_stats.get('max')}, min_abs={q_stats.get('min_abs')}, count={q_stats.get('count')}"
            )
        if len(per_question) > 3:
            logger.info(f" ... and {len(per_question) - 3} more questions")
    logger.info("=" * 50)


def _extract_numeric_values_from_object(obj: Any) -> list[float]:
    """Recursively extract numeric values (as floats) from an arbitrary object.

    Lists, tuples, and dict values are walked recursively; dict keys and any
    non-numeric leaves contribute nothing.

    Raises:
        TypeError: If a complex number is found anywhere in the object.
    """
    values: list[float] = []
    # NOTE(review): bool is a subclass of int, so True/False are counted as
    # 1.0/0.0 here — confirm this is intended for stepwise results.
    if isinstance(obj, int | float):
        values.append(float(obj))
        return values
    if isinstance(obj, complex):
        raise TypeError(f"Complex number found: {obj}")
    if isinstance(obj, list | tuple):
        for item in obj:
            values.extend(_extract_numeric_values_from_object(item))
        return values
    if isinstance(obj, dict):
        for v in obj.values():
            values.extend(_extract_numeric_values_from_object(v))
        return values
    # Any other leaf type (str, None, ...) is ignored.
    return values


def _merge_extrema(
    bucket: dict[str, Any], new_min: float, new_max: float, new_min_abs: float, new_count: int
) -> None:
    """Fold one step's extrema and value count into an aggregate bucket in place.

    ``bucket`` holds ``min``/``max``/``min_abs`` (None until first data) and a
    running ``count``.
    """
    bucket["min"] = new_min if bucket["min"] is None else min(bucket["min"], new_min)
    bucket["max"] = new_max if bucket["max"] is None else max(bucket["max"], new_max)
    bucket["min_abs"] = new_min_abs if bucket["min_abs"] is None else min(bucket["min_abs"], new_min_abs)
    bucket["count"] += new_count


def compute_stepwise_value_statistics(questions: list[Question]) -> dict[str, Any]:
    """Scan stepwise results from all questions and compute statistics.

    Returns a dictionary with:
    - overall_min: float | None
    - overall_max: float | None
    - overall_min_abs: float (0.0 when no numeric values were found at all)
    - per_question: list of {index, min, max, min_abs, count}
    - per_step: dict[int, {min, max, min_abs, count}] aggregated across all questions by step index
    - per_problem_type: dict[str, {min, max, min_abs, count}] aggregated by problem type
    - all_values: flat list[float] of every numeric value encountered across all steps/questions

    Raises:
        ValueError: If any step is missing its "result" entry.
        TypeError: If a parsed result contains a complex number.
    """
    overall_min: float | None = None
    overall_max: float | None = None
    overall_min_abs: float | None = None
    per_question: list[dict[str, Any]] = []
    per_step: dict[int, dict[str, Any]] = {}
    per_problem_type: dict[str, dict[str, Any]] = {}
    all_values: list[float] = []

    for q_index, question in enumerate(questions):
        q_min: float | None = None
        q_max: float | None = None
        q_min_abs: float | None = None
        q_count: int = 0

        # Resolve problem type key once per question (enum value when available).
        pt_key = getattr(question.problem_type, "value", str(question.problem_type))

        for step_index, step in enumerate(question.stepwise):
            # Parse the step result into a Python object
            result_str = step.get("result")
            if result_str is None:
                raise ValueError(f"Step {step_index} has no result")
            parsed = ast.literal_eval(result_str)

            # Extract numeric values
            numeric_values = _extract_numeric_values_from_object(parsed)

            # Ensure the step index is represented even when it yields no values.
            step_bucket = per_step.setdefault(step_index, {"min": None, "max": None, "min_abs": None, "count": 0})
            if not numeric_values:
                continue

            step_min = min(numeric_values)
            step_max = max(numeric_values)
            step_min_abs = min(abs(v) for v in numeric_values)
            step_count = len(numeric_values)

            # Aggregate raw values
            all_values.extend(float(v) for v in numeric_values)

            # Update overall stats
            overall_min = step_min if overall_min is None else min(overall_min, step_min)
            overall_max = step_max if overall_max is None else max(overall_max, step_max)
            overall_min_abs = step_min_abs if overall_min_abs is None else min(overall_min_abs, step_min_abs)

            # Update question stats
            q_min = step_min if q_min is None else min(q_min, step_min)
            q_max = step_max if q_max is None else max(q_max, step_max)
            q_min_abs = step_min_abs if q_min_abs is None else min(q_min_abs, step_min_abs)
            q_count += step_count

            # Update per-step and per-problem-type aggregated stats.  The
            # problem-type bucket is created lazily (only once data exists),
            # matching the per-step/per-question semantics.
            _merge_extrema(step_bucket, step_min, step_max, step_min_abs, step_count)
            pt_bucket = per_problem_type.setdefault(pt_key, {"min": None, "max": None, "min_abs": None, "count": 0})
            _merge_extrema(pt_bucket, step_min, step_max, step_min_abs, step_count)

        per_question.append({"index": q_index, "min": q_min, "max": q_max, "min_abs": q_min_abs, "count": q_count})

    return {
        "overall_min": overall_min,
        "overall_max": overall_max,
        # Fall back to 0.0 so downstream consumers never see None here.
        "overall_min_abs": overall_min_abs if overall_min_abs is not None else 0.0,
        "per_question": per_question,
        "per_step": per_step,
        "per_problem_type": per_problem_type,
        "all_values": all_values,
    }


def extract_all_numerical_values(statistics: dict[tuple, dict[str, Any]], use_min_max: bool = False) -> list[float]:
    """Extract all raw stepwise numerical values from all runs.

    Uses the aggregated `all_values` emitted by `compute_stepwise_value_statistics`
    for each combination rather than only per-question min/max boundaries.

    Args:
        statistics: Mapping of entropy combination -> statistics dict.
        use_min_max: When True, collect only the per-question min/max pairs.
    """
    all_values: list[float] = []
    for stats in statistics.values():
        if use_min_max:
            values = []
            per_question = stats.get("per_question", [])
            for q_stats in per_question:
                values.append(q_stats["min"])
                values.append(q_stats["max"])
        else:
            values = stats.get("all_values", [])
        all_values.extend(values)

    # Sanity check: questions with no numeric values would contribute None extrema.
    for value in all_values:
        assert value is not None, "Value is None"
    return all_values


def extract_values_by_combination(
    statistics: dict[tuple, dict[str, Any]], use_min_max: bool = False
) -> dict[tuple, list[float]]:
    """Extract raw stepwise numerical values grouped by entropy combination.

    Args:
        statistics: Mapping of entropy combination -> statistics dict.
        use_min_max: When True, collect only the per-question min/max pairs.
    """
    values_by_combination: dict[tuple, list[float]] = {}
    for combination, stats in statistics.items():
        if use_min_max:
            values = []
            per_question = stats.get("per_question", [])
            for q_stats in per_question:
                values.append(q_stats["min"])
                values.append(q_stats["max"])
        else:
            values = list(stats.get("all_values", []))
        values_by_combination[combination] = values

    # Sanity check: no combination should contribute missing values.
    for key, items in values_by_combination.items():
        assert all(value is not None for value in items), f"Values are None for combination {key}"
    return values_by_combination
def plot_overall_histogram(
    all_values: list[float], target_min: float = -1000, target_max: float = 1000, output_dir: Path | None = None
) -> None:
    """Plot histogram of all numerical values across all runs.

    Saves "overall_distribution.png" into ``output_dir`` (or the CWD) and
    shows the figure: a linear-scale histogram plus, when positive values
    exist, a log-scale histogram of only the positive values.
    """
    plt.figure(figsize=(12, 6))

    plt.subplot(1, 2, 1)
    plt.hist(all_values, bins=50, alpha=0.7, color="skyblue", edgecolor="black")
    plt.axvline(target_min, color="red", linestyle="--", label=f"Target Min ({target_min})")
    plt.axvline(target_max, color="red", linestyle="--", label=f"Target Max ({target_max})")
    plt.xlabel("Numerical Values")
    plt.ylabel("Frequency")
    plt.title("Distribution of All Numerical Values")
    plt.legend()
    plt.grid(True, alpha=0.3)

    # Log scale version
    plt.subplot(1, 2, 2)
    # Filter out zero and negative values for log scale
    positive_values = [v for v in all_values if v > 0]
    if positive_values:
        plt.hist(positive_values, bins=50, alpha=0.7, color="lightcoral", edgecolor="black")
        plt.axvline(target_max, color="red", linestyle="--", label=f"Target Max ({target_max})")
        plt.xlabel("Numerical Values (log scale)")
        plt.ylabel("Frequency")
        plt.title("Distribution of Positive Values (Log Scale)")
        plt.xscale("log")
        plt.legend()
        plt.grid(True, alpha=0.3)

    plt.tight_layout()
    filename = "overall_distribution.png"
    filepath = output_dir / filename if output_dir else filename
    plt.savefig(filepath, dpi=300, bbox_inches="tight")
    plt.show()


def plot_combination_comparison(
    values_by_combination: dict[tuple, list[float]], max_combinations: int = 12, output_dir: Path | None = None
) -> None:
    """Plot comparison of value distributions across different entropy combinations.

    Shows up to ``max_combinations`` histograms (largest sample counts first)
    on a 3x4 grid, annotated with mean/std/count, and saves
    "by_combination.png" into ``output_dir`` (or the CWD).
    """
    # Limit to top combinations by number of values
    sorted_combinations = sorted(values_by_combination.items(), key=lambda x: len(x[1]), reverse=True)[
        :max_combinations
    ]

    _, axes = plt.subplots(3, 4, figsize=(16, 12))
    axes = axes.flatten()

    for i, (combination, values) in enumerate(sorted_combinations):
        if i >= len(axes):
            break
        ax = axes[i]
        if values:  # Only plot if we have values
            ax.hist(values, bins=20, alpha=0.7, color=plt.colormaps["tab10"](i % 10), edgecolor="black")
            ax.yaxis.set_major_locator(MaxNLocator(integer=True))

        ax.set_title(f"Combo: {combination}", fontsize=10)
        ax.set_xlabel("Values")
        ax.set_ylabel("Freq")
        ax.grid(True, alpha=0.3)

        # Add statistics text
        if values:
            mean_val = np.mean(values)
            std_val = np.std(values)
            # BUGFIX: the std-dev label was garbled to a Latin "o"; restore sigma.
            ax.text(
                0.05,
                0.95,
                f"μ={mean_val:.1f}\nσ={std_val:.1f}\nn={len(values)}",
                transform=ax.transAxes,
                verticalalignment="top",
                bbox={"boxstyle": "round", "facecolor": "white", "alpha": 0.8},
            )

    # Hide unused subplots
    for i in range(len(sorted_combinations), len(axes)):
        axes[i].set_visible(False)

    plt.suptitle("Value Distributions by Entropy Combination", fontsize=14)
    plt.tight_layout()
    filename = "by_combination.png"
    filepath = output_dir / filename if output_dir else filename
    plt.savefig(filepath, dpi=300, bbox_inches="tight")
    plt.show()


def plot_target_compliance(
    statistics: dict[tuple, dict[str, Any]],
    target_min: float = -1000,
    target_max: float = 1000,
    output_dir: Path | None = None,
) -> None:
    """Plot how well each combination complies with target ranges.

    Scatters each combination's (overall_min, overall_max), colored by
    compliance, saves "compliance.png" into ``output_dir`` (or the CWD),
    and logs a compliance summary.

    Raises:
        ValueError: If a combination has a None overall min or max.
    """
    compliance_data = []
    for combination, stats in statistics.items():
        overall_min = stats.get("overall_min")
        overall_max = stats.get("overall_max")
        if overall_min is None or overall_max is None:
            raise ValueError(f"Overall min or max is None for combination {combination}")

        # overall_min/overall_max are guaranteed non-None above; avoid falsy-zero filtering
        within_range = (target_min <= overall_min) and (overall_max <= target_max)
        compliance_data.append({
            "combination": str(combination),
            "min": overall_min,
            "max": overall_max,
            "compliant": within_range,
        })

    if not compliance_data:
        logger.warning("No compliance data available for plotting")
        return

    compliant = [d for d in compliance_data if d["compliant"]]
    non_compliant = [d for d in compliance_data if not d["compliant"]]

    plt.figure(figsize=(12, 8))

    # Plot compliant combinations
    if compliant:
        plt.scatter(
            [d["min"] for d in compliant],
            [d["max"] for d in compliant],
            c="green",
            alpha=0.7,
            s=100,
            label=f"Compliant ({len(compliant)})",
        )

    # Plot non-compliant combinations
    if non_compliant:
        plt.scatter(
            [d["min"] for d in non_compliant],
            [d["max"] for d in non_compliant],
            c="red",
            alpha=0.7,
            s=100,
            label=f"Non-compliant ({len(non_compliant)})",
        )

    # Add target range box
    plt.axhline(target_max, color="blue", linestyle="--", alpha=0.5, label=f"Target Max ({target_max})")
    plt.axvline(target_min, color="blue", linestyle="--", alpha=0.5, label=f"Target Min ({target_min})")

    plt.xlabel("Overall Minimum Value")
    plt.ylabel("Overall Maximum Value")
    plt.title("Entropy Combinations: Target Range Compliance")
    plt.legend()
    plt.grid(True, alpha=0.3)
    plt.tight_layout()

    filename = "compliance.png"
    filepath = output_dir / filename if output_dir else filename
    plt.savefig(filepath, dpi=300, bbox_inches="tight")
    plt.show()

    # Print summary
    total_combinations = len(compliance_data)
    compliant_count = len(compliant)
    logger.info("Target Compliance Summary:")
    logger.info(f" Total combinations: {total_combinations}")
    logger.info(f" Compliant: {compliant_count} ({compliant_count / total_combinations * 100:.1f}%)")
    logger.info(
        f" Non-compliant: {total_combinations - compliant_count} ({(total_combinations - compliant_count) / total_combinations * 100:.1f}%)"
    )


def extract_report_metadata(
    top_choice: dict[str, Any],
    problem_type: Task,
    entropy_config: tuple[float, float] | dict[Task, tuple[float, float]],
    min_value_abs: float,
    entropy_jitter: float,
    *,
    step_size: float,
    samples_per_test: int,
    target_min_value: float,
    target_max_value: float,
) -> dict[str, Any]:
    """Build a report entry for the best entropy combination of one problem type.

    Args:
        top_choice: Winning entry from ``rank_entropy_combinations`` (keys:
            combination, score, overall_min, overall_max, min_abs, count).
        problem_type: The task this report entry describes.
        entropy_config: A (low, high) tuple for single-step problems, or a
            dict mapping each component Task to its (low, high) bounds.
        min_value_abs: Minimum absolute element value used during search.
        entropy_jitter: Tolerance subtracted from the upper bound when
            deciding whether the search stopped short of the boundary.
        step_size: Grid step used during the entropy search.
        samples_per_test: Number of samples generated per grid point.
        target_min_value: Lower bound the generated values should respect.
        target_max_value: Upper bound the generated values should respect.

    Raises:
        ValueError: If the component count is unexpected or the chosen
            combination's length does not match the component count.
    """
    if isinstance(entropy_config, tuple):
        # Single-step problem
        is_single_step = True
        components = [problem_type.name]
        difficulty_category = DifficultyCategory.ONE_TOOL_CALL.name
    else:
        # Multi-step problem - get components from the dict keys
        is_single_step = False
        components = [task.name for task in entropy_config if isinstance(task, Task)]

        # Determine difficulty category based on number of components
        num_components = len(components)
        if num_components == 2:
            difficulty_category = DifficultyCategory.TWO_TOOL_CALLS.name
        elif num_components == 3:
            difficulty_category = DifficultyCategory.THREE_TOOL_CALLS.name
        else:
            raise ValueError(f"Unexpected number of components for {problem_type}: {num_components}")

    # Validate the ordered combination length matches the number of components
    selected_combination = (
        list(top_choice["combination"])
        if isinstance(top_choice["combination"], list | tuple)
        else [top_choice["combination"]]
    )
    if len(selected_combination) != len(components):
        raise ValueError(
            f"Mismatch between combination length ({len(selected_combination)}) and components ({len(components)}) for {problem_type.name}"
        )

    # Optimized within the searched grid if the chosen entropy is strictly below
    # the configured upper bound (i.e., the search did not hit the boundary).
    optimized = False
    if isinstance(entropy_config, tuple):
        if len(selected_combination) > 0 and selected_combination[0] < (entropy_config[1] - entropy_jitter):
            optimized = True
    else:
        # Multi-step: mark optimized if ANY component's chosen entropy is strictly
        # below its configured upper bound (didn't hit boundary for that component).
        component_index = 0
        for task in entropy_config:
            if not isinstance(task, Task):
                # BUGFIX: skip non-Task keys without consuming a combination
                # slot — combination entries align with Task-typed keys only.
                continue
            if component_index < len(selected_combination):
                if selected_combination[component_index] < (entropy_config[task][1] - entropy_jitter):
                    optimized = True
                    break
            component_index += 1

    return {
        "combination": selected_combination,
        "score": top_choice["score"],
        "overall_min": top_choice["overall_min"],
        "overall_max": top_choice["overall_max"],
        "min_abs": top_choice["min_abs"],
        "count": top_choice["count"],
        "optimized": optimized,
        "metadata": {
            "is_single_step": is_single_step,
            "components": components,
            "difficulty_category": difficulty_category,
            "task_enum": problem_type.name,
            "entropy_jitter": entropy_jitter,
            "min_element_abs": min_value_abs,
            "step_size": step_size,
            "samples_per_test": samples_per_test,
            "target_min_value": target_min_value,
            "target_max_value": target_max_value,
        },
    }


def rank_entropy_combinations(
    statistics: dict[tuple, dict[str, Any]],
    *,
    target_min: float,
    target_max: float,
    weights: dict[str, float] | None = None,
) -> list[dict[str, Any]]:
    """Rank entropy combinations by how tightly they fill the target range.

    Combinations whose overall min/max are unknown (None) or fall outside
    [target_min, target_max] are dropped.  The survivors are scored by the
    total distance of their observed extrema from the target bounds, so a
    LOWER score is better, and sorted ascending (closest-to-target first).

    Args:
        statistics: Mapping of entropy combination -> statistics dict as
            produced by ``compute_stepwise_value_statistics``.
        target_min: Inclusive lower bound values must stay above.
        target_max: Inclusive upper bound values must stay below.
        weights: Currently unused; accepted for backward compatibility with
            callers that pass a weighting scheme.

    Returns:
        List of dicts sorted by ascending ``score``, each with keys:
        combination, score, min_distance, max_distance, overall_min,
        overall_max, min_abs, count.
    """
    ranked: list[dict[str, Any]] = []
    for combination, stats in statistics.items():
        # Hard gate: overall range must be known and fully within targets.
        overall_min = stats["overall_min"]
        overall_max = stats["overall_max"]
        if overall_min is None or overall_max is None:
            continue
        if not (target_min <= float(overall_min) <= float(target_max)):
            continue
        if not (target_min <= float(overall_max) <= float(target_max)):
            continue

        # Calculate distance from targets (lower is better)
        min_distance = abs(float(overall_min) - target_min)
        max_distance = abs(float(overall_max) - target_max)
        total_distance = min_distance + max_distance

        # Get min_abs and total count from statistics
        overall_min_abs = stats["overall_min_abs"]
        total_count = sum(q_stats["count"] for q_stats in stats["per_question"])

        ranked.append({
            "combination": combination,
            "score": total_distance,
            "min_distance": min_distance,
            "max_distance": max_distance,
            "overall_min": overall_min,
            "overall_max": overall_max,
            "min_abs": overall_min_abs,
            "count": total_count,
        })

    # Sort by total distance (ascending - closest first)
    ranked.sort(key=lambda d: d["score"])
    return ranked