Spaces:
Runtime error
Runtime error
| """Gradio demo for different clustering techniques | |
| Derived from https://scikit-learn.org/stable/auto_examples/cluster/plot_cluster_comparison.html | |
| """ | |
| import gradio as gr | |
| import matplotlib.pyplot as plt | |
| import numpy as np | |
| from sklearn.cluster import ( | |
| AgglomerativeClustering, Birch, DBSCAN, KMeans, MeanShift, OPTICS, SpectralClustering, estimate_bandwidth | |
| ) | |
| from sklearn.datasets import make_blobs, make_circles, make_moons | |
| from sklearn.mixture import GaussianMixture | |
| from sklearn.neighbors import kneighbors_graph | |
| from sklearn.preprocessing import StandardScaler | |
# Matplotlib 3.6 renamed its bundled seaborn style to "seaborn-v0_8" and later
# releases removed the bare "seaborn" name, so select whichever one this
# installation actually ships. If neither exists the default style is kept
# rather than crashing at import time.
for _style_name in ("seaborn-v0_8", "seaborn"):
    if _style_name in plt.style.available:
        plt.style.use(_style_name)
        break

SEED = 0  # shared random seed for data generation and seeded estimators
N_CLUSTERS = 4  # number of clusters requested from the fixed-k algorithms
N_SAMPLES = 1000  # number of points in every generated dataset
np.random.seed(SEED)
def normalize(X):
    """Return ``X`` after fitting and applying a fresh ``StandardScaler``."""
    scaler = StandardScaler()
    return scaler.fit_transform(X)
def get_regular():
    """Return standardized isotropic blobs and their true labels.

    One blob is generated around each corner of the square ``(+-1, +-1)``.
    """
    blob_centers = [[1, 1], [1, -1], [-1, 1], [-1, -1]]
    # The number of blobs must match what the clustering models are asked for.
    assert len(blob_centers) == N_CLUSTERS
    points, targets = make_blobs(
        n_samples=N_SAMPLES,
        centers=blob_centers,
        cluster_std=0.7,
        random_state=SEED,
    )
    return normalize(points), targets
def get_circles():
    """Return standardized ``make_circles`` points and their true labels."""
    points, targets = make_circles(
        n_samples=N_SAMPLES,
        factor=0.5,
        noise=0.05,
        random_state=SEED,
    )
    return normalize(points), targets
def get_moons():
    """Return standardized ``make_moons`` points and their true labels."""
    points, targets = make_moons(
        n_samples=N_SAMPLES,
        noise=0.05,
        random_state=SEED,
    )
    return normalize(points), targets
def get_noise():
    """Return standardized uniform noise with a single all-zero label.

    A local ``RandomState`` seeded with ``SEED`` is used instead of the global
    NumPy generator so that repeated calls return the same points, matching
    the other dataset generators, which all pass an explicit ``random_state``.

    Returns:
        Tuple ``(X, labels)`` where ``X`` has shape ``(N_SAMPLES, 2)`` and
        ``labels`` is an array of ``N_SAMPLES`` zeros.
    """
    rng = np.random.RandomState(SEED)
    X = rng.rand(N_SAMPLES, 2)
    labels = np.zeros(N_SAMPLES)
    return normalize(X), labels
def get_anisotropic():
    """Return standardized, linearly stretched blobs and their true labels.

    ``N_CLUSTERS`` blobs are generated and then multiplied by a fixed 2x2
    matrix. The result is standardized like every other dataset in this
    module, so scale-sensitive defaults (for example the DBSCAN ``eps``) see
    data on a comparable scale.

    Returns:
        Tuple ``(X, labels)`` with ``X`` of shape ``(N_SAMPLES, 2)``.
    """
    X, labels = make_blobs(n_samples=N_SAMPLES, centers=N_CLUSTERS, random_state=170)
    transformation = [[0.6, -0.6], [-0.4, 0.8]]
    X = np.dot(X, transformation)
    return normalize(X), labels
def get_varied():
    """Return standardized blobs with per-blob spreads of 1.0, 2.5 and 0.5."""
    spreads = [1.0, 2.5, 0.5]
    points, targets = make_blobs(
        n_samples=N_SAMPLES,
        cluster_std=spreads,
        random_state=SEED,
    )
    return normalize(points), targets
# Dataset name (as offered in the UI) -> zero-argument generator returning
# ``(X, labels)``.
DATA_MAPPING = {
    "regular": get_regular,
    "circles": get_circles,
    "moons": get_moons,
    "noise": get_noise,
    "anisotropic": get_anisotropic,
    "varied": get_varied,
}
def get_kmeans(X, **kwargs):
    """Fit k-means on ``X``; ``kwargs`` override the default parameters."""
    estimator = KMeans(
        init="k-means++",
        n_clusters=N_CLUSTERS,
        n_init=10,
        random_state=SEED,
    )
    estimator.set_params(**kwargs)
    fitted = estimator.fit(X)
    return fitted
def get_dbscan(X, **kwargs):
    """Fit DBSCAN (``eps=0.3`` by default) on ``X``; ``kwargs`` override it."""
    estimator = DBSCAN(eps=0.3)
    estimator.set_params(**kwargs)
    fitted = estimator.fit(X)
    return fitted
def get_agglomerative(X, **kwargs):
    """Fit Ward agglomerative clustering on ``X`` with a k-NN connectivity.

    ``kwargs`` override the default estimator parameters.
    """
    knn_graph = kneighbors_graph(X, n_neighbors=N_CLUSTERS, include_self=False)
    # Average the graph with its transpose so the connectivity is symmetric.
    symmetric_graph = 0.5 * (knn_graph + knn_graph.T)
    estimator = AgglomerativeClustering(
        n_clusters=N_CLUSTERS,
        linkage="ward",
        connectivity=symmetric_graph,
    )
    estimator.set_params(**kwargs)
    fitted = estimator.fit(X)
    return fitted
def get_meanshift(X, **kwargs):
    """Fit MeanShift on ``X`` using a bandwidth estimated from the data.

    ``kwargs`` override the default estimator parameters.
    """
    estimated_bandwidth = estimate_bandwidth(X, quantile=0.3)
    estimator = MeanShift(bandwidth=estimated_bandwidth, bin_seeding=True)
    estimator.set_params(**kwargs)
    fitted = estimator.fit(X)
    return fitted
def get_spectral(X, **kwargs):
    """Fit nearest-neighbour spectral clustering on ``X``.

    ``kwargs`` override the default estimator parameters.
    """
    estimator = SpectralClustering(
        n_clusters=N_CLUSTERS,
        eigen_solver="arpack",
        affinity="nearest_neighbors",
    )
    estimator.set_params(**kwargs)
    fitted = estimator.fit(X)
    return fitted
def get_optics(X, **kwargs):
    """Fit OPTICS on ``X``; ``kwargs`` override the default parameters."""
    estimator = OPTICS(min_samples=7, xi=0.05, min_cluster_size=0.1)
    estimator.set_params(**kwargs)
    fitted = estimator.fit(X)
    return fitted
def get_birch(X, **kwargs):
    """Fit Birch on ``X``; ``kwargs`` override the default parameters.

    The default cluster count is ``N_CLUSTERS`` so Birch is asked for the same
    number of clusters as the other fixed-k models in this module.

    Args:
        X: Sample matrix passed to ``Birch.fit``.
        **kwargs: Estimator parameters applied through ``set_params``.

    Returns:
        The fitted ``Birch`` estimator.
    """
    model = Birch(n_clusters=N_CLUSTERS)
    model.set_params(**kwargs)
    return model.fit(X)
def get_gaussianmixture(X, **kwargs):
    """Fit a full-covariance Gaussian mixture on ``X``.

    ``kwargs`` override the default estimator parameters.
    """
    estimator = GaussianMixture(
        n_components=N_CLUSTERS,
        covariance_type="full",
        random_state=SEED,
    )
    estimator.set_params(**kwargs)
    fitted = estimator.fit(X)
    return fitted
# Algorithm name (as offered in the UI) -> function that takes ``X`` and
# returns the fitted estimator.
MODEL_MAPPING = {
    "KMeans": get_kmeans,
    "DBSCAN": get_dbscan,
    "AgglomerativeClustering": get_agglomerative,
    "MeanShift": get_meanshift,
    "SpectralClustering": get_spectral,
    "OPTICS": get_optics,
    "Birch": get_birch,
    "GaussianMixture": get_gaussianmixture,
}
def plot_clusters(ax, X, labels):
    """Scatter-plot ``X`` on ``ax`` with one scatter call per distinct label.

    Every label present in ``labels`` is drawn, so clusterings that yield more
    groups than expected, or that mark noise with ``-1`` (as DBSCAN and OPTICS
    do), are shown in full instead of being silently truncated.

    Args:
        ax: Matplotlib axes (or compatible object) to draw on.
        X: Array indexed as ``X[mask, 0]`` / ``X[mask, 1]`` for the x and y
            coordinates.
        labels: One label per row of ``X``.

    Returns:
        The same ``ax`` object, with grid and ticks removed.
    """
    labels = np.asarray(labels)
    for label in np.unique(labels):
        idx = labels == label
        # Noise points get a fixed black colour so they do not look like a
        # regular cluster; real clusters use the axes' colour cycle.
        style = {"color": "black"} if label == -1 else {}
        ax.scatter(X[idx, 0], X[idx, 1], **style)
    # ``grid(None)`` would *toggle* the grid; ``False`` always turns it off.
    ax.grid(False)
    ax.set_xticks([])
    ax.set_yticks([])
    return ax
def cluster(clustering_algorithm: str, dataset: str):
    """Cluster the chosen dataset and plot true vs. predicted labels.

    Args:
        clustering_algorithm: Key into ``MODEL_MAPPING``.
        dataset: Key into ``DATA_MAPPING``.

    Returns:
        A matplotlib ``Figure`` with two side-by-side axes: the true labels on
        the left and the labels produced by the algorithm on the right.

    Raises:
        KeyError: If either name is not a key of its mapping.
    """
    # Local import: matplotlib is already a dependency of this file. The
    # figure is built directly rather than through ``plt.subplots`` because
    # pyplot keeps a reference to every figure it creates until it is closed,
    # so figures would accumulate with each request in a long-running server.
    from matplotlib.figure import Figure

    X, labels = DATA_MAPPING[dataset]()
    model = MODEL_MAPPING[clustering_algorithm](X)
    # Estimators without a ``labels_`` attribute (e.g. GaussianMixture) must
    # be asked to predict explicitly.
    if hasattr(model, "labels_"):
        y_pred = model.labels_.astype(int)
    else:
        y_pred = model.predict(X)

    fig = Figure(figsize=(16, 8))
    axes = fig.subplots(1, 2)

    ax = axes[0]
    plot_clusters(ax, X, labels)
    ax.set_title("True clusters")

    ax = axes[1]
    plot_clusters(ax, X, y_pred)
    ax.set_title(clustering_algorithm)
    return fig
# Build the two radio selectors first, then wire them into the interface.
_algorithm_radio = gr.Radio(
    list(MODEL_MAPPING),
    value="KMeans",
    label="clustering algorithm",
)
_dataset_radio = gr.Radio(
    list(DATA_MAPPING),
    value="regular",
    label="dataset",
)
demo = gr.Interface(
    fn=cluster,
    inputs=[_algorithm_radio, _dataset_radio],
    outputs=gr.Plot(),
)
demo.launch()