Spaces:
Running
Running
import base64
import html
import io
import logging
import os
import shutil
import subprocess
import tempfile

import gradio as gr
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from matplotlib.ticker import FuncFormatter
from PIL import Image
from sklearn.linear_model import LinearRegression
from sklearn.metrics import r2_score
| # Initialize logging | |
| logging.basicConfig(level=logging.INFO) | |
| logger = logging.getLogger(__name__) | |
def plot_driver_analysis(results_df, average_value, title):
    """
    Render a horizontal bar chart of per-factor Shapley importances.

    Args:
        results_df: DataFrame with "Predictor" (factor label) and
            "Importance_percent" columns.  The frame is NOT modified;
            sorting happens on an internal copy (the original mutated the
            caller's frame via ``inplace=True``).
        average_value: x-position (in percent) of the vertical reference line.
        title: chart title.

    Returns:
        PIL.Image.Image containing the rendered chart.

    Raises:
        Re-raises any plotting error after logging it.
    """
    logger.info("Plotting driver analysis with title '%s'.", title)
    try:
        # Fixed color per factor so charts stay comparable across runs.
        color_map = {
            "Factor 1": "#375570",
            "Factor 2": "#E3B05B",
            "Factor 3": "#C63F48",
            "Factor 4": "#418387",
            "Factor 5": "#DF8859",
            "Factor 6": "#6D93AB",
        }
        order = ["Factor 1", "Factor 2", "Factor 3", "Factor 4", "Factor 5", "Factor 6"]

        # Work on a copy so the caller's DataFrame is never reordered/mutated.
        df = results_df.copy()
        df["Predictor"] = pd.Categorical(df["Predictor"], categories=order, ordered=True)
        # Descending categorical sort puts Factor 1 at the top of the barh chart.
        df.sort_values("Predictor", ascending=False, inplace=True)

        fig, ax = plt.subplots(figsize=(10, 8))
        # Show x tick labels as whole percentages.
        ax.xaxis.set_major_formatter(FuncFormatter(lambda x, _: f"{x:.0f}%"))

        # Axis always starts at 0; 5pp of headroom leaves room for value labels.
        x_max = df["Importance_percent"].max() + 5
        plt.xlim(0, x_max)
        x_ticks = np.arange(0, np.ceil(x_max) + 5, 5)
        ax.set_xticks(x_ticks)
        # Dashed vertical gridline at every tick, behind the bars.
        for tick in x_ticks:
            ax.axvline(x=tick, color="grey", linestyle="--", linewidth=0.5, zorder=2)

        for i, row in enumerate(df.itertuples(index=False)):
            ax.barh(
                row.Predictor,
                row.Importance_percent,
                left=0,
                color=color_map.get(row.Predictor, "#cccccc"),
                edgecolor="white",
                height=0.6,
                zorder=3,
            )
            # Value label just past the end of the bar; i matches the bar's
            # y-position because categories appear in iteration order.
            ax.text(
                row.Importance_percent + 0.5,
                i,
                f"{row.Importance_percent:.1f}%",
                va="center",
                ha="left",
                color="#8c8b8c",
            )

        # Vertical reference line at the average importance.
        ax.axvline(average_value, color="black", linewidth=1, linestyle="-", zorder=3)
        plt.title(title, fontsize=14)
        # De-clutter: hide all spines except the bottom axis.
        ax.spines[["left", "top", "right"]].set_color("none")
        ax.tick_params(axis="y", colors="#8c8b8c", length=0)
        ax.set_axisbelow(True)
        plt.tight_layout()

        # Rasterize to a PIL image so Gradio can display it directly.
        img_data = io.BytesIO()
        plt.savefig(img_data, format="png", facecolor=fig.get_facecolor(), edgecolor="none")
        img_data.seek(0)
        img = Image.open(img_data)
        plt.close(fig)
        return img
    except Exception as e:
        logger.error("Error plotting driver analysis: %s", e)
        raise
def plot_factor_performance(driver_df, title):
    """
    Plot the mean agreement score (1-10 scale) for each of the six factors.

    Args:
        driver_df: DataFrame containing the six "Factor N" columns.
        title: chart title.

    Returns:
        PIL.Image.Image containing the rendered bar chart.

    Raises:
        Re-raises any plotting error after logging it.
    """
    factors = ["Factor 1", "Factor 2", "Factor 3", "Factor 4", "Factor 5", "Factor 6"]
    try:
        # Same per-factor palette as the driver-analysis chart.
        color_map = {
            "Factor 1": "#375570",
            "Factor 2": "#E3B05B",
            "Factor 3": "#C63F48",
            "Factor 4": "#418387",
            "Factor 5": "#DF8859",
            "Factor 6": "#6D93AB",
        }
        # Mean agreement score per factor, as a two-column frame.
        results_df = driver_df[factors].mean().reset_index()
        results_df.columns = ["Factor", "Agreement_Score"]

        fig, ax = plt.subplots(figsize=(10, 8))
        ax.bar(
            results_df["Factor"],
            results_df["Agreement_Score"],
            color=[color_map[factor] for factor in results_df["Factor"]],
            edgecolor="white",
            zorder=2,
        )
        # Print each score just above its bar.
        for i, (factor, score) in enumerate(results_df.itertuples(index=False, name=None)):
            ax.text(
                i,
                score + 0.1,
                f"{score:.1f}",
                ha="center",
                va="bottom",
                color="#8c8b8c",
            )

        # Fix the axis to the full 1-10 agreement scale.
        plt.ylim(1, 10)
        # BUG FIX: range(1, 10) stopped at 9, dropping the top-of-scale
        # tick even though ylim extends to 10.
        plt.yticks(range(1, 11))
        plt.ylabel("Agreement Score")
        plt.title(title, fontsize=14)
        ax.spines[["top", "right"]].set_color("none")
        # Dashed horizontal gridline at every tick, behind the bars.
        for y_tick in ax.get_yticks():
            ax.axhline(y=y_tick, color="grey", linestyle="--", linewidth=0.5, zorder=1)
        ax.set_axisbelow(True)
        plt.tight_layout()

        # Rasterize to a PIL image for Gradio.
        img_data = io.BytesIO()
        plt.savefig(img_data, format="png", facecolor=fig.get_facecolor(), edgecolor="none")
        img_data.seek(0)
        img = Image.open(img_data)
        plt.close(fig)
        return img
    except Exception as e:
        logger.error("Error plotting factor performance: %s", e)
        raise
def calculate_r2_image(r2_percent):
    """
    Render the model-validity (R²) donut gauge as an inline-HTML image.

    Draws a four-band donut with fixed quality labels and prints the rounded
    R² percentage in the centre, then returns an HTML snippet embedding the
    chart as a base64-encoded PNG suitable for a gr.HTML component.
    """
    band_labels = ["<40%: Deficient", ">50%: Gaps", ">60%: Proven", ">70%: Robust"]
    band_colors = ["#b03c3c", "#bdd8da", "#89b7bc", "#375a5e"]

    fig, ax = plt.subplots(figsize=(3.6, 3.6), subplot_kw=dict(aspect="equal"))
    # Four equal wedges, clockwise from the top, drawn as a ring.
    wedges, _ = ax.pie(
        [1, 1, 1, 1],
        startangle=90,
        counterclock=False,
        colors=band_colors,
        wedgeprops=dict(width=0.35),
    )

    # Place each band label just outside its wedge's angular midpoint.
    for wedge, label in zip(wedges, band_labels):
        mid_angle = np.deg2rad((wedge.theta1 + wedge.theta2) / 2)
        ax.text(
            1.5 * np.cos(mid_angle),
            1.5 * np.sin(mid_angle),
            label,
            ha='center', va='center',
            fontsize=9,
            color='black',
        )

    # Rounded R² percentage in the middle of the ring.
    ax.text(
        0, 0, f"{int(round(r2_percent))}%",
        ha='center', va='center',
        fontsize=19, fontweight='bold',
    )
    ax.set_title("Model Validity", fontsize=11, pad=10)
    ax.axis('off')
    # Transparent background so the gauge blends into the page theme.
    fig.patch.set_facecolor('none')
    ax.patch.set_facecolor('none')
    plt.tight_layout()

    buf = io.BytesIO()
    plt.savefig(buf, format='png', transparent=True, dpi=200)
    plt.close(fig)
    buf.seek(0)
    img_base64 = base64.b64encode(buf.read()).decode("utf-8")
    return f"""
    <div style='display: flex; justify-content: center; align-items: center;'>
        <img src='data:image/png;base64,{img_base64}' style='max-width: 240px; height: auto;'/>
    </div>
    """
def create_avg_target_display(avg_target, target_name, scale_info):
    """
    Render the average target score as an inline-HTML metric card.

    Args:
        avg_target: mean value of the detected target column.
        target_name: canonical target label (e.g. "NPS"), shown in the title.
        scale_info: human-readable scale string (e.g. "Scale: 1-10").

    Returns:
        HTML snippet embedding the card as a base64-encoded PNG.
    """
    fig, ax = plt.subplots(figsize=(3.6, 3.6))
    # Soft green disc behind the number.
    circle = plt.Circle((0.5, 0.5), 0.4, color='#4CAF50', alpha=0.3)
    ax.add_patch(circle)
    ax.text(0.5, 0.5, f"{avg_target:.1f}",
            ha='center', va='center', fontsize=24, fontweight='bold')
    ax.text(0.5, 0.2, scale_info,
            ha='center', va='center', fontsize=10, color='gray')
    ax.set_xlim(0, 1)
    ax.set_ylim(0, 1)
    # BUG FIX: the title was a placeholder-less f"Avg" and target_name went
    # unused — include the detected target so users know what the average is.
    ax.set_title(f"Avg {target_name}", fontsize=11, pad=10)
    ax.axis('off')
    # Transparent background so the card blends into the page theme.
    fig.patch.set_facecolor('none')
    ax.patch.set_facecolor('none')
    plt.tight_layout()
    buf = io.BytesIO()
    plt.savefig(buf, format='png', transparent=True, dpi=200)
    plt.close(fig)
    buf.seek(0)
    img_base64 = base64.b64encode(buf.read()).decode("utf-8")
    return f"""
    <div style='display: flex; justify-content: center; align-items: center;'>
        <img src='data:image/png;base64,{img_base64}' style='max-width: 240px; height: auto;'/>
    </div>
    """
def create_error_message(message):
    """
    Create an HTML error banner for display in a gr.HTML component.

    The message text (which may contain column names lifted from an
    arbitrary uploaded spreadsheet) is HTML-escaped so it can neither
    break the markup nor inject script into the page.
    """
    return f"""
    <div style='background-color: #ffebee; border: 1px solid #f44336; border-radius: 4px; padding: 16px; color: #c62828;'>
        <strong>Error:</strong> {html.escape(str(message))}
    </div>
    """
def create_placeholder_message():
    """
    Create the HTML banner shown before any file has been uploaded.

    Returns a static info box; no interpolation needed (the original used
    a pointless f-string prefix on a placeholder-less literal).
    """
    return """
    <div style='background-color: #f0f8ff; border: 1px solid #4682b4; border-radius: 4px; padding: 16px; color: #2e4057; text-align: center;'>
        <strong>Upload an Excel file to start analysis</strong><br>
        The analysis will begin automatically once your file is uploaded.
    </div>
    """
def find_target_column(df):
    """
    Locate the best target column in the dataset.

    Priority (highest first): Consideration -> Trust -> NPS -> Loyalty.
    Matching runs in three passes — exact column name, case-insensitive
    name, then substring match — and each pass honours the priority order.

    Args:
        df: the loaded survey DataFrame.

    Returns:
        (actual_column_name, canonical_target_name, scale_string), or
        (None, None, None) when no candidate column exists.
    """
    target_priorities = ["Consideration", "Trust", "NPS", "Loyalty"]
    # Substring fragments for the fuzzy third pass, in priority order.
    partial_keys = [
        ("consider", "Consideration"),
        ("trust", "Trust"),
        ("nps", "NPS"),
        ("loyal", "Loyalty"),
    ]

    def get_scale_info(column_data):
        # Describe the observed value range, e.g. "Scale: 1-10".
        values = column_data.dropna()
        return f"Scale: {values.min():.0f}-{values.max():.0f}"

    # Pass 1: exact matches, in priority order.
    for target in target_priorities:
        if target in df.columns:
            return target, target, get_scale_info(df[target])

    # Pass 2: case-insensitive matches, in priority order.
    df_columns_lower = {col.lower(): col for col in df.columns}
    for target in target_priorities:
        actual_col = df_columns_lower.get(target.lower())
        if actual_col is not None:
            return actual_col, target, get_scale_info(df[actual_col])

    # Pass 3: substring matches.  BUG FIX: iterate keywords in the outer
    # loop (not columns) so the documented priority decides the winner —
    # the original column-outer loop let sheet column order win instead.
    for keyword, canonical in partial_keys:
        for col in df.columns:
            if keyword in col.lower():
                return col, canonical, get_scale_info(df[col])

    return None, None, None
def call_r_script_for_target(input_file, csv_output_path, target_type):
    """
    Invoke the external R Shapley-regression script for one target type.

    Copies the target-specific result CSV the script produces to
    *csv_output_path*.  Returns True on success, False on any failure
    (failures are logged, never raised).
    """
    scratch_dir = os.path.dirname(csv_output_path)
    # Every output path the R script expects, keyed by file stem.
    stems = ("trust", "nps", "loyalty", "consideration", "satisfaction", "trustbuilder")
    out_paths = {stem: os.path.join(scratch_dir, f"{stem}.csv") for stem in stems}
    text_output_path = os.path.join(scratch_dir, "output.txt")

    target_key = target_type.lower()
    # Presence flags handed to the R script (trustbuilder is never requested).
    flags = [
        target_key == "nps",
        target_key == "loyalty",
        target_key == "consideration",
        target_key == "satisfaction",
        False,
    ]
    # CSV corresponding to the requested target; anything unrecognized
    # falls back to the consideration output, as before.
    known_targets = ("consideration", "trust", "nps", "loyalty")
    target_csv_path = out_paths[target_key] if target_key in known_targets else out_paths["consideration"]

    command = [
        "Rscript",
        "process_data.R",
        input_file,
        text_output_path,
        out_paths["trust"],
        out_paths["nps"],
        out_paths["loyalty"],
        out_paths["consideration"],
        out_paths["satisfaction"],
        out_paths["trustbuilder"],
    ] + [str(flag).upper() for flag in flags]

    try:
        subprocess.run(command, check=True, capture_output=True, text=True)
        logger.info("R script executed successfully")
        # Promote the target-specific CSV to the caller's expected location.
        if os.path.exists(target_csv_path) and target_csv_path != csv_output_path:
            shutil.copy2(target_csv_path, csv_output_path)
        return True
    except subprocess.CalledProcessError as e:
        logger.error("R script failed with error: %s", e)
        logger.error("R script stderr: %s", e.stderr)
        logger.error("R script stdout: %s", e.stdout)
        return False
    except Exception as e:
        logger.error("Error calling R script: %s", e)
        return False
def analyze_prospects_data(file_path):
    """
    Run the end-to-end driver analysis on an uploaded Excel workbook.

    Loads the "Driver" sheet (three banner rows before the header),
    validates the six factor columns, auto-detects the target column,
    fits an OLS model to report in-sample R², renders the metric cards
    and factor chart, then delegates the Shapley importance analysis to
    the external R script.

    Args:
        file_path: path to the uploaded .xlsx/.xls file, or None.

    Returns:
        (r2_html, avg_target_html, factor_performance_img, driver_img)
        on success; (error_or_placeholder_html, None, None, None) otherwise.
    """
    if file_path is None:
        return create_placeholder_message(), None, None, None
    logger.info("Analyzing prospects file: %s", file_path)
    try:
        df = pd.read_excel(file_path, sheet_name="Driver", header=3)

        # All six survey buckets must be present.
        required_factor_columns = ["Stability", "Development", "Relationship", "Benefit", "Vision", "Competence"]
        missing_factors = [col for col in required_factor_columns if col not in df.columns]
        if missing_factors:
            logger.error("Missing factor columns: %s", missing_factors)
            return create_error_message(f"Missing required columns: {missing_factors}"), None, None, None

        # Auto-detect the target column (Consideration > Trust > NPS > Loyalty).
        target_col, target_name, scale_info = find_target_column(df)
        if target_col is None:
            logger.error("No suitable target column found")
            logger.info("Available columns: %s", list(df.columns))
            return create_error_message(f"No suitable target column found. Available columns: {list(df.columns)}"), None, None, None
        logger.info("Using target column: %s (interpreted as %s)", target_col, target_name)

        # Rename the survey buckets to the anonymized factor labels used downstream.
        column_mapping = {
            "Stability": "Factor 1",
            "Development": "Factor 2",
            "Relationship": "Factor 3",
            "Benefit": "Factor 4",
            "Vision": "Factor 5",
            "Competence": "Factor 6",
        }
        # rename() returns a copy and ignores absent keys, replacing the
        # original copy-then-loop dance.
        df_analysis = df.rename(columns=column_mapping)

        # Fit plain OLS to report in-sample R² as "model validity".
        factors = ["Factor 1", "Factor 2", "Factor 3", "Factor 4", "Factor 5", "Factor 6"]
        X = df_analysis[factors].dropna()
        y = df.loc[X.index, target_col]
        # Drop rows where the target itself is missing.
        valid_mask = ~y.isna()
        X = X[valid_mask]
        y = y[valid_mask]
        if len(X) < 10:
            logger.error("Not enough valid data points: %s", len(X))
            return create_error_message(f"Not enough valid data points: {len(X)}. Need at least 10."), None, None, None

        model = LinearRegression()
        model.fit(X, y)
        r2_percent = r2_score(y, model.predict(X)) * 100
        avg_target = y.mean()
        logger.info("R² Score: %.1f%%, Average %s: %.1f", r2_percent, target_name, avg_target)

        # Metric cards and the factor-performance bar chart.
        r2_html = calculate_r2_image(r2_percent)
        avg_target_html = create_avg_target_display(avg_target, target_name, scale_info)
        factor_performance_img = plot_factor_performance(df_analysis, "Factor Performance (Agreement Scores)")

        # The R script writes its scratch files into a temp dir that is
        # always cleaned up via try/finally (the original repeated a bare
        # `except: pass` cleanup on every early-return path).
        temp_dir = tempfile.mkdtemp()
        try:
            csv_output_path = os.path.join(temp_dir, "results.csv")
            if not call_r_script_for_target(file_path, csv_output_path, target_name):
                return create_error_message("R script failed to execute. Shapley analysis not available."), None, None, None
            if not os.path.exists(csv_output_path):
                return create_error_message("R script did not produce expected output file."), None, None, None
            try:
                results_df = pd.read_csv(csv_output_path)
            except Exception as e:
                logger.error("Error reading R script output: %s", e)
                return create_error_message(f"Error reading R script output: {e}"), None, None, None
            if "Predictor" not in results_df.columns or "Importance" not in results_df.columns:
                logger.error("R script output missing required columns")
                return create_error_message("R script output is invalid - missing required columns."), None, None, None

            # The R script may emit original bucket names; normalize to factors.
            results_df["Predictor"] = results_df["Predictor"].map(
                lambda x: column_mapping.get(x, x)
            )
            results_df["Importance_percent"] = results_df["Importance"] * 100
            average_value = results_df["Importance_percent"].mean()
            driver_analysis_img = plot_driver_analysis(
                results_df,
                average_value,
                "Shapley Driver Analysis",
            )
        finally:
            # Best-effort cleanup on every exit path, success or error.
            try:
                shutil.rmtree(temp_dir)
            except OSError as e:
                logger.error("Error cleaning up temp files: %s", e)

        return r2_html, avg_target_html, factor_performance_img, driver_analysis_img
    except Exception as e:
        logger.error("Error analyzing data: %s", e)
        return create_error_message(f"Analysis failed: {str(e)}"), None, None, None
def load_default_file():
    """Analyze the bundled example workbook (if present) when the app starts."""
    default_file = "example_files/Volkswagen Non Customers.xlsx"
    # Guard clause: without the example file, show the upload prompt instead.
    if not os.path.exists(default_file):
        return create_placeholder_message(), None, None, None
    return analyze_prospects_data(default_file)
def handle_file_upload(file):
    """
    Handle a file-upload event and run the analysis automatically.

    Args:
        file: the value from gr.File — None when cleared.

    Returns:
        The 4-tuple produced by analyze_prospects_data.
    """
    if file is None:
        return create_placeholder_message(), None, None, None
    # BUG FIX: gr.File(type="filepath") delivers a plain string path, on
    # which `.name` raises AttributeError; older Gradio versions deliver a
    # tempfile wrapper exposing `.name`.  Accept both.
    path = file if isinstance(file, str) else file.name
    return analyze_prospects_data(path)
| # Gradio interface with light theme | |
| css = """ | |
| .metric-container { | |
| display: flex; | |
| justify-content: space-around; | |
| margin: 20px 0; | |
| } | |
| """ | |
| # JavaScript to force light theme | |
| js = """ | |
| function refresh() { | |
| const url = new URL(window.location); | |
| if (url.searchParams.get('__theme') !== 'light') { | |
| url.searchParams.set('__theme', 'light'); | |
| window.location.href = url.href; | |
| } | |
| } | |
| """ | |
| # Create the Gradio interface with light theme | |
| with gr.Blocks(css=css, js=js, theme=gr.themes.Soft()) as demo: | |
| gr.Markdown("Analysis showing what drives your target variable (e.g NPS, Decisions) using Factors 1-6") | |
| # File upload section - REMOVED the analyze button | |
| with gr.Row(): | |
| file_upload = gr.File( | |
| label="Upload Excel File - Analysis will start automatically", | |
| file_types=[".xlsx", ".xls"], | |
| type="filepath" | |
| ) | |
| # Metrics row | |
| with gr.Row(equal_height=True): | |
| with gr.Column(scale=1): | |
| r2_output = gr.HTML() | |
| with gr.Column(scale=1): | |
| avg_target_output = gr.HTML() | |
| # Charts | |
| with gr.Row(): | |
| with gr.Column(): | |
| factor_performance_plot = gr.Image(show_label=False) | |
| with gr.Column(): | |
| driver_analysis_plot = gr.Image(show_label=False) | |
| # REMOVED the analyze button click handler | |
| # Only keep the file upload change handler for automatic analysis | |
| file_upload.change( | |
| fn=handle_file_upload, | |
| inputs=[file_upload], | |
| outputs=[r2_output, avg_target_output, factor_performance_plot, driver_analysis_plot] | |
| ) | |
| # Launch the demo | |
| demo.launch(server_name="0.0.0.0", share=False) |