# Hugging Face Space application file (Space hardware: "CPU Upgrade").
| import gradio as gr | |
| import torch | |
| import numpy as np | |
| import pandas as pd | |
| import os | |
| import tempfile | |
| import json | |
| import base64 | |
| from io import BytesIO | |
| from pathlib import Path | |
| import pickle | |
| import joblib | |
| from collections import defaultdict | |
| from rdkit import Chem | |
| from rdkit.Chem import Draw, AllChem | |
| import plotly.graph_objects as go | |
| import plotly.express as px | |
| from sklearn.decomposition import PCA | |
| from huggingface_hub import hf_hub_download | |
# torch_molecule is optional: it provides the GNN-based predictors.
# TORCH_MOLECULE_AVAILABLE records whether the import succeeded so that
# later code can skip those models instead of failing.
try:
    import torch_molecule
    from torch_molecule import GREAMolecularPredictor, GNNMolecularPredictor
except ImportError:
    TORCH_MOLECULE_AVAILABLE = False
    print("Warning: torch_molecule not available. Some models may not work.")
else:
    TORCH_MOLECULE_AVAILABLE = True
    print('torch_molecule version: ', torch_molecule.__version__)
# ============= PROPERTY CONFIGURATION =============
# property_mapping.json maps a property abbreviation to a metadata dict; this
# file reads the 'category', 'full_name' and 'unit' fields of each entry.
# JSON is UTF-8 by specification, so decode it explicitly instead of relying
# on the platform's locale encoding.
with open('property_mapping.json', 'r', encoding='utf-8') as f:
    PROPERTY_MAPPING = json.load(f)

# Every property except those in the 'Gas Transport Properties' category.
ALL_PROPERTIES = {
    prop_abbr: prop_info
    for prop_abbr, prop_info in PROPERTY_MAPPING.items()
    if prop_info['category'] != 'Gas Transport Properties'
}

# Category name -> list of property abbreviations, in mapping order.
# Kept as a defaultdict so existing callers that index it keep working.
PROPERTY_CATEGORIES = defaultdict(list)
for prop_abbr, prop_info in ALL_PROPERTIES.items():
    PROPERTY_CATEGORIES[prop_info['category']].append(prop_abbr)

# Get just permeability properties for plotting
PERMEABILITY_PROPERTIES = [
    prop_abbr
    for prop_abbr, prop_info in ALL_PROPERTIES.items()
    if prop_info['category'] == 'Permeability Properties'
]

print(f"Loaded {len(ALL_PROPERTIES)} properties (excluding Gas Transport Properties)")
print(f"Categories: {list(PROPERTY_CATEGORIES.keys())}")
# Model families the app tries to load, in display order.
all_model_names = ['GREA', 'GCN', 'GIN', 'RandomForest', 'GaussianProcess']

# Training configuration
TRAIN_IN_LOG = True
HF_REPO_ID = "liuganghuggingface/polymer-prediction-gas-models"

# Default SMILES for testing (one polymer repeat unit per line).
DEFAULT_SMILES = "\n".join([
    "*c1cc2c(cc1*)C1(C(C)C)c3ccccc3C2(C(C)C)c2cc3c(cc21)Oc1cc2nc(*)c(*)nc2cc1O3",
    "*CN1CN(*)Cc2cc3c(cc21)C1c2ccccc2C3c2cc(*)c(*)cc21",
    "*C(=C(*)c1ccc2c(c1)C(C)(C)C(C)(C)C2(C)C)c1ccccc1",
])

# Selectivity boundary parameters.
# Each row: (gas1, gas2, (x_first, divisor_first), (x_second, divisor_second)).
# 'x' holds the two x-coordinates of the boundary line and 'y' holds
# x / divisor for each endpoint.
# NOTE(review): the divisors look like the second gas's permeability at each
# endpoint -- confirm against the source of these bounds.
_BOUND_ENDPOINTS = (
    ('CO2', 'CH4', (1.00E+05, 2.21E+04), (1.00E-02, 4.88E-06)),
    ('H2', 'CH4', (5.00E+04, 8.67E+04), (2.50E+00, 5.64E-04)),
    ('O2', 'N2', (5.00E+04, 2.78E+04), (1.00E-03, 2.43E-05)),
    ('H2', 'N2', (1.00E+05, 1.02E+05), (1.00E-01, 9.21E-06)),
    ('CO2', 'N2', (1.00E+06, 3.05E+05), (1.00E-04, 1.05E-08)),
)

SELECTIVITY_BOUNDS = {
    f"{first_gas}/{second_gas}": {
        'x': [x_a, x_b],
        'y': [x_a / div_a, x_b / div_b],
        'gases': (first_gas, second_gas),
    }
    for first_gas, second_gas, (x_a, div_a), (x_b, div_b) in _BOUND_ENDPOINTS
}
| # ============= UTILITY FUNCTIONS ============= | |
def smiles_to_mol_image(smiles, img_size=(200, 200)):
    """Render a SMILES string as a base64 PNG data URI.

    Returns None when RDKit cannot parse the SMILES.
    """
    molecule = Chem.MolFromSmiles(smiles)
    if molecule is None:
        return None

    png_buffer = BytesIO()
    Draw.MolToImage(molecule, size=img_size).save(png_buffer, format="PNG")
    encoded_png = base64.b64encode(png_buffer.getvalue()).decode()
    return "data:image/png;base64," + encoded_png
def validate_smiles(smiles_list):
    """Split SMILES strings into parseable and unparseable ones.

    Blank entries are skipped. Returns a tuple of:
      * valid entries as (index, stripped SMILES, RDKit canonical SMILES),
      * invalid entries as (index, stripped SMILES),
      * a human-readable report string.
    The index is the position in `smiles_list`.
    """
    valid_smiles, invalid_smiles = [], []
    for position, raw in enumerate(smiles_list):
        candidate = raw.strip()
        if not candidate:
            continue
        parsed = Chem.MolFromSmiles(candidate)
        if parsed is None:
            invalid_smiles.append((position, candidate))
            continue
        canonical = Chem.MolToSmiles(parsed, isomericSmiles=True)
        valid_smiles.append((position, candidate, canonical))

    report_parts = [
        f"✅ Valid SMILES: {len(valid_smiles)}\n",
        f"❌ Invalid SMILES: {len(invalid_smiles)}\n",
    ]
    if invalid_smiles:
        report_parts.append("\n**Invalid SMILES detected:**\n")
        report_parts.extend(
            f" - Line {position + 1}: `{bad}`\n" for position, bad in invalid_smiles
        )
        report_parts.append("\n⚠️ **Please remove or correct the invalid SMILES before proceeding.**")
    return valid_smiles, invalid_smiles, "".join(report_parts)
def smiles_to_fingerprint(smiles_list, n_bits=2048):
    """Build one Morgan fingerprint (radius 2, `n_bits` bits) per SMILES.

    A SMILES that RDKit cannot parse is given an all-zero vector, so the
    output always has one entry per input. Returns a NumPy array.
    """
    def _bit_vector(smiles):
        molecule = Chem.MolFromSmiles(smiles)
        if molecule is None:
            return np.zeros(n_bits)
        return np.array(AllChem.GetMorganFingerprintAsBitVect(molecule, 2, nBits=n_bits))

    return np.array([_bit_vector(smiles) for smiles in smiles_list])
| # ============= MODEL LOADING ============= | |
def load_all_models():
    """Download and load every model/property combination from the Hub.

    For each name in `all_model_names` and each property in `ALL_PROPERTIES`
    the file ``<model>_<property>.pt`` (GREA/GCN/GIN) or
    ``<model>_<property>.pkl`` (all other models) is fetched from `HF_REPO_ID`.

    Returns:
        dict: ``{model_name: {prop_abbr: (model, model_type)}}`` where
        ``model_type`` is ``'torch_molecule'`` or ``'sklearn'``. A model name
        is always present as a key; combinations that failed to download or
        load are simply absent (the error is printed, not raised).
    """
    print("Loading models from HuggingFace Hub (CPU only)...")

    # Constructors for the torch_molecule predictors. Lambdas defer the name
    # lookup so this table is safe to build even when the import failed.
    gnn_builders = {
        'GREA': lambda: GREAMolecularPredictor(device='cpu'),
        'GCN': lambda: GNNMolecularPredictor(gnn_type='gcn-virtual', device='cpu'),
        'GIN': lambda: GNNMolecularPredictor(gnn_type='gin-virtual', device='cpu'),
    }

    loaded_models = {}
    for model_name in all_model_names:
        loaded_models[model_name] = {}
        is_gnn = model_name in gnn_builders
        if is_gnn and not TORCH_MOLECULE_AVAILABLE:
            # Nothing can be loaded for this family; skip all its properties.
            continue
        extension = "pt" if is_gnn else "pkl"

        for prop_abbr in ALL_PROPERTIES:
            filename = f"{model_name.lower()}_{prop_abbr.lower()}.{extension}"
            # Best effort: one missing or broken file must not prevent the
            # remaining models from loading.
            try:
                print(f" Downloading {filename}...")
                model_path = hf_hub_download(repo_id=HF_REPO_ID, filename=filename)
                if is_gnn:
                    model = gnn_builders[model_name]()
                    model.load_from_local(model_path)
                    entry = (model, 'torch_molecule')
                else:
                    # SECURITY NOTE: joblib.load unpickles the file, which can
                    # execute arbitrary code. Only acceptable because
                    # HF_REPO_ID is a repository this app trusts.
                    model = joblib.load(model_path)
                    entry = (model, 'sklearn')
                loaded_models[model_name][prop_abbr] = entry
                print(f" ✓ Loaded {model_name} for {prop_abbr}")
            except Exception as e:
                print(f" ❌ Error loading {model_name} for {prop_abbr}: {e}")

    print("Model loading complete!")
    return loaded_models


# Loaded once at import time and shared by every request.
PRELOADED_MODELS = load_all_models()
| # ============= PREDICTION FUNCTIONS ============= | |
def predict_properties(smiles_list, selected_models, progress=gr.Progress()):
    """Predict every configured property for each SMILES with each selected model.

    Args:
        smiles_list: Raw SMILES strings; blank entries are skipped by
            validate_smiles.
        selected_models: Names of the models to run (keys of PRELOADED_MODELS).
        progress: Gradio progress callback, called as ``progress(fraction, desc=...)``.

    Returns:
        ``(all_predictions, report)``. ``all_predictions`` is None when no
        model is selected, any SMILES is invalid, no SMILES is given, or no
        model/property combination produced a prediction; ``report`` then
        explains why. Otherwise ``all_predictions`` is a dict with:
          * 'original_smiles' / 'standardized_smiles': lists of strings,
          * 'predictions': ``{view: {prop: 1-D array}}`` in original units,
          * 'predictions_log': same layout, log10 values,
        where ``view`` is each selected model name plus 'Average'.
    """
    if not selected_models:
        return None, "❌ Please select at least one model."

    progress(0.1, desc="Validating SMILES...")
    valid_smiles, invalid_smiles, validation_report = validate_smiles(smiles_list)
    if invalid_smiles:
        return None, validation_report
    if not valid_smiles:
        return None, "❌ No SMILES provided."

    _, original_smiles, standardized_smiles = zip(*valid_smiles)
    original_smiles = list(original_smiles)
    standardized_smiles = list(standardized_smiles)
    all_predictions = {
        'original_smiles': original_smiles,
        'standardized_smiles': standardized_smiles,
        'predictions': {},
        'predictions_log': {},
    }

    # Fingerprints are only needed by the sklearn models.
    X_fp = None
    if any(name in ('RandomForest', 'GaussianProcess') for name in selected_models):
        progress(0.2, desc="Computing fingerprints...")
        X_fp = smiles_to_fingerprint(standardized_smiles)

    model_errors = []
    # Predict ALL properties (not just permeability)
    available_props = list(ALL_PROPERTIES)
    total_predictions = len(available_props) * len(selected_models)
    pred_count = 0

    for model_name in selected_models:
        linear_by_prop = all_predictions['predictions'][model_name] = {}
        log_by_prop = all_predictions['predictions_log'][model_name] = {}
        for prop in available_props:
            progress(0.2 + 0.7 * pred_count / total_predictions,
                     desc=f"Predicting {prop} with {model_name}...")
            pred_count += 1

            entry = PRELOADED_MODELS.get(model_name, {}).get(prop)
            if entry is None:
                model_errors.append(f"{model_name} for {prop}")
                continue
            model, model_type = entry

            try:
                if model_type == 'torch_molecule':
                    with torch.no_grad():
                        predictions = model.predict(standardized_smiles)['prediction']
                else:
                    predictions = model.predict(X_fp)
                # Models may return lists or (n, 1) arrays; normalise to a
                # flat float array so the arithmetic below always works.
                predictions = np.asarray(predictions, dtype=float).reshape(-1)

                is_permeability = ALL_PROPERTIES[prop]['category'] == 'Permeability Properties'
                if is_permeability and TRAIN_IN_LOG:
                    # Model output is log10(value); convert back for display.
                    linear_by_prop[prop] = 10 ** predictions
                    log_by_prop[prop] = predictions
                else:
                    linear_by_prop[prop] = predictions
                    # Log transform for potential log-scale visualizations;
                    # the floor avoids log10(0).
                    log_by_prop[prop] = np.log10(np.maximum(np.abs(predictions), 1e-10))
            except Exception as e:
                print(f"Error predicting {model_name} for {prop}: {e}")
                model_errors.append(f"{model_name} for {prop}")

    # Calculate averages across the models that produced a value.
    # NOTE: the linear and log averages are computed independently, so for
    # more than one model 10**(log average) is not the linear average.
    progress(0.9, desc="Computing averages...")
    average_linear = all_predictions['predictions']['Average'] = {}
    average_log = all_predictions['predictions_log']['Average'] = {}
    for prop in available_props:
        contributing = [
            name for name in selected_models
            if prop in all_predictions['predictions'][name]
        ]
        if not contributing:
            continue
        average_linear[prop] = np.mean(
            [all_predictions['predictions'][name][prop] for name in contributing], axis=0)
        average_log[prop] = np.mean(
            [all_predictions['predictions_log'][name][prop] for name in contributing], axis=0)

    if not average_linear:
        # Every model/property combination was missing or failed: reporting
        # success here would show an empty result set as if it were valid.
        return None, (
            validation_report
            + "\n\n❌ None of the selected models produced a prediction. "
            + "Check that the models were loaded correctly."
        )

    report = validation_report + "\n"
    if model_errors:
        report += f"\n⚠️ Some models were not available: {len(set(model_errors))} property-model combinations\n"
    report += f"\n✅ Successfully predicted properties for {len(valid_smiles)} molecules using {len(selected_models)} model(s)."
    report += f"\n📊 Attempted to predict {len(available_props)} properties across all categories."
    report += "\n💻 All predictions run on CPU."
    progress(1.0, desc="Done!")
    return all_predictions, report
| # ============= VISUALIZATION FUNCTIONS ============= | |
def create_results_gallery_html(all_predictions, selected_view='Average', selected_category='All', max_display=30):
    """Build the HTML gallery of molecules and their predicted properties.

    Args:
        all_predictions: Dict produced by predict_properties, or None.
        selected_view: Key into ``all_predictions['predictions']``.
        selected_category: 'All' or a single property category name.
        max_display: Maximum number of molecules to render.

    Returns:
        str: An HTML fragment. A short message is returned when there is no
        data or the selected view has no predictions.
    """
    # Imported locally because the file does not import `html` at the top.
    from html import escape

    if all_predictions is None:
        return "<p>No data available.</p>"
    if selected_view not in all_predictions['predictions']:
        return "<p>No predictions available for selected view.</p>"

    predictions = all_predictions['predictions'][selected_view]
    num_molecules = len(all_predictions['original_smiles'])
    display_count = min(num_molecules, max_display)

    if selected_category == 'All':
        display_categories = PROPERTY_CATEGORIES
    else:
        # PROPERTY_CATEGORIES is a defaultdict: indexing it with an unknown
        # name would silently add that name to the shared mapping.
        display_categories = {selected_category: PROPERTY_CATEGORIES.get(selected_category, [])}

    parts = [f"""
    <div style='font-family: Arial, sans-serif; color: #000;'>
    <h3 style='color: #1a1a1a; margin-bottom: 10px;'>Prediction Results - {selected_category}</h3>
    <p style='color: #333; font-size: 1.1em; font-weight: 500;'><strong>Showing top {display_count} of {num_molecules} molecules.</strong></p>
    """]
    if num_molecules > max_display:
        parts.append(
            "<p style='color: #d97706; font-weight: 500; font-size: 1.05em;'>"
            f"⬇️ Download the CSV file below to see all {num_molecules} results.</p>\n"
        )

    for idx, smiles in enumerate(all_predictions['original_smiles'][:display_count]):
        mol_img = smiles_to_mol_image(smiles, img_size=(250, 250))
        if mol_img:
            image_html = "<img src='" + mol_img + "' style='width: 250px; height: 250px; display: block;'/>"
        else:
            image_html = "<p style='color: #ef4444; font-weight: 500;'>Invalid structure</p>"

        # SMILES is user input (RDKit accepts trailing free text after the
        # structure), so it must be escaped before being placed in HTML.
        shown_smiles = escape(smiles[:100]) + ("..." if len(smiles) > 100 else "")

        parts.append(f"""
        <div style='border: 2px solid #cbd5e1; border-radius: 12px; padding: 20px; margin: 20px 0; background: #ffffff; box-shadow: 0 2px 8px rgba(0,0,0,0.1);'>
        <div style='display: flex; gap: 25px; align-items: flex-start;'>
        <div style='flex-shrink: 0; background: #f8fafc; padding: 10px; border-radius: 8px;'>
        {image_html}
        </div>
        <div style='flex-grow: 1;'>
        <h4 style='color: #0f172a; margin: 0 0 15px 0; font-size: 1.3em; border-bottom: 2px solid #e2e8f0; padding-bottom: 8px;'>Molecule {idx + 1}</h4>
        <p style='color: #475569; margin-bottom: 15px; background: #f1f5f9; padding: 10px; border-radius: 6px; word-break: break-all;'>
        <strong style='color: #1e293b;'>SMILES:</strong>
        <code style='background: #e2e8f0; padding: 4px 8px; border-radius: 4px; font-size: 0.9em; color: #334155;'>{shown_smiles}</code>
        </p>
        """)

        # One panel per category; categories with no predicted property in
        # this view are omitted.
        for category, props in display_categories.items():
            category_props = [p for p in props if p in predictions]
            if not category_props:
                continue
            parts.append(f"""
            <div style='margin-top: 15px; background: #f8fafc; padding: 12px; border-radius: 8px; border-left: 4px solid #3b82f6;'>
            <strong style='color: #1e40af; font-size: 1.1em; display: block; margin-bottom: 8px;'>{category}:</strong>
            """)
            for prop in category_props:
                value = predictions[prop][idx]
                prop_info = ALL_PROPERTIES[prop]
                parts.append(f"""
                <div style='margin: 6px 0; padding: 6px 10px; background: #ffffff; border-radius: 4px;'>
                <span style='color: #334155; font-weight: 500;'>• {prop_info['full_name']}</span>
                <span style='color: #64748b;'>({prop_info['unit']}):</span>
                <span style='color: #0369a1; font-weight: bold; font-size: 1.05em;'>{value:.3f}</span>
                </div>
                """)
            parts.append("</div>")

        parts.append("""
        </div>
        </div>
        </div>
        """)

    parts.append("</div>")
    return "".join(parts)
def format_full_predictions_csv(all_predictions, selected_view='Average'):
    """Assemble the predictions of one view into a DataFrame for CSV export.

    The frame has a 'SMILES' column followed by one column per predicted
    property, named "<full_name> (<unit>)" and ordered by category.
    Returns None when `all_predictions` is None.
    """
    if all_predictions is None:
        return None

    frame = pd.DataFrame({'SMILES': all_predictions['original_smiles']})
    view_predictions = all_predictions['predictions'][selected_view]

    # Flatten the category grouping while keeping its order.
    ordered_props = [
        prop
        for props in PROPERTY_CATEGORIES.values()
        for prop in props
        if prop in view_predictions
    ]
    for prop in ordered_props:
        meta = ALL_PROPERTIES[prop]
        frame[f"{meta['full_name']} ({meta['unit']})"] = view_predictions[prop]
    return frame
def create_selectivity_plot_with_images(all_predictions, selected_view='Average', selectivity_pair='CO2/CH4', max_display=30):
    """Plot selectivity against permeability for one gas pair on log-log axes.

    Points are coloured by whether they lie above the boundary line defined
    in SELECTIVITY_BOUNDS for the pair. Returns a plotly Figure, or None
    when there is no data, the pair is unknown, or the selected view lacks
    log predictions for either gas.
    """
    if all_predictions is None:
        return None
    if selectivity_pair not in SELECTIVITY_BOUNDS:
        return None

    bounds = SELECTIVITY_BOUNDS[selectivity_pair]
    gas1, gas2 = bounds['gases']

    log_views = all_predictions['predictions_log']
    if selected_view not in log_views:
        return None
    view_logs = log_views[selected_view]
    if not (gas1 in view_logs and gas2 in view_logs):
        return None

    # Only the first `max_display` molecules are plotted.
    display_count = min(len(all_predictions['original_smiles']), max_display)
    gas1_perm = np.maximum(10 ** view_logs[gas1][:display_count], 1e-10)
    gas2_perm = np.maximum(10 ** view_logs[gas2][:display_count], 1e-10)
    selectivity = gas1_perm / gas2_perm

    (x1, x2), (y1, y2) = bounds['x'], bounds['y']

    fig = go.Figure()
    # Add 2008 upper bound
    fig.add_trace(go.Scatter(
        x=[x1, x2],
        y=[y1, y2],
        mode='lines',
        name='2008 Upper Bound',
        line=dict(color='red', width=3, dash='dash'),
        hoverinfo='name'
    ))

    # The bound is a straight line in log-log space: fit slope/intercept
    # there and compare each point against it.
    x1_log, x2_log = np.log10(x1), np.log10(x2)
    y1_log, y2_log = np.log10(y1), np.log10(y2)
    slope = (y1_log - y2_log) / (x1_log - x2_log)
    intercept = y1_log - slope * x1_log
    above_bound = np.log10(selectivity) > slope * np.log10(gas1_perm) + intercept

    hover_texts = []
    for i, smiles in enumerate(all_predictions['original_smiles'][:display_count]):
        shown = smiles if len(smiles) <= 80 else smiles[:77] + '...'
        status = "Above Bound" if above_bound[i] else "Below Bound"
        hover_texts.append(
            f"SMILES: {shown}<br>"
            f"{gas1}: {gas1_perm[i]:.3f} Barrer<br>"
            f"{gas2}: {gas2_perm[i]:.3f} Barrer<br>"
            f"Selectivity: {selectivity[i]:.3f}<br>"
            f"Status: {status}"
        )

    # One marker trace per group; an empty group adds no trace.
    marker_groups = (
        (above_bound, 'Above Bound', 'green', 10),
        (~above_bound, 'Below Bound', 'blue', 8),
    )
    for mask, label, color, size in marker_groups:
        if not np.any(mask):
            continue
        fig.add_trace(go.Scatter(
            x=gas1_perm[mask],
            y=selectivity[mask],
            mode='markers',
            name=label,
            marker=dict(color=color, size=size, symbol='circle'),
            text=[text for text, keep in zip(hover_texts, mask) if keep],
            hovertemplate='%{text}<extra></extra>'
        ))

    fig.update_xaxes(
        title=f"{gas1} Permeability (Barrer)",
        type="log",
        gridcolor='lightgray'
    )
    fig.update_yaxes(
        title=f"{gas1}/{gas2} Selectivity",
        type="log",
        gridcolor='lightgray'
    )
    fig.update_layout(
        title=f"{gas1}/{gas2} Selectivity Plot (Top {display_count} molecules)",
        hovermode='closest',
        showlegend=True,
        plot_bgcolor='white',
        height=600
    )
    return fig
def generate_all_selectivity_plots(all_predictions, selected_view='Average', max_display=30):
    """Build the selectivity plot of every gas pair in SELECTIVITY_BOUNDS.

    Returns a dict mapping the pair name to its figure; pairs whose plot
    could not be created are left out. Returns {} when there is no data.
    """
    if all_predictions is None:
        return {}

    figures = (
        (pair, create_selectivity_plot_with_images(all_predictions, selected_view, pair, max_display))
        for pair in SELECTIVITY_BOUNDS
    )
    return {pair: figure for pair, figure in figures if figure is not None}
def create_pca_plot(all_predictions, selected_view='Average', max_display=30):
    """Project molecule fingerprints onto two principal components.

    Args:
        all_predictions: Dict produced by predict_properties, or None.
        selected_view: Key into ``all_predictions['predictions']`` whose
            values are shown in the hover text.
        max_display: Maximum number of molecules to plot.

    Returns:
        A plotly Figure, or None when there is no data, the view is unknown,
        or fewer than two molecules are available (a 2-component PCA cannot
        be fitted on a single sample).
    """
    if all_predictions is None:
        return None
    predictions = all_predictions['predictions'].get(selected_view)
    if predictions is None:
        return None

    display_count = min(len(all_predictions['original_smiles']), max_display)
    if display_count < 2:
        # PCA(n_components=2) raises when n_samples < 2.
        return None

    # Compute fingerprints and perform PCA.
    X_fp = smiles_to_fingerprint(all_predictions['standardized_smiles'][:display_count])
    pca = PCA(n_components=2)
    X_pca = pca.fit_transform(X_fp)

    # Create hover texts with all properties, grouped by category.
    hover_texts = []
    for idx, smiles in enumerate(all_predictions['original_smiles'][:display_count]):
        shown = smiles if len(smiles) <= 60 else smiles[:57] + '...'
        lines = [f"SMILES: {shown}<br>"]
        for category, props in PROPERTY_CATEGORIES.items():
            predicted = [prop for prop in props if prop in predictions]
            if not predicted:
                # Skip headers of categories with nothing to show.
                continue
            lines.append(f"<br><b>{category}:</b><br>")
            for prop in predicted:
                prop_info = ALL_PROPERTIES[prop]
                lines.append(
                    f" {prop_info['full_name']}: {predictions[prop][idx]:.3f} {prop_info['unit']}<br>"
                )
        hover_texts.append("".join(lines))

    fig = go.Figure()
    fig.add_trace(go.Scatter(
        x=X_pca[:, 0],
        y=X_pca[:, 1],
        mode='markers',
        marker=dict(
            size=10,
            color=np.arange(display_count),  # colour encodes input order
            colorscale='Viridis',
            showscale=True,
            colorbar=dict(title="Molecule #")
        ),
        text=hover_texts,
        hovertemplate='%{text}<extra></extra>'
    ))
    fig.update_layout(
        title=f"PCA Visualization of Molecular Structures (Top {display_count} molecules)",
        xaxis_title=f"PC1 ({pca.explained_variance_ratio_[0]*100:.1f}% variance)",
        yaxis_title=f"PC2 ({pca.explained_variance_ratio_[1]*100:.1f}% variance)",
        hovermode='closest',
        plot_bgcolor='white',
        height=600
    )
    return fig
| # ============= GRADIO INTERFACE ============= | |
# Offer only the model families for which at least one property model loaded.
available_models = [name for name in all_model_names if PRELOADED_MODELS.get(name)]
if len(available_models) == 0:
    print("⚠️ WARNING: No models loaded!")
    # Fall back to the full list so the selector is not empty.
    available_models = all_model_names
with gr.Blocks(title="Polymer Property Prediction", theme=gr.themes.Soft()) as iface:
    # NOTE: HTML passed to gr.Markdown is kept flush-left on purpose so that
    # no line can be interpreted as an indented Markdown code block.
    gr.Markdown("""
<div style="text-align: center; padding: 20px; background: linear-gradient(135deg, #667eea 0%, #764ba2 100%); border-radius: 12px; margin-bottom: 20px;">
<h1 style="color: white; margin: 0; font-size: 2.5em; text-shadow: 2px 2px 4px rgba(0,0,0,0.3);">🔬 Polymer Property Prediction</h1>
<p style="font-size: 1.2em; color: #f0f0f0; margin: 10px 0 0 0; text-shadow: 1px 1px 2px rgba(0,0,0,0.3);">Predict electronic, dielectric & optical, thermal, physical & thermodynamic, and gas permeability properties</p>
<div style="margin-top: 15px;">
<a href="https://github.com/liugangcode/torch-molecule" target="_blank"
style="color: #fff; text-decoration: none; background: rgba(255,255,255,0.2); padding: 10px 20px; border-radius: 20px; font-weight: 500; backdrop-filter: blur(10px);">
💻 Powered by torch-molecule & sklearn
</a>
</div>
</div>
""")

    # Input section - more compact layout
    with gr.Row():
        with gr.Column(scale=2):
            gr.Markdown("### 📝 Input SMILES")
            smiles_text = gr.Textbox(
                label="Enter SMILES (one per line)",
                placeholder="Enter polymer SMILES strings, one per line...",
                lines=8,
                value=DEFAULT_SMILES
            )
            smiles_file = gr.File(
                label="Or upload a file (.txt, .csv, .smi)",
                file_types=[".txt", ".csv", ".smi"]
            )
        with gr.Column(scale=1):
            gr.Markdown("### ⚙️ Model Selection")
            model_selector = gr.CheckboxGroup(
                choices=available_models,
                label="Select Models for Ensemble Prediction",
                value=[available_models[0]] if available_models else [],
                info="Choose one or more models. Multiple models will be averaged for robust predictions."
            )
            predict_btn = gr.Button("🔮 Predict Properties", variant="primary", size="lg")
    prediction_status = gr.Textbox(label="📊 Status", lines=4, show_label=True)

    # Results section
    gr.Markdown("""
<div style='margin-top: 30px; padding: 15px; background: linear-gradient(135deg, #f093fb 0%, #f5576c 100%); border-radius: 10px;'>
<h2 style='color: white; margin: 0; text-shadow: 1px 1px 2px rgba(0,0,0,0.2);'>📊 Prediction Results</h2>
<p style='color: #fef3c7; margin: 5px 0 0 0; font-size: 0.9em;'>Results include: Electronic (bandgap, ionization energy), Dielectric & Optical (refractive index), Thermal (Tg, Tm, conductivity), Physical (density, FFV, radius of gyration), and Gas Permeability properties</p>
</div>
""")
    with gr.Row():
        # Hidden until a prediction exists; then lists 'Average' plus models.
        view_selector = gr.Radio(
            choices=['Average'],
            label="Select Model View",
            value='Average',
            visible=False
        )
        category_selector = gr.Radio(
            choices=['All'] + list(PROPERTY_CATEGORIES.keys()),
            label="Select Property Category",
            value='All',
            info="Choose a category to view specific properties"
        )
    results_html = gr.HTML(label="Results with Molecular Structures")
    download_btn = gr.DownloadButton(
        label="📥 Download All Results (CSV)",
        visible=False
    )

    # Visualization section
    gr.Markdown("""
<div style='margin-top: 30px; padding: 15px; background: linear-gradient(135deg, #4facfe 0%, #00f2fe 100%); border-radius: 10px;'>
<h2 style='color: white; margin: 0; text-shadow: 1px 1px 2px rgba(0,0,0,0.2);'>📈 Interactive Visualizations</h2>
<p style='color: #f0f9ff; margin: 5px 0 0 0; font-size: 0.95em;'>Limited to top 30 molecules for performance</p>
</div>
""")
    with gr.Tabs():
        with gr.TabItem("🎯 Gas Selectivity Analysis"):
            gr.Markdown("""
<div style='background: #dbeafe; border-left: 4px solid #2563eb; padding: 12px; border-radius: 6px; margin-bottom: 15px;'>
<p style='margin: 0; color: #1e40af; font-weight: 500;'>
Analyze gas separation performance against the 2008 Robeson upper bounds
</p>
</div>
""")
            plot_selectivity_btn = gr.Button("📊 Generate All Selectivity Plots", variant="secondary", size="lg")
            with gr.Row():
                selectivity_pair_selector = gr.Radio(
                    choices=list(SELECTIVITY_BOUNDS.keys()),
                    label="Select Gas Pair to View",
                    value='CO2/CH4',
                    visible=False
                )
            selectivity_plot = gr.Plot(label="Selectivity Plot")
        with gr.TabItem("🗺️ PCA Visualization"):
            gr.Markdown("""
<div style='background: #dbeafe; border-left: 4px solid #2563eb; padding: 12px; border-radius: 6px; margin-bottom: 15px;'>
<p style='margin: 0; color: #1e40af; font-weight: 500;'>
Explore the chemical space using PCA dimensionality reduction on molecular fingerprints
</p>
</div>
""")
            plot_pca_btn = gr.Button("📊 Generate PCA Plot", variant="secondary")
            pca_plot = gr.Plot(label="PCA Plot")

    # Hidden state: the latest prediction dict and the pre-built selectivity
    # figures keyed by gas pair.
    all_predictions_state = gr.State(None)
    all_selectivity_plots_state = gr.State({})

    # ----- Helpers shared by the event handlers -----
    def _read_smiles_file(file_input):
        """Return the non-blank SMILES strings contained in an uploaded file.

        `file_input` is either a path string or an object exposing the path
        as `.name`. CSV files must have a 'SMILES' column; any other file is
        read as one SMILES per line.
        """
        file_path = file_input if isinstance(file_input, str) else file_input.name
        if file_path.lower().endswith('.csv'):
            df = pd.read_csv(file_path)
            if 'SMILES' not in df.columns:
                # Previously ignored silently, which made the upload look
                # accepted while contributing nothing.
                raise ValueError("CSV file has no 'SMILES' column")
            values = df['SMILES'].dropna().astype(str).tolist()
            return [value.strip() for value in values if value.strip()]
        with open(file_path, 'r', encoding='utf-8') as f:
            return [line.strip() for line in f if line.strip()]

    def _export_csv(all_predictions, selected_view):
        """Write the selected view to a temporary CSV file and return its path."""
        df_full = format_full_predictions_csv(all_predictions, selected_view)
        # delete=False: the file must outlive this call so the download
        # button can serve it.
        with tempfile.NamedTemporaryFile(
            mode='w', delete=False, suffix='.csv', newline='', encoding='utf-8'
        ) as handle:
            df_full.to_csv(handle, index=False)
        return handle.name

    def _prediction_failed(message):
        """Outputs of on_predict when no prediction could be made."""
        return (
            None,
            "",
            message,
            gr.Radio(visible=False),
            gr.Radio(visible=True, value='All'),
            gr.DownloadButton(visible=False),
        )

    # ----- Event handlers -----
    def on_predict(text_input, file_input, selected_models):
        """Collect SMILES from the textbox and/or file, predict, and render."""
        smiles_list = []
        if text_input and text_input.strip():
            smiles_list.extend(
                line.strip() for line in text_input.splitlines() if line.strip()
            )
        if file_input is not None:
            try:
                smiles_list.extend(_read_smiles_file(file_input))
            except Exception as e:
                return _prediction_failed(f"❌ Error reading file: {str(e)}")
        if not smiles_list:
            return _prediction_failed("❌ Please provide SMILES strings.")

        # Remove duplicates while keeping first-seen order.
        unique_smiles = list(dict.fromkeys(smiles_list))

        all_predictions, report = predict_properties(unique_smiles, selected_models)
        if all_predictions is None:
            return _prediction_failed(report)

        # Create results gallery with default "All" category
        results_gallery = create_results_gallery_html(all_predictions, 'Average', 'All', max_display=30)
        csv_path = _export_csv(all_predictions, 'Average')

        view_options = ['Average'] + [m for m in selected_models if m in all_predictions['predictions']]
        return (
            all_predictions,
            results_gallery,
            report,
            gr.Radio(choices=view_options, value='Average', visible=True),
            gr.Radio(
                choices=['All'] + list(PROPERTY_CATEGORIES.keys()),
                value='All',
                visible=True
            ),
            gr.DownloadButton(
                label="📥 Download All Results (CSV)",
                value=csv_path,
                visible=True
            ),
        )

    def on_view_or_category_change(all_predictions, selected_view, selected_category):
        """Re-render the gallery and CSV for the chosen view/category."""
        if all_predictions is None:
            return "", gr.DownloadButton(visible=False)
        results_gallery = create_results_gallery_html(
            all_predictions, selected_view, selected_category, max_display=30)
        csv_path = _export_csv(all_predictions, selected_view)
        return results_gallery, gr.DownloadButton(
            label=f"📥 Download {selected_view} Results (CSV)",
            value=csv_path,
            visible=True
        )

    def on_plot_selectivity(all_predictions, selected_view):
        """Generate all selectivity plots at once."""
        if all_predictions is None:
            return {}, gr.Radio(visible=False), None
        all_plots = generate_all_selectivity_plots(all_predictions, selected_view, max_display=30)
        if not all_plots:
            return {}, gr.Radio(visible=False), None
        # Show CO2/CH4 first when available, otherwise the first plot built.
        first_pair = 'CO2/CH4' if 'CO2/CH4' in all_plots else next(iter(all_plots))
        return (
            all_plots,
            gr.Radio(choices=list(all_plots.keys()), value=first_pair, visible=True),
            all_plots[first_pair],
        )

    def on_selectivity_pair_change(all_plots_dict, selected_pair):
        """Switch between pre-generated selectivity plots."""
        if not all_plots_dict or selected_pair not in all_plots_dict:
            return None
        return all_plots_dict[selected_pair]

    def on_plot_pca(all_predictions, selected_view):
        """Build the PCA plot for the current predictions."""
        if all_predictions is None:
            return None
        return create_pca_plot(all_predictions, selected_view, max_display=30)

    # ----- Connect events -----
    predict_btn.click(
        on_predict,
        inputs=[smiles_text, smiles_file, model_selector],
        outputs=[all_predictions_state, results_html, prediction_status, view_selector, category_selector, download_btn]
    )
    # Both selectors refresh the same outputs from the same inputs.
    for selector in (view_selector, category_selector):
        selector.change(
            on_view_or_category_change,
            inputs=[all_predictions_state, view_selector, category_selector],
            outputs=[results_html, download_btn]
        )
    plot_selectivity_btn.click(
        on_plot_selectivity,
        inputs=[all_predictions_state, view_selector],
        outputs=[all_selectivity_plots_state, selectivity_pair_selector, selectivity_plot]
    )
    selectivity_pair_selector.change(
        on_selectivity_pair_change,
        inputs=[all_selectivity_plots_state, selectivity_pair_selector],
        outputs=[selectivity_plot]
    )
    plot_pca_btn.click(
        on_plot_pca,
        inputs=[all_predictions_state, view_selector],
        outputs=[pca_plot]
    )
if __name__ == "__main__":
    # Start the Gradio server only when the file is run as a script;
    # share=True is passed through to launch() unchanged.
    launch_options = {"share": True}
    iface.launch(**launch_options)