| |
| """ |
| Standalone inference script for Pyrosage AMES AttentiveFP Model |
| Usage: python inference.py "SMILES_STRING" |
| """ |
|
|
| import sys |
| import torch |
| from torch_geometric.nn import AttentiveFP |
| from rdkit import Chem |
| from torch_geometric.data import Data |
|
|
|
|
def smiles_to_data(smiles):
    """Featurize a SMILES string into a PyTorch Geometric ``Data`` graph.

    Node features (10 floats per atom, see ``_atom_feature_row``) and edge
    features (6 floats per directed edge, see ``_bond_feature_row``) must
    stay in sync with the ``in_channels``/``edge_dim`` used in ``load_model``.
    Each RDKit bond is emitted as two directed edges carrying the same
    feature row.

    Args:
        smiles: SMILES string to parse with RDKit.

    Returns:
        ``Data(x, edge_index, edge_attr)``, or ``None`` when RDKit cannot
        parse the string or the parsed molecule has no bonds.
    """
    mol = Chem.MolFromSmiles(smiles)
    if mol is None:
        return None

    node_rows = [_atom_feature_row(atom) for atom in mol.GetAtoms()]
    x = torch.tensor(node_rows, dtype=torch.float)

    pairs = []
    edge_rows = []
    for bond in mol.GetBonds():
        start, stop = bond.GetBeginAtomIdx(), bond.GetEndAtomIdx()
        row = _bond_feature_row(bond)
        # Undirected bond -> one edge in each direction, same features.
        pairs += [[start, stop], [stop, start]]
        edge_rows += [row, row]

    # A molecule with no bonds (e.g. a lone atom) yields no graph.
    if not pairs:
        return None

    edge_index = torch.tensor(pairs, dtype=torch.long).t()
    edge_attr = torch.tensor(edge_rows, dtype=torch.float)
    return Data(x=x, edge_index=edge_index, edge_attr=edge_attr)


def _atom_feature_row(atom):
    """Return the 10 numeric node features for one RDKit atom."""
    hybridization = atom.GetHybridization()
    kinds = Chem.rdchem.HybridizationType
    return [
        atom.GetAtomicNum(),
        atom.GetTotalDegree(),
        atom.GetFormalCharge(),
        atom.GetTotalNumHs(),
        atom.GetNumRadicalElectrons(),
        int(atom.GetIsAromatic()),
        int(atom.IsInRing()),
        # Hybridization flags; any other state leaves all three at 0.
        int(hybridization == kinds.SP),
        int(hybridization == kinds.SP2),
        int(hybridization == kinds.SP3),
    ]


def _bond_feature_row(bond):
    """Return the 6 numeric edge features for one RDKit bond."""
    bond_type = bond.GetBondType()
    kinds = Chem.rdchem.BondType
    return [
        # Bond-order flags; any other bond type leaves all four at 0.
        int(bond_type == kinds.SINGLE),
        int(bond_type == kinds.DOUBLE),
        int(bond_type == kinds.TRIPLE),
        int(bond_type == kinds.AROMATIC),
        int(bond.GetIsConjugated()),
        int(bond.IsInRing()),
    ]
|
|
|
|
def load_model(model_path='pytorch_model.bin'):
    """Build the AMES AttentiveFP network and load its trained weights.

    Args:
        model_path: Location of the training checkpoint. The default,
            ``'pytorch_model.bin'``, is resolved against the current working
            directory (the original behaviour). Any path accepted by
            ``torch.load`` works.

    Returns:
        The ``AttentiveFP`` model with weights loaded on CPU, in eval mode.

    Raises:
        FileNotFoundError: if ``model_path`` does not exist.
        KeyError: if the checkpoint lacks ``'model_state_dict'`` or
            ``'hyperparameters'``, or a hyperparameter read below.
    """
    # NOTE(review): depending on the installed torch version's `weights_only`
    # default, torch.load may unpickle arbitrary objects - only load trusted files.
    checkpoint = torch.load(model_path, map_location='cpu')

    # Fail with a message naming what is missing instead of a bare key name.
    required = ('model_state_dict', 'hyperparameters')
    missing = [key for key in required if key not in checkpoint]
    if missing:
        raise KeyError(
            f"checkpoint {model_path!r} is missing required entries: {missing}"
        )
    state_dict = checkpoint['model_state_dict']
    hyperparams = checkpoint['hyperparameters']

    model = AttentiveFP(
        # These two widths must equal the feature counts produced by
        # smiles_to_data: 10 per atom, 6 per bond.
        in_channels=10,
        hidden_channels=hyperparams["hidden_channels"],
        out_channels=1,
        edge_dim=6,
        num_layers=hyperparams["num_layers"],
        num_timesteps=hyperparams["num_timesteps"],
        dropout=hyperparams["dropout"],
    )

    model.load_state_dict(state_dict)
    model.eval()  # inference only: disables dropout
    return model
|
|
|
|
def predict(model, smiles):
    """Run the model on a single SMILES string.

    Args:
        model: Callable taking ``(x, edge_index, edge_attr, batch)``, such as
            the module returned by ``load_model``.
        smiles: SMILES string to featurize with ``smiles_to_data``.

    Returns:
        The model's single output value as a Python float, or ``None`` if
        the SMILES could not be turned into a graph.
    """
    graph = smiles_to_data(smiles)
    if graph is None:
        return None

    # One molecule per call, so every node is assigned to graph 0.
    batch_index = torch.zeros(graph.num_nodes, dtype=torch.long)
    with torch.no_grad():
        # NOTE(review): no sigmoid is applied here, so whether this value is a
        # probability depends on how the model was trained - confirm before
        # interpreting it against a 0.5 threshold.
        return model(graph.x, graph.edge_index, graph.edge_attr, batch_index).item()
|
|
|
|
def main():
    """Command-line entry point: score the SMILES string given in ``sys.argv[1]``.

    Prints the prediction to stdout. Usage and error messages go to stderr,
    and the process exits with status 1 when:
    - the argument count is wrong,
    - loading or prediction raises,
    - the SMILES cannot be featurized.
    """
    if len(sys.argv) != 2:
        print("Usage: python inference.py 'SMILES_STRING'", file=sys.stderr)
        print("Example: python inference.py 'CC(=O)OC1=CC=CC=C1C(=O)O'", file=sys.stderr)
        sys.exit(1)

    smiles = sys.argv[1]
    print("Loading AMES AttentiveFP model...")

    try:
        model = load_model()
        print(f"Making prediction for: {smiles}")
        prediction = predict(model, smiles)
    except Exception as e:
        # Top-level CLI boundary: report any load/inference failure and fail.
        print(f"Error: {e}", file=sys.stderr)
        sys.exit(1)

    if prediction is None:
        # Previously this path exited with status 0; signal failure instead.
        print("Error: Could not process SMILES string", file=sys.stderr)
        sys.exit(1)

    print(f'Classification result: {prediction:.4f} (>0.5 = positive, <=0.5 = negative)')
|
|
|
|
if __name__ == '__main__':
    # Run the CLI only when executed as a script, not when imported.
    main()
|
|