"""Gradio demo: Shapley driver analysis of Purchase Consideration for prospects.

On page load the app reads the "Driver" sheet of an Excel workbook, fits a
linear model of ``Consideration`` on six factor scores (for the R² gauge),
calls an external R script (``process_data.R``) for Shapley importances, and
renders two metric tiles plus two bar charts.
"""
import subprocess
import pandas as pd
import matplotlib.pyplot as plt
from matplotlib.ticker import FuncFormatter
import gradio as gr
import tempfile
import logging
from PIL import Image
import os
import io
import numpy as np
from sklearn.linear_model import LinearRegression
from sklearn.metrics import r2_score
import base64

# Initialize logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Display order of the six factors (top-to-bottom in the driver chart,
# left-to-right in the performance chart).
FACTOR_ORDER = ["Factor 1", "Factor 2", "Factor 3", "Factor 4", "Factor 5", "Factor 6"]

# Bar colour per factor; shared by both charts so colours always agree.
FACTOR_COLORS = {
    "Factor 1": "#375570",
    "Factor 2": "#E3B05B",
    "Factor 3": "#C63F48",
    "Factor 4": "#418387",
    "Factor 5": "#DF8859",
    "Factor 6": "#6D93AB",
}

# Source-sheet column names (trust buckets) -> anonymised factor names.
FACTOR_COLUMN_MAPPING = {
    "Stability": "Factor 1",
    "Development": "Factor 2",
    "Relationship": "Factor 3",
    "Benefit": "Factor 4",
    "Vision": "Factor 5",
    "Competence": "Factor 6",
}


def _figure_to_image(fig):
    """Render ``fig`` to an in-memory PNG and return it as a PIL image."""
    buf = io.BytesIO()
    fig.savefig(buf, format="png", facecolor=fig.get_facecolor(), edgecolor="none")
    buf.seek(0)
    return Image.open(buf)


def _figure_to_html(fig, alt):
    """Render ``fig`` as a transparent 200-dpi PNG embedded in an HTML <img> tag."""
    buf = io.BytesIO()
    fig.savefig(buf, format="png", transparent=True, dpi=200)
    img_base64 = base64.b64encode(buf.getvalue()).decode("utf-8")
    return (
        '<div style="text-align: center;">'
        f'<img src="data:image/png;base64,{img_base64}" alt="{alt}" style="max-width: 100%;">'
        "</div>"
    )


def plot_driver_analysis(results_df, average_value, title):
    """
    Plot driver analysis results with factors instead of trust buckets.

    Args:
        results_df: DataFrame with a ``Predictor`` column holding factor names
            ("Factor 1" .. "Factor 6") and an ``Importance_percent`` column.
            The caller's frame is not modified.
        average_value: x position of the black vertical reference line.
        title: chart title.

    Returns:
        PIL.Image.Image: the rendered horizontal bar chart.

    Raises:
        Any exception from pandas/matplotlib is logged and re-raised.
    """
    logger.info("Plotting driver analysis with title '%s'.", title)
    fig = None
    try:
        # Work on a copy so the caller's frame is neither re-typed nor re-sorted.
        plot_df = results_df.copy()
        plot_df["Predictor"] = pd.Categorical(
            plot_df["Predictor"], categories=FACTOR_ORDER, ordered=True
        )
        # barh puts the first category at the bottom, so sort descending to
        # show "Factor 1" on top.
        plot_df = plot_df.sort_values("Predictor", ascending=False)

        # Explicit fig/ax calls (not plt.*) avoid relying on the global
        # "current figure", which Gradio worker threads share.
        fig, ax = plt.subplots(figsize=(10, 8))
        ax.xaxis.set_major_formatter(FuncFormatter(lambda x, _: f"{x:.0f}%"))

        # x-axis starts at 0 and extends 5 points past the largest importance.
        x_min = 0
        x_max = plot_df["Importance_percent"].max() + 5
        ax.set_xlim(x_min, x_max)

        # Ticks every 5 points, each with a dashed vertical grid line.
        x_ticks = np.arange(np.floor(x_min), np.ceil(x_max) + 5, 5)
        ax.set_xticks(x_ticks)
        for tick in x_ticks:
            ax.axvline(x=tick, color="grey", linestyle="--", linewidth=0.5, zorder=2)

        for i, row in enumerate(plot_df.itertuples(index=False)):
            ax.barh(
                row.Predictor,
                row.Importance_percent,
                left=0,
                color=FACTOR_COLORS.get(row.Predictor, "#cccccc"),
                edgecolor="white",
                height=0.6,
                zorder=3,
            )
            # Value label just right of the bar; ``i`` is the bar's y slot
            # because categories are placed in plotting order.
            ax.text(
                row.Importance_percent + 0.5,
                i,
                f"{row.Importance_percent:.1f}%",
                va="center",
                ha="left",
                color="#8c8b8c",
            )

        # Draw average line
        ax.axvline(average_value, color="black", linewidth=1, linestyle="-", zorder=3)
        ax.set_title(title, fontsize=14)

        # Style the plot
        ax.spines[["left", "top", "right"]].set_color("none")
        ax.tick_params(axis="y", colors="#8c8b8c", length=0)
        ax.set_axisbelow(True)
        fig.tight_layout()

        return _figure_to_image(fig)
    except Exception as e:
        logger.error("Error plotting driver analysis: %s", e)
        raise
    finally:
        # Always release the figure, including on error, so pyplot does not leak it.
        if fig is not None:
            plt.close(fig)


def plot_factor_performance(driver_df, title):
    """
    Plot factor performance (agreement scores).

    Args:
        driver_df: DataFrame that contains the six "Factor N" columns.
        title: chart title.

    Returns:
        PIL.Image.Image: vertical bar chart of each factor's column mean,
        on a fixed 1-10 y-axis.
    """
    fig = None
    try:
        # Calculate mean scores, one row per factor in display order.
        results_df = driver_df[FACTOR_ORDER].mean().reset_index()
        results_df.columns = ["Factor", "Agreement_Score"]

        fig, ax = plt.subplots(figsize=(10, 8))
        ax.bar(
            results_df["Factor"],
            results_df["Agreement_Score"],
            color=[FACTOR_COLORS[factor] for factor in results_df["Factor"]],
            edgecolor="white",
            zorder=2,
        )

        # Add values on top of each bar.
        for i, score in enumerate(results_df["Agreement_Score"]):
            ax.text(
                i,
                score + 0.1,
                f"{score:.1f}",
                ha="center",
                va="bottom",
                color="#8c8b8c",
            )

        # Set y-axis
        ax.set_ylim(1, 10)
        ax.set_yticks(range(1, 11))
        ax.set_ylabel("Agreement Score")
        ax.set_title(title, fontsize=14)
        ax.spines[["top", "right"]].set_color("none")

        # Dashed horizontal grid line at every y tick, behind the bars.
        for y_tick in ax.get_yticks():
            ax.axhline(y=y_tick, color="grey", linestyle="--", linewidth=0.5, zorder=1)
        ax.set_axisbelow(True)
        fig.tight_layout()

        return _figure_to_image(fig)
    except Exception as e:
        logger.error("Error plotting factor performance: %s", e)
        raise
    finally:
        if fig is not None:
            plt.close(fig)


def calculate_r2_image(r2_percent):
    """
    Create R² visualization.

    Draws a four-segment donut (fixed category bands, equal-sized wedges) with
    ``r2_percent`` rounded to an integer in the centre.

    Returns:
        str: HTML snippet embedding the chart as a base64 PNG.
    """
    categories = [
        ("<40%: Deficient", "#b03c3c"),
        (">50%: Gaps", "#bdd8da"),
        (">60%: Proven", "#89b7bc"),
        (">70%: Robust", "#375a5e"),
    ]
    labels = [label for label, _ in categories]
    colors = [color for _, color in categories]

    fig, ax = plt.subplots(figsize=(3.6, 3.6), subplot_kw=dict(aspect="equal"))
    try:
        wedges, _ = ax.pie(
            [1] * 4,
            startangle=90,
            counterclock=False,
            colors=colors,
            wedgeprops=dict(width=0.35),
        )

        # Outer labels at each wedge's mid-angle, 1.5 radii from the centre.
        for label, wedge in zip(labels, wedges):
            angle = np.deg2rad((wedge.theta2 + wedge.theta1) / 2)
            ax.text(
                1.5 * np.cos(angle),
                1.5 * np.sin(angle),
                label,
                ha="center",
                va="center",
                fontsize=9,
                color="black",
            )

        # Center R² text
        ax.text(
            0, 0, f"{int(round(r2_percent))}%",
            ha="center", va="center", fontsize=19, fontweight="bold",
        )

        ax.set_title("Model Validity", fontsize=11, pad=10)
        ax.axis("off")
        fig.patch.set_facecolor("none")
        ax.patch.set_facecolor("none")
        fig.tight_layout()

        return _figure_to_html(fig, "Model Validity")
    finally:
        plt.close(fig)


def create_avg_target_display(avg_target):
    """
    Create average target (Purchase Consideration) visualization.

    Returns:
        str: HTML snippet embedding a circle with ``avg_target`` (one decimal)
        as a base64 PNG.
    """
    fig, ax = plt.subplots(figsize=(3.6, 3.6))
    try:
        # Create circular display
        circle = plt.Circle((0.5, 0.5), 0.4, color="#4CAF50", alpha=0.3)
        ax.add_patch(circle)

        ax.text(0.5, 0.5, f"{avg_target:.1f}", ha="center", va="center",
                fontsize=24, fontweight="bold")
        ax.text(0.5, 0.2, "Scale: 1-6", ha="center", va="center",
                fontsize=10, color="gray")

        ax.set_xlim(0, 1)
        ax.set_ylim(0, 1)
        ax.set_title("Avg Target", fontsize=11, pad=10)
        ax.axis("off")
        fig.patch.set_facecolor("none")
        ax.patch.set_facecolor("none")
        fig.tight_layout()

        return _figure_to_html(fig, "Avg Target")
    finally:
        plt.close(fig)


def call_r_script_simplified(input_file, csv_output_path):
    """
    Call R script for Shapley regression analysis on Consideration.

    Runs ``Rscript process_data.R <input_file> <csv_output_path>``. For demo
    purposes, if the script exits non-zero or the ``Rscript`` executable is
    not installed, mock results (``Predictor``/``Importance`` columns) are
    written to ``csv_output_path`` instead. Any other error is logged and
    re-raised.
    """
    command = ["Rscript", "process_data.R", input_file, csv_output_path]
    try:
        subprocess.run(command, check=True)
    except (subprocess.CalledProcessError, FileNotFoundError) as e:
        # FileNotFoundError: Rscript is not on PATH — treat like a failed run.
        logger.error("R script failed with error: %s", e)
        # For demo purposes, create mock data if R script fails
        mock_data = pd.DataFrame({
            "Predictor": list(FACTOR_ORDER),
            "Importance": [0.15, 0.22, 0.18, 0.20, 0.13, 0.12],
        })
        mock_data.to_csv(csv_output_path, index=False)
    except Exception as e:
        logger.error("Error calling R script: %s", e)
        raise


def analyze_prospects_data(file_path):
    """
    Analyze prospects data focusing on Purchase Consideration as target.

    Args:
        file_path: path to an Excel workbook with a "Driver" sheet whose
            header is on row 4 (``header=3``).

    Returns:
        tuple: ``(r2_html, avg_target_html, factor_performance_img,
        driver_analysis_img)``, or four ``None`` values if the sheet has no
        ``Consideration`` column.

    Raises:
        Any other error is logged with traceback and re-raised.
    """
    logger.info("Analyzing prospects file: %s", file_path)
    try:
        df = pd.read_excel(file_path, sheet_name="Driver", header=3)

        if "Consideration" not in df.columns:
            logger.error("Consideration column not found in dataset")
            return None, None, None, None

        # Renamed copy (trust buckets -> factors); missing source columns are ignored.
        df_analysis = df.rename(columns=FACTOR_COLUMN_MAPPING)

        # R² of a linear model of Consideration on the factors. Rows missing
        # either a factor score or the target are dropped, because
        # LinearRegression rejects NaN in X or y.
        factors = list(FACTOR_COLUMN_MAPPING.values())
        model_df = df_analysis[factors + ["Consideration"]].dropna()
        X = model_df[factors]
        y = model_df["Consideration"]
        model = LinearRegression()
        model.fit(X, y)
        r2_percent = r2_score(y, model.predict(X)) * 100

        # Average target over all respondents with a Consideration value.
        avg_target = df["Consideration"].mean()

        r2_html = calculate_r2_image(r2_percent)
        avg_target_html = create_avg_target_display(avg_target)
        factor_performance_img = plot_factor_performance(
            df_analysis, "Factor Performance (Agreement Scores)"
        )

        # Run Shapley analysis (or mock fallback) in a temp dir that is
        # removed even if reading the results fails.
        with tempfile.TemporaryDirectory() as temp_dir:
            csv_output_path = os.path.join(temp_dir, "consideration_results.csv")
            call_r_script_simplified(file_path, csv_output_path)
            results_df = pd.read_csv(csv_output_path)

        # Map predictor names if they come from R script with original names
        if "Predictor" in results_df.columns:
            results_df["Predictor"] = results_df["Predictor"].map(
                lambda name: FACTOR_COLUMN_MAPPING.get(name, name)
            )

        results_df["Importance_percent"] = results_df["Importance"] * 100
        average_value = results_df["Importance_percent"].mean()

        driver_analysis_img = plot_driver_analysis(
            results_df, average_value, "Shapley Driver Analysis - Purchase Consideration"
        )

        return r2_html, avg_target_html, factor_performance_img, driver_analysis_img
    except Exception as e:
        logger.exception("Error analyzing data: %s", e)
        raise


# Gradio interface with light theme
css = """
.metric-container {
    display: flex;
    justify-content: space-around;
    margin: 20px 0;
}
"""

# JavaScript to force light theme: reload with ?__theme=light if not already set.
js = """
function refresh() {
    const url = new URL(window.location);
    if (url.searchParams.get('__theme') !== 'light') {
        url.searchParams.set('__theme', 'light');
        window.location.href = url.href;
    }
}
"""

# Create the Gradio interface with light theme
with gr.Blocks(css=css, js=js, theme=gr.themes.Soft()) as demo:
    gr.Markdown("""

Driver Analysis Demo - Purchase Consideration

""")
    gr.Markdown("### Prospects Analysis")
    gr.Markdown("Analysis showing what drives Purchase Consideration among prospects")

    # Metrics row
    with gr.Row(equal_height=True):
        with gr.Column(scale=1):
            r2_output = gr.HTML()
        with gr.Column(scale=1):
            avg_target_output = gr.HTML()

    # Charts
    with gr.Row():
        with gr.Column():
            factor_performance_plot = gr.Image(show_label=False)
        with gr.Column():
            driver_analysis_plot = gr.Image(show_label=False)

    # Hidden state for file path
    prospects_file = gr.State(value="example_files/Volkswagen Non Customers.xlsx")

    # Auto-load on page load
    demo.load(
        fn=analyze_prospects_data,
        inputs=[prospects_file],
        outputs=[r2_output, avg_target_output, factor_performance_plot, driver_analysis_plot],
    )

# Launch without the theme parameter
demo.launch(server_name="0.0.0.0", share=False)