# shapely / app.py
# Wajahat698's picture
# Update app.py
# 9dc1a68 verified
import subprocess
import pandas as pd
import matplotlib.pyplot as plt
from matplotlib.ticker import FuncFormatter
import gradio as gr
import tempfile
import logging
from PIL import Image
import os
import io
import numpy as np
from sklearn.linear_model import LinearRegression
from sklearn.metrics import r2_score
import base64
# Initialize logging: configure the root logger at INFO level and create the
# module-level logger that the plotting and analysis functions in this file use.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
def plot_driver_analysis(results_df, average_value, title):
    """
    Plot driver analysis results with factors instead of trust buckets.

    Renders one horizontal bar per predictor, labelled with its importance
    percentage, plus a solid vertical reference line at ``average_value``.

    Args:
        results_df: DataFrame with a ``Predictor`` column (factor names) and
            a numeric ``Importance_percent`` column. It is not modified.
        average_value: X position of the solid vertical reference line.
        title: Chart title.

    Returns:
        PIL.Image.Image: The chart rendered as a PNG image.

    Raises:
        ValueError: If ``results_df`` has no rows.
        Exception: Any other plotting error is logged and re-raised.
    """
    logger.info("Plotting driver analysis with title '%s'.", title)
    fig = None
    try:
        if results_df.empty:
            raise ValueError("results_df is empty; nothing to plot")

        # Define color scheme for factors
        color_map = {
            "Factor 1": "#375570",
            "Factor 2": "#E3B05B",
            "Factor 3": "#C63F48",
            "Factor 4": "#418387",
            "Factor 5": "#DF8859",
            "Factor 6": "#6D93AB",
        }
        # Define the order
        order = ["Factor 1", "Factor 2", "Factor 3", "Factor 4", "Factor 5", "Factor 6"]

        # Predictors outside the known factor list are appended to the
        # categories so they are still plotted (in the fallback colour)
        # instead of collapsing to NaN.
        predictors = results_df["Predictor"]
        extras = [name for name in predictors.dropna().unique() if name not in color_map]

        # Sort a copy so the caller's DataFrame is left untouched. Descending
        # order draws the last factor first, i.e. at the bottom of the chart.
        plot_df = results_df.assign(
            Predictor=pd.Categorical(predictors, categories=order + extras, ordered=True)
        ).sort_values("Predictor", ascending=False)

        # Create the figure
        fig, ax = plt.subplots(figsize=(10, 8))

        # Show x tick labels as whole percentages
        ax.xaxis.set_major_formatter(FuncFormatter(lambda x, _: f"{x:.0f}%"))

        # Determine x-axis range: start at zero, leave room for the bar labels
        x_min = 0
        x_max = plot_df["Importance_percent"].max() + 5
        ax.set_xlim(x_min, x_max)

        # Ticks every 5 points with a dashed guide line on each. The range
        # deliberately runs past x_max, so the last tick may lie beyond the
        # limit set above (matplotlib may widen the view to include it).
        x_ticks = np.arange(np.floor(x_min), np.ceil(x_max) + 5, 5)
        ax.set_xticks(x_ticks)
        for tick in x_ticks:
            ax.axvline(x=tick, color="grey", linestyle="--", linewidth=0.5, zorder=2)

        # Create bars; each new category lands on the next integer y position,
        # so the enumerate index doubles as the label's y coordinate.
        for position, row in enumerate(plot_df.itertuples(index=False)):
            ax.barh(
                row.Predictor,
                row.Importance_percent,
                left=0,
                color=color_map.get(row.Predictor, "#cccccc"),
                edgecolor="white",
                height=0.6,
                zorder=3,
            )
            ax.text(
                row.Importance_percent + 0.5,
                position,
                f"{row.Importance_percent:.1f}%",
                va="center",
                ha="left",
                color="#8c8b8c",
            )

        # Draw average line
        ax.axvline(average_value, color="black", linewidth=1, linestyle="-", zorder=3)
        ax.set_title(title, fontsize=14)

        # Style the plot
        ax.spines[["left", "top", "right"]].set_color("none")
        ax.tick_params(axis="y", colors="#8c8b8c", length=0)
        ax.set_axisbelow(True)
        fig.tight_layout()

        # Save through the figure object rather than pyplot's "current
        # figure", which is shared global state.
        img_data = io.BytesIO()
        fig.savefig(img_data, format="png", facecolor=fig.get_facecolor(), edgecolor="none")
        img_data.seek(0)
        return Image.open(img_data)
    except Exception as e:
        logger.exception("Error plotting driver analysis: %s", e)
        raise
    finally:
        # Always release the figure, including on the error path.
        if fig is not None:
            plt.close(fig)
def plot_factor_performance(driver_df, title):
    """
    Plot factor performance (agreement scores).

    Draws one vertical bar per factor showing the column mean, with the value
    printed above the bar, on a fixed 1-10 y-axis.

    Args:
        driver_df: DataFrame containing the columns ``Factor 1`` to
            ``Factor 6``; each column is averaged.
        title: Chart title.

    Returns:
        PIL.Image.Image: The chart rendered as a PNG image.

    Raises:
        Exception: Any error (e.g. a missing factor column) is logged and
            re-raised.
    """
    factors = ["Factor 1", "Factor 2", "Factor 3", "Factor 4", "Factor 5", "Factor 6"]
    fig = None
    try:
        color_map = {
            "Factor 1": "#375570",
            "Factor 2": "#E3B05B",
            "Factor 3": "#C63F48",
            "Factor 4": "#418387",
            "Factor 5": "#DF8859",
            "Factor 6": "#6D93AB",
        }
        # Calculate mean scores
        results_df = driver_df[factors].mean().reset_index()
        results_df.columns = ["Factor", "Agreement_Score"]

        fig, ax = plt.subplots(figsize=(10, 8))
        ax.bar(
            results_df["Factor"],
            results_df["Agreement_Score"],
            color=[color_map[factor] for factor in results_df["Factor"]],
            edgecolor="white",
            zorder=2,
        )

        # Add values on top; bars sit at integer x positions 0..n-1
        for position, score in enumerate(results_df["Agreement_Score"]):
            ax.text(
                position,
                score + 0.1,
                f"{score:.1f}",
                ha="center",
                va="bottom",
                color="#8c8b8c",
            )

        # Set y-axis
        ax.set_ylim(1, 10)
        ax.set_yticks(range(1, 11))
        ax.set_ylabel("Agreement Score")
        ax.set_title(title, fontsize=14)
        ax.spines[["top", "right"]].set_color("none")

        # Add grid: one dashed guide line per y tick, drawn behind the bars
        for y_tick in ax.get_yticks():
            ax.axhline(y=y_tick, color="grey", linestyle="--", linewidth=0.5, zorder=1)
        ax.set_axisbelow(True)
        fig.tight_layout()

        # Save through the figure object rather than pyplot's "current
        # figure", which is shared global state.
        img_data = io.BytesIO()
        fig.savefig(img_data, format="png", facecolor=fig.get_facecolor(), edgecolor="none")
        img_data.seek(0)
        return Image.open(img_data)
    except Exception as e:
        logger.exception("Error plotting factor performance: %s", e)
        raise
    finally:
        # Always release the figure, including on the error path.
        if fig is not None:
            plt.close(fig)
def calculate_r2_image(r2_percent):
    """
    Create R² visualization.

    Draws a ring of equally sized, colour-coded bands with their labels
    around the outside and the rounded R² percentage in the centre. The ring
    acts as a legend only: the score is not mapped onto a particular band.

    Args:
        r2_percent: R² expressed as a percentage (e.g. ``63.4``).

    Returns:
        str: An HTML snippet embedding the chart as a base64-encoded PNG.
    """
    categories = [
        ("<40%: Deficient", "#b03c3c"),
        (">50%: Gaps", "#bdd8da"),
        (">60%: Proven", "#89b7bc"),
        (">70%: Robust", "#375a5e"),
    ]
    labels = [label for label, _ in categories]
    colors = [color for _, color in categories]

    fig, ax = plt.subplots(figsize=(3.6, 3.6), subplot_kw=dict(aspect="equal"))
    try:
        # One equal slice per category, clockwise from the top.
        wedges, _ = ax.pie(
            [1] * len(categories),
            startangle=90,
            counterclock=False,
            colors=colors,
            wedgeprops=dict(width=0.35),
        )

        # Add outer labels at each wedge's mid-angle, outside the ring
        for label, wedge in zip(labels, wedges):
            angle = np.deg2rad((wedge.theta2 + wedge.theta1) / 2)
            ax.text(
                1.5 * np.cos(angle),
                1.5 * np.sin(angle),
                label,
                ha="center",
                va="center",
                fontsize=9,
                color="black",
            )

        # Center R² text
        ax.text(
            0,
            0,
            f"{int(round(r2_percent))}%",
            ha="center",
            va="center",
            fontsize=19,
            fontweight="bold",
        )
        ax.set_title("Model Validity", fontsize=11, pad=10)
        ax.axis("off")
        fig.patch.set_facecolor("none")
        ax.patch.set_facecolor("none")
        fig.tight_layout()

        buf = io.BytesIO()
        fig.savefig(buf, format="png", transparent=True, dpi=200)
    finally:
        # Release the figure even if rendering fails.
        plt.close(fig)

    img_base64 = base64.b64encode(buf.getvalue()).decode("utf-8")
    return (
        "\n"
        "<div style='display: flex; justify-content: center; align-items: center;'>\n"
        f"<img src='data:image/png;base64,{img_base64}' style='max-width: 240px; height: auto;'/>\n"
        "</div>\n"
    )
def create_avg_target_display(avg_target):
    """
    Create average target (Purchase Consideration) visualization.

    Draws a translucent green circle with the average (one decimal place) in
    its centre and a fixed "Scale: 1-6" caption underneath.

    Args:
        avg_target: Average value of the target variable.

    Returns:
        str: An HTML snippet embedding the graphic as a base64-encoded PNG.
    """
    fig, ax = plt.subplots(figsize=(3.6, 3.6))
    try:
        # Create circular display
        ax.add_patch(plt.Circle((0.5, 0.5), 0.4, color="#4CAF50", alpha=0.3))
        ax.text(
            0.5,
            0.5,
            f"{avg_target:.1f}",
            ha="center",
            va="center",
            fontsize=24,
            fontweight="bold",
        )
        ax.text(
            0.5,
            0.2,
            "Scale: 1-6",
            ha="center",
            va="center",
            fontsize=10,
            color="gray",
        )
        # Unit square so the coordinates above are axis fractions.
        ax.set_xlim(0, 1)
        ax.set_ylim(0, 1)
        ax.set_title("Avg Target", fontsize=11, pad=10)
        ax.axis("off")
        fig.patch.set_facecolor("none")
        ax.patch.set_facecolor("none")
        fig.tight_layout()

        buf = io.BytesIO()
        fig.savefig(buf, format="png", transparent=True, dpi=200)
    finally:
        # Release the figure even if rendering fails.
        plt.close(fig)

    img_base64 = base64.b64encode(buf.getvalue()).decode("utf-8")
    return (
        "\n"
        "<div style='display: flex; justify-content: center; align-items: center;'>\n"
        f"<img src='data:image/png;base64,{img_base64}' style='max-width: 240px; height: auto;'/>\n"
        "</div>\n"
    )
def call_r_script_simplified(input_file, csv_output_path):
    """
    Call R script for Shapley regression analysis on Consideration.

    Runs ``Rscript process_data.R <input_file> <csv_output_path>``. If the
    script exits with a non-zero status, or the ``Rscript`` executable cannot
    be started at all, fixed mock importances for Factor 1-6 are written to
    ``csv_output_path`` instead so the demo can still render.

    Args:
        input_file: Path passed to the R script as its input.
        csv_output_path: Path of the CSV the R script (or the mock fallback)
            writes, with ``Predictor`` and ``Importance`` columns.

    Raises:
        Exception: Any other error while launching the script is logged and
            re-raised.
    """
    command = ["Rscript", "process_data.R", input_file, csv_output_path]
    try:
        subprocess.run(command, check=True)
    except (subprocess.CalledProcessError, FileNotFoundError) as e:
        # CalledProcessError: the script ran and failed.
        # FileNotFoundError: Rscript is not installed / not on PATH.
        logger.error("R script failed with error: %s", e)
        logger.warning(
            "Falling back to MOCK Shapley results written to %s", csv_output_path
        )
        # For demo purposes, create mock data if R script fails
        mock_data = pd.DataFrame({
            'Predictor': ['Factor 1', 'Factor 2', 'Factor 3', 'Factor 4', 'Factor 5', 'Factor 6'],
            'Importance': [0.15, 0.22, 0.18, 0.20, 0.13, 0.12],
        })
        mock_data.to_csv(csv_output_path, index=False)
    except Exception as e:
        logger.exception("Error calling R script: %s", e)
        raise
def analyze_prospects_data(file_path):
    """
    Analyze prospects data focusing on Purchase Consideration as target.

    Reads the ``Driver`` sheet of the workbook, fits a linear regression of
    ``Consideration`` on the six factor columns to obtain R², averages the
    target, and produces the Shapley driver chart from the R script output
    (or its mock fallback).

    Args:
        file_path: Path to the Excel workbook.

    Returns:
        tuple: ``(r2_html, avg_target_html, factor_performance_img,
        driver_analysis_img)``, or four ``None`` values when the workbook has
        no ``Consideration`` column.

    Raises:
        Exception: Any error during loading, modelling or plotting is logged
            and re-raised.
    """
    logger.info("Analyzing prospects file: %s", file_path)
    try:
        # Load Excel file
        df = pd.read_excel(file_path, sheet_name="Driver", header=3)

        # Map column names from trust buckets to factors
        column_mapping = {
            "Stability": "Factor 1",
            "Development": "Factor 2",
            "Relationship": "Factor 3",
            "Benefit": "Factor 4",
            "Vision": "Factor 5",
            "Competence": "Factor 6",
        }
        # rename() returns a renamed copy and ignores mapping keys that are
        # not present, so no per-column existence check is needed.
        df_analysis = df.rename(columns=column_mapping)

        # Check if Consideration column exists
        if "Consideration" not in df.columns:
            logger.error("Consideration column not found in dataset")
            return None, None, None, None

        # Calculate R² for Consideration model. Rows are dropped when either
        # a factor or the target is missing; filtering on the factors alone
        # would let NaN targets reach the regression and make fit() fail.
        factors = list(column_mapping.values())
        model_data = df_analysis[factors + ["Consideration"]].dropna()
        X = model_data[factors]
        y = model_data["Consideration"]  # Use Consideration as target
        model = LinearRegression()
        model.fit(X, y)
        r2_percent = r2_score(y, model.predict(X)) * 100

        # Calculate average target (Consideration) over all rows
        avg_target = df["Consideration"].mean()

        # Create visualizations
        r2_html = calculate_r2_image(r2_percent)
        avg_target_html = create_avg_target_display(avg_target)

        # Factor performance plot
        factor_performance_img = plot_factor_performance(
            df_analysis, "Factor Performance (Agreement Scores)"
        )

        # Run Shapley analysis on Consideration. The context manager removes
        # the scratch directory even if the R step or the CSV read fails.
        with tempfile.TemporaryDirectory() as temp_dir:
            csv_output_path = os.path.join(temp_dir, "consideration_results.csv")
            # Call R script or create mock results
            call_r_script_simplified(file_path, csv_output_path)
            # Load results before the directory is cleaned up
            results_df = pd.read_csv(csv_output_path)

        # Map predictor names if they come from R script with original names
        if "Predictor" in results_df.columns:
            results_df["Predictor"] = results_df["Predictor"].map(
                lambda x: column_mapping.get(x, x)
            )
        results_df["Importance_percent"] = results_df["Importance"] * 100
        average_value = results_df["Importance_percent"].mean()

        # Create driver analysis plot
        driver_analysis_img = plot_driver_analysis(
            results_df,
            average_value,
            "Shapley Driver Analysis - Purchase Consideration",
        )
        return r2_html, avg_target_html, factor_performance_img, driver_analysis_img
    except Exception as e:
        logger.exception("Error analyzing data: %s", e)
        raise
# Custom CSS handed to gr.Blocks below. It defines a flex-row
# `.metric-container` class; no component in the visible part of this file
# references that class.
css = """
.metric-container {
display: flex;
justify-content: space-around;
margin: 20px 0;
}
"""
# JavaScript handed to gr.Blocks below to force the light theme: when the
# page URL's `__theme` query parameter is not 'light', it sets the parameter
# and navigates to the updated URL.
js = """
function refresh() {
const url = new URL(window.location);
if (url.searchParams.get('__theme') !== 'light') {
url.searchParams.set('__theme', 'light');
window.location.href = url.href;
}
}
"""
# Assemble the dashboard: Soft theme plus the custom CSS and the light-theme
# JavaScript hook defined above.
with gr.Blocks(theme=gr.themes.Soft(), js=js, css=css) as demo:
    # Page heading and section intro
    gr.Markdown(
        """
<h2 style="text-align: center; font-size: 2.25rem; font-weight: 600;">
Driver Analysis Demo - Purchase Consideration
</h2>
"""
    )
    gr.Markdown("### Prospects Analysis")
    gr.Markdown("Analysis showing what drives Purchase Consideration among prospects")

    # Metrics row: R² gauge on the left, average target on the right
    with gr.Row(equal_height=True):
        with gr.Column(scale=1):
            r2_output = gr.HTML()
        with gr.Column(scale=1):
            avg_target_output = gr.HTML()

    # Charts row: factor performance next to the Shapley driver analysis
    with gr.Row():
        with gr.Column():
            factor_performance_plot = gr.Image(show_label=False)
        with gr.Column():
            driver_analysis_plot = gr.Image(show_label=False)

    # Hidden state holding the workbook path fed to the analysis
    prospects_file = gr.State(value="example_files/Volkswagen Non Customers.xlsx")

    # Run the analysis as soon as the page loads and fill all four outputs,
    # in the order analyze_prospects_data returns them.
    load_inputs = [prospects_file]
    load_outputs = [
        r2_output,
        avg_target_output,
        factor_performance_plot,
        driver_analysis_plot,
    ]
    demo.load(fn=analyze_prospects_data, inputs=load_inputs, outputs=load_outputs)

# Start the server on all interfaces, without a public share link
demo.launch(server_name="0.0.0.0", share=False)