JustinTX's picture
Add files using upload-large-folder tool
14c9c2b verified
"""
Evaluator for circle packing example (n=26) with improved timeout handling
"""
import os
import argparse
import numpy as np
from typing import Tuple, Optional, List, Dict, Any
from shinka.core import run_shinka_eval
try:
import matplotlib
matplotlib.use('Agg') # Non-interactive backend
import matplotlib.pyplot as plt
from matplotlib.patches import Circle, Rectangle
MATPLOTLIB_AVAILABLE = True
except ImportError:
MATPLOTLIB_AVAILABLE = False
def format_centers_string(centers: np.ndarray) -> str:
    """Render circle centers as one ``centers[i] = (x, y)`` line per circle.

    Args:
        centers: Iterable of (x, y) coordinate pairs, one per circle.

    Returns:
        The rendered lines joined with newlines. Each coordinate is printed
        with four decimal places; an empty input yields an empty string.
    """
    rendered = []
    for idx, (cx, cy) in enumerate(centers):
        rendered.append(f" centers[{idx}] = ({cx:.4f}, {cy:.4f})")
    return "\n".join(rendered)
def generate_circle_packing_visualization(
    centers: np.ndarray,
    radii: np.ndarray,
    output_path: str,
    sum_radii: float,
) -> bool:
    """
    Generate a visualization of the circle packing arrangement.

    The plot shows the unit square outline, every circle filled with a
    viridis color scaled by its radius relative to the largest radius, a
    title with the circle count and sum of radii, and a colorbar.

    Args:
        centers: Array of circle centers (n, 2)
        radii: Array of circle radii (n,)
        output_path: Path to save the PNG image
        sum_radii: Sum of all radii (for display)

    Returns:
        True if the image was written, False if matplotlib is unavailable
        or any step of the rendering raised (a warning is printed then).
    """
    if not MATPLOTLIB_AVAILABLE:
        return False

    fig = None
    try:
        # Accept plain sequences as well as arrays so `.max()` below works.
        centers = np.asarray(centers)
        radii = np.asarray(radii)

        fig, ax = plt.subplots(figsize=(10, 10))

        # Draw unit square boundary
        ax.add_patch(
            Rectangle((0, 0), 1, 1, fill=False, edgecolor='black', linewidth=3)
        )

        # Normalize radii for color mapping; fall back to 1.0 so that an
        # all-zero packing does not divide by zero.
        largest = radii.max()
        max_radius = largest if largest > 0 else 1.0

        # Draw all circles, colored by radius relative to the largest one
        for center, radius in zip(centers, radii):
            ax.add_patch(
                Circle(
                    center,
                    radius,
                    fill=True,
                    alpha=0.6,
                    facecolor=plt.cm.viridis(radius / max_radius),
                    edgecolor='darkblue',
                    linewidth=1.5,
                )
            )

        # Add title and score
        ax.set_title(
            f'Circle Packing (n={len(centers)})\nSum of Radii: {sum_radii:.4f}',
            fontsize=16,
            fontweight='bold',
            pad=20,
        )

        # Add grid for better spatial understanding
        ax.grid(True, alpha=0.3, linestyle='--')

        # Set axis properties (small margin around the unit square)
        ax.set_xlim(-0.05, 1.05)
        ax.set_ylim(-0.05, 1.05)
        ax.set_aspect('equal')
        ax.set_xlabel('X', fontsize=12)
        ax.set_ylabel('Y', fontsize=12)

        # Add colorbar to show radius scale
        sm = plt.cm.ScalarMappable(
            cmap=plt.cm.viridis,
            norm=plt.Normalize(vmin=0, vmax=max_radius),
        )
        sm.set_array([])
        cbar = fig.colorbar(sm, ax=ax, fraction=0.046, pad=0.04)
        cbar.set_label('Circle Radius', fontsize=12)

        # Save through the figure object rather than the implicit
        # "current figure" so the right canvas is always written.
        fig.savefig(output_path, dpi=150, bbox_inches='tight', facecolor='white')
        return True
    except Exception as e:
        print(f"Warning: Failed to generate visualization: {e}")
        return False
    finally:
        # Always release the figure, including when drawing or saving failed;
        # otherwise pyplot keeps it registered and its memory is never freed.
        if fig is not None:
            plt.close(fig)
def adapted_validate_packing(
    run_output: Tuple[np.ndarray, np.ndarray, float],
    atol: float = 1e-6,
    n_expected: int = 26,
) -> Tuple[bool, Optional[str]]:
    """
    Validates circle packing results based on the output of 'run_packing'.

    Checks, in order: array shapes, that every value is a finite number,
    non-negative radii, agreement between the reported and the actual sum
    of radii, containment in the unit square, and pairwise non-overlap.
    The first failing check determines the returned message.

    Args:
        run_output: Tuple (centers, radii, reported_sum) from run_packing.
            centers and radii may be arrays or plain sequences.
        atol: Absolute tolerance used for the sum comparison and for the
            boundary and overlap checks.
        n_expected: Required number of circles.

    Returns:
        (is_valid: bool, error_message: Optional[str]). On success the
        message is a short confirmation text rather than None.
    """
    centers, radii, reported_sum = run_output
    msg = "The circles are placed correctly. There are no overlaps or any circles outside the unit square."
    if not isinstance(centers, np.ndarray):
        centers = np.array(centers)
    if not isinstance(radii, np.ndarray):
        radii = np.array(radii)

    if centers.shape != (n_expected, 2):
        msg = (
            f"Centers shape incorrect. Expected ({n_expected}, 2), got {centers.shape}"
        )
        return False, msg
    if radii.shape != (n_expected,):
        msg = f"Radii shape incorrect. Expected ({n_expected},), got {radii.shape}"
        return False, msg

    # Every ordered comparison with NaN is False, so NaN coordinates would
    # slip through the boundary and overlap checks below. Reject them here.
    try:
        all_finite = bool(
            np.isfinite(centers.astype(float)).all()
            and np.isfinite(radii.astype(float)).all()
        )
    except (TypeError, ValueError):
        # Values that cannot be interpreted as floats are not valid either.
        all_finite = False
    if not all_finite:
        msg = "Centers and radii must contain only finite numbers (no NaN or inf)."
        return False, msg

    if np.any(radii < 0):
        negative_indices = np.where(radii < 0)[0]
        msg = f"Negative radii found for circles at indices: {negative_indices}"
        return False, msg

    if not np.isclose(np.sum(radii), reported_sum, atol=atol):
        msg = (
            f"Sum of radii ({np.sum(radii):.6f}) does not match "
            f"reported ({reported_sum:.6f})"
        )
        return False, msg

    # Each circle must lie inside the unit square (within tolerance).
    for i in range(n_expected):
        x, y = centers[i]
        r = radii[i]
        is_outside = (
            x - r < -atol or x + r > 1 + atol or y - r < -atol or y + r > 1 + atol
        )
        if is_outside:
            msg = (
                f"Circle {i} (x={x:.4f}, y={y:.4f}, r={r:.4f}) is outside unit square."
            )
            return False, msg

    # No two circles may overlap: center distance >= sum of radii (within atol).
    for i in range(n_expected):
        for j in range(i + 1, n_expected):
            dist = np.sqrt(np.sum((centers[i] - centers[j]) ** 2))
            if dist < radii[i] + radii[j] - atol:
                msg = (
                    f"Circles {i} & {j} overlap. Dist: {dist:.4f}, "
                    f"Sum Radii: {(radii[i] + radii[j]):.4f}"
                )
                return False, msg
    return True, msg
def get_circle_packing_kwargs(run_index: int) -> Dict[str, Any]:
    """Return the keyword arguments for one circle packing run.

    Args:
        run_index: Index of the run; ignored, since every run is called
            without extra arguments.

    Returns:
        A new empty dict.
    """
    no_kwargs: Dict[str, Any] = dict()
    return no_kwargs
def aggregate_circle_packing_metrics(
    results: List[Tuple[np.ndarray, np.ndarray, float]], results_dir: str
) -> Dict[str, Any]:
    """
    Aggregates metrics for circle packing. Assumes num_runs=1.

    Only the first entry of ``results`` is used. Besides building the metrics
    dict, this saves ``extra.npz`` (centers, radii, reported sum) and tries to
    write ``packing_viz.png`` into ``results_dir``. Failures of either file
    operation are reported in the metrics instead of being raised.

    Args:
        results: List of (centers, radii, reported_sum) tuples, one per run.
            centers and radii may be arrays or plain sequences.
        results_dir: Directory that receives the output files.

    Returns:
        Dict with ``combined_score`` (the reported sum of radii), ``public``
        and ``private`` sub-dicts, and, depending on the outcome of the file
        operations, ``extra_npz_save_error``, ``visualization_path`` or
        ``visualization_error``. With no results, a dict holding a zero
        score and an ``error`` message is returned.
    """
    if not results:
        return {"combined_score": 0.0, "error": "No results to aggregate"}

    centers, radii, reported_sum = results[0]
    # The validator accepts plain sequences, so coerce here as well;
    # otherwise `centers.shape` below fails for list outputs.
    centers = np.asarray(centers)
    radii = np.asarray(radii)

    public_metrics = {
        "centers_str": format_centers_string(centers),
        "num_circles": centers.shape[0],
    }
    private_metrics = {
        "reported_sum_of_radii": float(reported_sum),
    }
    metrics = {
        "combined_score": float(reported_sum),
        "public": public_metrics,
        "private": private_metrics,
    }

    # Save numpy data (best effort: an error is recorded, not raised)
    extra_file = os.path.join(results_dir, "extra.npz")
    try:
        np.savez(
            extra_file,
            centers=centers,
            radii=radii,
            reported_sum=reported_sum,
        )
        print(f"Detailed packing data saved to {extra_file}")
    except Exception as e:
        print(f"Error saving extra.npz: {e}")
        metrics["extra_npz_save_error"] = str(e)

    # Generate visualization (best effort as well)
    viz_file = os.path.join(results_dir, "packing_viz.png")
    try:
        success = generate_circle_packing_visualization(
            centers=centers,
            radii=radii,
            output_path=viz_file,
            sum_radii=reported_sum,
        )
        if success:
            print(f"Visualization saved to {viz_file}")
            metrics["visualization_path"] = viz_file
        else:
            if not MATPLOTLIB_AVAILABLE:
                print("Warning: matplotlib not available, skipping visualization")
            metrics["visualization_error"] = "Failed to generate visualization"
    except Exception as e:
        print(f"Error generating visualization: {e}")
        metrics["visualization_error"] = str(e)

    return metrics
def main(program_path: str, results_dir: str):
    """Evaluate a circle packing program via ``run_shinka_eval`` and report.

    Creates ``results_dir`` if needed, runs the program's ``run_packing``
    function once with validation and metric aggregation, then prints the
    outcome and every returned metric.

    Args:
        program_path: Path of the program file to evaluate.
        results_dir: Directory where results are stored.
    """
    print(f"Evaluating program: {program_path}")
    print(f"Saving results to: {results_dir}")
    os.makedirs(results_dir, exist_ok=True)

    def _aggregate_into_results_dir(
        run_results: List[Tuple[np.ndarray, np.ndarray, float]],
    ) -> Dict[str, Any]:
        # Closure that supplies results_dir, which the one-argument
        # aggregation callback would otherwise not receive.
        return aggregate_circle_packing_metrics(run_results, results_dir)

    metrics, correct, error_msg = run_shinka_eval(
        program_path=program_path,
        results_dir=results_dir,
        experiment_fn_name="run_packing",
        num_runs=1,
        get_experiment_kwargs=get_circle_packing_kwargs,
        validate_fn=adapted_validate_packing,
        aggregate_metrics_fn=_aggregate_into_results_dir,
    )

    if not correct:
        print(f"Evaluation or Validation failed: {error_msg}")
    else:
        print("Evaluation and Validation completed successfully.")

    print("Metrics:")
    for key, value in metrics.items():
        # Long strings (e.g. the formatted centers) are replaced by a marker.
        is_long_text = isinstance(value, str) and len(value) > 100
        shown = "<string_too_long_to_display>" if is_long_text else value
        print(f" {key}: {shown}")
if __name__ == "__main__":
    # Command-line entry point: read the two path options and evaluate.
    cli = argparse.ArgumentParser(
        description="Circle packing evaluator using shinka.eval"
    )
    cli.add_argument(
        "--program_path",
        default="initial.py",
        type=str,
        help="Path to program to evaluate (must contain 'run_packing')",
    )
    cli.add_argument(
        "--results_dir",
        default="results",
        type=str,
        help="Dir to save results (metrics.json, correct.json, extra.npz)",
    )
    options = cli.parse_args()
    main(program_path=options.program_path, results_dir=options.results_dir)