| --- |
| license: bsd-2-clause |
| --- |
| |
| The model weights provided here are for solubility prediction. |
|
|
| # Install FragNet |
|
|
| FragNet is available on GitHub: https://github.com/pnnl/FragNet |
|
|
| To install FragNet, run the following commands: |
|
|
| ```bash |
| # Fetch the FragNet sources and work from the repository root |
| git clone https://github.com/pnnl/FragNet.git |
| cd FragNet |
| # make sure a python virtual environment is activated |
| # Upgrade pip first, then install the dependencies listed in requirements.txt |
| pip install --upgrade pip |
| pip install -r requirements.txt |
| # torch-scatter is resolved from the wheel index passed with -f; the URL is |
| # pinned to torch 2.4.0 CPU builds -- presumably it has to match the torch |
| # build installed above, so adjust it for another torch/CUDA version (TODO confirm) |
| pip install torch-scatter -f https://data.pyg.org/whl/torch-2.4.0+cpu.html |
| # Install the FragNet package itself from the current directory |
| pip install . |
| ``` |
|
|
| # Load model |
| ```python |
| import torch |
| from huggingface_hub import hf_hub_download |
| from fragnet.model.gat.gat2 import FragNetFineTune |
| from huggingface.fragnet_config import FragNetConfig |
| |
| config_path = hf_hub_download(repo_id="gihan12/FragNet", filename="config.json") |
| model_path = hf_hub_download(repo_id="gihan12/FragNet", filename="pytorch_model.bin") |
| |
| config = FragNetConfig.from_json_file(config_path) |
| model = FragNetFineTune(**config.get_model_kwargs()) |
| model.load_state_dict(torch.load(model_path, map_location='cpu'), strict=False) |
| model.eval() |
| ``` |
|
|
| # Prepare molecule data in FragNet's input format |
|
|
| ```python |
| import pandas as pd |
| import pickle |
| from fragnet.dataset.data import CreateData |
| from fragnet.dataset.fragments import get_3Dcoords2 |
| |
| # A function to process SMILES |
| def smiles_to_fragnet_data(smiles, data_type="exp1s", frag_type="murcko"): |
| """Convert SMILES to FragNet data format.""" |
| create_data = CreateData( |
| data_type=data_type, |
| create_bond_graph_data=True, |
| add_dhangles=True, |
| ) |
| |
| # Get 3D coordinates |
| res = get_3Dcoords2(smiles, maxiters=500) |
| if res is None: |
| return None |
| |
| mol, conf_res = res |
| |
| # get_3Dcoords2 returns (mol, list of (conf_id, energy)) |
| # We need to get the conformer with the lowest energy |
| if not conf_res: |
| return None |
| |
| # Sort by energy and get the best conformer |
| conf_res_sorted = sorted(conf_res, key=lambda x: x[1]) |
| best_conf_id = conf_res_sorted[0][0] |
| best_conf = mol.GetConformer(best_conf_id) |
| |
| # create_data_point expects: (smiles, y, mol, conf, frag_type) |
| # For inference, use a dummy y value (0.0) - it will be replaced by prediction |
| args = (smiles, 0.0, mol, best_conf, frag_type) |
| data = create_data.create_data_point(args) |
| |
| # Fix y to be 1D tensor for proper batching |
| data.y = data.y.reshape(-1) |
| |
| return data |
| ``` |
|
|
| ```python |
| # Test with Aspirin |
| smiles = "CC(=O)OC1=CC=CC=C1C(=O)O" |
| data = smiles_to_fragnet_data(smiles) |
| |
| if data is not None: |
| print("✓ Data created successfully") |
| print(f" Atoms: {data.x_atoms.shape}") |
| print(f" Fragments: {data.x_frags.shape}") |
| else: |
| print("✗ Failed to create data") |
| ``` |
|
|
|
|
|
|
|
|
| # Make prediction |
| ```python |
| from fragnet.dataset.data import collate_fn |
| |
| if data is not None: |
| # Create batch using the proper collate function |
| batch = collate_fn([data]) |
| |
| # Predict |
| with torch.no_grad(): |
| prediction = model(batch) |
| print(f"\nPrediction for {smiles}") |
| print(f" Value: {prediction.item():.4f}") |
| |
| ``` |