# atomwalk12's picture
# initial commit
# 0dd6c2f
import ast
from pathlib import Path
from typing import Any
import matplotlib.pyplot as plt
import numpy as np
from matplotlib.ticker import MaxNLocator
from linalg_zero.generator.models import DifficultyCategory, Question, Task
from linalg_zero.shared.utils import get_logger
logger = get_logger(__name__)
def print_statistics_summary(statistics: dict[str, Any]) -> None:  # pragma: no cover
    """Log a human-readable digest of a statistics dictionary.

    The overall range is reported first, followed by the per-step,
    per-problem-type and per-question sections. A section is skipped when its
    entry is missing or empty, and only the first three questions are listed
    individually. Nothing but a short notice is logged for an empty input.
    """
    if not statistics:
        logger.info("No statistics available.")
        return

    def _render(entry: dict[str, Any]) -> str:
        # Shared "min/max/min_abs/count" fragment used by every section below.
        return (
            f"min={entry.get('min')}, max={entry.get('max')}, "
            f"min_abs={entry.get('min_abs')}, count={entry.get('count')}"
        )

    banner = "=" * 50
    for header_line in (banner, "DATASET STATISTICS SUMMARY", banner):
        logger.info(header_line)

    # Overall figures
    lowest = statistics.get("overall_min")
    highest = statistics.get("overall_max")
    lowest_abs = statistics.get("overall_min_abs")
    logger.info(f"Overall Range: {lowest} to {highest}")
    logger.info(f"Overall Min Absolute: {lowest_abs}")

    # Aggregates keyed by step index, listed in ascending index order
    step_table = statistics.get("per_step", {})
    if step_table:
        logger.info("Per-Step Statistics:")
        for step_idx in sorted(step_table):
            logger.info(f" Step {step_idx}: {_render(step_table[step_idx])}")

    # Aggregates keyed by problem type, in the dictionary's own order
    type_table = statistics.get("per_problem_type", {})
    if type_table:
        logger.info("Per-Problem-Type Statistics:")
        for problem_type in type_table:
            logger.info(f" {problem_type}: {_render(type_table[problem_type])}")

    # Per-question entries: a count, three samples, then a remainder note
    question_rows = statistics.get("per_question", [])
    if question_rows:
        logger.info(f"Per-Question Statistics: {len(question_rows)} questions analyzed")
        for ordinal, row in enumerate(question_rows[:3], start=1):
            logger.info(f" Q{ordinal}: {_render(row)}")
        not_shown = len(question_rows) - 3
        if not_shown > 0:
            logger.info(f" ... and {not_shown} more questions")

    logger.info(banner)
def _extract_numeric_values_from_object(obj: Any) -> list[float]:
"""Recursively extract numeric values (as floats) from an arbitrary object."""
values: list[float] = []
if isinstance(obj, int | float):
values.append(float(obj))
return values
if isinstance(obj, complex):
raise TypeError(f"Complex number found: {obj}")
if isinstance(obj, list | tuple):
for item in obj:
values.extend(_extract_numeric_values_from_object(item))
return values
if isinstance(obj, dict):
for v in obj.values():
values.extend(_extract_numeric_values_from_object(v))
return values
return values
def compute_stepwise_value_statistics(questions: list[Question]) -> dict[str, Any]:
    """Aggregate numeric statistics over the stepwise results of every question.

    Each step's ``"result"`` entry is parsed with ``ast.literal_eval`` and every
    numeric value found in the parsed object feeds the running aggregates.

    Returns a dictionary with:
    - overall_min / overall_max: float | None (None when no numeric value was seen)
    - overall_min_abs: float; falls back to 0.0 when no numeric value was seen
    - per_question: list of {index, min, max, min_abs, count}, one entry per question
    - per_step: dict[int, {min, max, min_abs, count}] aggregated across all questions by
      step index; an index that so far only produced non-numeric results holds a
      zero-count entry whose min/max/min_abs are None
    - per_problem_type: dict[str, {min, max, min_abs, count}] aggregated by problem type
    - all_values: flat list[float] of every numeric value encountered

    Raises:
        ValueError: if a step's ``"result"`` is missing or None.
        Errors raised by ``ast.literal_eval`` on an unparsable result propagate unchanged.
    """

    def _fold(running: Any, candidate: float, pick: Any) -> float:
        # None means "nothing recorded yet", so the candidate wins outright.
        return candidate if running is None else pick(running, candidate)

    def _absorb(table: dict[Any, dict[str, Any]], key: Any, low: float, high: float, low_abs: float, n: int) -> None:
        # Create the bucket on first sight, otherwise widen it in place.
        bucket = table.get(key)
        if bucket is None:
            table[key] = {"min": low, "max": high, "min_abs": low_abs, "count": n}
            return
        bucket["min"] = _fold(bucket["min"], low, min)
        bucket["max"] = _fold(bucket["max"], high, max)
        bucket["min_abs"] = _fold(bucket["min_abs"], low_abs, min)
        bucket["count"] += n

    global_low: float | None = None
    global_high: float | None = None
    global_low_abs: float | None = None
    question_rows: list[dict[str, Any]] = []
    by_step: dict[int, dict[str, Any]] = {}
    by_type: dict[str, dict[str, Any]] = {}
    flat_values: list[float] = []

    for q_position, question in enumerate(questions):
        q_low: float | None = None
        q_high: float | None = None
        q_low_abs: float | None = None
        q_total = 0

        # Enum-like problem types expose ``.value``; anything else is stringified.
        type_key = getattr(question.problem_type, "value", str(question.problem_type))

        for step_position, step in enumerate(question.stepwise):
            raw_result = step.get("result")
            if raw_result is None:
                raise ValueError(f"Step {step_position} has no result")

            numbers = _extract_numeric_values_from_object(ast.literal_eval(raw_result))
            if not numbers:
                # Keep the step index visible even though it contributed no numbers.
                by_step.setdefault(step_position, {"min": None, "max": None, "min_abs": None, "count": 0})
                continue

            low = min(numbers)
            high = max(numbers)
            low_abs = min(map(abs, numbers))
            n_values = len(numbers)

            flat_values.extend(map(float, numbers))

            # Dataset-wide extremes
            global_low = _fold(global_low, low, min)
            global_high = _fold(global_high, high, max)
            global_low_abs = _fold(global_low_abs, low_abs, min)

            # Extremes for the current question
            q_low = _fold(q_low, low, min)
            q_high = _fold(q_high, high, max)
            q_low_abs = _fold(q_low_abs, low_abs, min)
            q_total += n_values

            # Buckets shared across questions
            _absorb(by_step, step_position, low, high, low_abs, n_values)
            _absorb(by_type, type_key, low, high, low_abs, n_values)

        question_rows.append({
            "index": q_position,
            "min": q_low,
            "max": q_high,
            "min_abs": q_low_abs,
            "count": q_total,
        })

    return {
        "overall_min": global_low,
        "overall_max": global_high,
        "overall_min_abs": 0.0 if global_low_abs is None else global_low_abs,
        "per_question": question_rows,
        "per_step": by_step,
        "per_problem_type": by_type,
        "all_values": flat_values,
    }
def extract_all_numerical_values(statistics: dict[tuple, dict[str, Any]], use_min_max: bool = False) -> list[float]:
    """Flatten the numerical values of every run into a single list.

    By default the ``all_values`` list of each combination is concatenated in
    dictionary order. With ``use_min_max=True`` only the ``min`` and ``max`` of each
    ``per_question`` entry are collected instead. Missing entries count as empty.

    An assertion fails if any collected value is None.
    """
    collected: list[float] = []
    for stats in statistics.values():
        if not use_min_max:
            collected.extend(stats.get("all_values", []))
            continue
        for question_entry in stats.get("per_question", []):
            collected.append(question_entry["min"])
            collected.append(question_entry["max"])

    assert all(item is not None for item in collected), "Value is None"
    return collected
def extract_values_by_combination(
    statistics: dict[tuple, dict[str, Any]], use_min_max: bool = False
) -> dict[tuple, list[float]]:
    """Group the numerical values of each run under its entropy combination.

    Every key of ``statistics`` maps to a fresh list: a copy of that run's
    ``all_values`` by default, or the ``min``/``max`` of each ``per_question`` entry
    when ``use_min_max`` is true. Missing entries yield an empty list.

    An assertion fails if any combination contains a None value.
    """

    def _values_for(stats: dict[str, Any]) -> list[float]:
        if not use_min_max:
            return list(stats.get("all_values", []))
        bounds: list[float] = []
        for question_entry in stats.get("per_question", []):
            bounds.append(question_entry["min"])
            bounds.append(question_entry["max"])
        return bounds

    grouped: dict[tuple, list[float]] = {}
    for combination, stats in statistics.items():
        grouped[combination] = _values_for(stats)

    for key in grouped:
        assert all(value is not None for value in grouped[key]), f"Values are None for combination {key}"
    return grouped
def plot_overall_histogram(
    all_values: list[float], target_min: float = -1000, target_max: float = 1000, output_dir: Path | None = None
) -> None:
    """Plot the distribution of every numerical value collected across all runs.

    Two panels are drawn side by side:
    - left: histogram of all values on a linear axis with both target bounds marked;
    - right: histogram of the strictly positive values on a logarithmic x-axis with
      the upper target bound marked. The panel stays empty when there are no
      positive values.

    The figure is saved as ``overall_distribution.png`` (inside ``output_dir`` when
    one is given, otherwise as a relative path) and then shown with ``plt.show()``.

    Args:
        all_values: Flat list of values to plot.
        target_min: Lower target bound drawn as a dashed line on the linear panel.
        target_max: Upper target bound drawn as a dashed line on both panels.
        output_dir: Directory that receives the image; None keeps the bare filename.
    """
    plt.figure(figsize=(12, 6))

    # Linear-scale panel
    plt.subplot(1, 2, 1)
    plt.hist(all_values, bins=50, alpha=0.7, color="skyblue", edgecolor="black")
    plt.axvline(target_min, color="red", linestyle="--", label=f"Target Min ({target_min})")
    plt.axvline(target_max, color="red", linestyle="--", label=f"Target Max ({target_max})")
    plt.xlabel("Numerical Values")
    plt.ylabel("Frequency")
    plt.title("Distribution of All Numerical Values")
    plt.legend()
    plt.grid(True, alpha=0.3)

    # Log-scale panel: zero and negative values cannot be placed on a log axis
    plt.subplot(1, 2, 2)
    positive_values = [v for v in all_values if v > 0]
    if positive_values:
        smallest = min(positive_values)
        largest = max(positive_values)
        # Equal-width bins look wildly uneven once the axis is logarithmic, so use
        # geometrically spaced edges (51 edges -> 50 bins). A single distinct value
        # has no span to divide, so fall back to the default bin count.
        log_bins = np.geomspace(smallest, largest, 51) if smallest < largest else 50
        plt.hist(positive_values, bins=log_bins, alpha=0.7, color="lightcoral", edgecolor="black")
        plt.axvline(target_max, color="red", linestyle="--", label=f"Target Max ({target_max})")
        plt.xlabel("Numerical Values (log scale)")
        plt.ylabel("Frequency")
        plt.title("Distribution of Positive Values (Log Scale)")
        plt.xscale("log")
        plt.legend()
        plt.grid(True, alpha=0.3)

    plt.tight_layout()

    filename = "overall_distribution.png"
    filepath = output_dir / filename if output_dir else filename
    plt.savefig(filepath, dpi=300, bbox_inches="tight")
    plt.show()
def plot_combination_comparison(
    values_by_combination: dict[tuple, list[float]], max_combinations: int = 12, output_dir: Path | None = None
) -> None:
    """Plot one histogram per entropy combination for side-by-side comparison.

    Combinations are ordered by how many values they hold (largest first) and the
    first ``max_combinations`` of them are drawn on a four-column grid. The grid
    has at least three rows and grows when more than twelve panels are requested,
    so no requested combination is dropped. Panels with values also show their
    mean, standard deviation and sample count; unused panels are hidden.

    The figure is saved as ``by_combination.png`` (inside ``output_dir`` when one is
    given, otherwise as a relative path) and then shown with ``plt.show()``.

    Args:
        values_by_combination: Values to plot, keyed by entropy combination.
        max_combinations: Maximum number of combinations to draw.
        output_dir: Directory that receives the image; None keeps the bare filename.
    """
    n_columns = 4

    # Keep only the combinations with the most values
    sorted_combinations = sorted(values_by_combination.items(), key=lambda x: len(x[1]), reverse=True)[
        :max_combinations
    ]

    # Never fewer than the 3x4 layout used for the default of twelve panels;
    # ceil-divide so every selected combination gets an axis.
    n_rows = max(3, -(-len(sorted_combinations) // n_columns))
    _, axes = plt.subplots(n_rows, n_columns, figsize=(16, 4 * n_rows))
    axes = axes.flatten()

    for i, (combination, values) in enumerate(sorted_combinations):
        ax = axes[i]
        if values:  # Only plot if we have values
            ax.hist(values, bins=20, alpha=0.7, color=plt.colormaps["tab10"](i % 10), edgecolor="black")
        ax.yaxis.set_major_locator(MaxNLocator(integer=True))
        ax.set_title(f"Combo: {combination}", fontsize=10)
        ax.set_xlabel("Values")
        ax.set_ylabel("Freq")
        ax.grid(True, alpha=0.3)

        # Add statistics text
        if values:
            mean_val = np.mean(values)
            std_val = np.std(values)
            ax.text(
                0.05,
                0.95,
                # \u03c3 is the Greek sigma; the escape keeps the source free of a
                # look-alike character while labelling the standard deviation properly.
                f"μ={mean_val:.1f}\n\u03c3={std_val:.1f}\nn={len(values)}",
                transform=ax.transAxes,
                verticalalignment="top",
                bbox={"boxstyle": "round", "facecolor": "white", "alpha": 0.8},
            )

    # Hide unused subplots
    for i in range(len(sorted_combinations), len(axes)):
        axes[i].set_visible(False)

    plt.suptitle("Value Distributions by Entropy Combination", fontsize=14)
    plt.tight_layout()

    filename = "by_combination.png"
    filepath = output_dir / filename if output_dir else filename
    plt.savefig(filepath, dpi=300, bbox_inches="tight")
    plt.show()
def plot_target_compliance(
    statistics: dict[tuple, dict[str, Any]],
    target_min: float = -1000,
    target_max: float = 1000,
    output_dir: Path | None = None,
) -> None:
    """Scatter each combination's overall (min, max) and mark target compliance.

    A combination is compliant when ``target_min <= overall_min`` and
    ``overall_max <= target_max``. Compliant points are drawn in green, the others
    in red, together with dashed lines at the two target bounds. The figure is
    saved as ``compliance.png`` (inside ``output_dir`` when one is given, otherwise
    as a relative path), shown, and a compliance summary is logged.

    With no combinations at all, a warning is logged and nothing is plotted.

    Raises:
        ValueError: if a combination has no ``overall_min`` or ``overall_max``.
    """
    # One (label, low, high, ok) record per combination, in dictionary order.
    records = []
    for combination, stats in statistics.items():
        lowest = stats.get("overall_min")
        highest = stats.get("overall_max")
        if lowest is None or highest is None:
            raise ValueError(f"Overall min or max is None for combination {combination}")
        # Explicit None checks above mean a legitimate 0 is never treated as missing.
        inside = (target_min <= lowest) and (highest <= target_max)
        records.append((str(combination), lowest, highest, inside))

    if not records:
        logger.warning("No compliance data available for plotting")
        return

    passing = [record for record in records if record[3]]
    failing = [record for record in records if not record[3]]

    plt.figure(figsize=(12, 8))

    # Compliant group first, then non-compliant, so the legend order is stable.
    for group, colour, caption in ((passing, "green", "Compliant"), (failing, "red", "Non-compliant")):
        if not group:
            continue
        plt.scatter(
            [record[1] for record in group],
            [record[2] for record in group],
            c=colour,
            alpha=0.7,
            s=100,
            label=f"{caption} ({len(group)})",
        )

    # Target bounds: x is the overall minimum, y the overall maximum
    plt.axhline(target_max, color="blue", linestyle="--", alpha=0.5, label=f"Target Max ({target_max})")
    plt.axvline(target_min, color="blue", linestyle="--", alpha=0.5, label=f"Target Min ({target_min})")

    plt.xlabel("Overall Minimum Value")
    plt.ylabel("Overall Maximum Value")
    plt.title("Entropy Combinations: Target Range Compliance")
    plt.legend()
    plt.grid(True, alpha=0.3)
    plt.tight_layout()

    filename = "compliance.png"
    if output_dir:
        filepath = output_dir / filename
    else:
        filepath = filename
    plt.savefig(filepath, dpi=300, bbox_inches="tight")
    plt.show()

    # Summary
    n_total = len(records)
    n_passing = len(passing)
    n_failing = n_total - n_passing
    logger.info("Target Compliance Summary:")
    logger.info(f" Total combinations: {n_total}")
    logger.info(f" Compliant: {n_passing} ({n_passing / n_total * 100:.1f}%)")
    logger.info(f" Non-compliant: {n_failing} ({n_failing / n_total * 100:.1f}%)")
def extract_report_metadata(
    top_choice: dict[str, Any],
    problem_type: Task,
    entropy_config: tuple[float, float] | dict[Task, tuple[float, float]],
    min_value_abs: float,
    entropy_jitter: float,
    *,
    step_size: float,
    samples_per_test: int,
    target_min_value: float,
    target_max_value: float,
) -> dict[str, Any]:
    """Build the report entry describing the chosen entropy combination.

    A tuple ``entropy_config`` denotes a single-step problem whose only component
    is ``problem_type``; a dictionary denotes a multi-step problem whose components
    are its ``Task`` keys (two or three are supported).

    The ``optimized`` flag is set when a chosen entropy lies below its configured
    upper bound minus ``entropy_jitter``: for the sole component of a single-step
    problem, or for at least one component of a multi-step problem.

    Returns:
        The selected combination as a list, the score/range/count fields copied
        from ``top_choice``, the ``optimized`` flag, and a ``metadata`` dictionary
        recording the component layout and the search settings passed in.

    Raises:
        ValueError: if a multi-step configuration has a component count other
            than two or three, or if the combination length differs from the
            number of components.
    """
    single_step = isinstance(entropy_config, tuple)

    if single_step:
        task_keys = []
        components = [problem_type.name]
        difficulty_category = DifficultyCategory.ONE_TOOL_CALL.name
    else:
        # Components come from the dictionary keys, in their stored order.
        task_keys = [task for task in entropy_config if isinstance(task, Task)]
        components = [task.name for task in task_keys]
        num_components = len(components)
        category_by_size = {
            2: DifficultyCategory.TWO_TOOL_CALLS,
            3: DifficultyCategory.THREE_TOOL_CALLS,
        }
        category = category_by_size.get(num_components)
        if category is None:
            raise ValueError(f"Unexpected number of components for {problem_type}: {num_components}")
        difficulty_category = category.name

    # Normalise the chosen combination to a list, wrapping a bare scalar.
    raw_combination = top_choice["combination"]
    if isinstance(raw_combination, (list, tuple)):
        selected_combination = list(raw_combination)
    else:
        selected_combination = [raw_combination]

    if len(selected_combination) != len(components):
        raise ValueError(
            f"Mismatch between combination length ({len(selected_combination)}) and components ({len(components)}) for {problem_type.name}"
        )

    # A choice below (upper bound - jitter) means the search did not run into
    # the boundary of the grid for that component.
    if single_step:
        optimized = False
        if selected_combination and selected_combination[0] < (entropy_config[1] - entropy_jitter):
            optimized = True
    else:
        # Task keys and chosen entropies line up positionally.
        optimized = any(
            chosen < (entropy_config[task][1] - entropy_jitter)
            for task, chosen in zip(task_keys, selected_combination)
        )

    metadata = {
        "is_single_step": single_step,
        "components": components,
        "difficulty_category": difficulty_category,
        "task_enum": problem_type.name,
        "entropy_jitter": entropy_jitter,
        "min_element_abs": min_value_abs,
        "step_size": step_size,
        "samples_per_test": samples_per_test,
        "target_min_value": target_min_value,
        "target_max_value": target_max_value,
    }
    return {
        "combination": selected_combination,
        "score": top_choice["score"],
        "overall_min": top_choice["overall_min"],
        "overall_max": top_choice["overall_max"],
        "min_abs": top_choice["min_abs"],
        "count": top_choice["count"],
        "optimized": optimized,
        "metadata": metadata,
    }
def rank_entropy_combinations(
statistics: dict[tuple, dict[str, Any]],
*,
target_min: float,
target_max: float,
weights: dict[str, float] | None = None,
) -> list[dict[str, Any]]:
"""Rank entropy combinations by distribution quality.
Returns a list of dicts sorted by descending score, each with keys:
- combination: tuple of entropy values
- score: float in [0, 1]
- metrics: dict as computed by _compute_distribution_metrics
"""
if weights is None:
weights = {"compliance": 0.4, "center": 0.2, "coverage": 0.2, "balance": 0.1, "zero": 0.1}
ranked: list[dict[str, Any]] = []
for combination, stats in statistics.items():
# Hard gate 1: overall range must be fully within targets if required
overall_min = stats["overall_min"]
overall_max = stats["overall_max"]
if overall_min is None or overall_max is None:
continue
if not (target_min <= float(overall_min) <= float(target_max)):
continue
if not (target_min <= float(overall_max) <= float(target_max)):
continue
# Calculate distance from targets (lower is better)
min_distance = abs(float(overall_min) - target_min)
max_distance = abs(float(overall_max) - target_max)
total_distance = min_distance + max_distance
# Get min_abs and total count from statistics
overall_min_abs = stats["overall_min_abs"]
# Calculate total count across all questions
total_count = sum(q_stats["count"] for q_stats in stats["per_question"])
ranked.append({
"combination": combination,
"score": total_distance,
"min_distance": min_distance,
"max_distance": max_distance,
"overall_min": overall_min,
"overall_max": overall_max,
"min_abs": overall_min_abs,
"count": total_count,
})
# Sort by total distance (ascending - closest first)
ranked.sort(key=lambda d: d["score"])
return ranked