"""
Interactive Crystallization Component Predictor
===============================================
Streamlit app for Hugging Face Hub deployment
Predicts crystallization components using Simple Baseline and Advanced Baseline models
"""
import streamlit as st
import pandas as pd
import numpy as np
import joblib
import json
import os
import warnings

warnings.filterwarnings('ignore')

# Page config
st.set_page_config(
    page_title="Crystallization Predictor",
    page_icon="🔬",
    layout="wide",
    initial_sidebar_state="expanded"
)

# Get the directory of this script
BASE_DIR = os.path.dirname(os.path.abspath(__file__))
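

# Illustrative sketch (hypothetical helper, not called by the app below):
# every model artifact loaded in the prediction block lives at
# BASE_DIR/models/<approach_dir>/<file>.pkl, so the repeated os.path.join
# calls could be centralised in one function. The mapping from the sidebar
# label to the directory name is an assumption based on the paths used below.
import os


def _artifact_path(base_dir, approach_label, filename):
    """Return the path of a model artifact for the selected sidebar approach."""
    subdir = "advanced_baseline" if "Advanced" in approach_label else "simple_baseline"
    return os.path.join(base_dir, "models", subdir, filename)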

# Title and Introduction
st.title("🔬 Crystallization Component Predictor")
st.markdown("""
### Predict crystallization components using Machine Learning

This app uses trained machine learning models to predict the optimal components for protein crystallization
based on your experimental parameters.
""")
st.markdown("---")

# Sidebar
st.sidebar.header("⚙️ Model Selection")
approach = st.sidebar.radio(
    "Choose Approach:",
    ["Advanced Baseline (Recommended)", "Simple Baseline"],
    help="Advanced has concentration parsing and better accuracy"
)
st.sidebar.markdown("---")
st.sidebar.markdown("### 📊 Model Performance")

# Display performance metrics
try:
    simple_results_path = os.path.join(BASE_DIR, 'models', 'simple_baseline', 'training_results.json')
    advanced_results_path = os.path.join(BASE_DIR, 'models', 'advanced_baseline', 'training_results.json')
    # The JSON results are loaded here, but the metrics shown below are hard-coded
    if os.path.exists(simple_results_path):
        with open(simple_results_path, 'r') as f:
            simple_results = json.load(f)
    if os.path.exists(advanced_results_path):
        with open(advanced_results_path, 'r') as f:
            advanced_results = json.load(f)
    if "Simple" in approach:
        st.sidebar.metric("Name Accuracy", "61.12%")
        st.sidebar.metric("pH R²", "95.58%")
        st.sidebar.warning("⚠️ Conc: N/A")
    else:
        st.sidebar.metric("Name Accuracy", "64.18%")
        st.sidebar.metric("Conc R²", "47.33%")
        st.sidebar.metric("pH R²", "99.34%")
        st.sidebar.success("✅ Predicts name, concentration, and pH")
except Exception:
    st.sidebar.info("Using default metrics")
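

# Illustrative sketch (hypothetical helper, not called by the app): the
# training_results.json files above are loaded, but the sidebar displays
# hard-coded values. Reading a metric from such a file with a fallback could
# look like this. The key names you would pass in are hypothetical, since the
# JSON schema is not shown in this app.
import json
import os


def _load_metric(path, key, default):
    """Return results[key] from a JSON file, or `default` if it is missing or unreadable."""
    if not os.path.exists(path):
        return default
    try:
        with open(path, "r") as f:
            return json.load(f).get(key, default)
    except (OSError, ValueError, AttributeError):
        # OSError: unreadable file; ValueError: invalid JSON;
        # AttributeError: top-level JSON value is not an object
        return default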

st.sidebar.markdown("---")
st.sidebar.markdown("""
### ℹ️ About

This tool predicts three key crystallization parameters:

- **Component Name**: The chemical compound
- **Concentration**: Amount in solution (M)
- **pH**: Acidity/basicity level

**Recommended:** Advanced Baseline for complete predictions
""")

# Input Form
st.header("🎯 Input Crystallization Parameters")
col1, col2 = st.columns(2)

with col1:
    st.markdown("#### Crystallization Setup")
    cryst_method = st.selectbox(
        "Crystallization Method",
        [
            "VAPOR DIFFUSION, SITTING DROP",
            "VAPOR DIFFUSION, HANGING DROP",
            "VAPOR DIFFUSION",
            "BATCH MODE",
            "MICROBATCH"
        ],
        help="Select the crystallization technique you're using"
    )
    temp = st.slider(
        "Temperature (K)",
        250.0, 320.0, 293.0, 1.0,
        help="Typical room temperature is ~293K (20°C)"
    )
    ph = st.slider(
        "pH",
        0.0, 14.0, 7.0, 0.1,
        help="Initial pH of your crystallization solution"
    )

with col2:
    st.markdown("#### Crystal Properties")
    matthews = st.slider(
        "Matthews Coefficient",
        1.0, 4.5, 2.2, 0.1,
        help="Crystal volume per unit of protein molecular weight (Å³/Da)"
    )
    solvent = st.slider(
        "Percent Solvent Content (%)",
        0.0, 100.0, 45.0, 1.0,
        help="Percentage of the crystal volume occupied by solvent"
    )

st.markdown("---")
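

# Illustrative sketch (hypothetical helper, not called by the app below):
# the unscaled numerical features fed to each baseline, in the same order as
# in the prediction block. The Advanced Baseline appends four engineered
# features to the four raw inputs: two interaction terms, a pH difference that
# is unknown at inference time and therefore set to 0, and a solvent/Matthews
# ratio. Scaling and the TF-IDF method features are applied afterwards.
import numpy as np


def _numerical_features(temp, ph, matthews, solvent, advanced=True):
    """Return a (1, 4) array for the Simple Baseline or a (1, 8) array for the Advanced Baseline."""
    raw = [temp, ph, matthews, solvent]
    if not advanced:
        return np.array([raw])
    engineered = [
        temp * ph,                    # temperature x pH interaction
        matthews * solvent,           # Matthews x solvent interaction
        0.0,                          # pH difference: unknown for new inputs
        solvent / (matthews + 1e-6),  # solvent ratio (epsilon avoids division by zero)
    ]
    return np.array([raw + engineered])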

# Predict button
if st.button("🚀 Predict Components", type="primary", use_container_width=True):
    try:
        with st.spinner("🔄 Loading models and making predictions..."):
            if "Advanced" in approach:
                # Load advanced models
                model_name = joblib.load(os.path.join(BASE_DIR, 'models', 'advanced_baseline', 'model_component_name.pkl'))
                model_conc = joblib.load(os.path.join(BASE_DIR, 'models', 'advanced_baseline', 'model_component_conc.pkl'))
                model_ph = joblib.load(os.path.join(BASE_DIR, 'models', 'advanced_baseline', 'model_component_ph.pkl'))
                le = joblib.load(os.path.join(BASE_DIR, 'models', 'advanced_baseline', 'label_encoder_name.pkl'))
                scaler = joblib.load(os.path.join(BASE_DIR, 'models', 'advanced_baseline', 'scaler.pkl'))
                tfidf = joblib.load(os.path.join(BASE_DIR, 'models', 'advanced_baseline', 'tfidf.pkl'))

                # Feature engineering (Advanced Baseline needs 8 features)
                temp_ph_int = temp * ph
                matthews_solvent_int = matthews * solvent
                ph_diff = 0  # Unknown for new prediction
                solvent_ratio = solvent / (matthews + 1e-6)
                numerical = np.array([[
                    temp, ph, matthews, solvent,
                    temp_ph_int, matthews_solvent_int,
                    ph_diff, solvent_ratio,
                ]])
            else:
                # Load simple models
                model_name = joblib.load(os.path.join(BASE_DIR, 'models', 'simple_baseline', 'model_component_name.pkl'))
                model_ph = joblib.load(os.path.join(BASE_DIR, 'models', 'simple_baseline', 'model_component_ph.pkl'))
                le = joblib.load(os.path.join(BASE_DIR, 'models', 'simple_baseline', 'label_encoder_name.pkl'))
                scaler = joblib.load(os.path.join(BASE_DIR, 'models', 'simple_baseline', 'scaler.pkl'))
                tfidf = joblib.load(os.path.join(BASE_DIR, 'models', 'simple_baseline', 'tfidf.pkl'))

                # Simple baseline: only 4 features
                numerical = np.array([[temp, ph, matthews, solvent]])

            # Scale numerical features
            numerical_scaled = scaler.transform(numerical)

            # TF-IDF for crystallization method
            method_tfidf = tfidf.transform([cryst_method.upper()]).toarray()

            # Combine features
            X_pred = np.concatenate([numerical_scaled, method_tfidf], axis=1)

            # Make predictions
            pred_name_idx = model_name.predict(X_pred)[0]
            pred_name = le.inverse_transform([pred_name_idx])[0]
            pred_name_proba = model_name.predict_proba(X_pred)[0]
            # predict_proba columns follow model_name.classes_, so map column
            # positions back to encoded labels before decoding them
            class_labels = np.asarray(getattr(model_name, "classes_", np.arange(len(pred_name_proba))))
            top_5_idx = np.argsort(pred_name_proba)[-5:][::-1]
            top_5_names = le.inverse_transform(class_labels[top_5_idx])
            top_5_proba = pred_name_proba[top_5_idx]
            pred_ph = model_ph.predict(X_pred)[0]
            if "Advanced" in approach:
                pred_conc = model_conc.predict(X_pred)[0]

        # Display Results
        st.success("✅ Predictions Complete!")
        st.markdown("---")
        st.header("📊 Prediction Results")

        # Component Name
        st.subheader("1️⃣ Component_1_Name")
        st.markdown("**Most likely chemical component for crystallization:**")
        col1, col2 = st.columns([1, 2])
        with col1:
            st.metric("Predicted Component", pred_name)
            st.caption("Top prediction from the model")
        with col2:
            st.markdown("**Top 5 Predictions (with confidence):**")
            top5_df = pd.DataFrame({
                # Fewer than 5 rows if the model knows fewer than 5 classes
                'Rank': range(1, len(top_5_names) + 1),
                'Component': top_5_names,
                'Probability': [f"{p:.2%}" for p in top_5_proba]
            })
            st.dataframe(top5_df, hide_index=True, use_container_width=True)
        st.markdown("---")

        # Concentration
        st.subheader("2️⃣ Component_1_Conc")
        if "Advanced" in approach:
            col1, col2 = st.columns(2)
            with col1:
                st.metric("Predicted Concentration (log-scale)", f"{pred_conc:.4f}")
            with col2:
                predicted_molarity = 10**pred_conc
                st.metric("Predicted Molarity", f"{predicted_molarity:.6f} M")
            st.info(f"💡 Use approximately **{predicted_molarity:.6f} M** of {pred_name} in your crystallization trials")
        else:
            st.warning("⚠️ Not available in the Simple Baseline; use the Advanced Baseline for concentration predictions")
        st.markdown("---")

        # pH
        st.subheader("3️⃣ Component_1_pH")
        col1, col2 = st.columns([1, 2])
        with col1:
            st.metric("Predicted pH", f"{pred_ph:.2f}")
            # pH classification
            if pred_ph < 6:
                ph_class = "Acidic"
                ph_emoji = "🔴"
            elif pred_ph < 8:
                ph_class = "Neutral"
                ph_emoji = "🟢"
            else:
                ph_class = "Basic"
                ph_emoji = "🔵"
            st.caption(f"{ph_emoji} {ph_class} solution")
        with col2:
            # pH visualization
            ph_color = "red" if pred_ph < 6 else ("green" if pred_ph < 8 else "blue")
            st.markdown(f"""
            <div style='background: linear-gradient(to right, red, yellow, green, cyan, blue);
                height: 40px; border-radius: 10px; margin: 10px 0; border: 2px solid #333;'></div>
            <div style='display: flex; justify-content: space-between; font-size: 14px;'>
                <span><b>0</b> (Acidic)</span>
                <span><b>7</b> (Neutral)</span>
                <span><b>14</b> (Basic)</span>
            </div>
            <div style='text-align: center; margin-top: 15px;'>
                <b style='font-size: 24px; color: {ph_color};'>pH = {pred_ph:.2f}</b>
            </div>
            """, unsafe_allow_html=True)
        st.info(f"💡 Adjust your buffer to maintain pH ≈ **{pred_ph:.2f}** for optimal crystallization")

        # Input Summary
        st.markdown("---")
        st.subheader("📥 Input Summary")
        input_df = pd.DataFrame({
            'Parameter': [
                'Crystallization Method',
                'Temperature',
                'Input pH',
                'Matthews Coefficient',
                'Solvent Content'
            ],
            'Value': [
                cryst_method,
                f"{temp:.1f} K ({temp-273.15:.1f}°C)",
                f"{ph:.1f}",
                f"{matthews:.2f} Å³/Da",
                f"{solvent:.1f}%"
            ]
        })
        st.table(input_df)

        # Download Results
        st.markdown("---")
        st.subheader("💾 Download Results")
        results_dict = {
            'Crystallization Method': cryst_method,
            'Temperature (K)': temp,
            'Temperature (°C)': temp - 273.15,
            'Input pH': ph,
            'Matthews Coefficient': matthews,
            'Solvent Content (%)': solvent,
            'Predicted Component': pred_name,
            'Component Probability': f"{top_5_proba[0]:.4f}",
            'Predicted pH': f"{pred_ph:.2f}",
        }
        if "Advanced" in approach:
            results_dict['Predicted Concentration (log)'] = f"{pred_conc:.4f}"
            results_dict['Predicted Concentration (M)'] = f"{10**pred_conc:.6f}"
        results_df = pd.DataFrame([results_dict])
        csv = results_df.to_csv(index=False)
        st.download_button(
            label="📥 Download Predictions as CSV",
            data=csv,
            file_name="crystallization_predictions.csv",
            mime="text/csv",
        )

    except FileNotFoundError as e:
        st.error(f"""
        ❌ **Model files not found!**

        Error: {e}

        Please ensure model files are in the correct directory:

        - `models/simple_baseline/`
        - `models/advanced_baseline/`
        """)
    except Exception as e:
        st.error(f"❌ **Prediction Error:** {e}")
        with st.expander("🔍 Show full error details"):
            import traceback
            st.code(traceback.format_exc())
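

# Illustrative sketch (hypothetical helper, not called by the app above):
# ranking the k most probable components. The columns of predict_proba follow
# the classifier's classes_ order, so column positions are mapped through
# classes_ before decoding. Here `decode` stands in for
# LabelEncoder.inverse_transform; that substitution is an assumption made so
# the sketch runs without the pickled models.
import numpy as np


def _top_k(proba, classes, decode, k=5):
    """Return (names, probabilities) for the k highest-probability classes."""
    proba = np.asarray(proba)
    k = min(k, proba.size)                # allow fewer classes than k
    order = np.argsort(proba)[::-1][:k]   # column indices, most probable first
    labels = np.asarray(classes)[order]   # column index -> encoded label
    return list(decode(labels)), proba[order].tolist()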

# Model Comparison Section
st.markdown("---")
st.header("📈 Model Comparison")
comparison_df = pd.DataFrame({
    'Model': ['Simple Baseline', 'Advanced Baseline', 'Transformer'],
    'Name Accuracy': ['61.12%', '64.18% ⭐', '53.85%'],
    'Conc R²': ['N/A', '47.33%', '18.72%'],
    'pH R²': ['95.58%', '99.34% ⭐', '99.27%'],
    'Speed': ['⚡ Fast', '⚡ Fast', '🐌 Slow'],
    'Recommendation': ['Basic use', '✅ Best overall', 'Research only']
})
st.dataframe(
    comparison_df,
    hide_index=True,
    use_container_width=True,
    column_config={
        "Model": st.column_config.TextColumn("Model", width="medium"),
        "Name Accuracy": st.column_config.TextColumn("Name Accuracy", width="medium"),
        "Conc R²": st.column_config.TextColumn("Concentration R²", width="medium"),
        "pH R²": st.column_config.TextColumn("pH R²", width="medium"),
    }
)
st.markdown("""
**Model Selection Guide:**

- **Simple Baseline**: Fast predictions, no concentration. Good for quick pH and component estimates.
- **Advanced Baseline**: ⭐ Recommended for most users. Predicts all three targets and has the best name accuracy and pH R².
- **Transformer**: Deep learning approach that trails the Advanced Baseline on every metric; it requires more data to perform better.
""")

# Visualizations Section
st.markdown("---")
st.header("📊 Performance Visualizations")
viz_path = os.path.join(BASE_DIR, 'visualizations')
if os.path.exists(viz_path):
    try:
        tab1, tab2, tab3, tab4 = st.tabs([
            "📊 Name Accuracy",
            "📈 Concentration R²",
            "🧪 pH R²",
            "🎯 Complete Comparison"
        ])
        with tab1:
            img_path = os.path.join(viz_path, '01_component_name_comparison.png')
            if os.path.exists(img_path):
                st.image(img_path, use_column_width=True)
                st.caption("Comparison of component name prediction accuracy across all models")
        with tab2:
            img_path = os.path.join(viz_path, '02_component_conc_comparison.png')
            if os.path.exists(img_path):
                st.image(img_path, use_column_width=True)
                st.caption("Concentration prediction performance (R² scores)")
        with tab3:
            img_path = os.path.join(viz_path, '03_component_ph_comparison.png')
            if os.path.exists(img_path):
                st.image(img_path, use_column_width=True)
                st.caption("pH prediction performance (R² scores)")
        with tab4:
            img_path = os.path.join(viz_path, '05_complete_comparison.png')
            if os.path.exists(img_path):
                st.image(img_path, use_column_width=True)
                st.caption("Comprehensive comparison of all approaches and metrics")
    except Exception as e:
        st.warning(f"Could not load visualizations: {e}")
else:
    st.info("📊 Visualization files not found in this deployment")

# Information Section
st.markdown("---")
st.header("ℹ️ How It Works")

with st.expander("🔬 About Protein Crystallization"):
    st.markdown("""
    **Protein crystallization** is a crucial step in structural biology for determining 3D protein structures using X-ray crystallography.

    **Key Parameters:**

    - **Crystallization Method**: The technique used (e.g., vapor diffusion, batch mode)
    - **Temperature**: Affects protein stability and crystal growth
    - **pH**: Critical for protein solubility and crystal formation
    - **Matthews Coefficient**: Indicates crystal packing density
    - **Solvent Content**: Amount of solvent in the crystal lattice

    This tool helps predict optimal conditions based on historical crystallization data.
    """)
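

# Illustrative sketch (hypothetical helpers, not called by the app): the two
# crystal-property inputs are linked. The Matthews coefficient is
# V_M = V_cell / (Z * M) in Å³/Da, and assuming the usual protein partial
# specific volume of about 0.74 cm³/g, the solvent fraction is roughly
# 1 - 1.23 / V_M (e.g. the default V_M = 2.2 gives about 44%). The app itself
# takes both values as independent slider inputs.
def _matthews_coefficient(cell_volume_a3, mol_weight_da, z):
    """V_M in Å³/Da from unit-cell volume (Å³), molecular weight (Da), and molecules per cell."""
    return cell_volume_a3 / (z * mol_weight_da)


def _solvent_percent(v_m):
    """Approximate solvent content (%) from the Matthews coefficient."""
    return 100.0 * (1.0 - 1.23 / v_m)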

with st.expander("🤖 About the Models"):
    st.markdown("""
    **Simple Baseline:**

    - Random Forest classifier for component name
    - XGBoost regressor for pH
    - Uses 4 numerical features + TF-IDF of method

    **Advanced Baseline:**

    - Ensemble of Random Forest, XGBoost, LightGBM, and CatBoost
    - Includes concentration prediction with log-transformation
    - Uses 8 engineered features including interactions
    - Best overall performance: 64% name accuracy, 99% pH R²

    **Training Data:**

    - Based on protein crystallization experiments from PDB
    - Includes various crystallization methods and conditions
    - Models trained on structured crystallization data
    """)
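

# Illustrative sketch (hypothetical helpers, not called by the app): the
# concentration model predicts a log-transformed target, and the results
# section converts it back with 10**pred_conc, which implies a base-10 log.
# Assuming that, the forward and inverse transforms are:
import math


def _to_log_conc(molarity):
    """Log10 of a positive molar concentration (the assumed training target)."""
    return math.log10(molarity)


def _to_molarity(log_conc):
    """Invert the transform, as the results section does with 10**pred_conc."""
    return 10 ** log_conc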

with st.expander("📖 How to Use"):
    st.markdown("""
    1. **Select a model** in the sidebar (Advanced Baseline recommended)
    2. **Input your parameters**:
       - Choose crystallization method
       - Set temperature, pH, Matthews coefficient, and solvent content
    3. **Click "Predict Components"** to get predictions
    4. **Review results**:
       - Component name with confidence scores
       - Concentration (if using Advanced Baseline)
       - Optimal pH for crystallization
    5. **Download** results as CSV for your records

    💡 **Tip:** Start with the recommended default values and adjust based on your specific protein and experimental setup.
    """)

# Footer
st.markdown("---")
st.markdown("""
<div style='text-align: center; color: gray; padding: 20px;'>
    <p><b>🔬 Crystallization Component Prediction System</b></p>
    <p><i>Advanced Baseline achieves: 64% Name Accuracy | 47% Conc R² | 99% pH R²</i></p>
    <p>Built with Scikit-learn, XGBoost, LightGBM, CatBoost & Streamlit</p>
    <p style='font-size: 12px; margin-top: 10px;'>
        For research and educational purposes. Validate predictions experimentally.
    </p>
</div>
""", unsafe_allow_html=True)