# shinka-backup/shinka/core/sampler.py
# Uploaded by JustinTX
# Commit message: Add files using upload-large-folder tool
# Commit 2facf1f (verified)
from typing import List, Optional, Tuple
import numpy as np
from pathlib import Path
from shinka.database import Program
from shinka.prompts import (
construct_eval_history_msg,
perf_str,
format_text_feedback_section,
format_merged_aux_metrics_section,
format_recent_attempts_summary,
BASE_SYSTEM_MSG,
DIFF_SYS_FORMAT,
DIFF_ITER_MSG,
FULL_ITER_MSG,
FULL_SYS_FORMATS,
CROSS_SYS_FORMAT,
CROSS_ITER_MSG,
get_cross_component,
)
from shinka.prompts.prompts_init import INIT_SYSTEM_MSG, INIT_USER_MSG
import logging
logger = logging.getLogger(__name__)
class PromptSampler:
    """Assemble the system and user messages used to request a code patch.

    A patch type (for example ``"diff"``, ``"full"`` or ``"cross"``) is drawn
    at random according to the configured probabilities, and the matching
    prompt templates are filled in with the parent program, inspiration
    programs and optional feedback sections.
    """

    # File name looked up under ``<results_dir>/gen_<generation>/results/``
    # when collecting visualization images.
    VISUALIZATION_FILENAME = "packing_viz.png"

    def __init__(
        self,
        task_sys_msg: Optional[str] = None,
        language: str = "python",
        patch_types: Optional[List[str]] = None,
        patch_type_probs: Optional[List[float]] = None,
        use_text_feedback: bool = False,
        results_dir: Optional[str] = None,
    ):
        """Store the configuration and validate the patch-type distribution.

        Args:
            task_sys_msg: Task-specific system message. When ``None`` the
                module-level default system messages are used instead.
            language: Language name interpolated into the prompt templates.
            patch_types: Patch types that may be sampled. Defaults to
                ``["diff"]``.
            patch_type_probs: One probability per entry of ``patch_types``.
                Defaults to a uniform distribution over ``patch_types``.
            use_text_feedback: Whether text feedback and auxiliary metric
                sections are added to the prompt.
            results_dir: Directory searched for the diagnostic report and
                visualization images. ``None`` disables both lookups.

        Raises:
            ValueError: If ``patch_types`` is empty, the two lists differ in
                length, a probability is negative, or the probabilities do
                not sum to 1.0.
        """
        if patch_types is None:
            patch_types = ["diff"]
        # Copy so that later mutation of the caller's list cannot invalidate
        # the checks performed below.
        patch_types = list(patch_types)
        if not patch_types:
            raise ValueError("patch_types must contain at least one entry")

        if patch_type_probs is None:
            num_types = len(patch_types)
            patch_type_probs = [1.0 / num_types for _ in patch_types]
        else:
            patch_type_probs = list(patch_type_probs)

        if len(patch_types) != len(patch_type_probs):
            raise ValueError(
                "patch_types and patch_type_probs must have the same length, "
                f"got {len(patch_types)} and {len(patch_type_probs)}"
            )
        if any(prob < 0 for prob in patch_type_probs):
            raise ValueError("Coding type probabilities must be non-negative")

        # Check if probabilities sum to 1.0 w. tolerance for errors
        prob_sum = np.sum(patch_type_probs)
        if not np.isclose(prob_sum, 1.0, atol=1e-6):
            raise ValueError(
                f"Coding type probabilities must sum to 1.0, got {prob_sum:.6f}"
            )

        self.task_sys_msg = task_sys_msg
        self.language = language
        self.patch_types = patch_types
        self.patch_type_probs = patch_type_probs
        self.results_dir = results_dir
        # Whether to use text feedback in the prompt
        self.use_text_feedback = use_text_feedback

    def initial_program_prompt(self) -> Tuple[str, str]:
        """Generate the prompt for the initial program.

        Returns:
            A ``(system message, user message)`` pair. When a task system
            message is configured it is used both as the system message and
            as the task description inside the user message.
        """
        if self.task_sys_msg is None:
            sys_msg = INIT_SYSTEM_MSG
            task_description = "The user has not provided a task description."
        else:
            sys_msg = self.task_sys_msg
            task_description = self.task_sys_msg
        user_msg = INIT_USER_MSG.format(
            language=self.language,
            task_description=task_description,
        )
        return sys_msg, user_msg

    def sample(
        self,
        parent: Program,
        archive_inspirations: List[Program],
        top_k_inspirations: List[Program],
        meta_recommendations: Optional[str] = None,
        recent_programs: Optional[List[Program]] = None,
        best_program: Optional[Program] = None,
    ) -> Tuple[str, str, str, Optional[List[str]]]:
        """Sample a patch type and build the prompt for it.

        Args:
            parent: Program whose code and metrics are placed in the prompt.
            archive_inspirations: Inspiration programs rendered first in the
                evaluation history.
            top_k_inspirations: Inspiration programs rendered after the
                archive inspirations.
            meta_recommendations: Optional recommendations appended to the
                system message. Ignored when ``None``, ``"none"`` or when
                the sampled patch type is ``"cross"``.
            recent_programs: Programs passed to the recent-attempts summary.
            best_program: Program passed to the recent-attempts summary.

        Returns:
            ``(system message, user message, patch type, image paths)``;
            the last element is ``None`` when no image was found.

        Raises:
            NotImplementedError: If the sampled patch type is ``"paper"``.
            ValueError: If the sampled patch type is unknown, or no patch
                type can be sampled (see ``_sample_patch_type``).
        """
        sys_msg = (
            BASE_SYSTEM_MSG if self.task_sys_msg is None else self.task_sys_msg
        )

        # Crossover needs something to cross with, so it is excluded when
        # neither inspiration list has entries.
        has_inspirations = bool(archive_inspirations or top_k_inspirations)
        patch_type = self._sample_patch_type(has_inspirations)

        if patch_type == "paper":
            raise NotImplementedError("Paper edit not implemented.")
        iter_templates = {
            "diff": DIFF_ITER_MSG,
            "full": FULL_ITER_MSG,
            "cross": CROSS_ITER_MSG,
        }
        if patch_type not in iter_templates:
            raise ValueError(f"Invalid patch type: {patch_type}")

        if patch_type == "diff":
            sys_msg += DIFF_SYS_FORMAT
        elif patch_type == "full":
            # Randomly sample from different full rewrite variants
            variant_idx = np.random.randint(0, len(FULL_SYS_FORMATS))
            sys_msg += FULL_SYS_FORMATS[variant_idx]
        else:
            sys_msg += CROSS_SYS_FORMAT

        # Archive inspirations first, then top-k inspirations.
        # TODO(RobertTLange): Check if order needs inversion
        eval_history_msg = ""
        for inspirations in (archive_inspirations, top_k_inspirations):
            if inspirations:
                eval_history_msg += construct_eval_history_msg(
                    inspirations,
                    language=self.language,
                    include_text_feedback=self.use_text_feedback,
                )

        # Format text feedback section for current program
        text_feedback_section = ""
        if self.use_text_feedback:
            text_feedback_section = "\n" + format_text_feedback_section(
                parent.text_feedback
            )

        # Inject persistent diagnostic report (independent of parent
        # selection and of ``use_text_feedback``).
        diagnostic = self._read_diagnostic_report()
        if diagnostic:
            text_feedback_section += (
                "\n\nNote: The following diagnosis may refer to a different generation's code "
                "than the current program above. Use the performance metrics of the current "
                "program as ground truth, and treat the diagnosis as directional advice.\n\n"
                + diagnostic
            )

        # Merge auxiliary metric descriptions into one deduplicated section.
        aux_metrics_section = ""
        if self.use_text_feedback:
            aux_metrics_section = format_merged_aux_metrics_section(
                [parent, *archive_inspirations, *top_k_inspirations]
            )

        iter_msg = iter_templates[patch_type].format(
            language=self.language,
            code_content=parent.code,
            performance_metrics=perf_str(
                parent.combined_score, parent.public_metrics
            ),
            text_feedback_section=text_feedback_section,
        )
        if patch_type == "cross":
            iter_msg += "\n\n" + get_cross_component(
                archive_inspirations,
                top_k_inspirations,
                language=self.language,
            )

        # Inject recent attempt history table
        recent_attempts_section = format_recent_attempts_summary(
            recent_programs or [], best_program
        )
        if recent_attempts_section:
            iter_msg += recent_attempts_section
        if aux_metrics_section:
            iter_msg += aux_metrics_section

        # Add meta-recommendations if provided
        sum_rec_msg = ""
        if meta_recommendations not in [None, "none"] and patch_type != "cross":
            sum_rec_msg += "\n\n# Potential Recommendations"
            sum_rec_msg += (
                "\nThe following are potential recommendations for the "
                "next program generations:\n\n"
            )
            sum_rec_msg += f"\n{meta_recommendations}"

        # Check for visualization images
        images = self._collect_visualization_images(
            parent, archive_inspirations, top_k_inspirations
        )
        # Add note about images if any found
        if images:
            iter_msg += f"\n\n# Visual Information\nAttached are {len(images)} visualization(s) showing the current and/or inspiration program results."

        return (
            sys_msg + sum_rec_msg,
            eval_history_msg + "\n" + iter_msg,
            patch_type,
            images,
        )

    def _sample_patch_type(self, has_inspirations: bool) -> str:
        """Draw one patch type according to the configured probabilities.

        Args:
            has_inspirations: When ``False`` the ``"cross"`` type is removed
                from the candidates before sampling.

        Returns:
            The sampled patch type as a plain ``str``.

        Raises:
            ValueError: If no candidate with positive probability remains.
        """
        candidates = list(zip(self.patch_types, self.patch_type_probs))
        if not has_inspirations:
            candidates = [(kind, prob) for kind, prob in candidates if kind != "cross"]

        total = sum(prob for _, prob in candidates)
        if not candidates or total <= 0:
            raise ValueError(
                "No patch type can be sampled: all probability mass is on "
                "'cross', which requires at least one inspiration program"
            )

        # Always renormalize: the constructor accepts sums within 1e-6 of
        # 1.0, which is looser than what np.random.choice tolerates.
        kinds = [kind for kind, _ in candidates]
        probs = [prob / total for _, prob in candidates]
        # np.random.choice yields a numpy string; convert to a builtin str.
        return str(np.random.choice(kinds, p=probs))

    def _read_diagnostic_report(self) -> str:
        """Return the stripped diagnostic report text, or ``""`` if absent.

        The report is read from
        ``<results_dir>/eval_agent_memory/diagnostic_report.md``. Reading is
        best effort: a missing file is silent, other read or decoding
        failures are logged and treated as an empty report.
        """
        if not self.results_dir:
            return ""
        report_path = (
            Path(self.results_dir) / "eval_agent_memory" / "diagnostic_report.md"
        )
        try:
            return report_path.read_text(encoding="utf-8").strip()
        except FileNotFoundError:
            return ""
        except (OSError, UnicodeDecodeError) as exc:
            logger.warning(
                "Could not read diagnostic report %s: %s", report_path, exc
            )
            return ""

    def _collect_visualization_images(
        self,
        parent: Program,
        archive_inspirations: List[Program],
        top_k_inspirations: List[Program],
    ) -> Optional[List[str]]:
        """Collect visualization images for the prompt.

        Only the parent's image is looked up, at
        ``<results_dir>/gen_<parent.generation>/results/<VISUALIZATION_FILENAME>``.
        ``archive_inspirations`` and ``top_k_inspirations`` are accepted but
        currently unused; inspiration images are not collected.

        Returns:
            List of image paths, or None if no images found
        """
        images: List[str] = []
        # Construct results path from results_dir and generation
        if self.results_dir and parent.generation is not None:
            parent_viz = (
                Path(self.results_dir)
                / f"gen_{parent.generation}"
                / "results"
                / self.VISUALIZATION_FILENAME
            )
            if parent_viz.exists():
                images.append(str(parent_viz))
                logger.info("Found parent visualization: %s", parent_viz)
        return images or None