Spaces:
Sleeping
Sleeping
File size: 3,128 Bytes
315d4ad |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 |
from typing import List
from core.config import EvolutionConfig
from .molecule import Molecule
import pandas as pd
class Population:
    """Manages the population of molecules.

    Molecules are kept in insertion order in ``molecules`` and are
    de-duplicated by their ``smiles`` attribute. Selection and export
    behaviour is driven by the flags on ``config`` (``maximize_cn``,
    ``minimize_ysi``) together with ``population_size``,
    ``survivor_fraction`` and ``cn_objective``.
    """

    def __init__(self, config: EvolutionConfig):
        """Create an empty population governed by *config*."""
        self.config = config
        self.molecules: List[Molecule] = []
        # SMILES strings already stored; used for constant-time duplicate checks.
        self.seen_smiles: set = set()

    def add_molecule(self, mol: Molecule) -> bool:
        """Add a molecule if it's not already in the population.

        Returns:
            True if *mol* was appended, False if its ``smiles`` was
            already present (in which case nothing changes).
        """
        if mol.smiles in self.seen_smiles:
            return False
        self.molecules.append(mol)
        self.seen_smiles.add(mol.smiles)
        return True

    def add_molecules(self, molecules: List[Molecule]) -> int:
        """Add multiple molecules, return count added."""
        # Each call returns a bool; summing counts the successful insertions.
        return sum(self.add_molecule(mol) for mol in molecules)

    def pareto_front(self) -> List[Molecule]:
        """Extract the Pareto front from the population.

        Returns:
            The molecules that no *other* molecule dominates, as decided
            by ``Molecule.dominates(mol, config.maximize_cn)``, in
            insertion order. Returns an empty list when
            ``config.minimize_ysi`` is false (single-objective mode).
        """
        if not self.config.minimize_ysi:
            return []
        maximize_cn = self.config.maximize_cn
        return [
            mol
            for mol in self.molecules
            if not any(
                other.dominates(mol, maximize_cn)
                for other in self.molecules
                if other is not mol  # a molecule is never compared to itself
            )
        ]

    def get_survivors(self) -> List[Molecule]:
        """Select survivors for the next generation.

        The target count is
        ``int(config.population_size * config.survivor_fraction)``.

        * Single-objective mode (``config.minimize_ysi`` false): the
          molecules with the highest ``config.cn_objective(m.cn)``.
        * Otherwise: start from the Pareto front. If it is larger than
          the target it is trimmed to the best-ranked members; if it is
          smaller it is topped up with the best-ranked remaining
          molecules. Ranking is by objective (descending), then ``ysi``
          (ascending).

        Returns:
            A new list of at most the target number of molecules.
        """
        target_size = int(
            self.config.population_size * self.config.survivor_fraction
        )
        objective = self.config.cn_objective

        if not self.config.minimize_ysi:
            # Single objective mode
            return sorted(
                self.molecules,
                key=lambda m: objective(m.cn),
                reverse=True,
            )[:target_size]

        def rank_key(m):
            # Higher objective = better, so negate it; ties broken by lower YSI.
            return (-objective(m.cn), m.ysi)

        survivors = self.pareto_front()
        if len(survivors) > target_size:
            survivors = sorted(survivors, key=rank_key)[:target_size]
        elif len(survivors) < target_size:
            # Identity-based membership (matching the ``is not`` test used in
            # pareto_front) avoids a linear, __eq__-based scan per molecule.
            front_ids = {id(m) for m in survivors}
            remainder = sorted(
                (m for m in self.molecules if id(m) not in front_ids),
                key=rank_key,
            )
            survivors.extend(remainder[:target_size - len(survivors)])
        return survivors

    def to_dataframe(self) -> pd.DataFrame:
        """Convert population to DataFrame.

        One row per molecule, built from ``Molecule.to_dict()``. Rows are
        sorted by ``cn`` descending when ``config.maximize_cn`` is set,
        otherwise by ``cn_error`` ascending; when ``config.minimize_ysi``
        is set, ``ysi`` ascending is the secondary key. A 1-based ``rank``
        column reflecting the sorted order is inserted as the first column.

        An empty population yields an empty frame that still has the
        ``rank`` column, rather than raising ``KeyError`` on the missing
        sort columns.
        """
        df = pd.DataFrame([m.to_dict() for m in self.molecules])

        if self.config.maximize_cn:
            sort_cols = ["cn"]
            ascending = [False]  # Descending CN
        else:
            sort_cols = ["cn_error"]
            ascending = [True]
        if self.config.minimize_ysi:
            sort_cols.append("ysi")
            ascending.append(True)  # Ascending YSI as tie-breaker

        # A frame built from zero rows has no columns to sort by.
        if len(df) > 0:
            df = df.sort_values(sort_cols, ascending=ascending)
        df.insert(0, 'rank', range(1, len(df) + 1))
        return df
|