liuganghuggingface commited on
Commit
286c763
·
1 Parent(s): b503fed
Files changed (4) hide show
  1. .gradio/certificate.pem +31 -0
  2. README.md +3 -1
  3. app.py +880 -0
  4. requirements.txt +12 -0
.gradio/certificate.pem ADDED
@@ -0,0 +1,31 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ -----BEGIN CERTIFICATE-----
2
+ MIIFazCCA1OgAwIBAgIRAIIQz7DSQONZRGPgu2OCiwAwDQYJKoZIhvcNAQELBQAw
3
+ TzELMAkGA1UEBhMCVVMxKTAnBgNVBAoTIEludGVybmV0IFNlY3VyaXR5IFJlc2Vh
4
+ cmNoIEdyb3VwMRUwEwYDVQQDEwxJU1JHIFJvb3QgWDEwHhcNMTUwNjA0MTEwNDM4
5
+ WhcNMzUwNjA0MTEwNDM4WjBPMQswCQYDVQQGEwJVUzEpMCcGA1UEChMgSW50ZXJu
6
+ ZXQgU2VjdXJpdHkgUmVzZWFyY2ggR3JvdXAxFTATBgNVBAMTDElTUkcgUm9vdCBY
7
+ MTCCAiIwDQYJKoZIhvcNAQEBBQADggIPADCCAgoCggIBAK3oJHP0FDfzm54rVygc
8
+ h77ct984kIxuPOZXoHj3dcKi/vVqbvYATyjb3miGbESTtrFj/RQSa78f0uoxmyF+
9
+ 0TM8ukj13Xnfs7j/EvEhmkvBioZxaUpmZmyPfjxwv60pIgbz5MDmgK7iS4+3mX6U
10
+ A5/TR5d8mUgjU+g4rk8Kb4Mu0UlXjIB0ttov0DiNewNwIRt18jA8+o+u3dpjq+sW
11
+ T8KOEUt+zwvo/7V3LvSye0rgTBIlDHCNAymg4VMk7BPZ7hm/ELNKjD+Jo2FR3qyH
12
+ B5T0Y3HsLuJvW5iB4YlcNHlsdu87kGJ55tukmi8mxdAQ4Q7e2RCOFvu396j3x+UC
13
+ B5iPNgiV5+I3lg02dZ77DnKxHZu8A/lJBdiB3QW0KtZB6awBdpUKD9jf1b0SHzUv
14
+ KBds0pjBqAlkd25HN7rOrFleaJ1/ctaJxQZBKT5ZPt0m9STJEadao0xAH0ahmbWn
15
+ OlFuhjuefXKnEgV4We0+UXgVCwOPjdAvBbI+e0ocS3MFEvzG6uBQE3xDk3SzynTn
16
+ jh8BCNAw1FtxNrQHusEwMFxIt4I7mKZ9YIqioymCzLq9gwQbooMDQaHWBfEbwrbw
17
+ qHyGO0aoSCqI3Haadr8faqU9GY/rOPNk3sgrDQoo//fb4hVC1CLQJ13hef4Y53CI
18
+ rU7m2Ys6xt0nUW7/vGT1M0NPAgMBAAGjQjBAMA4GA1UdDwEB/wQEAwIBBjAPBgNV
19
+ HRMBAf8EBTADAQH/MB0GA1UdDgQWBBR5tFnme7bl5AFzgAiIyBpY9umbbjANBgkq
20
+ hkiG9w0BAQsFAAOCAgEAVR9YqbyyqFDQDLHYGmkgJykIrGF1XIpu+ILlaS/V9lZL
21
+ ubhzEFnTIZd+50xx+7LSYK05qAvqFyFWhfFQDlnrzuBZ6brJFe+GnY+EgPbk6ZGQ
22
+ 3BebYhtF8GaV0nxvwuo77x/Py9auJ/GpsMiu/X1+mvoiBOv/2X/qkSsisRcOj/KK
23
+ NFtY2PwByVS5uCbMiogziUwthDyC3+6WVwW6LLv3xLfHTjuCvjHIInNzktHCgKQ5
24
+ ORAzI4JMPJ+GslWYHb4phowim57iaztXOoJwTdwJx4nLCgdNbOhdjsnvzqvHu7Ur
25
+ TkXWStAmzOVyyghqpZXjFaH3pO3JLF+l+/+sKAIuvtd7u+Nxe5AW0wdeRlN8NwdC
26
+ jNPElpzVmbUq4JUagEiuTDkHzsxHpFKVK7q4+63SM1N95R1NbdWhscdCb+ZAJzVc
27
+ oyi3B43njTOQ5yOf+1CceWxG1bQVs5ZufpsMljq4Ui0/1lvh+wjChP4kqKOJ2qxq
28
+ 4RgqsahDYVvTH9w7jXbyLeiNdd8XM2w9U/t7y0Ff/9yi0GE44Za4rF2LN9d11TPA
29
+ mRGunUHBcnWEvgJBQl9nJEiU0Zsnvgc/ubhPgXRR4Xq37Z0j4r7g1SgEEzwxA57d
30
+ emyPxgcYxn/eR44/KJ4EBs+lVDR3veyJm+kXQ99b21/+jh5Xos1AnX5iItreGCc=
31
+ -----END CERTIFICATE-----
README.md CHANGED
@@ -11,4 +11,6 @@ license: mit
11
  short_description: 'Polymer property prediction for gas separation design '
12
  ---
13
 
14
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
 
 
11
  short_description: 'Polymer property prediction for gas separation design '
12
  ---
13
 
14
+ <!-- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference -->
15
+
16
+ Based on torch-molecule (https://github.com/liugangcode/torch-molecule) and sklearn.
app.py ADDED
@@ -0,0 +1,880 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+
3
+ import torch
4
+ import numpy as np
5
+ import pandas as pd
6
+ import os
7
+ import tempfile
8
+ from pathlib import Path
9
+ import pickle
10
+ import joblib
11
+
12
+ from rdkit import Chem
13
+ from rdkit.Chem import Draw, AllChem
14
+
15
+ import plotly.graph_objects as go
16
+ import plotly.express as px
17
+
18
+ from huggingface_hub import hf_hub_download
19
+
20
+ # Import torch_molecule models
21
+ try:
22
+ from torch_molecule import GREAMolecularPredictor, GNNMolecularPredictor
23
+ TORCH_MOLECULE_AVAILABLE = True
24
+ except ImportError:
25
+ TORCH_MOLECULE_AVAILABLE = False
26
+ print("Warning: torch_molecule not available. Some models may not work.")
27
+
28
+ all_properties = ['CH4', 'CO2', 'H2', 'N2', 'O2']
29
+ all_model_names = ['GREA', 'GCN', 'GIN', 'RandomForest', 'GaussianProcess']
30
+
31
+ # Training configuration - set to True if models were trained in log space
32
+ TRAIN_IN_LOG = True
33
+
34
+ # HuggingFace repository ID
35
+ HF_REPO_ID = "liuganghuggingface/polymer-prediction-gas-models"
36
+
37
+ # Default SMILES for testing
38
+ DEFAULT_SMILES = """*c1cc2c(cc1*)C1(C(C)C)c3ccccc3C2(C(C)C)c2cc3c(cc21)Oc1cc2nc(*)c(*)nc2cc1O3
39
+ *CN1CN(*)Cc2cc3c(cc21)C1c2ccccc2C3c2cc(*)c(*)cc21
40
+ *C(=C(*)c1ccc2c(c1)C(C)(C)C(C)(C)C2(C)C)c1ccccc1"""
41
+
42
+ # Selectivity boundary parameters (from 3_create_polymer_oracle.py)
43
+ SELECTIVITY_BOUNDS = {
44
+ 'CO2/CH4': {
45
+ 'x': [1.00E+05, 1.00E-02],
46
+ 'y': [1.00E+05/2.21E+04, 1.00E-02/4.88E-06],
47
+ 'gases': ('CO2', 'CH4')
48
+ },
49
+ 'H2/CH4': {
50
+ 'x': [5.00E+04, 2.50E+00],
51
+ 'y': [5.00E+04/8.67E+04, 2.50E+00/5.64E-04],
52
+ 'gases': ('H2', 'CH4')
53
+ },
54
+ 'O2/N2': {
55
+ 'x': [5.00E+04, 1.00E-03],
56
+ 'y': [5.00E+04/2.78E+04, 1.00E-03/2.43E-05],
57
+ 'gases': ('O2', 'N2')
58
+ },
59
+ 'H2/N2': {
60
+ 'x': [1.00E+05, 1.00E-01],
61
+ 'y': [1.00E+05/1.02E+05, 1.00E-01/9.21E-06],
62
+ 'gases': ('H2', 'N2')
63
+ },
64
+ 'CO2/N2': {
65
+ 'x': [1.00E+06, 1.00E-04],
66
+ 'y': [1.00E+06/3.05E+05, 1.00E-04/1.05E-08],
67
+ 'gases': ('CO2', 'N2')
68
+ }
69
+ }
70
+
71
+ # ============= MODEL LOADING =============
72
+
73
+ def load_all_models():
74
+ """
75
+ Load all available models from HuggingFace Hub at startup.
76
+
77
+ Returns:
78
+ Dictionary with structure: {model_name: {gas: (model, model_type)}}
79
+ """
80
+ print("Loading all models from HuggingFace Hub...")
81
+ loaded_models = {}
82
+
83
+ for model_name in all_model_names:
84
+ loaded_models[model_name] = {}
85
+
86
+ for gas in all_properties:
87
+ model_filename = f"{model_name.lower()}_{gas.lower()}"
88
+
89
+ try:
90
+ if model_name in ['GREA', 'GCN', 'GIN']:
91
+ filename = f"{model_filename}.pt"
92
+
93
+ if not TORCH_MOLECULE_AVAILABLE:
94
+ print(f" ⚠️ torch_molecule not available for {model_name}")
95
+ continue
96
+
97
+ # Download model from HuggingFace Hub
98
+ print(f" Downloading {filename} from HuggingFace Hub...")
99
+ model_path = hf_hub_download(
100
+ repo_id=HF_REPO_ID,
101
+ filename=filename
102
+ )
103
+ print('model path for .pt file: ', model_path)
104
+
105
+ # Instantiate model architecture
106
+ if model_name == 'GREA':
107
+ model = GREAMolecularPredictor()
108
+ elif model_name == 'GCN':
109
+ model = GNNMolecularPredictor(gnn_type='gcn-virtual')
110
+ elif model_name == 'GIN':
111
+ model = GNNMolecularPredictor(gnn_type='gin-virtual')
112
+
113
+ # Load model weights from downloaded file
114
+ model.load_from_local(model_path)
115
+
116
+ loaded_models[model_name][gas] = (model, 'torch_molecule')
117
+ print(f" ✓ Loaded {model_name} for {gas}")
118
+
119
+ else: # sklearn models
120
+ filename = f"{model_filename}.pkl"
121
+
122
+ # Download model from HuggingFace Hub
123
+ print(f" Downloading {filename} from HuggingFace Hub...")
124
+ model_path = hf_hub_download(
125
+ repo_id=HF_REPO_ID,
126
+ filename=filename
127
+ )
128
+ print('model path for .pkl file: ', model_path)
129
+
130
+ # Load sklearn model with joblib
131
+ model = joblib.load(model_path)
132
+ loaded_models[model_name][gas] = (model, 'sklearn')
133
+ print(f" ✓ Loaded {model_name} for {gas}")
134
+
135
+ except Exception as e:
136
+ print(f" ❌ Error loading {model_name} for {gas}: {e}")
137
+
138
+ print("Model loading complete!")
139
+ return loaded_models
140
+
141
+ # Load all models at startup
142
+ PRELOADED_MODELS = load_all_models()
143
+
144
+ # ============= PREDICTION FUNCTIONS =============
145
+
146
+ def validate_smiles(smiles_list):
147
+ """
148
+ Validate a list of SMILES strings.
149
+
150
+ Returns:
151
+ valid_smiles: List of valid SMILES (standardized)
152
+ invalid_smiles: List of invalid SMILES with indices
153
+ validation_report: String report of validation
154
+ """
155
+ valid_smiles = []
156
+ invalid_smiles = []
157
+
158
+ for idx, smiles in enumerate(smiles_list):
159
+ smiles = smiles.strip()
160
+ if not smiles:
161
+ continue
162
+
163
+ mol = Chem.MolFromSmiles(smiles)
164
+ if mol is not None:
165
+ # Standardize SMILES
166
+ standardized = Chem.MolToSmiles(mol, isomericSmiles=True)
167
+ valid_smiles.append((idx, smiles, standardized))
168
+ else:
169
+ invalid_smiles.append((idx, smiles))
170
+
171
+ report = f"✅ Valid SMILES: {len(valid_smiles)}\n"
172
+ report += f"❌ Invalid SMILES: {len(invalid_smiles)}\n"
173
+
174
+ if invalid_smiles:
175
+ report += "\n**Invalid SMILES detected:**\n"
176
+ for idx, smiles in invalid_smiles:
177
+ report += f" - Line {idx + 1}: `{smiles}`\n"
178
+ report += "\n⚠️ **Please remove or correct the invalid SMILES before proceeding.**"
179
+
180
+ return valid_smiles, invalid_smiles, report
181
+
182
+ def smiles_to_fingerprint(smiles_list, n_bits=2048):
183
+ """Convert SMILES to Morgan fingerprints for sklearn models."""
184
+ fingerprints = []
185
+ for smiles in smiles_list:
186
+ mol = Chem.MolFromSmiles(smiles)
187
+ if mol is not None:
188
+ fp = AllChem.GetMorganFingerprintAsBitVect(mol, 2, nBits=n_bits)
189
+ fingerprints.append(np.array(fp))
190
+ else:
191
+ fingerprints.append(np.zeros(n_bits))
192
+ return np.array(fingerprints)
193
+
194
+ def predict_properties(smiles_list, selected_models, progress=gr.Progress()):
195
+ """
196
+ Predict properties for a list of SMILES using selected models.
197
+
198
+ Args:
199
+ smiles_list: List of SMILES strings
200
+ selected_models: List of model names to use
201
+
202
+ Returns:
203
+ Dictionary with all predictions, report string
204
+ """
205
+ if not selected_models:
206
+ return None, "❌ Please select at least one model."
207
+
208
+ # Validate SMILES
209
+ progress(0.1, desc="Validating SMILES...")
210
+ valid_smiles, invalid_smiles, validation_report = validate_smiles(smiles_list)
211
+
212
+ # Stop if there are any invalid SMILES
213
+ if invalid_smiles:
214
+ return None, validation_report
215
+
216
+ if not valid_smiles:
217
+ return None, "❌ No SMILES provided."
218
+
219
+ # Extract standardized SMILES
220
+ indices, original_smiles, standardized_smiles = zip(*valid_smiles)
221
+
222
+ # Store all predictions by model
223
+ all_predictions = {
224
+ 'original_smiles': list(original_smiles),
225
+ 'standardized_smiles': list(standardized_smiles),
226
+ 'predictions': {}, # {model_name: {gas: predictions}}
227
+ 'predictions_log': {} # Store log-space predictions for plotting
228
+ }
229
+
230
+ # For sklearn models, prepare fingerprints once
231
+ X_fp = None
232
+ needs_fingerprints = any(model in selected_models for model in ['RandomForest', 'GaussianProcess'])
233
+ if needs_fingerprints:
234
+ progress(0.2, desc="Computing molecular fingerprints...")
235
+ X_fp = smiles_to_fingerprint(standardized_smiles)
236
+
237
+ # Track prediction errors
238
+ model_errors = []
239
+
240
+ # Make predictions for each gas and each model
241
+ total_predictions = len(all_properties) * len(selected_models)
242
+ pred_count = 0
243
+
244
+ for model_name in selected_models:
245
+ all_predictions['predictions'][model_name] = {}
246
+ all_predictions['predictions_log'][model_name] = {}
247
+
248
+ for gas in all_properties:
249
+ progress(0.2 + 0.7 * pred_count / total_predictions,
250
+ desc=f"Predicting {gas} with {model_name}...")
251
+
252
+ # Check if model is available
253
+ if model_name not in PRELOADED_MODELS or gas not in PRELOADED_MODELS[model_name]:
254
+ model_errors.append(f"{model_name} for {gas} (not available)")
255
+ pred_count += 1
256
+ continue
257
+
258
+ model, model_type = PRELOADED_MODELS[model_name][gas]
259
+
260
+ # Make predictions
261
+ try:
262
+ if model_type == 'torch_molecule':
263
+ predictions_dict = model.predict(list(standardized_smiles))
264
+ predictions = predictions_dict['prediction']
265
+ else: # sklearn
266
+ predictions = model.predict(X_fp)
267
+
268
+ # Ensure predictions are 1-dimensional
269
+ if isinstance(predictions, np.ndarray) and predictions.ndim > 1:
270
+ predictions = predictions.flatten()
271
+
272
+ # Store predictions
273
+ # If trained in log space, store both log and original space
274
+ if TRAIN_IN_LOG:
275
+ # predictions are in log space, convert to original for display
276
+ predictions_original = 10**predictions
277
+ all_predictions['predictions'][model_name][gas] = predictions_original
278
+ all_predictions['predictions_log'][model_name][gas] = predictions
279
+ else:
280
+ # predictions are already in original space
281
+ all_predictions['predictions'][model_name][gas] = predictions
282
+ all_predictions['predictions_log'][model_name][gas] = np.log10(np.maximum(predictions, 1e-10))
283
+
284
+ except Exception as e:
285
+ print(f"Error predicting with {model_name} for {gas}: {e}")
286
+ model_errors.append(f"{model_name} for {gas} (prediction error)")
287
+
288
+ pred_count += 1
289
+
290
+ # Calculate average predictions across models
291
+ progress(0.9, desc="Computing averages...")
292
+ all_predictions['predictions']['Average'] = {}
293
+ all_predictions['predictions_log']['Average'] = {}
294
+
295
+ for gas in all_properties:
296
+ gas_predictions = []
297
+ gas_predictions_log = []
298
+ for model_name in selected_models:
299
+ if model_name in all_predictions['predictions'] and gas in all_predictions['predictions'][model_name]:
300
+ gas_predictions.append(all_predictions['predictions'][model_name][gas])
301
+ gas_predictions_log.append(all_predictions['predictions_log'][model_name][gas])
302
+
303
+ if gas_predictions:
304
+ if len(gas_predictions) > 1:
305
+ stacked = np.array(gas_predictions)
306
+ stacked_log = np.array(gas_predictions_log)
307
+ all_predictions['predictions']['Average'][gas] = np.mean(stacked, axis=0)
308
+ all_predictions['predictions_log']['Average'][gas] = np.mean(stacked_log, axis=0)
309
+ else:
310
+ all_predictions['predictions']['Average'][gas] = gas_predictions[0]
311
+ all_predictions['predictions_log']['Average'][gas] = gas_predictions_log[0]
312
+
313
+ # Create summary report
314
+ report = validation_report + "\n"
315
+ if model_errors:
316
+ report += f"\n⚠️ Model issues: {', '.join(model_errors)}\n"
317
+ report += f"\n✅ Successfully made predictions for {len(valid_smiles)} molecules using {len(selected_models)} model(s)."
318
+ if TRAIN_IN_LOG:
319
+ report += f"\n📊 Note: Models were trained in log space. Predictions shown in original space (Barrer)."
320
+
321
+ progress(1.0, desc="Done!")
322
+ return all_predictions, report
323
+
324
+ def format_predictions_dataframe(all_predictions, selected_view='Average'):
325
+ """
326
+ Format predictions into a clean DataFrame for display.
327
+
328
+ Args:
329
+ all_predictions: Dictionary with all predictions
330
+ selected_view: Which model's predictions to show ('Average' or specific model name)
331
+
332
+ Returns:
333
+ DataFrame with formatted predictions
334
+ """
335
+ if all_predictions is None:
336
+ return None
337
+
338
+ # Create base DataFrame with only original SMILES
339
+ df = pd.DataFrame({
340
+ 'Original_SMILES': all_predictions['original_smiles']
341
+ })
342
+
343
+ # Add predictions for selected view
344
+ if selected_view in all_predictions['predictions']:
345
+ for gas in all_properties:
346
+ if gas in all_predictions['predictions'][selected_view]:
347
+ predictions = all_predictions['predictions'][selected_view][gas]
348
+ # Format to 3 decimal places
349
+ df[gas] = [f"{val:.3f}" for val in predictions]
350
+ else:
351
+ df[gas] = ['N/A'] * len(df)
352
+
353
+ return df
354
+
355
+ def create_selectivity_plot(all_predictions, selected_view='Average', selectivity_pair='CO2/CH4'):
356
+ """
357
+ Create a selectivity plot with 2008 upper bound.
358
+
359
+ Args:
360
+ all_predictions: Dictionary with all predictions
361
+ selected_view: Which model's predictions to show
362
+ selectivity_pair: Which gas pair to plot (e.g., 'CO2/CH4')
363
+
364
+ Returns:
365
+ Plotly figure
366
+ """
367
+ if all_predictions is None or selectivity_pair not in SELECTIVITY_BOUNDS:
368
+ return None
369
+
370
+ bounds = SELECTIVITY_BOUNDS[selectivity_pair]
371
+ gas1, gas2 = bounds['gases']
372
+
373
+ # Get predictions - use log space for plotting
374
+ if selected_view not in all_predictions['predictions_log']:
375
+ return None
376
+
377
+ if gas1 not in all_predictions['predictions_log'][selected_view] or gas2 not in all_predictions['predictions_log'][selected_view]:
378
+ return None
379
+
380
+ # Use log-space predictions for more accurate selectivity calculation
381
+ gas1_perm_log = all_predictions['predictions_log'][selected_view][gas1]
382
+ gas2_perm_log = all_predictions['predictions_log'][selected_view][gas2]
383
+
384
+ # Convert to original space for plotting
385
+ gas1_perm = 10**gas1_perm_log
386
+ gas2_perm = 10**gas2_perm_log
387
+
388
+ # Ensure positive values
389
+ gas1_perm = np.maximum(gas1_perm, 1e-10)
390
+ gas2_perm = np.maximum(gas2_perm, 1e-10)
391
+
392
+ # Calculate selectivity
393
+ selectivity = gas1_perm / gas2_perm
394
+
395
+ # Create boundary line
396
+ x1, x2 = bounds['x']
397
+ y1, y2 = bounds['y']
398
+
399
+ # Create figure
400
+ fig = go.Figure()
401
+
402
+ # Add 2008 upper bound line
403
+ fig.add_trace(go.Scatter(
404
+ x=[x1, x2],
405
+ y=[y1, y2],
406
+ mode='lines',
407
+ name='2008 Upper Bound',
408
+ line=dict(color='red', width=3, dash='dash'),
409
+ hoverinfo='name'
410
+ ))
411
+
412
+ # Add polymer points
413
+ smiles_list = all_predictions['original_smiles']
414
+
415
+ # Determine which polymers are above the bound
416
+ x_log = np.log10(gas1_perm)
417
+ y_log = np.log10(selectivity)
418
+
419
+ # Calculate boundary line parameters
420
+ x1_log, x2_log = np.log10(x1), np.log10(x2)
421
+ y1_log, y2_log = np.log10(y1), np.log10(y2)
422
+ a = (y1_log - y2_log) / (x1_log - x2_log)
423
+ b = y1_log - a * x1_log
424
+
425
+ # Calculate distance from boundary
426
+ y_bound = a * x_log + b
427
+ above_bound = y_log > y_bound
428
+
429
+ # Truncate long SMILES for hover text
430
+ hover_texts = []
431
+ for i, smiles in enumerate(smiles_list):
432
+ truncated = smiles if len(smiles) <= 100 else smiles[:97] + '...'
433
+ status = "Above Bound" if above_bound[i] else "Below Bound"
434
+ hover_text = (f"SMILES: {truncated}<br>"
435
+ f"{gas1}: {gas1_perm[i]:.3f}<br>"
436
+ f"{gas2}: {gas2_perm[i]:.3f}<br>"
437
+ f"Selectivity: {selectivity[i]:.3f}<br>"
438
+ f"Status: {status}")
439
+ hover_texts.append(hover_text)
440
+
441
+ # Add points (above bound)
442
+ if np.any(above_bound):
443
+ fig.add_trace(go.Scatter(
444
+ x=gas1_perm[above_bound],
445
+ y=selectivity[above_bound],
446
+ mode='markers',
447
+ name='Above Bound',
448
+ marker=dict(color='green', size=10, symbol='circle'),
449
+ text=[hover_texts[i] for i in range(len(hover_texts)) if above_bound[i]],
450
+ hovertemplate='%{text}<extra></extra>'
451
+ ))
452
+
453
+ # Add points (below bound)
454
+ if np.any(~above_bound):
455
+ fig.add_trace(go.Scatter(
456
+ x=gas1_perm[~above_bound],
457
+ y=selectivity[~above_bound],
458
+ mode='markers',
459
+ name='Below Bound',
460
+ marker=dict(color='blue', size=8, symbol='circle'),
461
+ text=[hover_texts[i] for i in range(len(hover_texts)) if not above_bound[i]],
462
+ hovertemplate='%{text}<extra></extra>'
463
+ ))
464
+
465
+ # Update layout
466
+ fig.update_xaxes(
467
+ title=f"{gas1} Permeability (Barrer)",
468
+ type="log",
469
+ gridcolor='lightgray'
470
+ )
471
+
472
+ fig.update_yaxes(
473
+ title=f"{gas1}/{gas2} Selectivity",
474
+ type="log",
475
+ gridcolor='lightgray'
476
+ )
477
+
478
+ fig.update_layout(
479
+ title=f"{gas1}/{gas2} Selectivity Plot",
480
+ hovermode='closest',
481
+ showlegend=True,
482
+ plot_bgcolor='white',
483
+ height=600
484
+ )
485
+
486
+ return fig
487
+
488
+ def get_polymers_above_bound(all_predictions, selected_view='Average', selectivity_pair='CO2/CH4'):
489
+ """
490
+ Get list of polymers above the 2008 upper bound.
491
+
492
+ Returns:
493
+ String listing polymers above bound
494
+ """
495
+ if all_predictions is None or selectivity_pair not in SELECTIVITY_BOUNDS:
496
+ return "No data available."
497
+
498
+ bounds = SELECTIVITY_BOUNDS[selectivity_pair]
499
+ gas1, gas2 = bounds['gases']
500
+
501
+ # Get predictions - use log space for calculation
502
+ if selected_view not in all_predictions['predictions_log']:
503
+ return "No predictions available for selected view."
504
+
505
+ if gas1 not in all_predictions['predictions_log'][selected_view] or gas2 not in all_predictions['predictions_log'][selected_view]:
506
+ return f"Predictions not available for {gas1} or {gas2}."
507
+
508
+ # Use log-space predictions
509
+ gas1_perm_log = all_predictions['predictions_log'][selected_view][gas1]
510
+ gas2_perm_log = all_predictions['predictions_log'][selected_view][gas2]
511
+
512
+ # Convert to original space
513
+ gas1_perm = 10**gas1_perm_log
514
+ gas2_perm = 10**gas2_perm_log
515
+
516
+ # Ensure positive values
517
+ gas1_perm = np.maximum(gas1_perm, 1e-10)
518
+ gas2_perm = np.maximum(gas2_perm, 1e-10)
519
+
520
+ # Calculate selectivity
521
+ selectivity = gas1_perm / gas2_perm
522
+
523
+ # Calculate which are above bound
524
+ x_log = np.log10(gas1_perm)
525
+ y_log = np.log10(selectivity)
526
+
527
+ x1, x2 = bounds['x']
528
+ y1, y2 = bounds['y']
529
+ x1_log, x2_log = np.log10(x1), np.log10(x2)
530
+ y1_log, y2_log = np.log10(y1), np.log10(y2)
531
+ a = (y1_log - y2_log) / (x1_log - x2_log)
532
+ b = y1_log - a * x1_log
533
+
534
+ y_bound = a * x_log + b
535
+ above_bound = y_log > y_bound
536
+
537
+ # Create report
538
+ smiles_list = all_predictions['original_smiles']
539
+ above_count = np.sum(above_bound)
540
+
541
+ report = f"**Polymers Above 2008 Upper Bound: {above_count}/{len(smiles_list)}**\n\n"
542
+
543
+ if above_count == 0:
544
+ report += "No polymers exceed the 2008 upper bound.\n"
545
+ else:
546
+ report += "| # | SMILES | " + gas1 + " | " + gas2 + " | Selectivity |\n"
547
+ report += "|---|--------|" + "-"*len(gas1) + "|" + "-"*len(gas2) + "|-------------|\n"
548
+
549
+ idx = 1
550
+ for i in range(len(smiles_list)):
551
+ if above_bound[i]:
552
+ smiles = smiles_list[i]
553
+ # Truncate if too long
554
+ if len(smiles) > 50:
555
+ smiles = smiles[:47] + "..."
556
+ report += f"| {idx} | `{smiles}` | {gas1_perm[i]:.3f} | {gas2_perm[i]:.3f} | {selectivity[i]:.3f} |\n"
557
+ idx += 1
558
+
559
+ return report
560
+
561
+ def process_smiles_input(text_input, file_input, selected_models):
562
+ """Process SMILES from text or file input."""
563
+ smiles_list = []
564
+
565
+ # Process text input
566
+ if text_input and text_input.strip():
567
+ lines = text_input.strip().split('\n')
568
+ smiles_list.extend([line.strip() for line in lines if line.strip()])
569
+
570
+ # Process file input
571
+ if file_input is not None:
572
+ try:
573
+ # Handle different file formats
574
+ file_path = file_input if isinstance(file_input, str) else file_input.name
575
+
576
+ # Try to read as CSV first
577
+ if file_path.endswith('.csv'):
578
+ df = pd.read_csv(file_input if isinstance(file_input, str) else file_input.name)
579
+ if 'SMILES' in df.columns:
580
+ # Read from SMILES column
581
+ smiles_from_file = df['SMILES'].dropna().astype(str).tolist()
582
+ smiles_list.extend([s.strip() for s in smiles_from_file if s.strip()])
583
+ else:
584
+ return None, f"❌ CSV file must contain a 'SMILES' column. Found columns: {', '.join(df.columns)}", []
585
+ else:
586
+ # Read as plain text file (.txt, .smi)
587
+ if isinstance(file_input, str):
588
+ with open(file_input, 'r') as f:
589
+ lines = f.readlines()
590
+ else:
591
+ content = file_input.read()
592
+ if isinstance(content, bytes):
593
+ content = content.decode('utf-8')
594
+ lines = content.strip().split('\n')
595
+
596
+ smiles_list.extend([line.strip() for line in lines if line.strip()])
597
+
598
+ except Exception as e:
599
+ return None, f"❌ Error reading file: {str(e)}", []
600
+
601
+ if not smiles_list:
602
+ return None, "❌ Please provide SMILES strings via text input or file upload.", []
603
+
604
+ # Remove duplicates while preserving order
605
+ seen = set()
606
+ unique_smiles = []
607
+ for s in smiles_list:
608
+ if s not in seen:
609
+ seen.add(s)
610
+ unique_smiles.append(s)
611
+
612
+ # Make predictions
613
+ all_predictions, report = predict_properties(unique_smiles, selected_models)
614
+
615
+ # Get available view options
616
+ view_options = []
617
+ if all_predictions:
618
+ view_options = ['Average'] + [m for m in selected_models if m in all_predictions['predictions']]
619
+
620
+ return all_predictions, report, view_options
621
+
622
+ # ============= GRADIO INTERFACE =============
623
+
624
+ # Get available models for the interface
625
+ available_models = []
626
+ for model_name in all_model_names:
627
+ if model_name in PRELOADED_MODELS and PRELOADED_MODELS[model_name]:
628
+ available_models.append(model_name)
629
+
630
+ if not available_models:
631
+ print("⚠️ WARNING: No models were successfully loaded!")
632
+ available_models = all_model_names # Show all options but they won't work
633
+
634
+ with gr.Blocks(title="Polymer Property Prediction for Gas Permeability and Separation") as iface:
635
+ # Navigation Bar
636
+ with gr.Row(elem_id="navbar"):
637
+ gr.Markdown("""
638
+ <div style="text-align: center;">
639
+ <h1>🔬 Polymer Property Prediction for Gas Permeability and Separation</h1>
640
+ <div style="display: flex; gap: 20px; justify-content: center; align-items: center; margin-top: 10px;">
641
+ <a href="https://github.com/liugangcode/torch-molecule" target="_blank" style="display: flex; align-items: center; gap: 5px; text-decoration: none; color: inherit;">
642
+ <img src="https://img.icons8.com/ios-glyphs/30/000000/github.png" alt="GitHub" />
643
+ <span>💻 Support by torch-molecule and sklearn</span>
644
+ </a>
645
+ </div>
646
+ </div>
647
+ """)
648
+
649
+ # Main content
650
+ gr.Markdown("""
651
+ ## Batch Property Prediction for gas permeability properties (CH₄, CO₂, H₂, N₂, O₂)
652
+
653
+ **Input Options:**
654
+ - **Text Box**: Enter SMILES strings (one per line)
655
+ - **File Upload**: Upload a text file containing SMILES strings (.txt, .csv, .smi), see example file format for details
656
+
657
+ **Model Selection**: Choose one or more prediction models. If multiple models are selected, an averaged prediction will also be provided.
658
+
659
+ ⚠️ **Note**: All SMILES must be valid. Invalid SMILES will prevent prediction and must be corrected first. We treat the * as the polymerization point.
660
+ """)
661
+
662
+ with gr.Row():
663
+ with gr.Column():
664
+ gr.Markdown("### Input SMILES")
665
+ smiles_text = gr.Textbox(
666
+ label="Enter SMILES (one per line)",
667
+ placeholder="Enter SMILES here...",
668
+ lines=10,
669
+ value=DEFAULT_SMILES
670
+ )
671
+
672
+ smiles_file = gr.File(
673
+ label="Or upload a file with SMILES",
674
+ file_types=[".txt", ".csv", ".smi"]
675
+ )
676
+
677
+ with gr.Accordion("📄 Example File Format", open=False):
678
+ gr.Markdown("""
679
+ **For CSV files (.csv):**
680
+
681
+ Your CSV file must contain a column named "SMILES". Other columns are optional.
682
+
683
+ Example CSV content:
684
+ ```
685
+ SMILES,Name,Notes
686
+ *c1cc2c(cc1*)C1(C(C)C)c3ccccc3C2(C(C)C)c2cc3c(cc21)Oc1cc2nc(*)c(*)nc2cc1O3,Polymer1,High performance
687
+ *CN1CN(*)Cc2cc3c(cc21)C1c2ccccc2C3c2cc(*)c(*)cc21,Polymer2,Good selectivity
688
+ *C(=C(*)c1ccc2c(c1)C(C)(C)C(C)(C)C2(C)C)c1ccccc1,Polymer3,Standard
689
+ ```
690
+
691
+ **For text files (.txt, .smi):**
692
+
693
+ Simply list one SMILES per line:
694
+ ```
695
+ *c1cc2c(cc1*)C1(C(C)C)c3ccccc3C2(C(C)C)c2cc3c(cc21)Oc1cc2nc(*)c(*)nc2cc1O3
696
+ *CN1CN(*)Cc2cc3c(cc21)C1c2ccccc2C3c2cc(*)c(*)cc21
697
+ *C(=C(*)c1ccc2c(c1)C(C)(C)C(C)(C)C2(C)C)c1ccccc1
698
+ ```
699
+ """)
700
+
701
+ with gr.Column():
702
+ gr.Markdown("### Model Selection")
703
+ model_selector = gr.CheckboxGroup(
704
+ choices=available_models,
705
+ label="Select Models to Use",
706
+ value=[available_models[0]] if available_models else [],
707
+ info="Select one or more models. Predictions will be averaged if multiple models are selected."
708
+ )
709
+
710
+ with gr.Accordion("ℹ️ Model Information", open=True):
711
+ gr.Markdown("""
712
+ **Available Models:**
713
+ - **GREA**: Graph Rationalization with Environment-based Augmentations (Deep Learning)
714
+ <a href="https://arxiv.org/abs/2206.02886" target="_blank" style="text-decoration: none; color: inherit;">
715
+ 📄 View Paper
716
+ </a>
717
+ - **GCN**: Graph Convolutional Network (Deep Learning)
718
+ - **GIN**: Graph Isomorphism Network (Deep Learning)
719
+ - **RandomForest**: Random Forest Regressor (ML)
720
+ - **GaussianProcess**: Gaussian Process Regressor (ML)
721
+
722
+ **Gas Properties:**
723
+ - CH₄: Methane permeability
724
+ - CO₂: Carbon dioxide permeability
725
+ - H₂: Hydrogen permeability
726
+ - N₂: Nitrogen permeability
727
+ - O₂: Oxygen permeability
728
+
729
+ Units are in Barrer (10⁻¹⁰ cm³(STP)·cm/(cm²·s·cmHg))
730
+ """)
731
+
732
+ predict_btn = gr.Button("🔮 Predict Properties", variant="primary", size="lg")
733
+
734
+ with gr.Row():
735
+ prediction_status = gr.Textbox(label="Status", lines=5)
736
+
737
+ with gr.Row():
738
+ view_selector = gr.Radio(
739
+ choices=['Average'],
740
+ label="Select which predictions to display",
741
+ value='Average',
742
+ visible=False
743
+ )
744
+
745
+ with gr.Row():
746
+ prediction_results = gr.Dataframe(
747
+ label="Prediction Results",
748
+ wrap=True,
749
+ interactive=False
750
+ )
751
+
752
+ with gr.Row():
753
+ download_btn = gr.DownloadButton(
754
+ label="📥 Download Results as CSV",
755
+ visible=False
756
+ )
757
+
758
+ # Selectivity Plot Section
759
+ gr.Markdown("## Gas Selectivity Analysis")
760
+ gr.Markdown("Visualize polymer performance against the 2008 upper bound for gas separation.")
761
+
762
+ with gr.Row():
763
+ selectivity_pair_selector = gr.Radio(
764
+ choices=list(SELECTIVITY_BOUNDS.keys()),
765
+ label="Select Gas Pair",
766
+ value='CO2/CH4'
767
+ )
768
+
769
+ with gr.Row():
770
+ selectivity_plot = gr.Plot(label="Selectivity Plot")
771
+
772
+ with gr.Row():
773
+ polymers_above_bound = gr.Markdown("Run prediction to see polymers above the bound.")
774
+
775
+ # Hidden state to store all predictions
776
+ all_predictions_state = gr.State(None)
777
+
778
+ def on_predict(text_input, file_input, selected_models):
779
+ all_predictions, report, view_options = process_smiles_input(text_input, file_input, selected_models)
780
+
781
+ if all_predictions is not None:
782
+ # Format DataFrame for display
783
+ df = format_predictions_dataframe(all_predictions, 'Average')
784
+
785
+ # Update view selector with available options
786
+ view_selector_update = gr.Radio(
787
+ choices=view_options,
788
+ value='Average',
789
+ visible=True
790
+ )
791
+
792
+ # Save raw predictions to CSV for download
793
+ temp_csv = tempfile.NamedTemporaryFile(mode='w', delete=False, suffix='.csv')
794
+ df.to_csv(temp_csv.name, index=False)
795
+ temp_csv.close()
796
+
797
+ # Create selectivity plot
798
+ plot_fig = create_selectivity_plot(all_predictions, 'Average', 'CO2/CH4')
799
+
800
+ # Get polymers above bound
801
+ above_bound_report = get_polymers_above_bound(all_predictions, 'Average', 'CO2/CH4')
802
+
803
+ return (
804
+ all_predictions,
805
+ df,
806
+ report,
807
+ view_selector_update,
808
+ gr.DownloadButton(
809
+ label="📥 Download Results as CSV",
810
+ value=temp_csv.name,
811
+ visible=True
812
+ ),
813
+ plot_fig,
814
+ above_bound_report
815
+ )
816
+ else:
817
+ return (
818
+ None,
819
+ None,
820
+ report,
821
+ gr.Radio(visible=False),
822
+ gr.DownloadButton(visible=False),
823
+ None,
824
+ "Run prediction to see polymers above the bound."
825
+ )
826
+
827
+ def on_view_change(all_predictions, selected_view, selectivity_pair):
828
+ if all_predictions is None:
829
+ return None, gr.DownloadButton(visible=False), None, "No data available."
830
+
831
+ df = format_predictions_dataframe(all_predictions, selected_view)
832
+
833
+ # Update download with new view
834
+ temp_csv = tempfile.NamedTemporaryFile(mode='w', delete=False, suffix='.csv')
835
+ df.to_csv(temp_csv.name, index=False)
836
+ temp_csv.close()
837
+
838
+ # Update plot
839
+ plot_fig = create_selectivity_plot(all_predictions, selected_view, selectivity_pair)
840
+
841
+ # Get polymers above bound
842
+ above_bound_report = get_polymers_above_bound(all_predictions, selected_view, selectivity_pair)
843
+
844
+ return df, gr.DownloadButton(
845
+ label=f"📥 Download {selected_view} Results as CSV",
846
+ value=temp_csv.name,
847
+ visible=True
848
+ ), plot_fig, above_bound_report
849
+
850
+ def on_selectivity_change(all_predictions, selected_view, selectivity_pair):
851
+ if all_predictions is None:
852
+ return None, "No data available."
853
+
854
+ plot_fig = create_selectivity_plot(all_predictions, selected_view, selectivity_pair)
855
+ above_bound_report = get_polymers_above_bound(all_predictions, selected_view, selectivity_pair)
856
+
857
+ return plot_fig, above_bound_report
858
+
859
+ predict_btn.click(
860
+ on_predict,
861
+ inputs=[smiles_text, smiles_file, model_selector],
862
+ outputs=[all_predictions_state, prediction_results, prediction_status, view_selector, download_btn, selectivity_plot, polymers_above_bound]
863
+ )
864
+
865
+ view_selector.change(
866
+ on_view_change,
867
+ inputs=[all_predictions_state, view_selector, selectivity_pair_selector],
868
+ outputs=[prediction_results, download_btn, selectivity_plot, polymers_above_bound]
869
+ )
870
+
871
+ selectivity_pair_selector.change(
872
+ on_selectivity_change,
873
+ inputs=[all_predictions_state, view_selector, selectivity_pair_selector],
874
+ outputs=[selectivity_plot, polymers_above_bound]
875
+ )
876
+
877
+ # Launch the interface
878
+ if __name__ == "__main__":
879
+ # iface.launch(share=True)
880
+ iface.launch(share=False)
requirements.txt ADDED
@@ -0,0 +1,12 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ pyarrow
2
+ pandas
3
+ joblib
4
+ scikit-learn==1.3.2
5
+ rdkit==2023.9.6
6
+ torch
7
+ huggingface_hub
8
+ gradio
9
+ imageio
10
+ spaces
11
+ torch-molecule
12
+ plotly