# shapely / app.py
# Author: Wajahat698 - last commit a38cada ("Update app.py"), verified.
import base64
import html
import io
import logging
import os
import shutil
import subprocess
import tempfile

import gradio as gr
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from matplotlib.ticker import FuncFormatter
from PIL import Image
from sklearn.linear_model import LinearRegression
from sklearn.metrics import r2_score
# Initialize logging: configure the root logger at INFO level so the
# logger.info(...) progress messages emitted by this module are shown.
logging.basicConfig(level=logging.INFO)
# Module-level logger used throughout this file.
logger = logging.getLogger(__name__)
def plot_driver_analysis(results_df, average_value, title):
    """
    Plot driver analysis results with factors instead of trust buckets.

    Draws one horizontal bar per predictor ("Factor 1" at the top down to
    "Factor 6"), labels each bar with its percentage, and draws a solid
    vertical reference line at ``average_value``.

    Args:
        results_df: DataFrame with a ``Predictor`` column (e.g. "Factor 1")
            and a numeric ``Importance_percent`` column. The caller's frame
            is not modified.
        average_value: x position (in percent) of the reference line.
        title: Chart title.

    Returns:
        A ``PIL.Image.Image`` holding the rendered PNG.

    Raises:
        Any plotting error is logged and re-raised.
    """
    logger.info("Plotting driver analysis with title '%s'.", title)
    # Factor colours; predictors not listed here are drawn in grey.
    color_map = {
        "Factor 1": "#375570",
        "Factor 2": "#E3B05B",
        "Factor 3": "#C63F48",
        "Factor 4": "#418387",
        "Factor 5": "#DF8859",
        "Factor 6": "#6D93AB",
    }
    # Display order, top to bottom.
    order = ["Factor 1", "Factor 2", "Factor 3", "Factor 4", "Factor 5", "Factor 6"]
    fig = None
    try:
        # Work on a copy so the caller's DataFrame is neither re-typed nor re-sorted.
        plot_df = results_df.copy()
        # A Categorical restricted to ``order`` would turn unknown predictor
        # names into NaN (unplottable); keep them as extra categories instead.
        extra = [p for p in dict.fromkeys(plot_df["Predictor"].dropna()) if p not in order]
        plot_df["Predictor"] = pd.Categorical(
            plot_df["Predictor"], categories=order + extra, ordered=True
        )
        # Descending sort: matplotlib's string-category axis assigns positions
        # 0, 1, 2, ... in first-seen order from the bottom, so "Factor 1" ends
        # up on top. The text labels below rely on the same positions (``i``).
        plot_df = plot_df.sort_values("Predictor", ascending=False)

        fig, ax = plt.subplots(figsize=(10, 8))
        ax.xaxis.set_major_formatter(FuncFormatter(lambda x, _: f"{x:.0f}%"))

        # x range: from 0 to 5 points past the longest bar.
        x_min = 0
        x_max = plot_df["Importance_percent"].max() + 5
        ax.set_xlim(x_min, x_max)
        # Ticks and dashed gridlines every 5%. Note: set_xticks widens the view
        # limits if needed so that every tick is visible.
        x_ticks = np.arange(np.floor(x_min), np.ceil(x_max) + 5, 5)
        ax.set_xticks(x_ticks)
        for tick in x_ticks:
            ax.axvline(x=tick, color="grey", linestyle="--", linewidth=0.5, zorder=2)

        for i, row in enumerate(plot_df.itertuples(index=False)):
            ax.barh(
                row.Predictor,
                row.Importance_percent,
                left=0,
                color=color_map.get(row.Predictor, "#cccccc"),
                edgecolor="white",
                height=0.6,
                zorder=3,
            )
            ax.text(
                row.Importance_percent + 0.5,
                i,
                f"{row.Importance_percent:.1f}%",
                va="center",
                ha="left",
                color="#8c8b8c",
            )

        # Reference line at the average importance.
        ax.axvline(average_value, color="black", linewidth=1, linestyle="-", zorder=3)
        ax.set_title(title, fontsize=14)

        ax.spines[["left", "top", "right"]].set_color("none")
        ax.tick_params(axis="y", colors="#8c8b8c", length=0)
        ax.set_axisbelow(True)
        fig.tight_layout()

        img_data = io.BytesIO()
        fig.savefig(img_data, format="png", facecolor=fig.get_facecolor(), edgecolor="none")
        img_data.seek(0)
        return Image.open(img_data)
    except Exception as e:
        logger.error("Error plotting driver analysis: %s", e)
        raise
    finally:
        # Always release the figure, including on error, so a long-running
        # server does not accumulate open pyplot figures.
        if fig is not None:
            plt.close(fig)
def plot_factor_performance(driver_df, title):
    """
    Plot factor performance (agreement scores) as a vertical bar chart.

    Each bar is the column mean of one of "Factor 1" ... "Factor 6" and is
    labelled with that mean to one decimal place. The y-axis is fixed to the
    range 1-10, with ticks and dashed gridlines at 1 through 9.

    Args:
        driver_df: DataFrame containing the columns "Factor 1" to "Factor 6".
        title: Chart title.

    Returns:
        A ``PIL.Image.Image`` holding the rendered PNG.

    Raises:
        Any error (e.g. a missing factor column) is logged and re-raised.
    """
    factors = ["Factor 1", "Factor 2", "Factor 3", "Factor 4", "Factor 5", "Factor 6"]
    color_map = {
        "Factor 1": "#375570",
        "Factor 2": "#E3B05B",
        "Factor 3": "#C63F48",
        "Factor 4": "#418387",
        "Factor 5": "#DF8859",
        "Factor 6": "#6D93AB",
    }
    fig = None
    try:
        # Column means -> two-column frame (factor name, mean score).
        results_df = driver_df[factors].mean().reset_index()
        results_df.columns = ["Factor", "Agreement_Score"]

        fig, ax = plt.subplots(figsize=(10, 8))
        ax.bar(
            results_df["Factor"],
            results_df["Agreement_Score"],
            color=[color_map[factor] for factor in results_df["Factor"]],
            edgecolor="white",
            zorder=2,
        )
        # Value labels just above each bar.
        for i, score in enumerate(results_df["Agreement_Score"]):
            ax.text(i, score + 0.1, f"{score:.1f}", ha="center", va="bottom", color="#8c8b8c")

        ax.set_ylim(1, 10)
        ax.set_yticks(range(1, 10))
        ax.set_ylabel("Agreement Score")
        ax.set_title(title, fontsize=14)
        ax.spines[["top", "right"]].set_color("none")

        # Dashed horizontal gridline at every y tick, drawn beneath the bars.
        for y_tick in ax.get_yticks():
            ax.axhline(y=y_tick, color="grey", linestyle="--", linewidth=0.5, zorder=1)
        ax.set_axisbelow(True)
        fig.tight_layout()

        img_data = io.BytesIO()
        fig.savefig(img_data, format="png", facecolor=fig.get_facecolor(), edgecolor="none")
        img_data.seek(0)
        return Image.open(img_data)
    except Exception as e:
        logger.error("Error plotting factor performance: %s", e)
        raise
    finally:
        # Release the figure on both the success and the error path.
        if fig is not None:
            plt.close(fig)
def calculate_r2_image(r2_percent):
    """
    Render the "Model Validity" (R²) donut chart as an embeddable HTML snippet.

    The donut has four equal segments, each labelled outside the ring with a
    validity band. The R² percentage, rounded to an integer, is printed in
    the centre.

    Args:
        r2_percent: R² expressed as a percentage (e.g. 63.2).

    Returns:
        An HTML ``<div>`` string embedding the chart as a base64 PNG.
    """
    bands = (
        ("<40%: Deficient", "#b03c3c"),
        (">50%: Gaps", "#bdd8da"),
        (">60%: Proven", "#89b7bc"),
        (">70%: Robust", "#375a5e"),
    )
    band_labels = [label for label, _ in bands]
    band_colors = [color for _, color in bands]

    fig, ax = plt.subplots(figsize=(3.6, 3.6), subplot_kw={"aspect": "equal"})
    wedges, _ = ax.pie(
        [1, 1, 1, 1],
        startangle=90,
        counterclock=False,
        colors=band_colors,
        wedgeprops={"width": 0.35},
    )

    # Place each band label at radius 1.5, centred on its wedge's mid-angle.
    for wedge, label in zip(wedges, band_labels):
        mid_angle = np.deg2rad((wedge.theta2 + wedge.theta1) / 2)
        ax.text(
            1.5 * np.cos(mid_angle),
            1.5 * np.sin(mid_angle),
            label,
            ha="center",
            va="center",
            fontsize=9,
            color="black",
        )

    ax.text(0, 0, f"{int(round(r2_percent))}%", ha="center", va="center", fontsize=19, fontweight="bold")
    ax.set_title("Model Validity", fontsize=11, pad=10)
    ax.axis("off")

    # Transparent background so the chart blends into the page.
    for patch in (fig.patch, ax.patch):
        patch.set_facecolor("none")
    fig.tight_layout()

    png = io.BytesIO()
    fig.savefig(png, format="png", transparent=True, dpi=200)
    plt.close(fig)
    encoded = base64.b64encode(png.getvalue()).decode("utf-8")
    return (
        "\n"
        "<div style='display: flex; justify-content: center; align-items: center;'>\n"
        f"<img src='data:image/png;base64,{encoded}' style='max-width: 240px; height: auto;'/>\n"
        "</div>\n"
    )
def create_avg_target_display(avg_target, target_name, scale_info):
    """
    Render the average-target badge as an embeddable HTML snippet.

    A translucent green circle shows ``avg_target`` to one decimal place,
    with ``scale_info`` as a grey caption underneath and the title "Avg".

    Args:
        avg_target: Mean value of the target column.
        target_name: Part of the signature but not used in the rendering.
        scale_info: Caption text, e.g. "Scale: 1-10".

    Returns:
        An HTML ``<div>`` string embedding the badge as a base64 PNG.
    """
    fig, ax = plt.subplots(figsize=(3.6, 3.6))

    badge = plt.Circle((0.5, 0.5), 0.4, color="#4CAF50", alpha=0.3)
    ax.add_patch(badge)
    ax.text(0.5, 0.5, f"{avg_target:.1f}", ha="center", va="center", fontsize=24, fontweight="bold")
    ax.text(0.5, 0.2, scale_info, ha="center", va="center", fontsize=10, color="gray")

    # Unit square so the circle and text coordinates above are in [0, 1].
    ax.set_xlim(0, 1)
    ax.set_ylim(0, 1)
    ax.set_title("Avg", fontsize=11, pad=10)
    ax.axis("off")

    # Transparent background so the badge blends into the page.
    for patch in (fig.patch, ax.patch):
        patch.set_facecolor("none")
    fig.tight_layout()

    png = io.BytesIO()
    fig.savefig(png, format="png", transparent=True, dpi=200)
    plt.close(fig)
    encoded = base64.b64encode(png.getvalue()).decode("utf-8")
    return (
        "\n"
        "<div style='display: flex; justify-content: center; align-items: center;'>\n"
        f"<img src='data:image/png;base64,{encoded}' style='max-width: 240px; height: auto;'/>\n"
        "</div>\n"
    )
def create_error_message(message):
    """
    Create an HTML error banner for display in a ``gr.HTML`` component.

    The message is HTML-escaped before interpolation. Callers pass exception
    text and column names read from the uploaded workbook, which would
    otherwise be rendered as live markup.

    Args:
        message: Error text; any object is accepted and converted with ``str()``.

    Returns:
        The banner as an HTML string.
    """
    safe_message = html.escape(str(message))
    return (
        "\n"
        "<div style='background-color: #ffebee; border: 1px solid #f44336; border-radius: 4px; padding: 16px; color: #c62828;'>\n"
        f"<strong>Error:</strong> {safe_message}\n"
        "</div>\n"
    )
def create_placeholder_message():
    """Return the HTML notice shown while no workbook has been uploaded."""
    lines = [
        "",
        "<div style='background-color: #f0f8ff; border: 1px solid #4682b4; border-radius: 4px; padding: 16px; color: #2e4057; text-align: center;'>",
        "<strong>Upload an Excel file to start analysis</strong><br>",
        "The analysis will begin automatically once your file is uploaded.",
        "</div>",
        "",
    ]
    return "\n".join(lines)
def find_target_column(df):
    """
    Find the best target column in the dataset.

    Candidates are tried in priority order Consideration -> Trust -> NPS ->
    Loyalty, in three passes of decreasing strictness:

    1. exact column-name match;
    2. case-insensitive match;
    3. substring match on "consider", "trust", "nps" or "loyal".

    Within each pass the higher-priority target wins, regardless of column
    order. Column labels need not be strings (numeric Excel headers are
    common); they are compared via ``str()``.

    Args:
        df: DataFrame whose columns are searched.

    Returns:
        ``(column_label, canonical_target_name, scale_text)``. ``scale_text``
        looks like ``"Scale: 1-10"`` and is built from the column's non-missing
        min and max. Returns ``(None, None, None)`` when nothing matches.
    """
    # (canonical name, substring for the fuzzy pass), highest priority first.
    target_priorities = [
        ("Consideration", "consider"),
        ("Trust", "trust"),
        ("NPS", "nps"),
        ("Loyalty", "loyal"),
    ]

    def get_scale_info(column_data):
        values = column_data.dropna()
        return f"Scale: {values.min():.0f}-{values.max():.0f}"

    def found(col, target):
        return col, target, get_scale_info(df[col])

    # Pass 1: exact matches.
    for target, _ in target_priorities:
        if target in df.columns:
            return found(target, target)

    # Pass 2: case-insensitive matches (str() guards non-string labels).
    columns_by_lower = {str(col).lower(): col for col in df.columns}
    for target, _ in target_priorities:
        if target.lower() in columns_by_lower:
            return found(columns_by_lower[target.lower()], target)

    # Pass 3: substring matches, still in priority order.
    for target, fragment in target_priorities:
        for col in df.columns:
            if fragment in str(col).lower():
                return found(col, target)

    return None, None, None
def call_r_script_for_target(input_file, csv_output_path, target_type, *, timeout=None):
    """
    Call R script for Shapley regression analysis for any target type.

    Runs ``Rscript process_data.R`` with, in order:

    - the input file and a text output path;
    - CSV paths for trust, nps, loyalty, consideration, satisfaction and
      trustbuilder;
    - "TRUE"/"FALSE" flags for nps, loyalty, consideration, satisfaction and
      trustbuilder.

    All outputs go in the directory of ``csv_output_path``. Only the flag
    matching ``target_type`` is TRUE, and trustbuilder is always FALSE. After
    a successful run, the target's CSV is copied to ``csv_output_path``.

    Args:
        input_file: Path of the Excel file handed to the R script.
        csv_output_path: Where the caller expects the results CSV. Its
            directory receives every R output file.
        target_type: Target name, matched case-insensitively against
            consideration / trust / nps / loyalty / satisfaction. Any other
            value falls back to the consideration CSV.
        timeout: Optional limit in seconds for the R process. None (the
            default) waits indefinitely.

    Returns:
        True if Rscript exited with status 0, otherwise False. True does not
        guarantee that ``csv_output_path`` exists; callers should check.
    """
    target_key = target_type.lower()
    output_dir = os.path.dirname(csv_output_path)
    text_output_path = os.path.join(output_dir, "output.txt")
    # Per-target CSVs, in the positional order process_data.R receives them.
    output_paths = {
        name: os.path.join(output_dir, f"{name}.csv")
        for name in ("trust", "nps", "loyalty", "consideration", "satisfaction", "trustbuilder")
    }
    # Boolean flags in the order nps, loyalty, consideration, satisfaction,
    # trustbuilder; trustbuilder is never requested here.
    flags = [
        target_key == "nps",
        target_key == "loyalty",
        target_key == "consideration",
        target_key == "satisfaction",
        False,
    ]
    command = [
        "Rscript",
        "process_data.R",
        input_file,
        text_output_path,
        *output_paths.values(),
        *(str(flag).upper() for flag in flags),
    ]
    # CSV holding this target's results. Satisfaction is included so that it
    # agrees with the satisfaction flag sent above.
    selectable = {"consideration", "trust", "nps", "loyalty", "satisfaction"}
    target_csv_path = output_paths[target_key if target_key in selectable else "consideration"]

    try:
        subprocess.run(command, check=True, capture_output=True, text=True, timeout=timeout)
        logger.info("R script executed successfully")
        if target_csv_path != csv_output_path:
            if os.path.exists(target_csv_path):
                shutil.copy2(target_csv_path, csv_output_path)
            else:
                logger.warning("R script did not write %s", target_csv_path)
        return True
    except subprocess.CalledProcessError as e:
        logger.error("R script failed with error: %s", e)
        logger.error("R script stderr: %s", e.stderr)
        logger.error("R script stdout: %s", e.stdout)
        return False
    except subprocess.TimeoutExpired as e:
        logger.error("R script timed out after %s seconds", e.timeout)
        return False
    except Exception as e:
        logger.error("Error calling R script: %s", e)
        return False
def _remove_temp_dir(path):
    """Best-effort removal of a scratch directory; a failure is logged, never raised."""
    try:
        shutil.rmtree(path)
    except OSError as e:
        logger.error("Error cleaning up temp files: %s", e)


def analyze_prospects_data(file_path):
    """
    Analyze prospects data with flexible target detection.

    Steps:
    1. Read the "Driver" sheet of the workbook (header on the fourth row,
       ``header=3``) and check that the six trust-bucket columns exist.
    2. Choose the target column with ``find_target_column``.
    3. Fit a linear regression of the target on the six factors. Rows with
       any missing value are dropped and at least 10 rows are required. The
       in-sample R² and the target mean are reported.
    4. Run the external R script for Shapley importances and plot them.

    Args:
        file_path: Path to the Excel workbook, or None.

    Returns:
        ``(r2_html, avg_target_html, factor_performance_img,
        driver_analysis_img)``. When ``file_path`` is None or any step fails,
        the first element is an HTML message and the other three are None.
    """
    if file_path is None:
        return create_placeholder_message(), None, None, None

    logger.info("Analyzing prospects file: %s", file_path)
    try:
        df = pd.read_excel(file_path, sheet_name="Driver", header=3)

        # Trust-bucket columns and the generic factor labels they are renamed to.
        column_mapping = {
            "Stability": "Factor 1",
            "Development": "Factor 2",
            "Relationship": "Factor 3",
            "Benefit": "Factor 4",
            "Vision": "Factor 5",
            "Competence": "Factor 6",
        }
        missing_factors = [col for col in column_mapping if col not in df.columns]
        if missing_factors:
            logger.error("Missing factor columns: %s", missing_factors)
            return create_error_message(f"Missing required columns: {missing_factors}"), None, None, None

        target_col, target_name, scale_info = find_target_column(df)
        if target_col is None:
            logger.error("No suitable target column found")
            logger.info("Available columns: %s", list(df.columns))
            return (
                create_error_message(f"No suitable target column found. Available columns: {list(df.columns)}"),
                None,
                None,
                None,
            )
        logger.info("Using target column: %s (interpreted as %s)", target_col, target_name)

        # All six trust-bucket columns are known to exist, so one rename suffices
        # (rename returns a new frame; df itself is untouched).
        df_analysis = df.rename(columns=column_mapping)

        # R² of the target regressed on the factors, using complete rows only.
        factors = list(column_mapping.values())
        X = df_analysis[factors].dropna()
        y = df.loc[X.index, target_col]
        valid_mask = y.notna()
        X = X[valid_mask]
        y = y[valid_mask]

        if len(X) < 10:
            logger.error("Not enough valid data points: %s", len(X))
            return create_error_message(f"Not enough valid data points: {len(X)}. Need at least 10."), None, None, None

        model = LinearRegression()
        model.fit(X, y)
        r2_percent = r2_score(y, model.predict(X)) * 100
        avg_target = y.mean()
        logger.info("R² Score: %.1f%%, Average %s: %.1f", r2_percent, target_name, avg_target)

        r2_html = calculate_r2_image(r2_percent)
        avg_target_html = create_avg_target_display(avg_target, target_name, scale_info)
        factor_performance_img = plot_factor_performance(df_analysis, "Factor Performance (Agreement Scores)")

        # Scratch directory for the R script outputs; removed on every exit
        # path, including exceptions raised while reading or plotting results.
        temp_dir = tempfile.mkdtemp()
        try:
            csv_output_path = os.path.join(temp_dir, "results.csv")
            if not call_r_script_for_target(file_path, csv_output_path, target_name):
                return create_error_message("R script failed to execute. Shapley analysis not available."), None, None, None
            if not os.path.exists(csv_output_path):
                return create_error_message("R script did not produce expected output file."), None, None, None

            try:
                results_df = pd.read_csv(csv_output_path)
            except Exception as e:
                logger.error("Error reading R script output: %s", e)
                return create_error_message(f"Error reading R script output: {e}"), None, None, None

            if "Predictor" not in results_df.columns or "Importance" not in results_df.columns:
                logger.error("R script output missing required columns")
                return create_error_message("R script output is invalid - missing required columns."), None, None, None

            # Predictors reported under the original trust-bucket names are
            # mapped to factor labels; anything else passes through unchanged.
            results_df["Predictor"] = results_df["Predictor"].map(lambda name: column_mapping.get(name, name))
            results_df["Importance_percent"] = results_df["Importance"] * 100
            average_value = results_df["Importance_percent"].mean()
            driver_analysis_img = plot_driver_analysis(results_df, average_value, "Shapley Driver Analysis")
        finally:
            _remove_temp_dir(temp_dir)

        return r2_html, avg_target_html, factor_performance_img, driver_analysis_img
    except Exception as e:
        # logger.exception also records the traceback for diagnosis.
        logger.exception("Error analyzing data: %s", e)
        return create_error_message(f"Analysis failed: {str(e)}"), None, None, None
def load_default_file():
    """
    Analyse the bundled example workbook, if it is present.

    Returns:
        The 4-tuple from ``analyze_prospects_data`` for
        ``example_files/Volkswagen Non Customers.xlsx``. If that file does
        not exist, the upload placeholder followed by three Nones.
    """
    example_path = "example_files/Volkswagen Non Customers.xlsx"
    if not os.path.exists(example_path):
        return create_placeholder_message(), None, None, None
    return analyze_prospects_data(example_path)
def handle_file_upload(file):
    """
    Handle file upload and analysis - automatically analyze when file is uploaded.

    Args:
        file: The ``gr.File`` value, in one of three forms:
            - a path string or ``os.PathLike`` (this app's component uses
              ``type="filepath"``);
            - a file-like wrapper (e.g. a tempfile object) whose ``.name``
              holds the path;
            - None when the upload is cleared.

    Returns:
        The 4-tuple from ``analyze_prospects_data``, or the placeholder
        message followed by three Nones when no file is present.
    """
    if file is None:
        return create_placeholder_message(), None, None, None
    # A plain str path has no ``.name`` attribute, so handle paths directly
    # and only fall back to ``.name`` for file-like wrappers.
    if isinstance(file, (str, os.PathLike)):
        file_path = os.fspath(file)
    else:
        file_path = file.name
    return analyze_prospects_data(file_path)
# Gradio interface with light theme
# Custom CSS passed to gr.Blocks(css=...).
# NOTE(review): no component in this file sets elem_classes="metric-container",
# so this rule may currently be unused - confirm before relying on it.
css = """
.metric-container {
display: flex;
justify-content: space-around;
margin: 20px 0;
}
"""
# JavaScript to force light theme, passed to gr.Blocks(js=...).
# If the page URL's "__theme" query parameter is not "light", it sets the
# parameter and navigates to the updated URL (a reload).
js = """
function refresh() {
const url = new URL(window.location);
if (url.searchParams.get('__theme') !== 'light') {
url.searchParams.set('__theme', 'light');
window.location.href = url.href;
}
}
"""
# Create the Gradio interface with light theme
with gr.Blocks(css=css, js=js, theme=gr.themes.Soft()) as demo:
    gr.Markdown("Analysis showing what drives your target variable (e.g NPS, Decisions) using Factors 1-6")

    # Upload section: there is no "Analyze" button; analysis runs on upload.
    with gr.Row():
        file_upload = gr.File(
            label="Upload Excel File - Analysis will start automatically",
            file_types=[".xlsx", ".xls"],
            type="filepath",
        )

    # Metrics row: R² donut and average-target badge (both HTML snippets).
    with gr.Row(equal_height=True):
        with gr.Column(scale=1):
            r2_output = gr.HTML()
        with gr.Column(scale=1):
            avg_target_output = gr.HTML()

    # Charts: factor performance and Shapley driver analysis images.
    with gr.Row():
        with gr.Column():
            factor_performance_plot = gr.Image(show_label=False)
        with gr.Column():
            driver_analysis_plot = gr.Image(show_label=False)

    # Re-run the analysis whenever the uploaded file changes. The output order
    # matches the 4-tuple returned by handle_file_upload.
    # NOTE(review): load_default_file() exists in this module but is not wired
    # up (e.g. via demo.load); confirm whether the example should load on start.
    file_upload.change(
        fn=handle_file_upload,
        inputs=[file_upload],
        outputs=[r2_output, avg_target_output, factor_performance_plot, driver_analysis_plot],
    )

# Launch only when executed as a script, so importing this module does not
# start a server as a side effect.
if __name__ == "__main__":
    demo.launch(server_name="0.0.0.0", share=False)