Spaces:
Sleeping
Sleeping
| import streamlit as st | |
| import pandas as pd | |
| import numpy as np | |
| import torch | |
| from predict import TabularTransformer, model_predict | |
| from sklearn.preprocessing import MinMaxScaler | |
| import matplotlib.pyplot as plt | |
| import shap | |
# Browser-tab title, icon and wide layout for the page.
st.set_page_config(page_title="Resistivity Prediction App", page_icon="🔮", layout="wide")

# Visible heading followed by a short usage note.
st.title("Resistivity Prediction App")
st.markdown(
    """
This app predicts resistivity based on input features. Enter the values for each feature
and click 'Predict' to get the prediction and explanation.
"""
)
def load_model_and_scalers(data_path='data.xlsx', model_path='model.pth'):
    """Load the reference data, fit the scalers and restore the trained model.

    The first eight columns of the spreadsheet are taken as input features and
    the ninth column as the target. Both scalers are fitted on that data every
    time this function is called.

    Args:
        data_path: Path of the Excel workbook holding features and target.
        model_path: Path of the saved ``TabularTransformer`` state dict.

    Returns:
        A tuple ``(model, scaler_X, scaler_y, feature_names, X)`` where
        ``model`` is in evaluation mode, the scalers are fitted
        ``MinMaxScaler`` instances, ``feature_names`` is the list of feature
        column names and ``X`` is the feature DataFrame.
    """
    n_features = 8

    # Features are the leading columns; the target is the column right after.
    df = pd.read_excel(data_path)
    X = df.iloc[:, 0:n_features]
    y = df.iloc[:, n_features]
    feature_names = X.columns.tolist()

    # MinMaxScaler.fit returns the fitted scaler itself.
    scaler_X = MinMaxScaler().fit(X)
    scaler_y = MinMaxScaler().fit(y.values.reshape(-1, 1))

    model = TabularTransformer(input_dim=n_features, output_dim=1)
    # The rest of this file feeds the model CPU tensors and calls .numpy() on
    # its output, so map the weights to CPU; without map_location a checkpoint
    # saved on a GPU cannot be loaded on a CPU-only machine.
    state_dict = torch.load(model_path, map_location='cpu')
    model.load_state_dict(state_dict)
    model.eval()

    return model, scaler_X, scaler_y, feature_names, X
def explain_prediction(model, input_df, X_background, scaler_X, scaler_y,
                       feature_names, output_path='shap_explanation.png'):
    """Explain a single prediction with SHAP and save a waterfall plot.

    A ``shap.KernelExplainer`` is built around the model, using up to 100
    rows sampled from ``X_background`` as the background set, and the SHAP
    values of ``input_df`` are drawn as a waterfall plot.

    Args:
        model: Model called on a float tensor of scaled features.
        input_df: One-row DataFrame of unscaled feature values to explain.
        X_background: DataFrame of unscaled features to sample background from.
        scaler_X: Fitted scaler applied to features before the model call.
        scaler_y: Fitted scaler used to map model output back to target units.
        feature_names: Feature labels, one per SHAP value.
        output_path: Where the plot image is written. The default keeps the
            original fixed file name; pass a distinct path per caller if
            several explanations may be produced at the same time.

    Returns:
        A tuple ``(expected_value, shap_values)``: the explainer's expected
        value as reported by SHAP and a 1-D array of per-feature SHAP values.

    Raises:
        ValueError: If the number of SHAP values does not match
            ``feature_names`` (for example when ``input_df`` has several rows).
    """
    def predict_fn(X):
        # SHAP supplies unscaled features; scale, predict, then undo the
        # target scaling so the explanation is in target units.
        X_tensor = torch.FloatTensor(scaler_X.transform(X))
        with torch.no_grad():
            scaled_pred = model(X_tensor).numpy()
        # inverse_transform needs a 2-D column; a no-op for (n, 1) output.
        return scaler_y.inverse_transform(scaled_pred.reshape(-1, 1))

    # Fixed seed keeps the background sample (and the explanation) repeatable.
    background_sample = X_background.sample(
        n=min(100, len(X_background)), random_state=42
    )
    explainer = shap.KernelExplainer(predict_fn, background_sample)
    shap_values = explainer.shap_values(input_df)

    # SHAP may return a list (one entry per output) or an array with extra
    # singleton axes; reduce either form to one value per feature.
    if isinstance(shap_values, list):
        shap_values = shap_values[0]
    shap_values = np.asarray(shap_values).flatten()
    if shap_values.size != len(feature_names):
        raise ValueError(
            f"Expected {len(feature_names)} SHAP values for a single-row "
            f"input, got {shap_values.size}"
        )

    # expected_value may be a scalar, a 0-d array or a 1-element array.
    base_value = np.ravel(explainer.expected_value)[0]

    fig = plt.figure(figsize=(12, 8))
    try:
        shap.plots.waterfall(
            shap.Explanation(
                values=shap_values,
                base_values=base_value,
                data=input_df.iloc[0].values,
                feature_names=feature_names,
            ),
            show=False,
        )
        plt.title('SHAP Value Contributions')
        plt.tight_layout()
        plt.savefig(output_path, dpi=300, bbox_inches='tight')
    finally:
        # Release the figure even when plotting or saving fails, so a
        # long-running app does not accumulate open figures.
        plt.close(fig)

    return explainer.expected_value, shap_values
# Load the model, scalers and reference data. Only this step is guarded by
# the "loading" error message, so later failures are not misreported as
# missing files.
try:
    model, scaler_X, scaler_y, feature_names, X = load_model_and_scalers()
except Exception as e:
    st.error(f"""
Error loading the model and data. Please make sure:
1. The model file 'model.pth' exists
2. The data file 'data.xlsx' exists
3. All required packages are installed
Error details: {str(e)}
""")
else:
    st.subheader("Input Features")

    # Two columns: the first half of the features goes left, the rest right.
    col1, col2 = st.columns(2)
    half = len(feature_names) // 2

    # Feature name -> value entered by the user.
    input_values = {}
    for i, feature in enumerate(feature_names):
        # Inputs are limited to the range seen in the reference data and
        # start at the column mean.
        min_val = float(X[feature].min())
        max_val = float(X[feature].max())
        with col1 if i < half else col2:
            input_values[feature] = st.number_input(
                f"{feature}",
                min_value=min_val,
                max_value=max_val,
                value=float(X[feature].mean()),
                help=f"Range: {min_val:.2f} to {max_val:.2f}"
            )

    if st.button("Predict"):
        try:
            # One-row frame whose column order follows feature_names.
            input_df = pd.DataFrame([input_values])

            prediction = model_predict(model, input_df, scaler_X, scaler_y)
            # NOTE(review): model_predict is defined elsewhere; flattening
            # first makes the formatting work whether it returns a 1-D or an
            # (n, 1) array - confirm its return type.
            predicted_value = float(np.asarray(prediction).ravel()[0])

            st.subheader("Prediction Result")
            st.markdown(f"### Predicted Resistivity: {predicted_value:.2f}")

            st.subheader("Feature Importance Explanation")
            # The reference data doubles as the SHAP background set.
            expected_value, shap_values = explain_prediction(
                model, input_df, X, scaler_X, scaler_y, feature_names
            )
            # explain_prediction writes the waterfall plot to this file.
            st.image('shap_explanation.png')
        except Exception as e:
            st.error(f"Error while computing the prediction or its explanation: {e}")