Spaces:
Runtime error
Runtime error
| import gradio as gr | |
| import datasets as ds | |
| import pandas as pd | |
| import numpy as np | |
| from sklearn.ensemble import RandomForestClassifier | |
| from lime.lime_tabular import LimeTabularExplainer | |
# Fetch the wine-recognition data from the Hugging Face Hub as a pandas
# DataFrame, then strip surrounding whitespace from every column name.
wines = ds.load_dataset("katossky/wine-recognition", split='train').to_pandas()
wines.columns = wines.columns.str.strip()

# The black-box model that LIME will explain: a random forest trained on all
# columns except 'label', which is the target. `fit` returns the estimator
# itself, so the fitted model is bound directly to `predictor`.
predictor = RandomForestClassifier(
    n_estimators=1000,
    max_depth=5,
    n_jobs=4,
    random_state=44,  # fixed seed for reproducibility
).fit(wines.drop(columns='label'), wines['label'])
def plot_explanation(instance_part_1, instance_part_2, instance_part_3, sigma):
    """Classify one wine and explain the prediction with LIME.

    The three table arguments hold consecutive slices of a single wine's
    feature values. Laid side by side, they must give exactly one value per
    model feature, in training-column order.

    Args:
        instance_part_1, instance_part_2, instance_part_3: One-row tables
            from the ``gr.Dataframe`` inputs (DataFrames or nested lists).
        sigma: Kernel width handed to LIME's explainer (``kernel_width``).

    Returns:
        ``(confidences, figure)``: a dict mapping each class label (as a
        string) to its predicted probability, and the figure produced by
        LIME's ``as_pyplot_figure`` for the most probable class.

    Raises:
        ValueError: if the tables do not contain exactly one value per
            feature.
    """
    features = wines.drop('label', axis=1)
    feature_names = features.columns

    # Join the slices by position (not by pandas index alignment) and force a
    # float dtype: cells edited in the UI are not guaranteed to be numeric.
    instance_np = np.hstack([
        np.asarray(part, dtype=float)
        for part in (instance_part_1, instance_part_2, instance_part_3)
    ]).ravel()
    if instance_np.size != len(feature_names):
        raise ValueError(
            f"expected {len(feature_names)} feature values, "
            f"got {instance_np.size}"
        )

    def predict_proba(rows):
        # LIME passes bare ndarrays; restore the training column names so that
        # scikit-learn's feature-name check matches the fitted model.
        return predictor.predict_proba(pd.DataFrame(rows, columns=feature_names))

    explainer = LimeTabularExplainer(
        training_data=features.to_numpy(),  # LIME documents a 2-D ndarray here
        feature_names=feature_names.to_list(),
        discretize_continuous=False,
        kernel_width=sigma,
    )
    explanation = explainer.explain_instance(
        instance_np,
        predict_proba,
        top_labels=len(predictor.classes_),  # explain every class
        num_features=5,
    )

    predictions = predict_proba(instance_np.reshape(1, -1))[0]
    # LIME keys its explanations by column index of predict_proba's output,
    # so the figure is selected by argmax index, while the confidences are
    # reported under the model's actual class labels.
    label = int(np.argmax(predictions))
    confidences = {
        str(cls): float(prob)
        for cls, prob in zip(predictor.classes_, predictions)
    }
    return confidences, explanation.as_pyplot_figure(label=label)
# Default kernel width: 0.75 * sqrt(number of feature columns), where every
# column except one (the target) counts as a feature. This matches LIME's own
# documented default. The slider lets the user go from just above zero up to
# twice that default.
sigma_default = (wines.shape[1] - 1) ** 0.5 * 0.75
sigma = gr.Slider(
    minimum=0.001,
    maximum=2 * sigma_default,
    value=sigma_default,
    label='σ',
)
# One randomly drawn wine pre-fills the input tables when the app starts.
instance_complete = wines.sample(1)


def _feature_table(start, stop, **kwargs):
    """Return a fixed one-row numeric ``gr.Dataframe`` for ``wines.columns[start:stop]``.

    The row is pre-filled with the matching values of ``instance_complete``.
    The column count is derived from the header slice, so headers,
    ``col_count`` and the pre-filled row always agree. Extra keyword
    arguments (e.g. ``label``, ``show_label``) are forwarded to
    ``gr.Dataframe``.
    """
    headers = wines.columns[start:stop].to_list()
    return gr.Dataframe(
        headers=headers,
        row_count=(1, "fixed"),
        col_count=(len(headers), "fixed"),
        datatype="number",
        value=instance_complete.iloc[:, start:stop].values.tolist(),
        **kwargs,
    )


# Column 0 is skipped (assumed to be the 'label' target); the remaining
# columns are spread over three tables.
instance_part_1 = _feature_table(1, 6, label="Chemical properties of the wine")
# NOTE: the original author reported that show_label=False has no visible effect.
instance_part_2 = _feature_table(6, 10, label="", show_label=False)
instance_part_3 = _feature_table(10, None, label="", show_label=False)
# Assemble the UI: the three feature tables plus the kernel-width slider feed
# plot_explanation, whose (confidences, figure) result fills a label and a plot.
# The web server is then started.
demo = gr.Interface(
    plot_explanation,
    [instance_part_1, instance_part_2, instance_part_3, sigma],
    ["label", "plot"],
)
demo.launch()