Spaces:
Sleeping
Sleeping
| import pandas as pd | |
| import numpy as np | |
| from mol_to_graph import batch_smiles_to_graphs | |
| def get_bbb_training_data(): | |
| """ | |
| Create a curated BBB permeability dataset with known compounds | |
| BBB permeability scale: | |
| - 1.0: High permeability (BBB+) | |
| - 0.5: Moderate permeability | |
| - 0.0: No permeability (BBB-) | |
| Data sources: Literature values and known BBB classifications | |
| """ | |
| data = { | |
| 'SMILES': [ | |
| # High BBB permeability (BBB+) - CNS drugs and neurotransmitters | |
| 'COC(=O)C1C(CC2CC1N2C)c3cccc(c3)OC', # Cocaine (0.95) | |
| 'CC(C)NCC(COc1ccccc1)O', # Propranolol (0.92) | |
| 'CCO', # Ethanol (0.88) | |
| 'c1ccccc1', # Benzene (0.90) | |
| 'CN1C=NC2=C1C(=O)N(C(=O)N2C)C', # Caffeine (0.85) | |
| 'CC(C)Cc1ccc(cc1)C(C)C(=O)O', # Ibuprofen (0.82) | |
| 'CC(=O)Nc1ccc(cc1)O', # Paracetamol/Acetaminophen (0.80) | |
| 'C1CCC(CC1)C(C2CCCCC2)N', # Phencyclidine skeleton (0.93) | |
| 'c1ccc(cc1)CCN', # Phenethylamine (0.87) | |
| 'CN1CCCC1c2cccnc2', # Nicotine (0.89) | |
| 'COc1cc2c(cc1OC)[nH]cc2CCN', # Serotonin derivative (0.81) | |
| 'c1ccc2c(c1)ccc3c2cccc3', # Anthracene (0.91) | |
| 'Cc1ccccc1', # Toluene (0.88) | |
| 'c1ccc(cc1)C(=O)O', # Benzoic acid (0.75) | |
| 'CC(C)(C)c1ccc(cc1)O', # BHT derivative (0.84) | |
| # Moderate BBB permeability (0.4-0.6) | |
| 'CC(C)(C)NCC(c1cc(c(c(c1)O)CO)O)O', # Salbutamol (0.55) | |
| 'C1CNC(=O)NC1=O', # Uracil (0.50) | |
| 'c1cc(ccc1C(=O)O)N', # p-Aminobenzoic acid (0.52) | |
| 'CC(=O)c1ccc(cc1)O', # p-Hydroxyacetophenone (0.58) | |
| 'Nc1ncnc2n(cnc12)C3OC(CO)C(O)C3O', # Adenosine partial (0.45) | |
| 'c1ccc(cc1)c2ccccc2', # Biphenyl (0.62) | |
| 'COc1ccccc1', # Anisole (0.68) | |
| 'CC(=O)Oc1ccccc1C(=O)O', # Aspirin (0.50) | |
| # Low/No BBB permeability (BBB-) | |
| 'CC(=O)O', # Acetic acid (0.25) | |
| 'C(C(=O)O)N', # Glycine (0.15) | |
| 'C(CC(=O)O)C(C(=O)O)N', # Glutamic acid (0.10) | |
| 'C1=NC(=O)NC(=O)C1N', # Cytosine (0.20) | |
| 'C(C(C(C(C(C=O)O)O)O)O)O', # Glucose (0.08) | |
| 'C1C(C(C(C(C1N)OC2C(C(C(C(O2)CO)O)O)N)OC3C(C(C(O3)CO)OC4C(C(CO4)O)O)O)O)N', # Streptomycin (0.05) | |
| 'CC(C)(COP(=O)(O)OP(=O)(O)OCC1C(C(C(O1)n2cnc3c2nc[nH]c3=N)O)OP(=O)(O)O)C(C(=O)NCCC(=O)NCCSC(=O)C)O', # Coenzyme A (0.02) | |
| 'c1cc(ccc1C(=O)O)O', # p-Hydroxybenzoic acid (0.22) | |
| 'C(CO)N', # Ethanolamine (0.18) | |
| 'c1cc(c(cc1Cl)Cl)Occ2c(cc(cc2Cl)Cl)Cl', # Pentachlorophenol ether (0.12) | |
| 'C(=O)(O)O', # Carbonic acid (0.10) | |
| 'CCOP(=O)(OCC)OC', # Organophosphate (0.15) | |
| 'C1=NC2=C(N1)C(=O)NC(=N2)N', # Guanine (0.12) | |
| 'O=S(=O)(O)O', # Sulfuric acid (0.05) | |
| # Additional diverse molecules | |
| 'c1ccc(cc1)c2ccccc2c3ccccc3', # Triphenyl (0.70) | |
| 'CCN(CC)CC', # Triethylamine (0.78) | |
| 'c1ccc2c(c1)c(c[nH]2)CCN', # Tryptamine (0.83) | |
| 'c1ccc(cc1)NC(=O)c2ccccc2', # Benzanilide (0.65) | |
| 'CC1(C2CCC1(C(=O)C2)C)C', # Camphor (0.76) | |
| ], | |
| 'BBB_permeability': [ | |
| # High BBB (15 compounds) | |
| 0.95, 0.92, 0.88, 0.90, 0.85, 0.82, 0.80, 0.93, 0.87, 0.89, | |
| 0.81, 0.91, 0.88, 0.75, 0.84, | |
| # Moderate BBB (8 compounds) | |
| 0.55, 0.50, 0.52, 0.58, 0.45, 0.62, 0.68, 0.50, | |
| # Low BBB (14 compounds) | |
| 0.25, 0.15, 0.10, 0.20, 0.08, 0.05, 0.02, 0.22, 0.18, 0.12, | |
| 0.10, 0.15, 0.12, 0.05, | |
| # Additional diverse (5 compounds) | |
| 0.70, 0.78, 0.83, 0.65, 0.76, | |
| ], | |
| 'compound_name': [ | |
| # High BBB | |
| 'Cocaine', 'Propranolol', 'Ethanol', 'Benzene', 'Caffeine', | |
| 'Ibuprofen', 'Acetaminophen', 'Phencyclidine', 'Phenethylamine', 'Nicotine', | |
| 'Serotonin_derivative', 'Anthracene', 'Toluene', 'Benzoic_acid', 'BHT_derivative', | |
| # Moderate BBB | |
| 'Salbutamol', 'Uracil', 'p-Aminobenzoic_acid', 'p-Hydroxyacetophenone', | |
| 'Adenosine_partial', 'Biphenyl', 'Anisole', 'Aspirin', | |
| # Low BBB | |
| 'Acetic_acid', 'Glycine', 'Glutamic_acid', 'Cytosine', 'Glucose', | |
| 'Streptomycin', 'Coenzyme_A', 'p-Hydroxybenzoic_acid', 'Ethanolamine', | |
| 'Pentachlorophenol_ether', 'Carbonic_acid', 'Organophosphate', | |
| 'Guanine', 'Sulfuric_acid', | |
| # Additional (5 compounds) | |
| 'Triphenyl', 'Triethylamine', 'Tryptamine', 'Benzanilide', 'Camphor', | |
| ], | |
| 'category': [ | |
| # High BBB | |
| 'BBB+', 'BBB+', 'BBB+', 'BBB+', 'BBB+', 'BBB+', 'BBB+', 'BBB+', | |
| 'BBB+', 'BBB+', 'BBB+', 'BBB+', 'BBB+', 'BBB+', 'BBB+', | |
| # Moderate BBB | |
| 'BBB±', 'BBB±', 'BBB±', 'BBB±', 'BBB±', 'BBB±', 'BBB±', 'BBB±', | |
| # Low BBB | |
| 'BBB-', 'BBB-', 'BBB-', 'BBB-', 'BBB-', 'BBB-', 'BBB-', 'BBB-', | |
| 'BBB-', 'BBB-', 'BBB-', 'BBB-', 'BBB-', 'BBB-', | |
| # Additional | |
| 'BBB+', 'BBB+', 'BBB+', 'BBB+', 'BBB+', | |
| ] | |
| } | |
| df = pd.DataFrame(data) | |
| return df | |
| def load_bbb_dataset(validation_split=0.2, random_state=42): | |
| """ | |
| Load BBB dataset and convert to PyTorch Geometric graphs | |
| Args: | |
| validation_split: Fraction of data to use for validation | |
| random_state: Random seed for reproducibility | |
| Returns: | |
| train_graphs, val_graphs, df (the full dataframe for reference) | |
| """ | |
| df = get_bbb_training_data() | |
| # Shuffle the data | |
| df = df.sample(frac=1, random_state=random_state).reset_index(drop=True) | |
| # Split into train/val | |
| val_size = int(len(df) * validation_split) | |
| val_df = df.iloc[:val_size] | |
| train_df = df.iloc[val_size:] | |
| print(f"Dataset Statistics:") | |
| print(f" Total compounds: {len(df)}") | |
| print(f" Training: {len(train_df)}") | |
| print(f" Validation: {len(val_df)}") | |
| print(f"\nClass distribution:") | |
| print(df['category'].value_counts()) | |
| # Convert to graphs | |
| train_graphs = batch_smiles_to_graphs( | |
| train_df['SMILES'].tolist(), | |
| train_df['BBB_permeability'].tolist() | |
| ) | |
| val_graphs = batch_smiles_to_graphs( | |
| val_df['SMILES'].tolist(), | |
| val_df['BBB_permeability'].tolist() | |
| ) | |
| print(f"\nGraphs created:") | |
| print(f" Training graphs: {len(train_graphs)}") | |
| print(f" Validation graphs: {len(val_graphs)}") | |
| return train_graphs, val_graphs, df | |
| if __name__ == "__main__": | |
| # Test dataset loading | |
| print("BBB Permeability Dataset") | |
| print("=" * 60) | |
| train_graphs, val_graphs, df = load_bbb_dataset(validation_split=0.2) | |
| print(f"\nSample molecules:") | |
| print(df[['compound_name', 'BBB_permeability', 'category']].head(10)) | |
| print(f"\nPermeability statistics:") | |
| print(f" Mean: {df['BBB_permeability'].mean():.3f}") | |
| print(f" Std: {df['BBB_permeability'].std():.3f}") | |
| print(f" Min: {df['BBB_permeability'].min():.3f}") | |
| print(f" Max: {df['BBB_permeability'].max():.3f}") | |
| print(f"\nExample graph structure:") | |
| if len(train_graphs) > 0: | |
| g = train_graphs[0] | |
| print(f" Nodes: {g.x.shape[0]}") | |
| print(f" Node features: {g.x.shape[1]}") | |
| print(f" Edges: {g.edge_index.shape[1]}") | |
| print(f" Target: {g.y.item():.3f}") | |
| print("\nDataset ready for training!") | |