# JustinTX's picture
# Add files using upload-large-folder tool
# 2facf1f verified
import json
import logging
import sqlite3
import time
from dataclasses import asdict, dataclass, field
from functools import wraps
from pathlib import Path
import random
import numpy as np
from typing import Any, Dict, List, Optional, Tuple, Union
import math
from .complexity import analyze_code_metrics
from .parents import CombinedParentSelector
from .inspirations import CombinedContextSelector
from .islands import CombinedIslandManager
from .display import DatabaseDisplay
from shinka.llm.embedding import EmbeddingClient
logger = logging.getLogger(__name__)
def clean_nan_values(obj: Any) -> Any:
    """
    Return ``obj`` with every non-finite float replaced by None.

    Dicts, lists and tuples are rebuilt recursively. NaN and +/-inf values
    (Python floats or NumPy floating scalars) become None, finite NumPy
    floating scalars are converted to plain Python floats, and NumPy arrays
    with a floating dtype are turned into nested lists before being cleaned.
    Any other object is returned unchanged. The result is intended to be
    safe to hand to ``json.dumps``.
    """
    # Containers: rebuild with cleaned members.
    if isinstance(obj, dict):
        return {k: clean_nan_values(v) for k, v in obj.items()}
    if isinstance(obj, list):
        return [clean_nan_values(entry) for entry in obj]
    if isinstance(obj, tuple):
        return tuple(clean_nan_values(entry) for entry in obj)

    # Non-finite scalars map to None.
    if isinstance(obj, float) and not math.isfinite(obj):
        return None
    if isinstance(obj, np.floating) and not np.isfinite(obj):
        return None

    # Everything without a floating NumPy dtype passes through untouched.
    has_float_dtype = hasattr(obj, "dtype") and np.issubdtype(obj.dtype, np.floating)
    if not has_float_dtype:
        return obj

    if not np.isscalar(obj):
        # Arrays: convert to (nested) lists, then clean those recursively.
        return clean_nan_values(obj.tolist())
    if np.isnan(obj) or np.isinf(obj):
        return None
    return float(obj)
@dataclass
class DatabaseConfig:
    """Settings for the program database and its selection strategies."""

    # --- Storage ---
    db_path: str = "evolution_db.sqlite"
    num_islands: int = 4
    archive_size: int = 100

    # --- Inspiration parameters ---
    elite_selection_ratio: float = 0.3  # Proportion of elite inspirations
    num_archive_inspirations: int = 5  # Number of archive inspiration programs
    num_top_k_inspirations: int = 2  # Number of top-k inspiration programs

    # --- Island model / migration parameters ---
    migration_interval: int = 10  # Migrate every N generations
    migration_rate: float = 0.1  # Proportion of island population to migrate
    island_elitism: bool = True  # Keep best programs on their islands
    # Enforce full island separation for inspirations
    enforce_island_separation: bool = True

    # --- Parent selection ---
    # One of "weighted", "power_law" or "beam_search"
    parent_selection_strategy: str = "power_law"

    # Power-law parent selection
    exploitation_alpha: float = 1.0  # 0 = uniform, 1 = power-law
    exploitation_ratio: float = 0.2  # Chance to pick from archive

    # Weighted tree parent selection
    parent_selection_lambda: float = 10.0  # > 0, sharpness of sigmoid

    # Beam search parent selection
    num_beams: int = 5

    # --- Embedding model name ---
    embedding_model: str = "text-embedding-3-small"
def db_retry(max_retries=5, initial_delay=0.1, backoff_factor=2):
    """
    Build a decorator that re-runs a database call when SQLite raises.

    The wrapped callable is attempted at most ``max_retries`` times. After
    every failed attempt except the last, the wrapper logs a warning, sleeps,
    and multiplies the sleep time by ``backoff_factor`` (the first sleep lasts
    ``initial_delay`` seconds). The last failure is logged as an error and
    re-raised. Exceptions other than the listed sqlite3 errors propagate
    immediately without a retry.
    """
    retryable_errors = (
        sqlite3.OperationalError,
        sqlite3.DatabaseError,
        sqlite3.IntegrityError,
    )

    def decorator(func):
        @wraps(func)
        def wrapper(*args, **kwargs):
            delay = initial_delay
            for attempt in range(1, max_retries + 1):
                try:
                    return func(*args, **kwargs)
                except retryable_errors as e:
                    out_of_attempts = attempt == max_retries
                    if out_of_attempts:
                        logger.error(
                            f"DB operation {func.__name__} failed after "
                            f"{max_retries} retries: {e}"
                        )
                        raise
                    logger.warning(
                        f"DB operation {func.__name__} failed with "
                        f"{type(e).__name__}: {e}. "
                        f"Retrying in {delay:.2f}s..."
                    )
                    time.sleep(delay)
                    delay *= backoff_factor
            # Only reached when the loop body never ran (max_retries <= 0).
            raise RuntimeError(
                f"DB retry logic failed for function {func.__name__} without "
                "raising an exception."
            )

        return wrapper

    return decorator
@dataclass
class Program:
    """Represents a program in the database."""

    # Program identification
    id: str
    code: str
    language: str = "python"

    # Evolution information
    parent_id: Optional[str] = None
    # IDs of programs used as archive inspiration
    archive_inspiration_ids: List[str] = field(default_factory=list)
    # IDs of programs used as top-k inspiration
    top_k_inspiration_ids: List[str] = field(default_factory=list)
    island_idx: Optional[int] = None
    generation: int = 0
    timestamp: float = field(default_factory=time.time)
    code_diff: Optional[str] = None

    # Performance metrics
    combined_score: float = 0.0
    public_metrics: Dict[str, Any] = field(default_factory=dict)
    private_metrics: Dict[str, Any] = field(default_factory=dict)
    text_feedback: Union[str, List[str]] = ""
    correct: bool = False  # Whether the program is functionally correct
    children_count: int = 0

    # Derived features
    complexity: float = 0.0  # Calculated based on code or other features
    embedding: List[float] = field(default_factory=list)
    embedding_pca_2d: List[float] = field(default_factory=list)
    embedding_pca_3d: List[float] = field(default_factory=list)
    embedding_cluster_id: Optional[int] = None

    # Migration history
    migration_history: List[Dict[str, Any]] = field(default_factory=list)

    # Metadata
    metadata: Dict[str, Any] = field(default_factory=dict)

    # Archive status
    in_archive: bool = False

    def to_dict(self) -> Dict[str, Any]:
        """Convert to dict representation, cleaning NaN values for JSON."""
        return clean_nan_values(asdict(self))

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> "Program":
        """
        Build a Program from a plain mapping such as a decoded database row.

        Dict-typed fields that are missing or hold a non-dict value are reset
        to ``{}``; list-typed fields that are missing or hold a non-list value
        are reset to ``[]``. Keys that are not Program fields are ignored.

        The input mapping is left untouched: normalisation happens on a
        shallow copy, so callers can keep using ``data`` afterwards.

        Args:
            data: Mapping of field names to values.

        Returns:
            A new Program instance.

        Raises:
            TypeError: if a required field (``id`` or ``code``) is absent.
        """
        # Shallow copy so the caller's mapping is never modified.
        normalized = dict(data)

        dict_fields = ("public_metrics", "private_metrics", "metadata")
        list_fields = (
            "archive_inspiration_ids",
            "top_k_inspiration_ids",
            "embedding",
            "embedding_pca_2d",
            "embedding_pca_3d",
            "migration_history",
        )
        for name in dict_fields:
            if not isinstance(normalized.get(name), dict):
                normalized[name] = {}
        for name in list_fields:
            if not isinstance(normalized.get(name), list):
                normalized[name] = []

        # Drop keys that are not Program fields to avoid TypeError in cls(**...).
        known_fields = set(cls.__dataclass_fields__)
        return cls(**{k: v for k, v in normalized.items() if k in known_fields})
class ProgramDatabase:
"""
SQLite-backed database for storing and managing programs during an
evolutionary process.
Supports MAP-Elites style feature-based organization, island
populations, and an archive of elites.
"""
def __init__(
    self,
    config: DatabaseConfig,
    embedding_model: str = "text-embedding-3-small",
    read_only: bool = False,
):
    """
    Open (or create) the SQLite store and set up in-memory bookkeeping.

    Args:
        config: Database settings. ``db_path`` is read here (an empty or
            missing value selects an in-memory database); the whole object
            is also handed to CombinedIslandManager.
        embedding_model: Model name passed to EmbeddingClient. Unused when
            ``read_only`` is True because no client is created.
        read_only: Open the file through a ``mode=ro`` URI, skip table
            creation and skip building the embedding client.

    Raises:
        FileNotFoundError: ``read_only`` is True and the database file does
            not exist.
    """
    self.config = config
    self.conn: Optional[sqlite3.Connection] = None
    self.cursor: Optional[sqlite3.Cursor] = None
    self.read_only = read_only
    # Only create embedding client if not in read-only mode
    # (e.g., WebUI doesn't need it for visualization)
    # NOTE(review): config.embedding_model is not consulted here; only the
    # `embedding_model` argument reaches EmbeddingClient -- confirm intended.
    if not read_only:
        self.embedding_client = EmbeddingClient(model_name=embedding_model)
    else:
        self.embedding_client = None
    # Run-level state; overwritten by _load_metadata_from_db() below.
    self.last_iteration: int = 0
    self.best_program_id: Optional[str] = None
    self.beam_search_parent_id: Optional[str] = None
    # For deferring expensive operations
    self._schedule_migration: bool = False
    # Initialize island manager (will be set after db connection)
    self.island_manager: Optional[CombinedIslandManager] = None
    db_path_str = getattr(self.config, "db_path", None)
    if db_path_str:
        db_file = Path(db_path_str).resolve()
        if not read_only:
            # Robustness check for unclean shutdown with WAL: an empty main
            # file next to leftover -wal/-shm files is treated as stale state
            # and the side files are deleted before connecting.
            db_wal_file = Path(f"{db_file}-wal")
            db_shm_file = Path(f"{db_file}-shm")
            if (
                db_file.exists()
                and db_file.stat().st_size == 0
                and (db_wal_file.exists() or db_shm_file.exists())
            ):
                logger.warning(
                    f"Database file {db_file} is empty but WAL/SHM files "
                    "exist. This may indicate an unclean shutdown. "
                    "Removing WAL/SHM files to attempt recovery."
                )
                if db_wal_file.exists():
                    db_wal_file.unlink()
                if db_shm_file.exists():
                    db_shm_file.unlink()
            # Make sure the target directory exists before SQLite creates the file.
            db_file.parent.mkdir(parents=True, exist_ok=True)
            self.conn = sqlite3.connect(str(db_file), timeout=30.0)
            logger.debug(f"Connected to SQLite database: {db_file}")
        else:
            # Read-only mode never creates a file, so a missing one is an error.
            if not db_file.exists():
                raise FileNotFoundError(
                    f"Database file not found for read-only connection: {db_file}"
                )
            # NOTE(review): the path is interpolated into the URI unescaped;
            # paths containing '?' or '#' would presumably be misparsed -- verify.
            db_uri = f"file:{db_file}?mode=ro"
            self.conn = sqlite3.connect(db_uri, uri=True, timeout=30.0)
            logger.debug(
                "Connected to SQLite database in read-only mode: %s",
                db_file,
            )
    else:
        # No path configured: fall back to a private in-memory database.
        # NOTE(review): combined with read_only=True the tables are never
        # created on this connection, so the metadata load below looks like
        # it would fail -- confirm whether that combination is supported.
        self.conn = sqlite3.connect(":memory:")
        logger.info("Initialized in-memory SQLite database.")
    # Rows are accessed by column name throughout this class.
    self.conn.row_factory = sqlite3.Row
    self.cursor = self.conn.cursor()
    if not self.read_only:
        self._create_tables()
    self._load_metadata_from_db()
    # Initialize island manager now that database is ready
    self.island_manager = CombinedIslandManager(
        cursor=self.cursor,
        conn=self.conn,
        config=self.config,
    )
    count = self._count_programs_in_db()
    logger.debug(f"DB initialized with {count} programs.")
    logger.debug(
        f"Last iter: {self.last_iteration}. Best ID: {self.best_program_id}"
    )
def _create_tables(self):
    """
    Configure the connection and make sure the schema exists.

    Applies the SQLite pragmas, creates the ``programs``, ``archive`` and
    ``metadata_store`` tables plus the ``programs`` indices if they are
    missing, commits, and then runs the schema migrations.

    Raises:
        ConnectionError: if the cursor or connection is not available.
    """
    if not self.cursor or not self.conn:
        raise ConnectionError("DB not connected.")

    run = self.cursor.execute

    # Pragmas for better performance and stability, applied in this order.
    pragmas = (
        "PRAGMA journal_mode = WAL;",  # WAL: better concurrency, less locking
        "PRAGMA busy_timeout = 30000;",  # 30 second busy timeout
        "PRAGMA wal_autocheckpoint = 1000;",  # Checkpoint every 1000 pages
        "PRAGMA synchronous = NORMAL;",  # Safer, faster
        "PRAGMA cache_size = -64000;",  # 64MB cache
        "PRAGMA temp_store = MEMORY;",
        "PRAGMA foreign_keys = ON;",  # For data integrity
    )
    for pragma in pragmas:
        run(pragma)

    run(
        """
CREATE TABLE IF NOT EXISTS programs (
id TEXT PRIMARY KEY,
code TEXT NOT NULL,
language TEXT NOT NULL,
parent_id TEXT,
archive_inspiration_ids TEXT, -- JSON serialized List[str]
top_k_inspiration_ids TEXT, -- JSON serialized List[str]
generation INTEGER NOT NULL,
timestamp REAL NOT NULL,
code_diff TEXT, -- Stores edit difference
combined_score REAL,
public_metrics TEXT, -- JSON serialized Dict[str, Any]
private_metrics TEXT, -- JSON serialized Dict[str, Any]
text_feedback TEXT, -- Text feedback for the program
complexity REAL, -- Calculated complexity metric
embedding TEXT, -- JSON serialized List[float]
embedding_pca_2d TEXT, -- JSON serialized List[float]
embedding_pca_3d TEXT, -- JSON serialized List[float]
embedding_cluster_id INTEGER,
correct BOOLEAN DEFAULT 0, -- Correct (0=False, 1=True)
children_count INTEGER NOT NULL DEFAULT 0,
metadata TEXT, -- JSON serialized Dict[str, Any]
migration_history TEXT, -- JSON of migration events
island_idx INTEGER -- Add island_idx to the schema
)
"""
    )

    # One single-column index per commonly queried column.
    indexed_columns = (
        "generation",
        "timestamp",
        "complexity",
        "parent_id",
        "children_count",
        "island_idx",
    )
    for column in indexed_columns:
        run(f"CREATE INDEX IF NOT EXISTS idx_programs_{column} ON programs({column})")

    run(
        """
CREATE TABLE IF NOT EXISTS archive (
program_id TEXT PRIMARY KEY,
FOREIGN KEY (program_id) REFERENCES programs(id)
ON DELETE CASCADE
)
"""
    )
    run(
        """
CREATE TABLE IF NOT EXISTS metadata_store (
key TEXT PRIMARY KEY, value TEXT
)
"""
    )
    self.conn.commit()

    # Bring databases created by an earlier schema version up to date.
    self._run_migrations()
    logger.debug("Database tables and indices ensured to exist.")
def _run_migrations(self):
"""Run database migrations for schema changes."""
if not self.cursor or not self.conn:
raise ConnectionError("DB not connected.")
# Migration 1: Add text_feedback column if it doesn't exist
try:
# Check if text_feedback column exists
self.cursor.execute("PRAGMA table_info(programs)")
columns = [row[1] for row in self.cursor.fetchall()]
if "text_feedback" not in columns:
logger.info("Adding text_feedback column to programs table")
self.cursor.execute(
"ALTER TABLE programs ADD COLUMN text_feedback TEXT DEFAULT ''"
)
self.conn.commit()
logger.info("Successfully added text_feedback column")
except sqlite3.Error as e:
logger.error(f"Error during text_feedback migration: {e}")
# Don't raise - this is not critical for existing functionality
@db_retry()
def _load_metadata_from_db(self):
    """
    Restore run-level state from ``metadata_store`` into this instance.

    Sets ``last_iteration`` (0 when nothing is stored), ``best_program_id``
    and ``beam_search_parent_id`` (None when the row is absent, NULL, or
    holds the literal text "None"). Unless the database is read-only, a key
    with no usable value is written back with its default so the row exists.

    Raises:
        ConnectionError: if no cursor is available.
    """
    if not self.cursor:
        raise ConnectionError("DB cursor not available.")

    def read_value(key: str) -> Optional[str]:
        """Return the stored value for ``key``; None if no row or NULL."""
        self.cursor.execute(
            "SELECT value FROM metadata_store WHERE key = ?", (key,)
        )
        row = self.cursor.fetchone()
        return row["value"] if row else None

    def read_program_id(key: str) -> Optional[str]:
        """Return a stored program ID, clearing rows that hold no real ID."""
        stored = read_value(key)
        if stored is None or stored == "None":
            # Initialize, or clear a value that was stored as the string 'None'.
            if not self.read_only:
                self._update_metadata_in_db(key, None)
            return None
        return str(stored)

    stored_iteration = read_value("last_iteration")
    self.last_iteration = (
        int(stored_iteration) if stored_iteration is not None else 0
    )
    # Seed the row only when nothing usable is stored. The previous check was
    # inverted ("is not None"): it rewrote an existing value on every load and
    # never initialised a row whose value was NULL.
    if stored_iteration is None and not self.read_only:
        self._update_metadata_in_db("last_iteration", str(self.last_iteration))

    self.best_program_id = read_program_id("best_program_id")
    self.beam_search_parent_id = read_program_id("beam_search_parent_id")
@db_retry()
def _update_metadata_in_db(self, key: str, value: Optional[str]):
    """
    Insert or replace one key/value pair in ``metadata_store`` and commit.

    Args:
        key: Metadata key.
        value: Text to store; None is written as SQL NULL.

    Raises:
        ConnectionError: if the cursor or connection is not available.
    """
    cursor, conn = self.cursor, self.conn
    if not (cursor and conn):
        raise ConnectionError("DB not connected.")
    upsert = "INSERT OR REPLACE INTO metadata_store (key, value) VALUES (?, ?)"
    cursor.execute(upsert, (key, value))
    conn.commit()
@db_retry()
def _count_programs_in_db(self) -> int:
    """Return the number of rows in ``programs`` (0 if there is no cursor)."""
    if not self.cursor:
        return 0
    self.cursor.execute("SELECT COUNT(*) FROM programs")
    count_row = self.cursor.fetchone()
    if not count_row:
        return 0
    return count_row["COUNT(*)"]
@db_retry()
def add(self, program: Program, verbose: bool = False) -> str:
    """
    Persist a program and update all bookkeeping that depends on it.

    Steps, in order:
      1. Ask the island manager to assign an island to the program.
      2. If ``program.complexity`` is 0.0, compute it with
         ``analyze_code_metrics`` and store the metrics under
         ``program.metadata["code_analysis_metrics"]``; on failure fall
         back to the length of the code.
      3. Insert the row and increment the parent's ``children_count`` in
         a single transaction.
      4. Update the archive, the tracked best program, embeddings and
         clusters, and the stored ``last_iteration``.
      5. Optionally print a summary, create per-island copies when the
         island manager asks for them, flag a migration if one is due, and
         call ``check_scheduled_operations()`` before returning.

    Note that the program object passed in is modified (island, complexity,
    metadata, embedding).

    Args:
        program: The Program object to add.
        verbose: When True, print a summary of the added program.

    Returns:
        str: The ID of the added program.

    Raises:
        PermissionError: if the database was opened read-only.
        ConnectionError: if the cursor or connection is not available.
        sqlite3.IntegrityError: re-raised after rollback when the insert
            violates a constraint.
        Exception: any other error during the insert is re-raised after
            rollback.
    """
    # NOTE(review): this method is wrapped in db_retry, so an SQLite error
    # raised after the commit below would re-run the whole method, including
    # the INSERT -- confirm that this is acceptable.
    if self.read_only:
        raise PermissionError("Cannot add program in read-only mode.")
    if not self.cursor or not self.conn:
        raise ConnectionError("DB not connected.")
    self.island_manager.assign_island(program)
    # Calculate complexity if not pre-set (or if default 0.0)
    if program.complexity == 0.0:
        try:
            code_metrics = analyze_code_metrics(program.code, program.language)
            program.complexity = code_metrics.get("complexity_score", 0.0)
            if program.metadata is None:
                program.metadata = {}
            program.metadata["code_analysis_metrics"] = code_metrics
        except Exception as e:
            logger.warning(
                f"Could not calculate complexity for program {program.id}: {e}"
            )
            program.complexity = float(len(program.code))  # Fallback to length
    # Embedding is expected to be provided by the user.
    # Ensure program.embedding is a list, even if empty.
    if not isinstance(program.embedding, list):
        logger.warning(
            f"Program {program.id} embedding is not a list, "
            "defaulting to empty list."
        )
        program.embedding = []
    # Pre-serialize all JSON data once; None-valued containers are stored
    # as empty JSON objects/arrays.
    public_metrics_json = json.dumps(program.public_metrics or {})
    private_metrics_json = json.dumps(program.private_metrics or {})
    metadata_json = json.dumps(program.metadata or {})
    archive_insp_ids_json = json.dumps(program.archive_inspiration_ids or [])
    top_k_insp_ids_json = json.dumps(program.top_k_inspiration_ids or [])
    embedding_json = json.dumps(program.embedding)  # Serialize embedding
    embedding_pca_2d_json = json.dumps(program.embedding_pca_2d or [])
    embedding_pca_3d_json = json.dumps(program.embedding_pca_3d or [])
    migration_history_json = json.dumps(program.migration_history or [])
    # Handle text_feedback - the column is plain text, so convert to a string
    text_feedback_str = program.text_feedback
    if isinstance(text_feedback_str, list):
        # Join list items with newlines for readability
        text_feedback_str = "\n".join(str(item) for item in text_feedback_str)
    elif text_feedback_str is None:
        text_feedback_str = ""
    else:
        text_feedback_str = str(text_feedback_str)
    # Begin transaction - the insert and the parent update commit together
    self.conn.execute("BEGIN TRANSACTION")
    try:
        # Insert the program in a single operation
        self.cursor.execute(
            """
INSERT INTO programs
(id, code, language, parent_id, archive_inspiration_ids,
top_k_inspiration_ids, generation, timestamp, code_diff,
combined_score, public_metrics, private_metrics,
text_feedback, complexity, embedding, embedding_pca_2d,
embedding_pca_3d, embedding_cluster_id, correct,
children_count, metadata, island_idx, migration_history)
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?,
?, ?, ?, ?, ?, ?)
""",
            (
                program.id,
                program.code,
                program.language,
                program.parent_id,
                archive_insp_ids_json,
                top_k_insp_ids_json,
                program.generation,
                program.timestamp,
                program.code_diff,
                program.combined_score,
                public_metrics_json,
                private_metrics_json,
                text_feedback_str,
                program.complexity,
                embedding_json,  # Use serialized embedding
                embedding_pca_2d_json,
                embedding_pca_3d_json,
                program.embedding_cluster_id,
                program.correct,
                program.children_count,
                metadata_json,
                program.island_idx,
                migration_history_json,
            ),
        )
        # Increment parent's children_count
        if program.parent_id:
            self.cursor.execute(
                "UPDATE programs SET children_count = children_count + 1 "
                "WHERE id = ?",
                (program.parent_id,),
            )
        # Commit the main program insertion and related operations
        self.conn.commit()
        logger.info(
            "Program %s added to DB - score: %s.",
            program.id,
            program.combined_score,
        )
    except sqlite3.IntegrityError as e:
        # Constraint violation: undo the partial transaction.
        self.conn.rollback()
        logger.error(f"IntegrityError for program {program.id}: {e}")
        raise
    except Exception as e:
        self.conn.rollback()
        logger.error(f"Error adding program {program.id}: {e}")
        raise
    # Everything below runs after the insert has been committed.
    self._update_archive(program)
    # Update best program tracking
    self._update_best_program(program)
    # Recompute embeddings and clusters for all programs
    self._recompute_embeddings_and_clusters()
    # Update generation tracking
    if program.generation > self.last_iteration:
        self.last_iteration = program.generation
        self._update_metadata_in_db("last_iteration", str(self.last_iteration))
    # Print verbose summary if requested
    if verbose:
        self._print_program_summary(program)
    # Check if this program needs to be copied to other islands
    if self.island_manager.needs_island_copies(program):
        logger.info(
            f"Creating copies of initial program {program.id} for all islands"
        )
        self.island_manager.copy_program_to_islands(program)
        # Remove the flag from the original program's metadata and persist
        # the cleaned metadata.
        if program.metadata:
            program.metadata.pop("_needs_island_copies", None)
            metadata_json = json.dumps(program.metadata)
            self.cursor.execute(
                "UPDATE programs SET metadata = ? WHERE id = ?",
                (metadata_json, program.id),
            )
            self.conn.commit()
    # Check if migration should be scheduled
    if self.island_manager.should_schedule_migration(program):
        self._schedule_migration = True
    # Run whatever has been scheduled before handing control back.
    self.check_scheduled_operations()
    return program.id
def _program_from_row(self, row: sqlite3.Row) -> Optional[Program]:
    """
    Turn one ``programs`` row into a Program object.

    JSON text columns are decoded; a column that is empty or cannot be
    decoded falls back to an empty dict or list. A missing or NULL
    ``text_feedback`` becomes "" and ``in_archive`` is coerced to bool
    (False when the row has no such column).

    Returns:
        The Program, or None when ``row`` is falsy.
    """
    if not row:
        return None
    record = dict(row)

    def decode(column, fallback, warn=False):
        """Replace record[column] with its decoded JSON, or with fallback()."""
        raw_text = record.get(column)
        if raw_text:
            try:
                record[column] = json.loads(raw_text)
                return
            except json.JSONDecodeError:
                if warn:
                    logger.warning(
                        f"Could not decode {column} for program "
                        f"{record.get('id')}. Defaulting to empty list."
                    )
        record[column] = fallback()

    # Dict-valued columns
    for column in ("public_metrics", "private_metrics", "metadata"):
        decode(column, dict)

    # text_feedback is a plain string column
    if record.get("text_feedback") is None:
        record["text_feedback"] = ""

    # List-valued columns; the flag says whether a decode failure is logged
    list_columns = (
        ("archive_inspiration_ids", False),
        ("top_k_inspiration_ids", True),
        ("embedding", True),
        ("embedding_pca_2d", False),
        ("embedding_pca_3d", False),
        ("migration_history", True),
    )
    for column, warn_on_failure in list_columns:
        decode(column, list, warn=warn_on_failure)

    # Archive status
    record["in_archive"] = bool(record.get("in_archive", 0))
    return Program.from_dict(record)
@db_retry()
def get(self, program_id: str) -> Optional[Program]:
    """
    Look up a single program by its ID.

    Returns:
        The decoded Program, or None when the query yields no row.

    Raises:
        ConnectionError: if no cursor is available.
    """
    if not self.cursor:
        raise ConnectionError("DB not connected.")
    lookup = "SELECT * FROM programs WHERE id = ?"
    self.cursor.execute(lookup, (program_id,))
    return self._program_from_row(self.cursor.fetchone())
@db_retry()
def sample(
    self,
    target_generation=None,
    novelty_attempt=None,
    max_novelty_attempts=None,
    resample_attempt=None,
    max_resample_attempts=None,
) -> Tuple[Program, List[Program], List[Program]]:
    """
    Choose a parent program and the inspiration programs to go with it.

    While some island is still uninitialized, the earliest program by
    timestamp is returned as parent with no inspirations. Otherwise an
    initialized island is picked at random, a parent is sampled from it via
    CombinedParentSelector, and inspirations are gathered via
    CombinedContextSelector. The five keyword arguments are only forwarded
    to the sampling summary printer.

    Returns:
        ``(parent, archive_inspirations, top_k_inspirations)``.

    Raises:
        ConnectionError: if no cursor is available.
        RuntimeError: if the database is empty, the initial program cannot
            be loaded, or no parent could be sampled.
    """
    if not self.cursor:
        raise ConnectionError("DB not connected.")

    # Passed through unchanged to the summary printer on either path.
    progress = (
        target_generation,
        novelty_attempt,
        max_novelty_attempts,
        resample_attempt,
        max_resample_attempts,
    )

    # Bootstrap path: not every island has been initialized yet.
    if not self.island_manager.are_all_islands_initialized():
        # Initial program = first program in the database by timestamp.
        self.cursor.execute("SELECT * FROM programs ORDER BY timestamp ASC LIMIT 1")
        earliest_row = self.cursor.fetchone()
        if not earliest_row:
            raise RuntimeError("No programs found in database")
        parent = self._program_from_row(earliest_row)
        if not parent:
            raise RuntimeError("Failed to load initial program")
        logger.info(
            f"Not all islands initialized. Using initial program {parent.id} "
            "without inspirations."
        )
        self._print_sampling_summary_helper(parent, [], [], *progress)
        return parent, [], []

    # Regular path: sample an island, then constrain the parent to it.
    sampled_island = random.choice(self.island_manager.get_initialized_islands())
    logger.debug(f"Sampling from island {sampled_island}")

    selector = CombinedParentSelector(
        cursor=self.cursor,
        conn=self.conn,
        config=self.config,
        get_program_func=self.get,
        best_program_id=self.best_program_id,
        beam_search_parent_id=self.beam_search_parent_id,
        last_iteration=self.last_iteration,
        update_metadata_func=self._update_metadata_in_db,
        get_best_program_func=self.get_best_program,
    )
    parent = selector.sample_parent(island_idx=sampled_island)
    if not parent:
        raise RuntimeError(f"Failed to sample parent from island {sampled_island}")

    # Fall back to 5 / 2 when the config object lacks these attributes.
    archive_quota = getattr(self.config, "num_archive_inspirations", 5)
    top_k_quota = getattr(self.config, "num_top_k_inspirations", 2)

    inspiration_source = CombinedContextSelector(
        cursor=self.cursor,
        conn=self.conn,
        config=self.config,
        get_program_func=self.get,
        best_program_id=self.best_program_id,
        get_island_idx_func=self.island_manager.get_island_idx,
        program_from_row_func=self._program_from_row,
    )
    sampled_context = inspiration_source.sample_context(
        parent, archive_quota, top_k_quota
    )
    archive_inspirations, top_k_inspirations = sampled_context
    logger.debug(
        f"Sampled parent {parent.id} from island {sampled_island}, "
        f"{len(archive_inspirations)} archive inspirations, "
        f"{len(top_k_inspirations)} top-k inspirations."
    )
    self._print_sampling_summary_helper(
        parent, archive_inspirations, top_k_inspirations, *progress
    )
    return parent, archive_inspirations, top_k_inspirations
def _print_sampling_summary_helper(
    self,
    parent,
    archive_inspirations,
    top_k_inspirations,
    target_generation=None,
    novelty_attempt=None,
    max_novelty_attempts=None,
    resample_attempt=None,
    max_resample_attempts=None,
):
    """
    Forward a sampling result to DatabaseDisplay.print_sampling_summary.

    The DatabaseDisplay instance is built on first use and cached on the
    object as ``_database_display``; all arguments are passed through
    positionally in the order received.
    """
    if not hasattr(self, "_database_display"):
        display_settings = {
            "cursor": self.cursor,
            "conn": self.conn,
            "config": self.config,
            "island_manager": self.island_manager,
            "count_programs_func": self._count_programs_in_db,
            "get_best_program_func": self.get_best_program,
        }
        self._database_display = DatabaseDisplay(**display_settings)

    summary_args = (
        parent,
        archive_inspirations,
        top_k_inspirations,
        target_generation,
        novelty_attempt,
        max_novelty_attempts,
        resample_attempt,
        max_resample_attempts,
    )
    self._database_display.print_sampling_summary(*summary_args)
@db_retry()
def get_best_program(self, metric: Optional[str] = None) -> Optional[Program]:
    """
    Return the best program that is marked correct, or None if there is none.

    With ``metric`` unset, the tracked ``best_program_id`` is returned
    directly when it still refers to a correct program; a stale ID is
    cleared and the best program is recomputed. Ranking rules:

    * ``metric`` given: highest ``public_metrics[metric]`` among correct
      programs that report that metric.
    * otherwise, if any program has a ``combined_score``: highest
      ``combined_score`` (a score of 0.0 is a valid value).
    * otherwise: highest mean of the ``public_metrics`` values.

    The winner's ID is stored as the new tracked best program (and persisted
    unless the database is read-only).

    Args:
        metric: Optional key in ``public_metrics`` to rank by.

    Returns:
        The best correct Program, or None when no program qualifies.

    Raises:
        ConnectionError: if no cursor is available.
    """
    if not self.cursor:
        raise ConnectionError("DB not connected.")

    # Fast path: trust the tracked ID when no specific metric is requested.
    if metric is None and self.best_program_id:
        program = self.get(self.best_program_id)
        if program and program.correct:  # Ensure best program is correct
            return program
        # Stale ID or incorrect program: forget it and re-evaluate below.
        logger.warning(
            f"Tracked best_program_id '{self.best_program_id}' "
            "not found or incorrect. Re-evaluating."
        )
        if not self.read_only:
            self._update_metadata_in_db("best_program_id", None)
        self.best_program_id = None

    # Fetch only correct programs and rank them in Python.
    self.cursor.execute("SELECT * FROM programs WHERE correct = 1")
    all_rows = self.cursor.fetchall()
    if not all_rows:
        logger.debug("No correct programs found in database.")
        return None

    # Decode through the shared row helper so malformed JSON degrades to
    # empty containers and list columns (embedding, inspiration IDs, ...)
    # are decoded exactly as in get().
    decoded = (self._program_from_row(row) for row in all_rows)
    programs = [p for p in decoded if p is not None]
    if not programs:
        return None

    sorted_p: List[Program] = []
    log_key = "average metrics"
    if metric:
        progs_with_metric = [
            p for p in programs if p.public_metrics and metric in p.public_metrics
        ]
        sorted_p = sorted(
            progs_with_metric,
            key=lambda p_item: p_item.public_metrics.get(metric, -float("inf")),
            reverse=True,
        )
        log_key = f"metric '{metric}'"
    elif any(p.combined_score is not None for p in programs):
        progs_with_cs = [p for p in programs if p.combined_score is not None]
        # Sort on the score itself. The former key `score or -inf` mapped a
        # legitimate score of 0.0 to -inf, ranking it below negative scores.
        sorted_p = sorted(
            progs_with_cs,
            key=lambda p_item: p_item.combined_score,
            reverse=True,
        )
        log_key = "combined_score"
    else:
        progs_with_metrics = [p for p in programs if p.public_metrics]
        sorted_p = sorted(
            progs_with_metrics,
            key=lambda p_item: sum(p_item.public_metrics.values())
            / len(p_item.public_metrics)
            if p_item.public_metrics
            else -float("inf"),
            reverse=True,
        )

    if not sorted_p:
        logger.debug("No correct programs matched criteria for get_best_program.")
        return None

    best_overall = sorted_p[0]
    logger.debug(f"Best correct program by {log_key}: {best_overall.id}")
    if self.best_program_id != best_overall.id:  # Update ID if different
        logger.info(
            "Updating tracked best program from "
            f"'{self.best_program_id}' to '{best_overall.id}'."
        )
        self.best_program_id = best_overall.id
        if not self.read_only:
            self._update_metadata_in_db("best_program_id", self.best_program_id)
    return best_overall
@db_retry()
def get_all_programs(self) -> List[Program]:
    """
    Return every stored program, each flagged with its archive membership.

    The query left-joins ``archive`` so that ``in_archive`` is 1 for
    programs that have an archive row and 0 otherwise.

    Raises:
        ConnectionError: if no cursor is available.
    """
    if not self.cursor:
        raise ConnectionError("DB not connected.")
    self.cursor.execute(
        """
SELECT p.*,
CASE WHEN a.program_id IS NOT NULL THEN 1 ELSE 0 END as in_archive
FROM programs p
LEFT JOIN archive a ON p.id = a.program_id
"""
    )
    loaded: List[Program] = []
    for record in self.cursor.fetchall():
        program = self._program_from_row(record)
        # Skip rows the helper could not turn into a Program.
        if program is not None:
            loaded.append(program)
    return loaded
@db_retry()
def get_programs_by_generation(self, generation: int) -> List[Program]:
    """
    Return all programs whose ``generation`` equals the given value.

    Raises:
        ConnectionError: if no cursor is available.
    """
    if not self.cursor:
        raise ConnectionError("DB not connected.")
    by_generation = "SELECT * FROM programs WHERE generation = ?"
    self.cursor.execute(by_generation, (generation,))
    decoded = map(self._program_from_row, self.cursor.fetchall())
    return [program for program in decoded if program is not None]
@db_retry()
def get_recent_programs(self, n: int = 10) -> List[Program]:
    """
    Return up to ``n`` programs, newest first.

    Ordering is by generation descending, then timestamp descending.

    Raises:
        ConnectionError: if no cursor is available.
    """
    if not self.cursor:
        raise ConnectionError("DB not connected.")
    newest_first = (
        "SELECT * FROM programs ORDER BY generation DESC, timestamp DESC LIMIT ?"
    )
    self.cursor.execute(newest_first, (n,))
    recent: List[Program] = []
    for record in self.cursor.fetchall():
        program = self._program_from_row(record)
        if program is not None:
            recent.append(program)
    return recent
@db_retry()
def get_top_programs(
    self,
    n: int = 10,
    metric: Optional[str] = "combined_score",
    correct_only: bool = False,
) -> List[Program]:
    """Return up to ``n`` programs ranked best-first by ``metric``.

    Args:
        n: Maximum number of programs to return.
        metric: Ranking criterion. ``"combined_score"`` and ``"timestamp"``
            are sorted and limited in SQL on the column of the same name.
            Any other non-empty name is looked up in each program's
            ``public_metrics`` and sorted in Python. A falsy value ranks by
            the average of the numeric ``public_metrics`` values.
        correct_only: When true, only rows with ``correct = 1`` are
            considered.

    Returns:
        The ranked programs, or an empty list when nothing matches.

    Raises:
        ConnectionError: If no cursor is available.
    """
    if not self.cursor:
        raise ConnectionError("DB not connected.")

    correctness_filter = "WHERE correct = 1" if correct_only else ""
    sorted_in_sql = metric in ("combined_score", "timestamp")

    if metric == "combined_score":
        # Sort on the dedicated combined_score column directly in SQL.
        base_query = """
            SELECT * FROM programs
            WHERE combined_score IS NOT NULL
            """
        if correct_only:
            base_query += " AND correct = 1"
        base_query += " ORDER BY combined_score DESC LIMIT ?"
        self.cursor.execute(base_query, (n,))
    elif metric == "timestamp":
        self.cursor.execute(
            f"SELECT * FROM programs {correctness_filter} "
            "ORDER BY timestamp DESC LIMIT ?",
            (n,),
        )
    else:
        # Metric lives inside the JSON blob: fetch everything, sort below.
        self.cursor.execute(f"SELECT * FROM programs {correctness_filter}")

    all_rows = self.cursor.fetchall()
    if not all_rows:
        return []

    # JSON-encoded columns. The list-valued ones are decoded here as well so
    # this method hands Program.from_dict the same shape of data as
    # get_top_programs_thread_safe does.
    dict_fields = ("public_metrics", "private_metrics", "metadata")
    list_fields = (
        "archive_inspiration_ids",
        "top_k_inspiration_ids",
        "embedding",
        "embedding_pca_2d",
        "embedding_pca_3d",
        "migration_history",
    )

    programs = []
    for row_data in all_rows:
        p_dict = dict(row_data)
        for key in dict_fields:
            raw = p_dict.get(key)
            if raw:
                try:
                    p_dict[key] = json.loads(raw)
                except json.JSONDecodeError:
                    p_dict[key] = {}
            else:
                p_dict[key] = {}
        for key in list_fields:
            raw = p_dict.get(key)
            if isinstance(raw, str):
                try:
                    p_dict[key] = json.loads(raw)
                except json.JSONDecodeError:
                    p_dict[key] = []
        programs.append(Program.from_dict(p_dict))

    if sorted_in_sql:
        # SQL already ordered and limited the rows.
        return programs[:n]

    def _is_number(value: Any) -> bool:
        return isinstance(value, (int, float)) and not math.isnan(value)

    if metric:

        def sort_key(p_item: Program) -> float:
            # A stored None (e.g. a NaN cleaned before serialization) or any
            # other non-numeric value ranks last instead of breaking sorted().
            value = p_item.public_metrics.get(metric)
            return value if _is_number(value) else -float("inf")

        candidates = [
            p for p in programs if p.public_metrics and metric in p.public_metrics
        ]
    else:

        def sort_key(p_item: Program) -> float:
            # Average only the numeric metric values.
            numbers = [v for v in p_item.public_metrics.values() if _is_number(v)]
            return sum(numbers) / len(numbers) if numbers else -float("inf")

        candidates = [p for p in programs if p.public_metrics]

    return sorted(candidates, key=sort_key, reverse=True)[:n]
def save(self, path: Optional[str] = None) -> None:
    """Persist ``last_iteration`` to the metadata table and commit.

    ``path`` is never written to: it is only compared with the connected
    database's path so a mismatch can be reported. Data always goes to the
    currently connected database.

    Args:
        path: Optional path the caller believes it is saving to.
    """
    if not (self.conn and self.cursor):
        logger.warning("No DB connection, skipping save.")
        return

    connected_path = self.config.db_path
    if path:
        if not connected_path:
            logger.warning(
                f"Attempting to save with path '{path}' but current "
                "database is in-memory. Metadata will be committed to the "
                "in-memory instance."
            )
        elif Path(path).resolve() != Path(connected_path).resolve():
            logger.warning(
                f"Save path '{path}' differs from connected DB "
                f"'{connected_path}'. Metadata saved to "
                "connected DB."
            )

    # Record progress, then flush whatever is still pending.
    self._update_metadata_in_db("last_iteration", str(self.last_iteration))
    self.conn.commit()
    logger.info(
        f"Database state committed. Last iteration: "
        f"{self.last_iteration}. Best: {self.best_program_id}"
    )
def load(self, path: str) -> None:
    """Reconnect this instance to the SQLite database file at ``path``.

    Any existing connection is closed first. If the target file exists but
    is empty while ``-wal``/``-shm`` companions are present, those
    companions are deleted before connecting. After connecting, tables are
    created via ``_create_tables`` and metadata is reloaded via
    ``_load_metadata_from_db``. ``self.config.db_path`` is updated to the
    resolved path.

    Args:
        path: Filesystem path of the database to open. Missing parent
            directories are created.
    """
    logger.info(f"Loading database from '{path}'...")
    if self.conn:
        db_display_name = self.config.db_path or ":memory:"
        logger.info(f"Closing existing connection to '{db_display_name}'.")
        self.conn.close()
        # Drop the closed handles so a failed reconnect below cannot leave
        # unusable objects on the instance.
        self.conn = None
        self.cursor = None

    # print_summary/_print_program_summary cache a DatabaseDisplay built
    # around the cursor/connection of the moment; discard it so the next
    # call rebuilds it with the handles created below.
    if "_database_display" in vars(self):
        del self._database_display

    db_path_obj = Path(path).resolve()

    # Robustness check for unclean shutdown with WAL
    db_wal_file = Path(f"{db_path_obj}-wal")
    db_shm_file = Path(f"{db_path_obj}-shm")
    if (
        db_path_obj.exists()
        and db_path_obj.stat().st_size == 0
        and (db_wal_file.exists() or db_shm_file.exists())
    ):
        # The message is a complete f-string, so no extra positional
        # arguments may be passed to the logger.
        logger.warning(
            f"Database file {db_path_obj} is empty but WAL/SHM files "
            "exist. This may indicate an unclean shutdown. Removing "
            "WAL/SHM files to attempt recovery."
        )
        if db_wal_file.exists():
            db_wal_file.unlink()
        if db_shm_file.exists():
            db_shm_file.unlink()

    self.config.db_path = str(db_path_obj)  # Update config
    if not db_path_obj.exists():
        logger.warning(
            f"DB file '{db_path_obj}' not found. New DB created if writes occur."
        )
        db_path_obj.parent.mkdir(parents=True, exist_ok=True)

    self.conn = sqlite3.connect(str(db_path_obj), timeout=30.0)
    self.conn.row_factory = sqlite3.Row
    self.cursor = self.conn.cursor()
    self._create_tables()
    self._load_metadata_from_db()

    count = self._count_programs_in_db()
    logger.info(
        f"Loaded DB from '{db_path_obj}'. {count} programs. "
        f"Last iter: {self.last_iteration}."
    )
def _is_better(self, program1: Program, program2: Program) -> bool:
# First prioritize correctness
if program1.correct and not program2.correct:
return True
if program2.correct and not program1.correct:
return False
# If both have same correctness status, compare scores
s1 = program1.combined_score
s2 = program2.combined_score
if s1 is not None and s2 is not None:
if s1 != s2:
return s1 > s2
elif s1 is not None:
return True # p1 has score, p2 doesn't
elif s2 is not None:
return False # p2 has score, p1 doesn't
try:
avg1 = (
sum(program1.public_metrics.values()) / len(program1.public_metrics)
if program1.public_metrics
else -float("inf")
)
avg2 = (
sum(program2.public_metrics.values()) / len(program2.public_metrics)
if program2.public_metrics
else -float("inf")
)
if avg1 != avg2:
return avg1 > avg2
except Exception:
return False
return program1.timestamp > program2.timestamp # Tie-breaker
@db_retry()
def _update_archive(self, program: Program) -> None:
    """Offer ``program`` to the bounded archive, evicting the worst entry if full.

    Only programs with a positive ``combined_score`` are eligible (not just
    correct ones), so plateau problems still get archive diversity for
    inspiration. While the archive holds fewer than
    ``config.archive_size`` entries the program is simply inserted. Once
    full, the worst archived entry (per ``_is_better``) is replaced, but only
    when ``program`` is better than it.

    Args:
        program: Candidate program; its ``id``, ``combined_score`` and the
            attributes read by ``_is_better`` are used.
    """
    if (
        not self.cursor
        or not self.conn
        or not hasattr(self.config, "archive_size")
        or self.config.archive_size <= 0
    ):
        logger.debug("Archive update skipped (config/DB issue or size <= 0).")
        return

    if not program.combined_score or program.combined_score <= 0:
        logger.debug(f"Program {program.id} not added to archive (score <= 0).")
        return

    insert_sql = "INSERT OR IGNORE INTO archive (program_id) VALUES (?)"

    self.cursor.execute("SELECT COUNT(*) FROM archive")
    count = (self.cursor.fetchone() or [0])[0]
    if count < self.config.archive_size:
        self.cursor.execute(insert_sql, (program.id,))
        self.conn.commit()
        return

    # Archive is full. If the program already holds a slot there is nothing
    # to replace; evicting another entry and re-inserting it would either
    # violate uniqueness or silently shrink the archive.
    self.cursor.execute(
        "SELECT 1 FROM archive WHERE program_id = ?", (program.id,)
    )
    if self.cursor.fetchone() is not None:
        logger.debug(f"Program {program.id} is already in the archive.")
        self.conn.commit()
        return

    self.cursor.execute(
        "SELECT a.program_id, p.combined_score, p.timestamp, p.correct "
        "FROM archive a JOIN programs p ON a.program_id = p.id"
    )
    archived_rows = self.cursor.fetchall()
    if not archived_rows:  # Should not happen if count was > 0
        self.cursor.execute(insert_sql, (program.id,))
        self.conn.commit()
        return

    # _is_better needs Program objects; build minimal ones from the joined
    # columns rather than loading each full program (which could be slow).
    archived = [
        Program(
            id=r_data["program_id"],
            code="",
            combined_score=r_data["combined_score"],
            timestamp=r_data["timestamp"],
            correct=bool(r_data["correct"]),
        )
        for r_data in archived_rows
    ]

    worst_in_archive = archived[0]
    for p_archived in archived[1:]:
        if self._is_better(worst_in_archive, p_archived):
            worst_in_archive = p_archived

    if self._is_better(program, worst_in_archive):
        self.cursor.execute(
            "DELETE FROM archive WHERE program_id = ?",
            (worst_in_archive.id,),
        )
        self.cursor.execute(insert_sql, (program.id,))
        logger.info(
            f"Program {program.id} replaced {worst_in_archive.id} in archive."
        )
    self.conn.commit()
@db_retry()
def _update_best_program(self, program: Program) -> None:
    """Make ``program`` the tracked best if it is correct and beats the current one.

    Incorrect programs are ignored. Otherwise the currently tracked best is
    loaded via ``self.get``; ``program`` takes over when there is no
    current best or ``_is_better(program, current)`` holds. The new id is
    stored in ``self.best_program_id`` and written to the metadata table.

    Args:
        program: Newly evaluated candidate.
    """
    # Only consider correct programs for best program tracking
    if not program.correct:
        logger.debug(f"Program {program.id} not considered for best (not correct).")
        return

    current_best_p = None
    if self.best_program_id:
        current_best_p = self.get(self.best_program_id)

    if current_best_p is not None and not self._is_better(program, current_best_p):
        return

    self.best_program_id = program.id
    self._update_metadata_in_db("best_program_id", self.best_program_id)

    log_msg = f"New best program: {program.id}"
    if current_best_p:
        p1_score = program.combined_score or 0.0
        p2_score = current_best_p.combined_score or 0.0
        # Show "old -> new" for each field; without a separator the two
        # values run together and become unreadable (e.g. "35" for 3 and 5).
        log_msg += (
            f" (gen: {current_best_p.generation} -> {program.generation}, "
            f"score: {p2_score:.4f} -> {p1_score:.4f}, "
            f"island: {current_best_p.island_idx} -> {program.island_idx})"
        )
    else:
        score = program.combined_score or 0.0
        log_msg += (
            f" (gen: {program.generation}, score: {score:.4f}, initialized "
            f"island: {program.island_idx})."
        )
    logger.info(log_msg)
def print_summary(self, console=None) -> None:
    """Render a summary of the database contents through ``DatabaseDisplay``.

    The display helper is created on first use and cached on the instance
    as ``_database_display``. Before printing it is told the current
    ``last_iteration``.

    Args:
        console: Forwarded unchanged to ``DatabaseDisplay.print_summary``.
    """
    try:
        display = self._database_display
    except AttributeError:
        # First use: build the helper and cache it.
        self._database_display = DatabaseDisplay(
            cursor=self.cursor,
            conn=self.conn,
            config=self.config,
            island_manager=self.island_manager,
            count_programs_func=self._count_programs_in_db,
            get_best_program_func=self.get_best_program,
        )
        display = self._database_display

    display.set_last_iteration(self.last_iteration)
    display.print_summary(console)
def _print_program_summary(self, program) -> None:
"""Print a rich summary of a newly added program using DatabaseDisplay."""
if not hasattr(self, "_database_display"):
self._database_display = DatabaseDisplay(
cursor=self.cursor,
conn=self.conn,
config=self.config,
island_manager=self.island_manager,
count_programs_func=self._count_programs_in_db,
get_best_program_func=self.get_best_program,
)
self._database_display.print_program_summary(program)
def check_scheduled_operations(self):
    """Run work that was deferred while adding programs.

    If the ``_schedule_migration`` flag is set, island migration is
    performed for ``last_iteration`` and the flag is cleared afterwards.
    Otherwise nothing happens.
    """
    if not self._schedule_migration:
        return

    logger.info("Running scheduled migration operation")
    self.island_manager.perform_migration(self.last_iteration)
    self._schedule_migration = False
def close(self):
    """Close the database connection and forget the closed handles.

    ``self.conn`` and ``self.cursor`` are reset to ``None`` so the
    ``if not self.cursor`` / ``if not self.conn`` guards used by the other
    methods report "DB not connected." instead of operating on a closed
    connection. Calling ``close`` again is a no-op.
    """
    if self.conn:
        self.conn.close()
    self.conn = None
    self.cursor = None
def _cosine_similarity(self, vec1: List[float], vec2: List[float]) -> float:
"""Compute cosine similarity between two vectors."""
if not vec1 or not vec2 or len(vec1) != len(vec2):
return 0.0
arr1 = np.array(vec1, dtype=np.float32)
arr2 = np.array(vec2, dtype=np.float32)
norm_a = np.linalg.norm(arr1)
norm_b = np.linalg.norm(arr2)
if norm_a == 0 or norm_b == 0:
return 0.0
similarity = np.dot(arr1, arr2) / (norm_a * norm_b)
return float(similarity)
@db_retry()
def compute_similarity_thread_safe(
    self, vec: List[float], island_idx: int
) -> List[float]:
    """Compute cosine similarities against an island using a private connection.

    Opens a dedicated SQLite connection to ``config.db_path`` for the
    duration of the call instead of using the instance's shared cursor.

    Args:
        vec: Embedding to compare against.
        island_idx: Island whose stored embeddings are compared.

    Returns:
        One similarity per stored non-empty embedding that could be decoded;
        an empty list when the island has none. Rows whose embedding is not
        valid JSON are skipped with a warning rather than failing the whole
        computation.

    Raises:
        Exception: Any database error is logged and re-raised.
    """
    conn = None
    try:
        # Create a new connection for this thread
        conn = sqlite3.connect(
            self.config.db_path, check_same_thread=False, timeout=60.0
        )
        conn.row_factory = sqlite3.Row
        cursor = conn.cursor()
        cursor.execute(
            "SELECT id, embedding FROM programs WHERE island_idx = ? "
            "AND embedding IS NOT NULL AND embedding != '[]'",
            (island_idx,),
        )

        similarities = []
        for row in cursor.fetchall():
            try:
                db_embedding = json.loads(row["embedding"])
            except json.JSONDecodeError:
                # One corrupt row must not abort the whole computation.
                logger.warning(
                    f"Could not decode embedding for program {row['id']}"
                )
                continue
            if db_embedding:
                similarities.append(self._cosine_similarity(vec, db_embedding))
        return similarities
    except Exception as e:
        logger.error(f"Thread-safe similarity computation failed: {e}")
        raise
    finally:
        if conn:
            conn.close()
@db_retry()
def compute_similarity(
    self, code_embedding: List[float], island_idx: int
) -> List[float]:
    """
    Score the given embedding against every embedded program of an island.

    Args:
        code_embedding: The embedding to compare against
        island_idx: The island index to constrain the search to

    Returns:
        One score per matching row, in query order. Scores are cosine
        similarities (range -1 to 1); rows whose stored embedding is empty
        or cannot be decoded contribute 0.0. An empty list is returned when
        ``code_embedding`` is empty or the island has no embeddings.

    Raises:
        ConnectionError: If no cursor is available.
    """
    if not self.cursor:
        raise ConnectionError("DB not connected.")

    if not code_embedding:
        logger.warning("Empty code embedding provided to compute_similarity")
        return []

    # Fetch every program of the island that has a stored embedding.
    self.cursor.execute(
        """
        SELECT id, embedding FROM programs
        WHERE island_idx = ? AND embedding IS NOT NULL AND embedding != '[]'
        """,
        (island_idx,),
    )
    records = self.cursor.fetchall()
    if not records:
        logger.debug(f"No programs with embeddings found in island {island_idx}")
        return []

    scores: List[float] = []
    for record in records:
        try:
            stored = json.loads(record["embedding"])
        except json.JSONDecodeError:
            logger.warning(f"Could not decode embedding for program {record['id']}")
            scores.append(0.0)
            continue
        # Empty embeddings cannot be compared; score them as 0.0.
        scores.append(
            self._cosine_similarity(code_embedding, stored) if stored else 0.0
        )

    logger.debug(
        f"Computed {len(scores)} similarity scores for "
        f"island {island_idx}"
    )
    return scores
@db_retry()
def get_most_similar_program(
    self, code_embedding: List[float], island_idx: int
) -> Optional[Program]:
    """
    Get the most similar program to the given embedding in the specified island.

    Args:
        code_embedding: The embedding to compare against
        island_idx: The island index to constrain the search to

    Returns:
        The Program with the highest cosine similarity (loaded via
        ``self.get``), or None if the embedding is empty or the island has
        no usable stored embeddings. On ties the first row wins.

    Raises:
        ConnectionError: If no cursor is available.
    """
    if not self.cursor:
        raise ConnectionError("DB not connected.")

    if not code_embedding:
        logger.warning("Empty code embedding provided to get_most_similar_program")
        return None

    # Get all programs in the specified island that have embeddings
    self.cursor.execute(
        """
        SELECT id, embedding FROM programs
        WHERE island_idx = ? AND embedding IS NOT NULL AND embedding != '[]'
        """,
        (island_idx,),
    )
    rows = self.cursor.fetchall()
    if not rows:
        logger.debug(f"No programs with embeddings found in island {island_idx}")
        return None

    # Start below every possible cosine value: a -1.0 sentinel with a strict
    # ">" could never select a candidate whose similarity is exactly -1.0.
    max_similarity = -float("inf")
    most_similar_id = None
    for row in rows:
        try:
            embedding = json.loads(row["embedding"])
        except json.JSONDecodeError:
            logger.warning(f"Could not decode embedding for program {row['id']}")
            continue
        if not embedding:  # Skip empty embeddings
            continue
        similarity = self._cosine_similarity(code_embedding, embedding)
        if similarity > max_similarity:
            max_similarity = similarity
            most_similar_id = row["id"]

    if most_similar_id is None:
        return None
    return self.get(most_similar_id)
@db_retry()
def get_most_similar_program_thread_safe(
    self, code_embedding: List[float], island_idx: int
) -> Optional[Program]:
    """
    Thread-safe version of get_most_similar_program that creates its own DB connection.

    Args:
        code_embedding: The embedding to compare against
        island_idx: The island index to constrain the search to

    Returns:
        The most similar Program object, or None if the embedding is empty,
        no usable stored embedding exists, or any error occurs (errors are
        logged, not raised).
    """
    if not code_embedding:
        logger.warning(
            "Empty code embedding provided to get_most_similar_program_thread_safe"
        )
        return None

    conn = None
    try:
        # Create a new connection for this thread
        conn = sqlite3.connect(
            self.config.db_path, check_same_thread=False, timeout=60.0
        )
        conn.row_factory = sqlite3.Row
        cursor = conn.cursor()

        # Get all programs in the specified island that have embeddings
        cursor.execute(
            """
            SELECT id, embedding FROM programs
            WHERE island_idx = ? AND embedding IS NOT NULL AND embedding != '[]'
            """,
            (island_idx,),
        )
        rows = cursor.fetchall()
        if not rows:
            return None

        best_id = None
        best_similarity = -float("inf")
        for row in rows:
            try:
                embedding = json.loads(row["embedding"])
                if not embedding:  # Check if embedding is not empty
                    continue
                # Reuse the shared helper: it returns 0.0 for zero-norm or
                # mismatched vectors, whereas an unguarded division yields
                # NaN, which would then win any arg-max.
                similarity = self._cosine_similarity(code_embedding, embedding)
            except (json.JSONDecodeError, ValueError, TypeError) as e:
                logger.warning(
                    f"Error computing similarity for program {row['id']}: {e}"
                )
                continue
            # Strict ">" keeps the first row on ties.
            if similarity > best_similarity:
                best_similarity = similarity
                best_id = row["id"]

        if best_id is None:
            return None

        # Get the full program data
        cursor.execute("SELECT * FROM programs WHERE id = ?", (best_id,))
        row = cursor.fetchone()
        if row:
            return self._program_from_row(row)
        return None
    except Exception as e:
        logger.error(f"Error in get_most_similar_program_thread_safe: {e}")
        return None
    finally:
        if conn:
            conn.close()
@db_retry()
def _recompute_embeddings_and_clusters(self, num_clusters: int = 4):
    """Refresh PCA projections and cluster ids for all embedded programs.

    Does nothing in read-only mode or when fewer than ``num_clusters``
    programs have a stored embedding. The embedding client computes 2D and
    3D PCA reductions plus cluster assignments; a failure there is logged
    and the method returns. The results are then written back inside one
    explicit transaction, which is rolled back (and the error logged) if
    any update fails.

    Args:
        num_clusters: Number of clusters requested from the embedding client.

    Raises:
        ConnectionError: If the cursor or connection is missing.
    """
    if self.read_only:
        return
    if not self.cursor or not self.conn:
        raise ConnectionError("DB not connected.")

    self.cursor.execute(
        "SELECT id, embedding FROM programs "
        "WHERE embedding IS NOT NULL AND embedding != '[]'"
    )
    records = self.cursor.fetchall()
    available = len(records)
    if available < num_clusters:
        logger.info(
            f"Not enough programs with embeddings ({available}) to "
            f"perform clustering. Need at least {num_clusters}."
        )
        return

    program_ids = [record["id"] for record in records]
    embeddings = [json.loads(record["embedding"]) for record in records]

    # Dimensionality reduction and clustering are delegated to the client.
    client = self.embedding_client
    try:
        logger.info(
            "Recomputing PCA-reduced embedding features for %s programs.",
            len(program_ids),
        )
        reduced_2d = client.get_dim_reduction(embeddings, method="pca", dims=2)
        reduced_3d = client.get_dim_reduction(embeddings, method="pca", dims=3)
        cluster_ids = client.get_embedding_clusters(
            embeddings, num_clusters=num_clusters
        )
    except Exception as e:
        logger.error(f"Failed to recompute embedding features: {e}")
        return

    update_sql = """
        UPDATE programs
        SET embedding_pca_2d = ?,
        embedding_pca_3d = ?,
        embedding_cluster_id = ?
        WHERE id = ?
        """

    # All updates go into a single transaction.
    self.conn.execute("BEGIN TRANSACTION")
    try:
        update_params = [
            (
                json.dumps(reduced_2d[idx].tolist()),
                json.dumps(reduced_3d[idx].tolist()),
                int(cluster_ids[idx]),
                program_id,
            )
            for idx, program_id in enumerate(program_ids)
        ]
        for params in update_params:
            self.cursor.execute(update_sql, params)
        self.conn.commit()
        logger.info(
            "Successfully updated embedding features for %s programs.",
            len(program_ids),
        )
    except Exception as e:
        self.conn.rollback()
        logger.error("Failed to update programs with new embedding features: %s", e)
@db_retry()
def _recompute_embeddings_and_clusters_thread_safe(self, num_clusters: int = 4):
    """
    Recompute PCA projections and cluster ids using a private DB connection.

    Does nothing in read-only mode. A dedicated SQLite connection is opened
    for the call and always closed at the end. When fewer than
    ``num_clusters`` programs have embeddings the method returns (logging
    only if there was at least one). Failures inside the embedding client
    are logged and end the call; failures while writing the results roll
    the transaction back and are re-raised.

    Args:
        num_clusters: Number of clusters requested from the embedding client.
    """
    if self.read_only:
        return

    conn = None
    try:
        # Create a new connection for this thread
        conn = sqlite3.connect(
            self.config.db_path, check_same_thread=False, timeout=60.0
        )
        conn.row_factory = sqlite3.Row
        cursor = conn.cursor()
        cursor.execute(
            "SELECT id, embedding FROM programs "
            "WHERE embedding IS NOT NULL AND embedding != '[]'"
        )
        records = cursor.fetchall()
        available = len(records)
        if available < num_clusters:
            if available > 0:
                logger.info(
                    f"Not enough programs with embeddings ({available}) to "
                    f"perform clustering. Need at least {num_clusters}."
                )
            return

        program_ids = [record["id"] for record in records]
        embeddings = [json.loads(record["embedding"]) for record in records]

        # Dimensionality reduction and clustering are delegated to the client.
        client = self.embedding_client
        try:
            logger.info(
                "Recomputing PCA-reduced embedding features for %s programs.",
                len(program_ids),
            )
            logger.info("Computing 2D PCA reduction...")
            reduced_2d = client.get_dim_reduction(embeddings, method="pca", dims=2)
            logger.info("2D PCA reduction completed")

            logger.info("Computing 3D PCA reduction...")
            reduced_3d = client.get_dim_reduction(embeddings, method="pca", dims=3)
            logger.info("3D PCA reduction completed")

            logger.info(f"Computing GMM clustering with {num_clusters} clusters...")
            cluster_ids = client.get_embedding_clusters(
                embeddings, num_clusters=num_clusters
            )
            logger.info("GMM clustering completed")
        except Exception as e:
            logger.error(f"Failed to recompute embedding features: {e}")
            return

        update_sql = """
            UPDATE programs
            SET embedding_pca_2d = ?,
            embedding_pca_3d = ?,
            embedding_cluster_id = ?
            WHERE id = ?
            """

        # All updates go into a single transaction.
        conn.execute("BEGIN TRANSACTION")
        try:
            update_params = [
                (
                    json.dumps(reduced_2d[idx].tolist()),
                    json.dumps(reduced_3d[idx].tolist()),
                    int(cluster_ids[idx]),
                    program_id,
                )
                for idx, program_id in enumerate(program_ids)
            ]
            for params in update_params:
                cursor.execute(update_sql, params)
            conn.commit()
            logger.info(
                "Successfully updated embedding features for %s programs.",
                len(program_ids),
            )
        except Exception as e:
            conn.rollback()
            logger.error(
                "Failed to update programs with new embedding features: %s", e
            )
            raise  # Re-raise exception
    except Exception as e:
        logger.error(f"Thread-safe embedding recomputation failed: {e}")
        raise  # Re-raise exception
    finally:
        if conn:
            conn.close()
@db_retry()
def get_programs_by_generation_thread_safe(self, generation: int) -> List[Program]:
    """Thread-safe version of get_programs_by_generation.

    Opens a dedicated SQLite connection for the call and decodes the
    JSON-encoded columns by hand before constructing each ``Program``.

    Args:
        generation: Generation number to match.

    Returns:
        One ``Program`` per matching row.
    """
    # Columns stored as JSON text. The dict-valued ones fall back to {} when
    # they cannot be decoded, the rest to [].
    dict_fields = ("public_metrics", "private_metrics", "metadata")
    json_fields = dict_fields + (
        "archive_inspiration_ids",
        "top_k_inspiration_ids",
        "embedding",
        "embedding_pca_2d",
        "embedding_pca_3d",
        "migration_history",
    )

    conn = None
    try:
        conn = sqlite3.connect(
            self.config.db_path, check_same_thread=False, timeout=60.0
        )
        conn.row_factory = sqlite3.Row
        cursor = conn.cursor()
        cursor.execute("SELECT * FROM programs WHERE generation = ?", (generation,))
        rows = cursor.fetchall()

        programs = []
        for row in rows:
            if not row:
                continue
            program_data = dict(row)
            # Manually handle JSON deserialization for thread safety
            for key in json_fields:
                value = program_data.get(key)
                if not isinstance(value, str):
                    continue
                try:
                    program_data[key] = json.loads(value)
                except json.JSONDecodeError:
                    # metadata is a mapping too, so it must not fall back
                    # to a list.
                    program_data[key] = {} if key in dict_fields else []
            programs.append(Program(**program_data))
        return programs
    finally:
        if conn:
            conn.close()
@db_retry()
def get_top_programs_thread_safe(
    self,
    n: int = 10,
    correct_only: bool = True,
) -> List[Program]:
    """Thread-safe version of get_top_programs.

    Opens a dedicated SQLite connection for the call and ranks programs by
    ``combined_score`` (descending) in SQL, ignoring rows without a score.

    Args:
        n: Value bound to the SQL ``LIMIT`` clause.
        correct_only: When true, only rows with ``correct = 1`` qualify.

    Returns:
        The ranked programs, or an empty list when nothing matches.
    """
    dict_like = {"public_metrics", "private_metrics", "metadata"}
    json_fields = [
        "public_metrics",
        "private_metrics",
        "metadata",
        "archive_inspiration_ids",
        "top_k_inspiration_ids",
        "embedding",
        "embedding_pca_2d",
        "embedding_pca_3d",
        "migration_history",
    ]

    conn = None
    try:
        conn = sqlite3.connect(
            self.config.db_path, check_same_thread=False, timeout=60.0
        )
        conn.row_factory = sqlite3.Row
        cursor = conn.cursor()

        # Use combined_score for sorting
        query_parts = [
            """
            SELECT * FROM programs
            WHERE combined_score IS NOT NULL
            """
        ]
        if correct_only:
            query_parts.append(" AND correct = 1")
        query_parts.append(" ORDER BY combined_score DESC LIMIT ?")
        cursor.execute("".join(query_parts), (n,))

        fetched = cursor.fetchall()
        if not fetched:
            return []

        ranked = []
        for record in fetched:
            program_data = dict(record)

            # Manually handle JSON deserialization for thread safety
            for key in json_fields:
                raw = program_data.get(key)
                if not isinstance(raw, str):
                    continue
                try:
                    program_data[key] = json.loads(raw)
                except json.JSONDecodeError:
                    program_data[key] = {} if key in dict_like else []

            # A missing or NULL text_feedback becomes an empty string.
            if program_data.get("text_feedback") is None:
                program_data["text_feedback"] = ""

            ranked.append(Program.from_dict(program_data))
        return ranked
    finally:
        if conn:
            conn.close()
def _get_programs_for_island(self, island_idx: int) -> List[Program]:
"""
Get all programs for a specific island.
"""