MadishettiMahesh commited on
Commit
4888ee7
·
verified ·
1 Parent(s): 69d501d

Upload 3 files

Browse files
Files changed (3) hide show
  1. app.py +165 -0
  2. best_solv_sage.pth +3 -0
  3. requirements.txt +6 -0
app.py ADDED
@@ -0,0 +1,165 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ import torch
3
+ import numpy as np
4
+ import pandas as pd
5
+
6
+ from torch_geometric.nn import SAGEConv, global_mean_pool
7
+ from torch_geometric.data import Batch
8
+ from torch_geometric.utils import from_smiles
9
+
10
+ # ===============================================================
11
+ # Model Definition (MUST MATCH TRAINING)
12
+ # ===============================================================
13
+ class MolEncoderSAGE(torch.nn.Module):
14
+ def __init__(self, in_dim=9, hidden=128, layers=3):
15
+ super().__init__()
16
+ self.convs = torch.nn.ModuleList()
17
+ self.convs.append(SAGEConv(in_dim, hidden))
18
+ for _ in range(layers - 1):
19
+ self.convs.append(SAGEConv(hidden, hidden))
20
+
21
+ def forward(self, data):
22
+ x = data.x.float()
23
+ edge_index = data.edge_index
24
+ batch = data.batch
25
+
26
+ for conv in self.convs:
27
+ x = torch.relu(conv(x, edge_index))
28
+
29
+ return global_mean_pool(x, batch)
30
+
31
+
32
+ class SolvSAGENet(torch.nn.Module):
33
+ def __init__(self, hidden=128, layers=3, dropout=0.1):
34
+ super().__init__()
35
+ self.solute = MolEncoderSAGE(9, hidden, layers)
36
+ self.solvent = MolEncoderSAGE(9, hidden, layers)
37
+
38
+ self.mlp = torch.nn.Sequential(
39
+ torch.nn.Linear(2 * hidden, 256),
40
+ torch.nn.ReLU(),
41
+ torch.nn.Dropout(dropout),
42
+ torch.nn.Linear(256, 128),
43
+ torch.nn.ReLU(),
44
+ torch.nn.Linear(128, 1)
45
+ )
46
+
47
+ def forward(self, s, v):
48
+ z = torch.cat([self.solute(s), self.solvent(v)], dim=1)
49
+ return self.mlp(z)
50
+
51
+ # ===============================================================
52
+ # Load Model (cached)
53
+ # ===============================================================
54
+ @st.cache_resource
55
+ def load_model():
56
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
57
+ model = SolvSAGENet(hidden=128, layers=3, dropout=0.1).to(device)
58
+ model.load_state_dict(torch.load("best_solv_sage.pth", map_location=device))
59
+ model.eval()
60
+ return model, device
61
+
62
+ model, device = load_model()
63
+
64
+ # ===============================================================
65
+ # Streamlit UI
66
+ # ===============================================================
67
+ st.set_page_config(page_title="ΔG_solv Prediction", layout="centered")
68
+
69
+ st.title("🔬 ΔGₛₒₗᵥ Prediction (GraphSAGE)")
70
+ st.markdown("""
71
+ Enter **solute and solvent SMILES** to predict
72
+ **Solvation Free Energy (ΔGₛₒₗᵥ)** in kcal/mol.
73
+ """)
74
+
75
+ # ===============================================================
76
+ # Single Prediction
77
+ # ===============================================================
78
+ st.header("🧪 Single Prediction")
79
+
80
+ solute_smiles = st.text_input(
81
+ "Solute SMILES",
82
+ value="CCO",
83
+ help="Example: CCO (ethanol)"
84
+ )
85
+
86
+ solvent_smiles = st.text_input(
87
+ "Solvent SMILES",
88
+ value="O",
89
+ help="Example: O (water)"
90
+ )
91
+
92
+ if st.button("Predict ΔGₛₒₗᵥ"):
93
+ try:
94
+ # Convert SMILES → graphs
95
+ solute_graph = from_smiles(solute_smiles)
96
+ solvent_graph = from_smiles(solvent_smiles)
97
+
98
+ # Create batch
99
+ solute_batch = Batch.from_data_list([solute_graph]).to(device)
100
+ solvent_batch = Batch.from_data_list([solvent_graph]).to(device)
101
+
102
+ # Predict
103
+ with torch.no_grad():
104
+ prediction = model(solute_batch, solvent_batch).item()
105
+
106
+ st.success(f"✅ Predicted ΔGₛₒₗᵥ: **{prediction:.3f} kcal/mol**")
107
+
108
+ except Exception as e:
109
+ st.error("❌ Invalid SMILES or model error")
110
+ st.write(e)
111
+
112
+ # ===============================================================
113
+ # Batch Prediction
114
+ # ===============================================================
115
+ st.header("📂 Batch Prediction (CSV Upload)")
116
+
117
+ st.markdown("""
118
+ Upload a CSV file with **columns**:
119
+ - `mol_solute`
120
+ - `mol_solvent`
121
+ """)
122
+
123
+ uploaded_file = st.file_uploader("Upload CSV", type=["csv"])
124
+
125
+ if uploaded_file:
126
+ df = pd.read_csv(uploaded_file)
127
+
128
+ if {"mol_solute", "mol_solvent"}.issubset(df.columns):
129
+ predictions = []
130
+
131
+ with torch.no_grad():
132
+ for _, row in df.iterrows():
133
+ try:
134
+ s = from_smiles(row["mol_solute"])
135
+ v = from_smiles(row["mol_solvent"])
136
+
137
+ sb = Batch.from_data_list([s]).to(device)
138
+ vb = Batch.from_data_list([v]).to(device)
139
+
140
+ pred = model(sb, vb).item()
141
+ predictions.append(pred)
142
+ except:
143
+ predictions.append(np.nan)
144
+
145
+ df["predicted_Gsolv"] = predictions
146
+
147
+ st.dataframe(df)
148
+
149
+ st.download_button(
150
+ label="⬇️ Download Predictions",
151
+ data=df.to_csv(index=False),
152
+ file_name="predicted_gsolv.csv",
153
+ mime="text/csv"
154
+ )
155
+ else:
156
+ st.error("CSV must contain columns: mol_solute, mol_solvent")
157
+
158
+ # ===============================================================
159
+ # Footer
160
+ # ===============================================================
161
+ st.markdown("---")
162
+ st.markdown(
163
+ "🧠 **Graph Neural Network (GraphSAGE)** \n"
164
+ "🔗 PyTorch Geometric | Molecular ML"
165
+ )
best_solv_sage.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:114d73ec142074f3c79013af6169de54ed1730f7e6876fa155a3ef9a9ae5a21c
3
+ size 950317
requirements.txt ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ streamlit
2
+ torch
3
+ torch-geometric
4
+ pandas
5
+ numpy
6
+ scikit-learn