File size: 3,063 Bytes
492a96f
 
 
 
6c5cf6f
 
cbb2cbe
007085f
09baacd
 
 
 
 
 
 
 
 
 
 
 
 
 
007085f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e17f87b
 
007085f
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
---
license: bsd-2-clause
---

The model weights provided here are for predicting molecular solubility with FragNet.

# Install FragNet

FragNet is available on GitHub: https://github.com/pnnl/FragNet

To install FragNet, run the following commands:

```bash
# Fetch the FragNet source code and enter the repository
git clone https://github.com/pnnl/FragNet.git
cd FragNet
# make sure a python virtual environment is activated
# Upgrade pip itself before installing anything else
pip install --upgrade pip
# Install the dependencies listed in requirements.txt
pip install -r requirements.txt
# Install torch-scatter, letting pip look for wheels on the PyG page for torch 2.4.0 (CPU)
# NOTE(review): presumably this must match the installed torch version — confirm
pip install torch-scatter -f https://data.pyg.org/whl/torch-2.4.0+cpu.html
# Install the FragNet package from the current directory
pip install .
```

# Load model
```python
import torch
from huggingface_hub import hf_hub_download
from fragnet.model.gat.gat2 import FragNetFineTune
from huggingface.fragnet_config import FragNetConfig

# Download the model configuration and the weights from the Hugging Face Hub
config_path = hf_hub_download(repo_id="gihan12/FragNet", filename="config.json")
model_path = hf_hub_download(repo_id="gihan12/FragNet", filename="pytorch_model.bin")

# Build the model from the keyword arguments stored in the configuration
config = FragNetConfig.from_json_file(config_path)
model = FragNetFineTune(**config.get_model_kwargs())

# weights_only=True restricts unpickling to tensors and plain containers, so a
# downloaded checkpoint cannot run arbitrary code while it is being loaded.
state_dict = torch.load(model_path, map_location='cpu', weights_only=True)

# strict=False tolerates key mismatches; report them instead of silently
# running a model whose parameters were only partially loaded.
load_result = model.load_state_dict(state_dict, strict=False)
if load_result.missing_keys:
    print(f"Warning: keys missing from checkpoint: {load_result.missing_keys}")
if load_result.unexpected_keys:
    print(f"Warning: unexpected keys in checkpoint: {load_result.unexpected_keys}")

# Switch to evaluation mode for inference
model.eval()
```

# Prepare molecule data the proper way

```python
import pandas as pd
import pickle
from fragnet.dataset.data import CreateData
from fragnet.dataset.fragments import get_3Dcoords2

# A function to process SMILES
def smiles_to_fragnet_data(smiles, data_type="exp1s", frag_type="murcko"):
    """Convert a SMILES string to the FragNet data format for inference.

    Args:
        smiles: SMILES string of the molecule.
        data_type: Forwarded to ``CreateData`` as ``data_type``.
        frag_type: Forwarded to ``create_data_point`` as the last element
            of its argument tuple.

    Returns:
        The object produced by ``CreateData.create_data_point`` with its
        ``y`` attribute reshaped to 1D, or ``None`` when ``get_3Dcoords2``
        returns ``None`` or an empty conformer list.
    """
    create_data = CreateData(
        data_type=data_type,
        create_bond_graph_data=True,
        add_dhangles=True,
    )

    # Get 3D coordinates
    res = get_3Dcoords2(smiles, maxiters=500)
    if res is None:
        return None

    # get_3Dcoords2 returns (mol, list of (conf_id, energy))
    mol, conf_res = res
    if not conf_res:
        return None

    # Pick the lowest-energy conformer directly; min() returns the first
    # minimal entry, the same one a stable sort would place first.
    best_conf_id = min(conf_res, key=lambda entry: entry[1])[0]
    best_conf = mol.GetConformer(best_conf_id)

    # create_data_point expects: (smiles, y, mol, conf, frag_type)
    # For inference, y is a dummy value (0.0); the model prediction is
    # what the caller actually uses.
    args = (smiles, 0.0, mol, best_conf, frag_type)
    data = create_data.create_data_point(args)

    # Make y a 1D tensor for proper batching
    data.y = data.y.reshape(-1)

    return data
```

```python
# Example molecule: aspirin, written as a SMILES string
smiles = "CC(=O)OC1=CC=CC=C1C(=O)O"
data = smiles_to_fragnet_data(smiles)

# Handle the failure case first; otherwise print the shapes of the
# atom and fragment feature arrays.
if data is None:
    print("✗ Failed to create data")
else:
    print("✓ Data created successfully")
    print(f"  Atoms: {data.x_atoms.shape}")
    print(f"  Fragments: {data.x_frags.shape}")
```




# Make prediction
```python
from fragnet.dataset.data import collate_fn

if data is not None:
    # Wrap the single data object in a batch with the collate function
    batch = collate_fn([data])

    # Run the forward pass without tracking gradients
    with torch.no_grad():
        prediction = model(batch)

    # Extract the scalar prediction and report it
    predicted_value = prediction.item()
    print(f"\nPrediction for {smiles}")
    print(f"  Value: {predicted_value:.4f}")

```