Create app.py
Browse files
app.py
ADDED
|
@@ -0,0 +1,96 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import gradio as gr
|
| 2 |
+
import matplotlib.pyplot as plt
|
| 3 |
+
from sklearn.datasets import load_iris
|
| 4 |
+
from sklearn.svm import SVC
|
| 5 |
+
from sklearn.model_selection import StratifiedKFold, permutation_test_score
|
| 6 |
+
import numpy as np
|
| 7 |
+
import tempfile
|
| 8 |
+
import os
|
| 9 |
+
|
| 10 |
+
def run_permutation_test(display_option, kernel, random_state, n_permutations):
    """Run permutation tests with an SVC on Iris and save the requested histograms.

    Two experiments are available, both scored with accuracy under a
    2-fold stratified, shuffled cross-validation:

    * "original": the real Iris features against the Iris labels.
    * "random": 20 normally distributed random features (fixed seed 0)
      against the same Iris labels.

    Only the experiments selected by ``display_option`` are computed.

    Args:
        display_option: "original", "random" or "both". Any other value
            computes nothing and yields ``(None, None)``.
        kernel: Kernel name forwarded to ``SVC``.
        random_state: Seed forwarded to ``SVC``; coerced to ``int`` unless
            it is ``None``.
        n_permutations: Number of label permutations; coerced to ``int``.

    Returns:
        A tuple ``(original_plot_path, random_plot_path)``. Each entry is
        the path of a PNG file, or ``None`` when that plot was not requested.
    """
    # UI sliders may hand over floats (e.g. 7.0); scikit-learn only accepts
    # integral values for these two parameters.
    if random_state is not None:
        random_state = int(random_state)
    n_permutations = int(n_permutations)

    iris = load_iris()
    X = iris.data
    y = iris.target

    clf = SVC(kernel=kernel, random_state=random_state)
    cv = StratifiedKFold(2, shuffle=True, random_state=0)

    original_plot_path = None
    random_plot_path = None

    if display_option in ("original", "both"):
        score_iris, perm_scores_iris, pvalue_iris = permutation_test_score(
            clf, X, y, scoring="accuracy", cv=cv, n_permutations=n_permutations
        )
        original_plot_path = _save_permutation_histogram(
            perm_scores_iris,
            score_iris,
            pvalue_iris,
            text_xy=(0.7, 10),
            filename="original_plot.png",
        )

    if display_option in ("random", "both"):
        # Features unrelated to the labels, generated from a fixed seed so
        # the plot is reproducible between calls.
        n_uncorrelated_features = 20
        rng = np.random.RandomState(seed=0)
        X_rand = rng.normal(size=(X.shape[0], n_uncorrelated_features))

        score_rand, perm_scores_rand, pvalue_rand = permutation_test_score(
            clf, X_rand, y, scoring="accuracy", cv=cv, n_permutations=n_permutations
        )
        random_plot_path = _save_permutation_histogram(
            perm_scores_rand,
            score_rand,
            pvalue_rand,
            text_xy=(0.14, 7.5),
            filename="random_plot.png",
            xlim_left=0.13,
        )

    return original_plot_path, random_plot_path


def _save_permutation_histogram(
    perm_scores, score, pvalue, text_xy, filename, xlim_left=None
):
    """Plot the permutation-score histogram with the real score and save it as PNG.

    Args:
        perm_scores: Scores obtained on permuted labels.
        score: Score obtained on the unpermuted labels (drawn as a dashed
            red vertical line).
        pvalue: p-value shown in the annotation.
        text_xy: ``(x, y)`` data coordinates of the annotation text.
        filename: File name of the PNG inside a fresh temporary directory.
        xlim_left: Optional lower x-axis limit.

    Returns:
        Path of the written PNG file.
    """
    fig, ax = plt.subplots()
    try:
        ax.hist(perm_scores, bins=20, density=True)
        if xlim_left is not None:
            ax.set_xlim(xlim_left)
        ax.axvline(score, ls="--", color="r")
        score_label = f"Score on original\ndata: {score:.2f}\n(p-value: {pvalue:.3f})"
        ax.text(text_xy[0], text_xy[1], score_label, fontsize=12)
        ax.set_xlabel("Accuracy score")
        ax.set_ylabel("Probability")

        # The directory is not removed here: the path is returned to the
        # caller, so the file has to outlive this function.
        plot_path = os.path.join(tempfile.mkdtemp(), filename)
        # Save through the figure object rather than pyplot's global
        # "current figure", which another concurrent call could have changed.
        fig.savefig(plot_path)
    finally:
        # Always release this specific figure, even if plotting fails.
        plt.close(fig)
    return plot_path
|
| 61 |
+
|
| 62 |
+
# Input widgets, listed in the same order as the parameters of
# run_permutation_test: display option, kernel, random state, permutations.
_display_option_input = gr.inputs.Dropdown(
    choices=["original", "random", "both"],
    label="Display Option",
    default="both",
)
_kernel_input = gr.inputs.Dropdown(
    choices=["linear", "rbf", "poly"],
    label="Kernel",
    default="linear",
)
_random_state_input = gr.inputs.Slider(
    minimum=0,
    maximum=10,
    step=1,
    label="Random State",
    default=7,
)
_n_permutations_input = gr.inputs.Slider(
    minimum=100,
    maximum=2000,
    step=100,
    label="Number of Permutations",
    default=1000,
)

_TITLE = "Test with permutations the significance of a classification score"
_DESCRIPTION = (
    "This example demonstrates the use of permutation_test_score to evaluate "
    "the significance of a cross-validated score using permutations. "
    "This operation is being performed on the Iris Dataset. "
    "See the original scikit-learn example here: "
    "https://scikit-learn.org/stable/auto_examples/model_selection/"
    "plot_permutation_tests_for_classification.html"
)

# Each example row supplies one value per input widget, in widget order.
_EXAMPLES = [
    ["both", "linear", 7, 1000],
    ["original", "rbf", 3, 500],
    ["random", "poly", 5, 1500],
]

# Two image outputs: the first and second path returned by
# run_permutation_test.
iface = gr.Interface(
    fn=run_permutation_test,
    inputs=[
        _display_option_input,
        _kernel_input,
        _random_state_input,
        _n_permutations_input,
    ],
    outputs=["image", "image"],
    title=_TITLE,
    description=_DESCRIPTION,
    examples=_EXAMPLES,
    allow_flagging=False,
)
iface.launch()
|