| from typing import List, Optional, Tuple |
| import numpy as np |
| from pathlib import Path |
| from shinka.database import Program |
| from shinka.prompts import ( |
| construct_eval_history_msg, |
| perf_str, |
| format_text_feedback_section, |
| format_merged_aux_metrics_section, |
| format_recent_attempts_summary, |
| BASE_SYSTEM_MSG, |
| DIFF_SYS_FORMAT, |
| DIFF_ITER_MSG, |
| FULL_ITER_MSG, |
| FULL_SYS_FORMATS, |
| CROSS_SYS_FORMAT, |
| CROSS_ITER_MSG, |
| get_cross_component, |
| ) |
| from shinka.prompts.prompts_init import INIT_SYSTEM_MSG, INIT_USER_MSG |
| import logging |
|
|
# Module-level logger named after this module; used by PromptSampler to report
# discovered visualizations and non-fatal read failures.
logger = logging.getLogger(__name__)
|
|
|
|
class PromptSampler:
    """Assemble system/user prompts for creating and mutating programs.

    A sampler holds the task description, the target language and a
    categorical distribution over patch types. ``sample`` draws one patch
    type from that distribution and builds the matching prompt pair.
    """

    def __init__(
        self,
        task_sys_msg: Optional[str] = None,
        language: str = "python",
        patch_types: Optional[List[str]] = None,
        patch_type_probs: Optional[List[float]] = None,
        use_text_feedback: bool = False,
        results_dir: Optional[str] = None,
    ):
        """Store the configuration and validate the patch-type distribution.

        Args:
            task_sys_msg: Task-specific system message. When ``None`` the
                library default system messages are used instead.
            language: Language name substituted into the prompt templates.
            patch_types: Patch types that may be sampled. Defaults to
                ``["diff"]``.
            patch_type_probs: Probability of each entry of ``patch_types``.
                Defaults to a uniform distribution over ``patch_types``.
            use_text_feedback: Whether text feedback and auxiliary metric
                sections are added to the prompts.
            results_dir: Directory searched for the diagnostic report and
                the parent visualization image. ``None`` disables both.

        Raises:
            ValueError: If ``patch_types`` is empty, the two lists differ in
                length, a probability is negative, or the probabilities do
                not sum to 1.0.
        """
        if patch_types is None:
            patch_types = ["diff"]
        if len(patch_types) == 0:
            raise ValueError("patch_types must contain at least one patch type")
        if patch_type_probs is None:
            # Uniform default; for a single patch type this is [1.0].
            patch_type_probs = [1.0 / len(patch_types)] * len(patch_types)

        if len(patch_types) != len(patch_type_probs):
            raise ValueError(
                "patch_types and patch_type_probs must have the same length, "
                f"got {len(patch_types)} and {len(patch_type_probs)}"
            )
        if any(p < 0 for p in patch_type_probs):
            raise ValueError("Coding type probabilities must be non-negative")

        prob_sum = np.sum(patch_type_probs)
        if not np.isclose(prob_sum, 1.0, atol=1e-6):
            raise ValueError(
                f"Coding type probabilities must sum to 1.0, got {prob_sum:.6f}"
            )

        self.task_sys_msg = task_sys_msg
        self.language = language
        self.patch_types = patch_types
        self.patch_type_probs = patch_type_probs
        self.results_dir = results_dir
        self.use_text_feedback = use_text_feedback

    def initial_program_prompt(self) -> Tuple[str, str]:
        """Generate the prompt for the initial program.

        Returns:
            A ``(system_message, user_message)`` pair. Without a task system
            message the default initial system message and a placeholder
            task description are used.
        """
        if self.task_sys_msg is None:
            sys_msg = INIT_SYSTEM_MSG
            task_description = "The user has not provided a task description."
        else:
            sys_msg = self.task_sys_msg
            task_description = self.task_sys_msg

        user_msg = INIT_USER_MSG.format(
            language=self.language,
            task_description=task_description,
        )
        return sys_msg, user_msg

    def sample(
        self,
        parent: Program,
        archive_inspirations: List[Program],
        top_k_inspirations: List[Program],
        meta_recommendations: Optional[str] = None,
        recent_programs: Optional[List[Program]] = None,
        best_program: Optional[Program] = None,
    ) -> Tuple[str, str, str, Optional[List[str]]]:
        """Sample a patch type and build the prompts for mutating ``parent``.

        Args:
            parent: Program whose code and metrics are placed in the prompt.
            archive_inspirations: Inspiration programs from the archive.
            top_k_inspirations: Top-performing inspiration programs.
            meta_recommendations: Optional recommendations appended to the
                system message (skipped for ``"cross"`` patches and for the
                values ``None`` / ``"none"``).
            recent_programs: Recent attempts summarized in the user message.
            best_program: Best program passed to the recent-attempts summary.

        Returns:
            ``(system_message, user_message, patch_type, images)`` where
            ``images`` is a list of image paths or ``None``.

        Raises:
            NotImplementedError: If the sampled patch type is ``"paper"``.
            ValueError: If the sampled patch type is unknown, or if there
                are no inspirations and no patch type other than ``"cross"``
                is configured.
        """
        sys_msg = BASE_SYSTEM_MSG if self.task_sys_msg is None else self.task_sys_msg

        has_inspirations = (
            len(archive_inspirations) > 0 or len(top_k_inspirations) > 0
        )
        patch_type = self._choose_patch_type(has_inspirations)

        if patch_type == "diff":
            sys_msg += DIFF_SYS_FORMAT
        elif patch_type == "full":
            # Pick one of the available full-rewrite instruction variants.
            full_variant_idx = np.random.randint(0, len(FULL_SYS_FORMATS))
            sys_msg += FULL_SYS_FORMATS[full_variant_idx]
        elif patch_type == "cross":
            sys_msg += CROSS_SYS_FORMAT

        # Evaluation history: archive inspirations first, then top-k.
        eval_history_msg = ""
        for inspirations in (archive_inspirations, top_k_inspirations):
            if len(inspirations) > 0:
                eval_history_msg += construct_eval_history_msg(
                    inspirations,
                    language=self.language,
                    include_text_feedback=self.use_text_feedback,
                )

        text_feedback_section = ""
        if self.use_text_feedback:
            text_feedback_section = "\n" + format_text_feedback_section(
                parent.text_feedback
            )

        # The diagnostic report is appended whenever it exists and is
        # non-empty, independently of ``use_text_feedback``.
        diagnostic = self._read_diagnostic_report()
        if diagnostic:
            text_feedback_section += (
                "\n\nNote: The following diagnosis may refer to a different generation's code "
                "than the current program above. Use the performance metrics of the current "
                "program as ground truth, and treat the diagnosis as directional advice.\n\n"
                + diagnostic
            )

        aux_metrics_section = ""
        if self.use_text_feedback:
            aux_metrics_section = format_merged_aux_metrics_section(
                [parent] + archive_inspirations + top_k_inspirations
            )

        if patch_type == "paper":
            raise NotImplementedError("Paper edit not implemented.")
        iter_templates = {
            "diff": DIFF_ITER_MSG,
            "full": FULL_ITER_MSG,
            "cross": CROSS_ITER_MSG,
        }
        if patch_type not in iter_templates:
            raise ValueError(f"Invalid patch type: {patch_type}")

        iter_msg = iter_templates[patch_type].format(
            language=self.language,
            code_content=parent.code,
            performance_metrics=perf_str(
                parent.combined_score, parent.public_metrics
            ),
            text_feedback_section=text_feedback_section,
        )
        if patch_type == "cross":
            iter_msg += "\n\n" + get_cross_component(
                archive_inspirations,
                top_k_inspirations,
                language=self.language,
            )

        recent_attempts_section = format_recent_attempts_summary(
            recent_programs or [], best_program
        )
        if recent_attempts_section:
            iter_msg += recent_attempts_section

        if aux_metrics_section:
            iter_msg += aux_metrics_section

        sum_rec_msg = ""
        if meta_recommendations not in (None, "none") and patch_type != "cross":
            sum_rec_msg += "\n\n# Potential Recommendations"
            sum_rec_msg += (
                "\nThe following are potential recommendations for the "
                "next program generations:\n\n"
            )
            sum_rec_msg += f"\n{meta_recommendations}"

        images = self._collect_visualization_images(
            parent, archive_inspirations, top_k_inspirations
        )
        if images:
            image_note = f"\n\n# Visual Information\nAttached are {len(images)} visualization(s) showing the current and/or inspiration program results."
            iter_msg += image_note

        return (
            sys_msg + sum_rec_msg,
            eval_history_msg + "\n" + iter_msg,
            patch_type,
            images,
        )

    def _choose_patch_type(self, has_inspirations: bool) -> str:
        """Draw a patch type, excluding ``"cross"`` when nothing can be crossed.

        Returns a plain ``str`` rather than the ``np.str_`` produced by
        ``np.random.choice``.
        """
        if has_inspirations:
            return str(
                np.random.choice(self.patch_types, p=self.patch_type_probs)
            )

        candidates = [
            (t, p)
            for t, p in zip(self.patch_types, self.patch_type_probs)
            if t != "cross"
        ]
        if not candidates:
            raise ValueError(
                "Cannot sample a patch type: no inspirations were provided "
                "and 'cross' is the only configured patch type."
            )

        valid_types = [t for t, _ in candidates]
        weights = [p for _, p in candidates]
        total = sum(weights)
        if total > 0:
            valid_probs = [p / total for p in weights]
        else:
            # All probability mass was on "cross"; renormalizing would divide
            # by zero, so fall back to a uniform draw over the other types.
            valid_probs = [1.0 / len(weights)] * len(weights)
        return str(np.random.choice(valid_types, p=valid_probs))

    def _read_diagnostic_report(self) -> str:
        """Return the stripped diagnostic report text, or ``""`` if unavailable.

        Reading is best-effort: a missing file is silent, any other read or
        decode failure is logged and treated as "no report".
        """
        if not self.results_dir:
            return ""
        diagnostic_path = (
            Path(self.results_dir) / "eval_agent_memory" / "diagnostic_report.md"
        )
        try:
            return diagnostic_path.read_text(encoding="utf-8").strip()
        except (FileNotFoundError, NotADirectoryError):
            return ""
        except (OSError, UnicodeError) as exc:
            logger.warning(
                "Could not read diagnostic report %s: %s", diagnostic_path, exc
            )
            return ""

    def _collect_visualization_images(
        self,
        parent: Program,
        archive_inspirations: List[Program],
        top_k_inspirations: List[Program],
    ) -> Optional[List[str]]:
        """Collect visualization images for the parent program.

        Only the parent's ``packing_viz.png`` under
        ``<results_dir>/gen_<generation>/results`` is looked up. The
        inspiration arguments are accepted but not used here.

        Returns:
            List of image paths, or None if no images found
        """
        if not self.results_dir or parent.generation is None:
            return None

        parent_results_dir = (
            Path(self.results_dir) / f"gen_{parent.generation}" / "results"
        )
        parent_viz = parent_results_dir / "packing_viz.png"
        if not parent_viz.exists():
            return None

        logger.info("Found parent visualization: %s", parent_viz)
        return [str(parent_viz)]
|
|