# linear-regression / src/streamlit_app.py
# Streamlit app for interactive linear regression analysis (Hugging Face Space).
import streamlit as st
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import statsmodels.api as sm
from statsmodels.stats.outliers_influence import variance_inflation_factor
from statsmodels.stats.stattools import jarque_bera
from statsmodels.stats.diagnostic import het_breuschpagan, normal_ad
from scipy.stats import boxcox, shapiro
from sklearn.model_selection import train_test_split, cross_val_score
from sklearn.preprocessing import StandardScaler, PowerTransformer
from sklearn.metrics import mean_squared_error, mean_absolute_error, r2_score
from sklearn.linear_model import LinearRegression
import plotly.express as px
import plotly.graph_objects as go
from plotly.subplots import make_subplots
import os
import warnings
# Configuration
warnings.filterwarnings('ignore')  # keep library warnings out of the UI
# Opt out of Streamlit telemetry (useful for private / air-gapped deployments).
os.environ["STREAMLIT_BROWSER_GATHER_USAGE_STATS"] = "false"
os.environ["STREAMLIT_METRICS_ENABLED"] = "false"
# NOTE: set_page_config must be the first Streamlit command executed.
st.set_page_config(page_title="Advanced Regression Analysis", layout="wide")
def load_data():
    """Load a user-uploaded dataset (CSV / TXT / Excel) and show a preview.

    Returns:
        pd.DataFrame or None: the parsed dataset, or None when no file has
        been uploaded yet, the file type is unsupported, or parsing fails.
    """
    uploaded_data = st.file_uploader('📂 Upload Data File', type=['csv', 'txt', 'xlsx', 'xls'])
    if uploaded_data is None:
        return None
    try:
        if uploaded_data.type == 'text/plain':
            delimiter = st.radio('Select delimiter (separator)', [',', '\t', '|', ' ', 'Auto Detect'])
            if delimiter == 'Auto Detect':
                # sep=None with the python engine lets pandas sniff the delimiter.
                df = pd.read_csv(uploaded_data, sep=None, engine='python')
            else:
                df = pd.read_csv(uploaded_data, sep=delimiter)
        elif uploaded_data.type == 'text/csv':
            df = pd.read_csv(uploaded_data)
        elif uploaded_data.type in ('application/vnd.openxmlformats-officedocument.spreadsheetml.sheet',
                                    'application/vnd.ms-excel'):
            df = pd.read_excel(uploaded_data)
        else:
            # BUGFIX: an unrecognized MIME type previously fell through and
            # raised UnboundLocalError on `df`; report it clearly instead.
            st.error(f"Unsupported file type: {uploaded_data.type}")
            return None
        # Basic data quality check
        st.write('### 🔍 Dataset Preview')
        st.dataframe(df.head())
        # Show data summary
        with st.expander("📊 Data Summary"):
            st.write("**Data Types:**")
            st.dataframe(df.dtypes.astype(str))
            st.write("**Descriptive Statistics:**")
            st.dataframe(df.describe())
        return df
    except Exception as e:
        st.error(f"Error loading file: {str(e)}")
        return None
@st.cache_data
def calculate_vif(X):
    """Compute variance inflation factors for the numeric columns of X.

    Args:
        X (pd.DataFrame): candidate feature matrix; non-numeric columns and
            rows with missing values are dropped first.

    Returns:
        pd.DataFrame or None: columns Feature / VIF / Severity sorted by VIF
        descending, or None when fewer than two usable features remain.
    """
    X = X.select_dtypes(include=[np.number]).dropna()
    # BUGFIX: guard before X.iloc[0] — an empty frame (all rows dropped)
    # previously raised IndexError here.
    if X.empty:
        return None
    # Drop constant columns: their VIF is undefined (collinear with intercept).
    X = X.loc[:, (X != X.iloc[0]).any()]
    if X.shape[1] < 2:
        # VIF regresses each feature on the others; needs >= 2 features.
        return None
    vif_data = pd.DataFrame()
    vif_data["Feature"] = X.columns
    vif_data["VIF"] = [variance_inflation_factor(X.values, i) for i in range(X.shape[1])]
    # Conventional thresholds: VIF > 10 severe, > 5 moderate.
    vif_data["Severity"] = np.where(vif_data["VIF"] > 10, "High",
                                    np.where(vif_data["VIF"] > 5, "Moderate", "Low"))
    return vif_data.sort_values("VIF", ascending=False)
@st.cache_data
def transform_data(X, y, transformations):
    """Apply the selected transformations to features and target.

    Feature transforms compose in a fixed order (log -> sqrt -> boxcox), each
    operating on the output of the previous one; likewise for the target.

    Args:
        X (pd.DataFrame): feature matrix.
        y (pd.Series): target vector.
        transformations (list[str]): any of 'log', 'sqrt', 'boxcox' for the
            features and 'log_y', 'sqrt_y', 'boxcox_y' for the target.

    Returns:
        tuple: (transformed X, transformed y).
    """
    X_trans = X.copy()
    y_trans = y.copy()
    # Apply transformations to features
    if 'log' in transformations:
        X_trans = np.log1p(X_trans)
    if 'sqrt' in transformations:
        X_trans = np.sqrt(X_trans)
    if 'boxcox' in transformations:
        for col in X_trans.columns:
            series = X_trans[col]
            # Box-Cox needs strictly positive, non-constant data; scipy raises
            # ValueError on constant input, so skip those columns.
            # BUGFIX: dropped the old `+ 1e-6` offset — values are already
            # verified positive, and the shift was never undone downstream.
            if (series > 0).all() and series.nunique() > 1:
                X_trans[col], _ = boxcox(series)
    # Apply transformations to target
    if 'log_y' in transformations:
        y_trans = np.log1p(y_trans)
    if 'sqrt_y' in transformations:
        y_trans = np.sqrt(y_trans)
    if 'boxcox_y' in transformations:
        if (y_trans > 0).all() and pd.Series(y_trans).nunique() > 1:
            y_trans, _ = boxcox(y_trans)
    return X_trans, y_trans
def plot_residual_analysis(y_true, y_pred, residuals):
    """Render a 2x2 grid of residual diagnostic plots in Streamlit.

    Panels: residuals vs fitted values, normal Q-Q plot with a fitted
    reference line, residual histogram, and residuals in observation order.

    Args:
        y_true: observed target values (kept for interface compatibility).
        y_pred: fitted/predicted values.
        residuals: observed minus fitted values.
    """
    grid = make_subplots(rows=2, cols=2,
                         subplot_titles=("Residuals vs Fitted",
                                         "Q-Q Plot",
                                         "Residual Histogram",
                                         "Residuals vs Order"))

    # Panel 1: residuals against fitted values (curvature / funnel patterns).
    grid.add_trace(go.Scatter(x=y_pred, y=residuals, mode='markers', name='Residuals'),
                   row=1, col=1)
    grid.add_hline(y=0, line_dash="dot", row=1, col=1)

    # Panel 2: normal Q-Q plot of the residuals.
    prob_plot = sm.ProbPlot(residuals)
    theo_q = prob_plot.theoretical_quantiles
    samp_q = prob_plot.sample_quantiles
    grid.add_trace(go.Scatter(x=theo_q, y=samp_q, mode='markers', name='Q-Q Points'),
                   row=1, col=2)
    # Least-squares reference line through the Q-Q points.
    fit_slope, fit_intercept = np.polyfit(theo_q, samp_q, 1)
    ref_x = np.array([theo_q.min(), theo_q.max()])
    grid.add_trace(go.Scatter(x=ref_x, y=fit_slope * ref_x + fit_intercept,
                              mode='lines', line=dict(color='red'),
                              name='Reference Line'),
                   row=1, col=2)

    # Panel 3: distribution of the residuals.
    grid.add_trace(go.Histogram(x=residuals, nbinsx=50, name='Residuals'),
                   row=2, col=1)

    # Panel 4: residuals in row order (autocorrelation / drift check).
    grid.add_trace(go.Scatter(x=np.arange(len(residuals)), y=residuals,
                              mode='lines+markers', name='Residuals'),
                   row=2, col=2)
    grid.add_hline(y=0, line_dash="dot", row=2, col=2)

    grid.update_layout(height=800,
                       showlegend=False,
                       template='plotly_white',
                       margin=dict(l=50, r=50, b=50, t=50))
    st.plotly_chart(grid, use_container_width=True)
def main():
    """Drive the full workflow: load, clean, transform, fit, diagnose, predict."""
    st.title('📈 Statistical Linear Regression Analysis')
    st.markdown("""
    This tool provides comprehensive linear regression analysis with diagnostics and visualizations.
    Upload your data, select variables, and explore the results!
    """)

    df = load_data()
    if df is None:
        return

    # ----- Data cleaning ----------------------------------------------------
    st.sidebar.header("Data Cleaning Options")
    if df.isnull().sum().sum() > 0:
        st.sidebar.warning("⚠️ Dataset contains missing values")
        impute_method = st.sidebar.selectbox(
            "Imputation method",
            ['Fill with mean', 'Fill with median', 'Fill with mode', 'Drop rows']
        )
        if impute_method == 'Fill with mean':
            # BUGFIX: numeric_only — df.mean() on frames with object columns
            # raises TypeError in modern pandas.
            df.fillna(df.mean(numeric_only=True), inplace=True)
        elif impute_method == 'Fill with median':
            df.fillna(df.median(numeric_only=True), inplace=True)
        elif impute_method == 'Fill with mode':
            df.fillna(df.mode().iloc[0], inplace=True)
        elif impute_method == 'Drop rows':
            df.dropna(inplace=True)

    # ----- Outlier handling ---------------------------------------------------
    outlier_method = st.sidebar.selectbox(
        "Outlier handling",
        ['None', 'Z-score (3σ)', 'IQR Method']
    )
    # BUGFIX: the selection above was previously collected but never applied.
    numeric_cols = df.select_dtypes(include=[np.number]).columns
    if outlier_method == 'Z-score (3σ)' and len(numeric_cols) > 0:
        # Keep rows whose numeric values all lie within 3 standard deviations.
        zscores = (df[numeric_cols] - df[numeric_cols].mean()) / df[numeric_cols].std()
        df = df[(zscores.abs() <= 3).all(axis=1)]
    elif outlier_method == 'IQR Method' and len(numeric_cols) > 0:
        q1 = df[numeric_cols].quantile(0.25)
        q3 = df[numeric_cols].quantile(0.75)
        iqr = q3 - q1
        within = ~((df[numeric_cols] < q1 - 1.5 * iqr) |
                   (df[numeric_cols] > q3 + 1.5 * iqr)).any(axis=1)
        df = df[within]

    # ----- Variable selection -------------------------------------------------
    st.header("Variable Selection")
    col1, col2 = st.columns(2)
    with col1:
        predictors = st.multiselect(
            '🎯 Select Predictor Variables',
            [col for col in df.columns if df[col].nunique() > 1],
            help="Select multiple features for multiple regression"
        )
    with col2:
        target = st.selectbox(
            '📌 Select Target Variable',
            [col for col in df.columns if col not in predictors]
        )
    if not predictors or not target:
        st.warning("Please select at least one predictor and a target variable")
        st.stop()

    X = df[predictors]
    y = df[target]

    # ----- Transformations ----------------------------------------------------
    st.header("Data Transformations")
    transformations = st.multiselect(
        "Apply transformations to improve model performance",
        ['log', 'sqrt', 'boxcox', 'log_y', 'sqrt_y', 'boxcox_y'],
        help="Log and sqrt help with right-skewed data. Box-Cox requires positive values."
    )
    if transformations:
        X, y = transform_data(X, y, transformations)

    # ----- Model configuration ------------------------------------------------
    st.header("Model Configuration")
    col1, col2 = st.columns(2)
    with col1:
        test_size = st.slider('Test set size (%)', 10, 50, 20, 5) / 100
        random_state = st.number_input('Random seed', 0, 1000, 42)
    with col2:
        scale_data = st.checkbox("Standardize features", True)
        cv_folds = st.selectbox("Cross-validation folds", [3, 5, 10], 2)

    X_train, X_test, y_train, y_test = train_test_split(
        X, y, test_size=test_size, random_state=random_state
    )
    # Reset indices so positional plots and joins line up after the shuffle.
    X_train = X_train.reset_index(drop=True)
    X_test = X_test.reset_index(drop=True)
    y_train = y_train.reset_index(drop=True)
    y_test = y_test.reset_index(drop=True)

    # Standardize if requested (fit on train only to avoid leakage).
    if scale_data:
        scaler = StandardScaler()
        X_train = pd.DataFrame(scaler.fit_transform(X_train), columns=X_train.columns)
        X_test = pd.DataFrame(scaler.transform(X_test), columns=X_test.columns)

    # statsmodels OLS needs an explicit intercept column.
    X_train_const = sm.add_constant(X_train)
    X_test_const = sm.add_constant(X_test)

    # Fit both libraries: statsmodels for inference, sklearn for CV scoring.
    model_sm = sm.OLS(y_train, X_train_const).fit()
    model_sk = LinearRegression().fit(X_train, y_train)

    cv_scores = cross_val_score(model_sk, X_train, y_train,
                                cv=cv_folds, scoring='r2')

    y_pred = model_sm.predict(X_test_const)
    y_train_pred = model_sm.predict(X_train_const)

    # ----- Performance metrics ------------------------------------------------
    st.header("Model Performance")
    col1, col2, col3, col4 = st.columns(4)
    with col1:
        st.metric("R² (Training)", f"{model_sm.rsquared:.3f}")
    with col2:
        st.metric("Adj. R² (Training)", f"{model_sm.rsquared_adj:.3f}")
    with col3:
        st.metric("R² (Test)", f"{r2_score(y_test, y_pred):.3f}")
    with col4:
        st.metric("CV R² (Mean)", f"{np.mean(cv_scores):.3f}")
    st.markdown("---")

    # ----- Actual vs predicted --------------------------------------------------
    fig_avp = px.scatter(
        x=y_test, y=y_pred,
        labels={'x': 'Actual', 'y': 'Predicted'},
        title='Actual vs Predicted Values',
        trendline="ols"
    )
    # Dotted y = x line marks perfect prediction.
    fig_avp.add_shape(type="line", x0=y_test.min(), y0=y_test.min(),
                      x1=y_test.max(), y1=y_test.max(),
                      line=dict(color="Red", dash="dot"))
    st.plotly_chart(fig_avp, use_container_width=True)

    # ----- Feature importance ---------------------------------------------------
    if len(predictors) > 1:
        st.subheader("Feature Importance")
        coef_df = pd.DataFrame({
            'Feature': X_train_const.columns[1:],   # skip the intercept
            'Coefficient': model_sm.params[1:],
            'Absolute Impact': np.abs(model_sm.params[1:])
        }).sort_values('Absolute Impact', ascending=False)
        fig_coef = px.bar(coef_df, x='Feature', y='Coefficient',
                          color='Coefficient', color_continuous_scale='RdBu',
                          title='Feature Coefficients')
        st.plotly_chart(fig_coef, use_container_width=True)

    # ----- Diagnostics -----------------------------------------------------------
    st.header("Model Diagnostics")
    residuals = y_train - y_train_pred

    with st.expander("Residual Analysis"):
        plot_residual_analysis(y_train, y_train_pred, residuals)
        # Normality tests on the training residuals.
        jb_stat, jb_pval = jarque_bera(residuals)[:2]
        ad_stat, ad_pval = normal_ad(residuals)[:2]
        sh_stat, sh_pval = shapiro(residuals)
        col1, col2, col3 = st.columns(3)
        with col1:
            st.metric("Jarque-Bera p-value", f"{jb_pval:.4f}")
        with col2:
            st.metric("Anderson-Darling p-value", f"{ad_pval:.4f}")
        with col3:
            st.metric("Shapiro-Wilk p-value", f"{sh_pval:.4f}")
        if any(p < 0.05 for p in [jb_pval, ad_pval, sh_pval]):
            st.warning("Residuals may not be normally distributed")
        else:
            st.success("Residuals appear normally distributed")

    with st.expander("Heteroscedasticity Check"):
        _, bp_pval, _, _ = het_breuschpagan(residuals, X_train_const)
        st.metric("Breusch-Pagan p-value", f"{bp_pval:.4f}")
        if bp_pval < 0.05:
            st.warning("Evidence of heteroscedasticity")
        else:
            st.success("No significant heteroscedasticity detected")

    with st.expander("Multicollinearity Check"):
        vif_data = calculate_vif(X_train)
        if vif_data is not None:
            fig_vif = px.bar(vif_data, x='Feature', y='VIF', color='Severity',
                             color_discrete_map={'High': 'red', 'Moderate': 'orange', 'Low': 'green'},
                             title='Variance Inflation Factors (VIF)')
            st.plotly_chart(fig_vif, use_container_width=True)
            high_vif = vif_data[vif_data['VIF'] > 10]
            if not high_vif.empty:
                st.warning("High multicollinearity detected in these features:")
                st.dataframe(high_vif)
        else:
            st.info("Not enough features to calculate VIF")

    # ----- Model summary ---------------------------------------------------------
    st.header("Model Summary")
    with st.expander("Detailed Summary"):
        st.write(model_sm.summary())

    # ----- Prediction interface --------------------------------------------------
    st.header("Make Predictions")
    st.markdown("Enter values for prediction (using original scale):")
    input_values = {}
    cols = st.columns(min(3, len(predictors)))
    for i, predictor in enumerate(predictors):
        with cols[i % len(cols)]:
            input_values[predictor] = st.number_input(
                predictor,
                value=float(X[predictor].median()),
                step=float(X[predictor].std() / 10)
            )
    if st.button("Predict"):
        input_df = pd.DataFrame([input_values])
        # Run the input through the same feature pipeline as training.
        if transformations:
            # The dummy target series is discarded.
            input_df, _ = transform_data(input_df, pd.Series([0]), transformations)
        if scale_data:
            input_df = pd.DataFrame(scaler.transform(input_df), columns=input_df.columns)
        input_df = sm.add_constant(input_df, has_constant='add')
        prediction = model_sm.predict(input_df)
        # Undo the target transformation where an inverse is available.
        if 'log_y' in transformations:
            prediction = np.expm1(prediction)
        elif 'sqrt_y' in transformations:
            prediction = np.square(prediction)
        elif 'boxcox_y' in transformations:
            # BUGFIX: the fitted Box-Cox lambda is not retained, so the value
            # cannot be back-transformed; say so instead of reporting silently.
            st.warning("Box-Cox target transform cannot be inverted; "
                       "values are shown on the transformed scale.")
        st.success(f"**Predicted {target}:** {prediction[0]:.2f}")
        # Show the interval of the mean response (conf_int of the prediction).
        pred_ci = model_sm.get_prediction(input_df).conf_int()
        if 'log_y' in transformations:
            pred_ci = np.expm1(pred_ci)
        elif 'sqrt_y' in transformations:
            pred_ci = np.square(pred_ci)
        st.info(f"95% Confidence Interval: ({pred_ci[0][0]:.2f}, {pred_ci[0][1]:.2f})")
if __name__ == '__main__':
    # BUGFIX: st.set_page_config was called a second time here; Streamlit
    # raises StreamlitAPIException when it is called more than once per run
    # (it already runs at import time near the top of this file).
    main()