# bio-acdc/core.py
# Uploaded by AliSaadatV (commit d08bd7c, "Upload core.py").
"""
Core Bio-ACDC optimization loop.
Manages the coevolution of biological LM populations and sequence tasks.
"""
import os
import json
import logging
import numpy as np
import torch
from typing import List, Dict, Optional, Tuple, Any
from dataclasses import dataclass, field
from pathlib import Path
from .archive import BioArchive, BioSolution
from .tasks import BioTaskPool
from .evaluator import BioEvaluator
from .mergers import BioModelMerger
from .mutators import BioMutator
# Module-level logger, named after this module so handlers/levels can be
# configured per module.
logger = logging.getLogger(__name__)
@dataclass
class BioACDCConfig:
    """Settings that drive a Bio-ACDC optimization run.

    Attributes:
        archive_size: Passed to ``BioArchive`` as its ``archive_size``.
        num_generations: Number of generations the evolution loop runs.
        offspring_per_gen: Offspring (merge attempts) created per generation.
        seed_model_paths: Models evaluated as the initial population; also
            the fallback parent pool while the archive is too small.
        base_model_path: Optional base model path (not read by the
            ``BioACDC`` loop itself).
        num_tasks_per_eval: Tasks requested from the task pool for each
            model evaluation.
        task_difficulty_threshold: Difficulty threshold (not read by the
            ``BioACDC`` loop itself).
        use_mutation: Whether merged offspring may be mutated.
        mutation_rate: Chance that an offspring is mutated, when mutation is
            enabled and a mutator is available.
        eval_batch_size: Evaluation batch size (not read by the ``BioACDC``
            loop; presumably consumed by the evaluator - TODO confirm).
        max_sequence_length: Sequence length cap (not read by the ``BioACDC``
            loop; presumably consumed by the evaluator - TODO confirm).
        output_dir: Root directory for all run artifacts.
        save_frequency: Checkpoint interval, in generations.
        device: ``"cuda"`` if CUDA was available when this class was
            defined, otherwise ``"cpu"``.
    """

    # -- population ---------------------------------------------------------
    archive_size: int = 50
    num_generations: int = 100
    offspring_per_gen: int = 10

    # -- models -------------------------------------------------------------
    seed_model_paths: List[str] = field(default_factory=list)  # new list per instance
    base_model_path: Optional[str] = None

    # -- tasks --------------------------------------------------------------
    num_tasks_per_eval: int = 20
    task_difficulty_threshold: float = 0.5

    # -- evolution ----------------------------------------------------------
    use_mutation: bool = True
    mutation_rate: float = 0.3

    # -- evaluation ---------------------------------------------------------
    eval_batch_size: int = 8
    max_sequence_length: int = 1024

    # -- output -------------------------------------------------------------
    output_dir: str = "./bio_acdc_output"
    save_frequency: int = 5

    # -- hardware (default resolved once, at class-definition time) ---------
    device: str = "cuda" if torch.cuda.is_available() else "cpu"
class BioACDC:
    """
    Main Bio-ACDC optimization controller.

    Coevolves a population of biological language models with synthetic
    sequence tasks to discover specialized experts. Each generation merges
    (and optionally mutates) parents drawn from the archive, evaluates the
    offspring on tasks sampled from the task pool, updates the archive, and
    asks the task pool for new tasks.
    """

    def __init__(
        self,
        config: BioACDCConfig,
        task_pool: BioTaskPool,
        evaluator: BioEvaluator,
        merger: BioModelMerger,
        mutator: Optional[BioMutator] = None,
    ):
        """
        Args:
            config: Optimization settings.
            task_pool: Supplies evaluation tasks and generates new tasks
                every generation.
            evaluator: Scores a model path on a list of tasks.
            merger: Builds an offspring model from a list of parent paths.
            mutator: Optional; applied to a merged offspring with probability
                ``config.mutation_rate`` when ``config.use_mutation`` is set.

        Side effects:
            Creates ``models/``, ``archives/``, ``tasks/`` and
            ``evaluations/`` under ``config.output_dir``.
        """
        self.config = config
        self.task_pool = task_pool
        self.evaluator = evaluator
        self.merger = merger
        self.mutator = mutator

        # Archive of evaluated solutions; neighbour count and dominated score
        # are fixed here rather than taken from the config.
        self.archive = BioArchive(
            archive_size=config.archive_size,
            k_neighbors=5,
            dominated_score=999.0,
        )

        # Output layout under config.output_dir.
        self.output_dir = Path(config.output_dir)
        self.model_dir = self.output_dir / "models"
        self.archive_dir = self.output_dir / "archives"
        self.task_dir = self.output_dir / "tasks"
        self.eval_dir = self.output_dir / "evaluations"
        for d in (self.model_dir, self.archive_dir, self.task_dir, self.eval_dir):
            d.mkdir(parents=True, exist_ok=True)

        # Fixed seed so parent selection and mutation decisions are reproducible.
        self.rng = np.random.RandomState(42)
        logger.info(f"Bio-ACDC initialized with config: {config}")

    def _evaluate(self, model_path: str) -> Tuple[float, Dict[Any, float]]:
        """Score ``model_path`` on a freshly sampled batch of tasks.

        Returns:
            ``(fitness, skill_vector)``. ``skill_vector`` maps each sampled
            task's ``task_id`` to the evaluator's metric for it (0.0 when the
            evaluator returned none). ``fitness`` is the mean of those values,
            or 0.0 when no tasks were sampled.
        """
        tasks = self.task_pool.get_tasks(n=self.config.num_tasks_per_eval)
        metrics = self.evaluator.evaluate_model(model_path, tasks)
        skill_vector = {t.task_id: metrics.get(t.task_id, 0.0) for t in tasks}
        # Cast to a builtin float so downstream consumers see a plain Python value.
        fitness = float(np.mean(list(skill_vector.values()))) if skill_vector else 0.0
        return fitness, skill_vector

    def initialize_seed_population(self) -> None:
        """Evaluate every configured seed model and add it to the archive.

        Seed solutions are recorded as generation 0. The resulting archive is
        saved as the generation-0 checkpoint. Evaluation errors propagate,
        because the seeds are user-supplied.
        """
        logger.info(f"Initializing seed population from {len(self.config.seed_model_paths)} models")
        for model_path in self.config.seed_model_paths:
            logger.info(f"Evaluating seed model: {model_path}")
            fitness, skill_vector = self._evaluate(model_path)
            solution = BioSolution(
                model_path=model_path,
                fitness=fitness,
                skill_vector=skill_vector,
                generation=0,
            )
            self.archive.add_solution(solution)
            logger.info(f"Seed model {model_path}: fitness={fitness:.4f}")
        self._save_archive(0)

    def evolve(self) -> List[BioSolution]:
        """Run the full optimization loop and return the archive's solutions.

        The archive is seeded first. Then, for each generation
        1..``num_generations``, the loop:

        - creates offspring and evaluates them;
        - updates the archive and refreshes the task pool;
        - checkpoints every ``save_frequency`` generations (disabled when
          ``save_frequency <= 0``).

        An offspring whose evaluation raises is logged and dropped instead of
        aborting the run. The archive and task pool are saved once more after
        the last generation.
        """
        logger.info("Starting Bio-ACDC evolution")
        self.initialize_seed_population()

        for generation in range(1, self.config.num_generations + 1):
            logger.info(f"=== Generation {generation}/{self.config.num_generations} ===")

            new_solutions = self._generate_offspring(generation)

            # Evaluate offspring on the current task pool. Failures are skipped,
            # matching how failed merges are handled in _generate_offspring.
            evaluated: List[BioSolution] = []
            for sol in new_solutions:
                try:
                    sol.fitness, sol.skill_vector = self._evaluate(sol.model_path)
                except Exception as e:
                    logger.exception(f"Failed to evaluate offspring {sol.model_path}: {e}")
                    continue
                evaluated.append(sol)

            # Update archive with dominated novelty search
            self.archive.update(evaluated)

            # Generate new tasks based on archive weaknesses
            self._update_task_pool(generation)

            # Periodic checkpoint; guard avoids ZeroDivisionError when disabled.
            save_every = self.config.save_frequency
            if save_every > 0 and generation % save_every == 0:
                self._save_archive(generation)
                self._save_tasks(generation)

            best = self.archive.get_best()
            if best:
                logger.info(f"Generation {generation}: Best fitness={best.fitness:.4f}, Archive size={len(self.archive.solutions)}")

        # Final save (tasks too, so the last pool is not lost between checkpoints).
        self._save_archive(self.config.num_generations)
        self._save_tasks(self.config.num_generations)
        return self.archive.solutions

    def _generate_offspring(self, generation: int) -> List[BioSolution]:
        """Create up to ``offspring_per_gen`` new, not-yet-evaluated models.

        Each offspring is built as follows:

        - Two distinct parents are drawn uniformly from the archive. If the
          archive holds fewer than two solutions, they are drawn from
          ``config.seed_model_paths`` instead.
        - The parents are merged, and the result is mutated with probability
          ``config.mutation_rate`` if mutation is enabled.

        An offspring whose merge or mutation raises is logged and skipped.

        Returns:
            New solutions with ``fitness=0.0`` and an empty skill vector. The
            list is empty when fewer than two candidate parents exist.
        """
        # Parent pool: archive models, falling back to seeds when the archive
        # cannot provide two distinct parents.
        candidates = [s.model_path for s in self.archive.solutions]
        if len(candidates) < 2:
            candidates = list(self.config.seed_model_paths)
        if len(candidates) < 2:
            # Sampling 2 without replacement would raise inside numpy and
            # abort the whole run; report it clearly instead.
            logger.warning(
                f"Generation {generation}: need at least 2 candidate parents, "
                f"found {len(candidates)}; no offspring created"
            )
            return []

        new_solutions: List[BioSolution] = []
        for i in range(self.config.offspring_per_gen):
            # Sample indices rather than objects so numpy never has to coerce
            # the candidates into an array.
            idx = self.rng.choice(len(candidates), size=2, replace=False)
            parent_paths = [candidates[j] for j in idx]

            save_path = str(self.model_dir / f"gen_{generation}_ind_{i}")
            try:
                merged_path = self.merger.merge(
                    parent_paths=parent_paths,
                    save_path=save_path,
                )
                # rng.random() is only drawn when mutation is actually possible.
                if (
                    self.config.use_mutation
                    and self.mutator is not None
                    and self.rng.random() < self.config.mutation_rate
                ):
                    merged_path = self.mutator.mutate(
                        model_path=merged_path,
                        save_path=save_path + "_mutated",
                    )
            except Exception as e:
                logger.exception(f"Failed to create offspring {i}: {e}")
                continue
            else:
                new_solutions.append(
                    BioSolution(
                        model_path=merged_path,
                        fitness=0.0,  # filled in by evolve()
                        skill_vector={},
                        generation=generation,
                        parent_paths=parent_paths,
                    )
                )
        return new_solutions

    def _update_task_pool(self, generation: int) -> None:
        """Request new tasks targeting the archive's weaknesses and save them.

        Five tasks are requested per call, using the archive's difficulty
        weights. Each returned task is written to
        ``tasks/gen_{generation}_{task_id}.json``.
        """
        difficulty_weights = self.archive.compute_difficulty_weights()
        new_tasks = self.task_pool.generate_new_tasks(
            archive=self.archive,
            num_tasks=5,
            difficulty_weights=difficulty_weights,
        )
        for task in new_tasks:
            task_path = self.task_dir / f"gen_{generation}_{task.task_id}.json"
            with open(task_path, "w", encoding="utf-8") as f:
                json.dump(task.to_dict(), f, indent=2)
        logger.info(f"Generated {len(new_tasks)} new tasks")

    def _save_archive(self, generation: int) -> None:
        """Write the archive to ``archives/archive_gen_{generation}.json``."""
        archive_path = self.archive_dir / f"archive_gen_{generation}.json"
        self.archive.save(archive_path)
        logger.info(f"Archive saved to {archive_path}")

    def _save_tasks(self, generation: int) -> None:
        """Write every task in the pool to ``tasks/tasks_gen_{generation}.json``."""
        tasks_path = self.task_dir / f"tasks_gen_{generation}.json"
        tasks_data = [t.to_dict() for t in self.task_pool.tasks]
        with open(tasks_path, "w", encoding="utf-8") as f:
            json.dump(tasks_data, f, indent=2)