Spaces:
Running
Running
| import subprocess | |
| import pandas as pd | |
| import matplotlib.pyplot as plt | |
| from matplotlib.ticker import FuncFormatter | |
| import gradio as gr | |
| import tempfile | |
| import logging | |
| from PIL import Image | |
| import os | |
| import io | |
| import numpy as np | |
| from sklearn.linear_model import LinearRegression | |
| from sklearn.metrics import r2_score | |
| import base64 | |
# Initialize logging.
# basicConfig installs an INFO-level root handler at import time; every
# function in this module reports through the module-level `logger` below.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
def plot_driver_analysis(results_df, average_value, title):
    """
    Plot driver analysis results with factors instead of trust buckets.

    Draws one horizontal bar per predictor, labels each bar with its
    percentage, and marks ``average_value`` with a solid vertical line.

    Args:
        results_df: DataFrame with a "Predictor" column and a numeric
            "Importance_percent" column. It is not modified.
        average_value: x position of the vertical reference line.
        title: Chart title.

    Returns:
        The rendered chart as a PIL image (PNG data held in memory).

    Raises:
        Exception: Any failure is logged and re-raised unchanged.
    """
    logger.info("Plotting driver analysis with title '%s'.", title)

    # Color scheme for the factors; the key order is also the display order.
    color_map = {
        "Factor 1": "#375570",
        "Factor 2": "#E3B05B",
        "Factor 3": "#C63F48",
        "Factor 4": "#418387",
        "Factor 5": "#DF8859",
        "Factor 6": "#6D93AB",
    }
    order = list(color_map)

    fig = None
    try:
        # Work on a copy so the caller's DataFrame keeps its dtype and order.
        plot_df = results_df.copy()

        # Predictors outside the fixed order would turn into NaN under a
        # closed categorical; append them so they are still drawn (in the
        # grey fallback color) instead of breaking the bar call.
        known = set(order)
        extras = [
            name
            for name in plot_df["Predictor"].dropna().unique()
            if name not in known
        ]
        plot_df["Predictor"] = pd.Categorical(
            plot_df["Predictor"], categories=order + extras, ordered=True
        )
        # Descending sort puts "Factor 1" last, i.e. at the top of the chart.
        plot_df = plot_df.sort_values("Predictor", ascending=False)

        fig, ax = plt.subplots(figsize=(10, 8))

        # Show x tick labels as whole percentages.
        ax.xaxis.set_major_formatter(FuncFormatter(lambda x, _: f"{x:.0f}%"))

        # x-axis starts at zero and leaves headroom for the value labels.
        x_min = 0
        x_max = plot_df["Importance_percent"].max() + 5
        ax.set_xlim(x_min, x_max)

        # Ticks every 5 points, each with a dashed vertical guide line.
        x_ticks = np.arange(np.floor(x_min), np.ceil(x_max) + 5, 5)
        ax.set_xticks(x_ticks)
        for tick in x_ticks:
            ax.axvline(x=tick, color="grey", linestyle="--", linewidth=0.5, zorder=2)

        # One bar per predictor; `i` is the bar's position on the category axis.
        for i, row in enumerate(plot_df.itertuples(index=False)):
            ax.barh(
                row.Predictor,
                row.Importance_percent,
                left=0,
                color=color_map.get(row.Predictor, "#cccccc"),
                edgecolor="white",
                height=0.6,
                zorder=3,
            )
            ax.text(
                row.Importance_percent + 0.5,
                i,
                f"{row.Importance_percent:.1f}%",
                va="center",
                ha="left",
                color="#8c8b8c",
            )

        # Average reference line.
        ax.axvline(average_value, color="black", linewidth=1, linestyle="-", zorder=3)
        ax.set_title(title, fontsize=14)

        # Styling: hide all spines except the bottom one, mute the y labels.
        ax.spines[["left", "top", "right"]].set_color("none")
        ax.tick_params(axis="y", colors="#8c8b8c", length=0)
        ax.set_axisbelow(True)
        fig.tight_layout()

        # Render through the figure object rather than pyplot's global
        # "current figure", which is shared between concurrent requests.
        img_data = io.BytesIO()
        fig.savefig(
            img_data, format="png", facecolor=fig.get_facecolor(), edgecolor="none"
        )
        img_data.seek(0)
        return Image.open(img_data)
    except Exception as e:
        logger.exception("Error plotting driver analysis: %s", e)
        raise
    finally:
        # Release the figure on both the success and the failure path.
        if fig is not None:
            plt.close(fig)
def plot_factor_performance(driver_df, title):
    """
    Plot factor performance (agreement scores).

    Averages each of the six factor columns of ``driver_df`` and draws the
    means as a vertical bar chart with the value printed above each bar.

    Args:
        driver_df: DataFrame containing the columns "Factor 1" .. "Factor 6".
        title: Chart title.

    Returns:
        The rendered chart as a PIL image (PNG data held in memory).

    Raises:
        Exception: Any failure (for example a missing factor column) is
            logged and re-raised unchanged.
    """
    factors = ["Factor 1", "Factor 2", "Factor 3", "Factor 4", "Factor 5", "Factor 6"]
    color_map = {
        "Factor 1": "#375570",
        "Factor 2": "#E3B05B",
        "Factor 3": "#C63F48",
        "Factor 4": "#418387",
        "Factor 5": "#DF8859",
        "Factor 6": "#6D93AB",
    }

    fig = None
    try:
        # Mean score per factor, as a two-column frame.
        results_df = driver_df[factors].mean().reset_index()
        results_df.columns = ["Factor", "Agreement_Score"]

        fig, ax = plt.subplots(figsize=(10, 8))
        ax.bar(
            results_df["Factor"],
            results_df["Agreement_Score"],
            color=[color_map[factor] for factor in results_df["Factor"]],
            edgecolor="white",
            zorder=2,
        )

        # Value label just above each bar; `i` is the bar's x position.
        for i, score in enumerate(results_df["Agreement_Score"]):
            ax.text(
                i,
                score + 0.1,
                f"{score:.1f}",
                ha="center",
                va="bottom",
                color="#8c8b8c",
            )

        # Fixed 1-10 y-axis with a tick on every integer.
        ax.set_ylim(1, 10)
        ax.set_yticks(list(range(1, 11)))
        ax.set_ylabel("Agreement Score")
        ax.set_title(title, fontsize=14)
        ax.spines[["top", "right"]].set_color("none")

        # Dashed horizontal guide line at every tick, drawn behind the bars.
        for y_tick in ax.get_yticks():
            ax.axhline(y=y_tick, color="grey", linestyle="--", linewidth=0.5, zorder=1)
        ax.set_axisbelow(True)
        fig.tight_layout()

        # Render through the figure object rather than pyplot's global
        # "current figure", which is shared between concurrent requests.
        img_data = io.BytesIO()
        fig.savefig(
            img_data, format="png", facecolor=fig.get_facecolor(), edgecolor="none"
        )
        img_data.seek(0)
        return Image.open(img_data)
    except Exception as e:
        logger.exception("Error plotting factor performance: %s", e)
        raise
    finally:
        # Release the figure on both the success and the failure path.
        if fig is not None:
            plt.close(fig)
def calculate_r2_image(r2_percent):
    """
    Create R² visualization.

    Draws a four-segment donut (one equally sized segment per validity
    category) with the rounded R² percentage in the centre, and returns it
    embedded in an HTML snippet as a base64 PNG.

    Args:
        r2_percent: R² expressed as a percentage (e.g. 63.2 for 0.632).

    Returns:
        HTML string containing a centred ``<img>`` with the inline PNG.
    """
    categories = [
        ("<40%: Deficient", "#b03c3c"),
        (">50%: Gaps", "#bdd8da"),
        (">60%: Proven", "#89b7bc"),
        (">70%: Robust", "#375a5e"),
    ]
    labels = [label for label, _ in categories]
    colors = [color for _, color in categories]

    fig, ax = plt.subplots(figsize=(3.6, 3.6), subplot_kw=dict(aspect="equal"))
    try:
        # Equal-weight wedges: the ring is a legend, not a gauge.
        wedges, _ = ax.pie(
            [1] * len(categories),
            startangle=90,
            counterclock=False,
            colors=colors,
            wedgeprops=dict(width=0.35),
        )

        # Place each label outside the ring, on the wedge's bisecting angle.
        for wedge, label in zip(wedges, labels):
            angle = np.deg2rad((wedge.theta2 + wedge.theta1) / 2)
            ax.text(
                1.5 * np.cos(angle),
                1.5 * np.sin(angle),
                label,
                ha="center",
                va="center",
                fontsize=9,
                color="black",
            )

        # R² value in the centre of the donut.
        ax.text(
            0,
            0,
            f"{int(round(r2_percent))}%",
            ha="center",
            va="center",
            fontsize=19,
            fontweight="bold",
        )
        ax.set_title("Model Validity", fontsize=11, pad=10)
        ax.axis("off")
        fig.patch.set_facecolor("none")
        ax.patch.set_facecolor("none")
        fig.tight_layout()

        # Save via the figure object, not pyplot's global current figure.
        buf = io.BytesIO()
        fig.savefig(buf, format="png", transparent=True, dpi=200)
    finally:
        # Close the figure even if rendering raised.
        plt.close(fig)

    img_base64 = base64.b64encode(buf.getvalue()).decode("utf-8")
    return (
        "\n"
        "<div style='display: flex; justify-content: center; align-items: center;'>\n"
        f"<img src='data:image/png;base64,{img_base64}' style='max-width: 240px; height: auto;'/>\n"
        "</div>\n"
    )
def create_avg_target_display(avg_target):
    """
    Create average target (Purchase Consideration) visualization.

    Draws the value, formatted to one decimal, inside a translucent green
    circle with a "Scale: 1-6" caption, and returns it embedded in an HTML
    snippet as a base64 PNG.

    Args:
        avg_target: Average target score to display.

    Returns:
        HTML string containing a centred ``<img>`` with the inline PNG.
    """
    fig, ax = plt.subplots(figsize=(3.6, 3.6))
    try:
        # Circular badge with the value in the middle.
        circle = plt.Circle((0.5, 0.5), 0.4, color="#4CAF50", alpha=0.3)
        ax.add_patch(circle)
        ax.text(
            0.5, 0.5, f"{avg_target:.1f}",
            ha="center", va="center", fontsize=24, fontweight="bold",
        )
        ax.text(
            0.5, 0.2, "Scale: 1-6",
            ha="center", va="center", fontsize=10, color="gray",
        )
        ax.set_xlim(0, 1)
        ax.set_ylim(0, 1)
        # The circle is defined in data coordinates; without an equal aspect
        # ratio the title's extra top margin squashes it into an ellipse.
        ax.set_aspect("equal")
        ax.set_title("Avg Target", fontsize=11, pad=10)
        ax.axis("off")
        fig.patch.set_facecolor("none")
        ax.patch.set_facecolor("none")
        fig.tight_layout()

        # Save via the figure object, not pyplot's global current figure.
        buf = io.BytesIO()
        fig.savefig(buf, format="png", transparent=True, dpi=200)
    finally:
        # Close the figure even if rendering raised.
        plt.close(fig)

    img_base64 = base64.b64encode(buf.getvalue()).decode("utf-8")
    return (
        "\n"
        "<div style='display: flex; justify-content: center; align-items: center;'>\n"
        f"<img src='data:image/png;base64,{img_base64}' style='max-width: 240px; height: auto;'/>\n"
        "</div>\n"
    )
def call_r_script_simplified(input_file, csv_output_path):
    """
    Call R script for Shapley regression analysis on Consideration.

    Runs ``Rscript process_data.R <input_file> <csv_output_path>``. If the
    script exits with a non-zero status, a fixed set of mock importance
    values is written to ``csv_output_path`` instead (demo fallback), so
    the caller always finds a CSV with "Predictor" and "Importance" columns.

    Args:
        input_file: Path of the workbook handed to the R script.
        csv_output_path: Path the results CSV is written to.

    Raises:
        Exception: Any failure other than a non-zero exit status (for
            example a missing ``Rscript`` executable) is logged and
            re-raised.
    """
    command = ["Rscript", "process_data.R", input_file, csv_output_path]
    try:
        subprocess.run(command, check=True)
    except subprocess.CalledProcessError as e:
        logger.error("R script failed with error: %s", e)
        # For demo purposes, create mock data if the R script fails. Say so
        # explicitly: downstream charts will show fabricated numbers.
        logger.warning(
            "Writing mock importance values to %s instead of R results.",
            csv_output_path,
        )
        mock_data = pd.DataFrame({
            'Predictor': ['Factor 1', 'Factor 2', 'Factor 3', 'Factor 4', 'Factor 5', 'Factor 6'],
            'Importance': [0.15, 0.22, 0.18, 0.20, 0.13, 0.12],
        })
        mock_data.to_csv(csv_output_path, index=False)
    except Exception as e:
        # logger.exception keeps the traceback alongside the same message.
        logger.exception("Error calling R script: %s", e)
        raise
def analyze_prospects_data(file_path):
    """
    Analyze prospects data focusing on Purchase Consideration as target.

    Reads the "Driver" sheet of the workbook, fits a linear regression of
    "Consideration" on the six factor columns to obtain R², averages the
    target, and builds the four dashboard outputs.

    Args:
        file_path: Path of the Excel workbook to analyze.

    Returns:
        Tuple ``(r2_html, avg_target_html, factor_performance_img,
        driver_analysis_img)``, or four ``None`` values when the workbook
        has no "Consideration" column.

    Raises:
        Exception: Any other failure is logged and re-raised unchanged.
    """
    logger.info("Analyzing prospects file: %s", file_path)

    # Map column names from trust buckets to factors.
    column_mapping = {
        "Stability": "Factor 1",
        "Development": "Factor 2",
        "Relationship": "Factor 3",
        "Benefit": "Factor 4",
        "Vision": "Factor 5",
        "Competence": "Factor 6",
    }
    factors = list(column_mapping.values())
    target = "Consideration"

    try:
        # The header row of the "Driver" sheet is the fourth row (index 3).
        df = pd.read_excel(file_path, sheet_name="Driver", header=3)

        # Renamed copy for analysis; names absent from the sheet are ignored.
        df_analysis = df.rename(columns=column_mapping)

        if target not in df.columns:
            logger.error("Consideration column not found in dataset")
            return None, None, None, None

        # Drop rows missing any factor OR the target: a NaN in the target
        # would make the regression fit fail.
        model_df = df_analysis[factors + [target]].dropna()
        X = model_df[factors]
        y = model_df[target]
        model = LinearRegression()
        model.fit(X, y)
        r2_percent = r2_score(y, model.predict(X)) * 100

        # Average target (Consideration) over every row that has a value.
        avg_target = df[target].mean()

        r2_html = calculate_r2_image(r2_percent)
        avg_target_html = create_avg_target_display(avg_target)
        factor_performance_img = plot_factor_performance(
            df_analysis, "Factor Performance (Agreement Scores)"
        )

        # Run the Shapley analysis in a scratch directory that is removed
        # even when the R call or the CSV read raises.
        with tempfile.TemporaryDirectory() as temp_dir:
            csv_output_path = os.path.join(temp_dir, "consideration_results.csv")
            # Call R script or create mock results.
            call_r_script_simplified(file_path, csv_output_path)
            results_df = pd.read_csv(csv_output_path)

        # The R script may report predictors under their original names.
        if "Predictor" in results_df.columns:
            results_df["Predictor"] = results_df["Predictor"].map(
                lambda x: column_mapping.get(x, x)
            )
        results_df["Importance_percent"] = results_df["Importance"] * 100
        average_value = results_df["Importance_percent"].mean()

        driver_analysis_img = plot_driver_analysis(
            results_df,
            average_value,
            "Shapley Driver Analysis - Purchase Consideration",
        )
        return r2_html, avg_target_html, factor_performance_img, driver_analysis_img
    except Exception as e:
        logger.exception("Error analyzing data: %s", e)
        raise
# Gradio interface with light theme
css = """
.metric-container {
display: flex;
justify-content: space-around;
margin: 20px 0;
}
"""

# JavaScript to force light theme: reload the page with ?__theme=light
# unless that query parameter is already set.
js = """
function refresh() {
const url = new URL(window.location);
if (url.searchParams.get('__theme') !== 'light') {
url.searchParams.set('__theme', 'light');
window.location.href = url.href;
}
}
"""

# Page heading. Kept flush-left because Markdown would render lines
# indented by four or more spaces as a code block.
header_html = """
<h2 style="text-align: center; font-size: 2.25rem; font-weight: 600;">
Driver Analysis Demo - Purchase Consideration
</h2>
"""

# Workbook analyzed automatically when the page loads.
default_prospects_file = "example_files/Volkswagen Non Customers.xlsx"

# Create the Gradio interface with light theme
with gr.Blocks(css=css, js=js, theme=gr.themes.Soft()) as demo:
    gr.Markdown(header_html)
    gr.Markdown("### Prospects Analysis")
    gr.Markdown("Analysis showing what drives Purchase Consideration among prospects")

    # Metrics row: R² donut on the left, average target badge on the right.
    with gr.Row(equal_height=True):
        with gr.Column(scale=1):
            r2_output = gr.HTML()
        with gr.Column(scale=1):
            avg_target_output = gr.HTML()

    # Charts: factor performance next to the driver analysis.
    with gr.Row():
        with gr.Column():
            factor_performance_plot = gr.Image(show_label=False)
        with gr.Column():
            driver_analysis_plot = gr.Image(show_label=False)

    # Hidden state for file path
    prospects_file = gr.State(value=default_prospects_file)

    # Auto-load on page load; output order matches the tuple returned by
    # analyze_prospects_data.
    demo.load(
        fn=analyze_prospects_data,
        inputs=[prospects_file],
        outputs=[r2_output, avg_target_output, factor_performance_plot, driver_analysis_plot],
    )

# Launch without the theme parameter. Guarded so that importing this module
# (for example from a test or a reloader) does not start a blocking server.
if __name__ == "__main__":
    demo.launch(server_name="0.0.0.0", share=False)