Spaces:
Sleeping
Sleeping
| import numpy as np | |
| import plotly.express as px | |
| import plotly.graph_objects as go | |
| from sklearn.datasets import make_moons | |
| from sklearn.linear_model import LogisticRegression | |
| from sklearn.svm import LinearSVC | |
| from sklearn.ensemble import RandomForestClassifier | |
| from shiny import App, ui, reactive, render | |
| from shinywidgets import render_plotly, output_widget, render_widget | |
| import logging | |
# Root logging setup, done once at import time: records at INFO and above,
# each prefixed with a timestamp and the level name.
_LOG_FORMAT = "%(asctime)s [%(levelname)s] %(message)s"
logging.basicConfig(format=_LOG_FORMAT, level=logging.INFO)

# Application logger used by the rest of this module.
logger = logging.getLogger("ELVIS")
# Toy two-feature dataset from sklearn's make_moons; the fixed random_state
# makes the points reproducible across restarts.
X, y = make_moons(noise=0.25, random_state=0)

# Module-level model read by the server's decision_data() helper.
model = LogisticRegression().fit(X, y)

# Debug-level only: a bare print() here would dump every label to stdout on
# each import, bypassing the logger configured above.
logger.debug("Labels: %s", y)
| #with ui.sidebar(): | |
| #ui.input_slider( | |
| #"grid_resolution", | |
| #"Grid resolution", | |
| #min=50, | |
| #max=400, | |
| #value=200, | |
| #step=25, | |
| #) | |
| #ui.input_slider( | |
| #"marker_size", | |
| #"Marker size", | |
| #min=4, | |
| #max=20, | |
| #value=8, | |
| #) | |
| #ui.input_action_button("reset", "Reset") | |
def _read_text(path):
    """Return the entire contents of the UTF-8 text file at *path*.

    A ``with`` block closes the handle as soon as the file is read; an explicit
    encoding avoids depending on the platform's locale default.
    """
    with open(path, encoding="utf-8") as fh:
        return fh.read()


# Page layout: a title row, then a row with the plot card and a tabbed control
# card, padded on both sides by borderless spacer cards.
app_ui = ui.page_fluid(
    ui.include_css("styles.css"),
    # Title row: 1-column spacer, 11-column heading.
    ui.layout_columns(
        # first column
        ui.card(style="border: none !important; box-shadow: none !important;"),
        # second column
        ui.markdown("""
# Decision Boundary Visualizer
"""),
        col_widths=[1, 11]
    ),
    # Main row: spacer | plot | controls | spacer.
    ui.layout_columns(
        #ui.card(class_="card-no-border"),
        ui.card(style="border: none !important; box-shadow: none !important;"),
        # Placeholder for the widget output with id "decision_plot".
        ui.card(output_widget("decision_plot")), #hover=True)),
        ui.card(ui.navset_tab(
            ui.nav_panel("Dataset",
                ui.input_radio_buttons(
                    "dataset",
                    "Dataset type",
                    {
                        "a": "Draw2D",
                        "b": "Upload",
                        "c": "sklearn",},
                    inline=True,
                ),
                ui.input_radio_buttons(
                    "pointlabel_choice",
                    "Choose point label",
                    {
                        "a": "Gray",
                        "b": "Orange",
                        "c": "Blue",},
                    inline=True,
                ),
            ),
            ui.nav_panel("Classifier",
                ui.input_select(
                    "classifier",  # input ID
                    "Select classifier:",  # label displayed above the dropdown
                    {"LinearSVC": "LinearSVC",
                     "RandomForestClassifier": "RandomForestClassifier", "LogisticRegression": "LogisticRegression"},  # choices: value: label
                    selected="LinearSVC"  # optional default selected value
                ),
            ),
            # TODO: the Export and Options tabs still show copy-pasted
            # placeholder text and have no controls yet.
            ui.nav_panel("Export", ui.p("This is the table tab")),
            ui.nav_panel("Options", ui.p("This is the table tab")),
            ui.nav_panel("Usage", ui.markdown(_read_text("usage.md"))),
        )),
        ui.card(class_="card-no-border"),
        col_widths=[1, 6, 4, 1]),
)
def server(input, output, session):
    """Shiny server function: render the decision-boundary plot.

    Parameters
    ----------
    input :
        Shiny inputs; ``input.classifier()`` returns the value chosen in the
        "classifier" select box (a key of ``classifiers`` below).
    output, session :
        Shiny output/session objects (unused here; required by the signature).
    """

    # Select-box value -> estimator class. A dict lookup replaces the original
    # if-chain, which left ``model`` unbound for any unexpected value.
    classifiers = {
        "LogisticRegression": LogisticRegression,
        "LinearSVC": LinearSVC,
        "RandomForestClassifier": RandomForestClassifier,
    }

    def decision_data(clf=None, res=200, margin=0.5):
        """Predict ``clf`` on a ``res`` x ``res`` grid covering the data.

        Parameters
        ----------
        clf : fitted estimator, optional
            Model whose ``predict`` is evaluated; defaults to the module-level
            ``model``.
        res : int
            Number of grid points along each axis.
        margin : float
            Padding added on every side of the data's bounding box.

        Returns
        -------
        xx, yy, Z : ndarray, each reshaped to ``xx.shape``
            Grid x-coordinates, y-coordinates and the predicted label at each
            grid point.
        """
        if clf is None:
            clf = model
        x_min, x_max = X[:, 0].min() - margin, X[:, 0].max() + margin
        y_min, y_max = X[:, 1].min() - margin, X[:, 1].max() + margin
        xx, yy = np.meshgrid(
            np.linspace(x_min, x_max, res),
            np.linspace(y_min, y_max, res),
        )
        Z = clf.predict(np.c_[xx.ravel(), yy.ravel()])
        return xx, yy, Z.reshape(xx.shape)

    # Without this decorator the function is never registered as an output,
    # so the "decision_plot" widget in the UI would stay empty.
    @render_plotly
    def decision_plot():
        """Plot the training points over the selected classifier's regions.

        Re-runs whenever ``input.classifier()`` changes, because it reads it.
        """
        name = input.classifier()
        logger.debug("Rendering decision plot for %s", name)
        try:
            clf = classifiers[name]()
        except KeyError:
            raise ValueError(f"Unknown classifier: {name!r}") from None
        logger.info("Fitting model %s", clf)
        clf.fit(X, y)

        # Predicted label on a 100x100 grid spanning exactly the data's
        # bounding box (no margin), as before.
        # Note: Should not apply normalization/jittering to meshgrid points
        xx, yy, preds = decision_data(clf, res=100, margin=0.0)
        logger.debug("grid shape = %s", xx.shape)

        # "identity" makes plotly use these strings as literal CSS colours.
        # Otherwise px treats them as category labels and assigns palette
        # colours by first appearance, which need not match the region colours.
        fig = px.scatter(
            x=X[:, 0],
            y=X[:, 1],
            color=['blue' if p == 1 else 'red' for p in y],
            color_discrete_map="identity",
        )
        fig.add_trace(
            go.Scatter(
                x=xx.ravel(),
                y=yy.ravel(),
                mode='markers',
                marker=dict(
                    color=['blue' if p == 1 else 'red' for p in preds.ravel()],
                    size=1,
                ),
                # Grid points are background shading, not data: no tooltips or
                # legend entry for them.
                hoverinfo='skip',
                showlegend=False,
            )
        )
        return fig
| #def decision_plot1(): #decision_boundary=False, save_figure=False): | |
| #''' | |
| #Plot data and decision boundary with matplotlib and return as PIL image. | |
| #''' | |
| #logger.info("Initializing figure") | |
| ##fig = plt.figure(figsize=(self.canvas_width/100., self.canvas_height/100.0), dpi=self.dpi) | |
| #fig = plt.figure() | |
| ## set entire figure to be the canvas to allow simple conversion of mouse | |
| ## position to coordinates in the figure | |
| #ax = fig.add_axes([0., 0., 1., 1.]) # | |
| #ax.margins(x=0, y=0) # no padding in both directions | |
| ##if self.dataset_type == 'Draw2D': | |
| ### draw canvas boundary | |
| ###ax.scatter([0, 0, 1, 1], [0, 1, 0, 1], color='brown') | |
| ### DO NOT CHANGE THE COLOR OF THE BOUNDARY | |
| ### IT WILL BREAK THE ORIGIN COORDINATE DETECTION | |
| ##ax.plot([0, 0, 1, 1, 0], [0, 1, 1, 0, 0], color='black') | |
| ##for spine in ax.spines.values(): | |
| ##spine.set_color((0.1, 0.1, 0.1)) | |
| ## TODO: allow showing x and y axes with ticks and labels | |
| #if (self.dataset is not None and len(self.dataset) > 0): | |
| #try: | |
| #X = self._get_features() | |
| #y = self.dataset.target.values | |
| #logger.info("Data:\n" + str(X)) | |
| #logger.info("Target:\n" + str(y)) | |
| ## preprocess features | |
| #X = self._process_features(X) | |
| ## embed features to 2D for visualization | |
| #Z, embedder = self._embed_features(X, return_embedder=True) | |
| ##ax.set_title("Click to add points") | |
| #labels = np.unique(y) | |
| #colors = label2color(labels, cmap=self.cmap) | |
| #logger.info("Classes:\n" + str(labels)) | |
| #logger.info("Colors:\n" + str(colors)) | |
| #l2c = dict(zip(labels, colors)) | |
| ## scatter plots for data | |
| #for l, label in enumerate(labels): | |
| ##print('class', label) | |
| ##ax.scatter(*zip(*self.dataset[self.dataset.label == label].features), color=label, label=label) | |
| #subset = Z[y == label] | |
| #ax.scatter(subset[:, 0], subset[:, 1], color=colors[l], label=label, s=self.marker_size) | |
| #ax.legend() | |
| ## plot the decision boundary | |
| #if decision_boundary: | |
| #model = self.model_class(**parse_param_string(self.model_args)) | |
| #logger.info("Fitting model " + str(model)) | |
| #model.fit(X, y) | |
| #self.model = model | |
| ## plot decision boundary in the projected space | |
| ## xx, yy = np.meshgrid(np.linspace(Z[:, 0].min(), 1, 100), np.linspace(0, 1, 100)) | |
| ## Note: Should not apply normalization/jittering to meshgrid points | |
| #if self.dataset_type == 'Draw2D': | |
| #xx, yy = np.meshgrid(np.linspace(0, 1, self.num_dots), | |
| #np.linspace(0, 1, self.num_dots)) | |
| #else: | |
| #xx, yy = np.meshgrid(np.linspace(Z[:, 0].min(), Z[:, 0].max(), self.num_dots), | |
| #np.linspace(Z[:, 1].min(), Z[:, 1].max(), self.num_dots)) | |
| ##logger.info("grid_x shape = " + str(xx.shape) + "; grid_y shape =" + str(yy.shape)) | |
| #grid = np.c_[xx.ravel(), yy.ravel()] | |
| ##scores = clf.decision_function(grid)[:, 1].reshape(xx.shape) | |
| ##scores = clf.decision_function(grid).reshape(xx.shape) | |
| ##ax.contour(xx, yy, scores)#, levels=[0], colors="black", linestyles="--") | |
| #print('grid', grid) | |
| #print('inverse', embedder.inverse_transform(grid)) | |
| #preds = model.predict(embedder.inverse_transform(grid)).reshape(xx.shape) | |
| ##print(preds.shape, xx.shape, yy.shape) | |
| #ax.scatter(xx.ravel(), yy.ravel(), c=[l2c[l] for l in preds.ravel()], s=1, alpha=0.5) | |
| #except Exception as e: | |
| #print(traceback.format_exc()) | |
| ##raise gr.Error(f"⚠️ {e}") | |
| #gr.Info(f"⚠️ {e}") | |
| #buf = io.BytesIO() | |
| #ax.figure.savefig(buf, format="png", bbox_inches="tight", pad_inches=0) | |
| #plt.close(fig) | |
| #buf.seek(0) | |
| #img = Image.open(buf) | |
| #if save_figure: | |
| #ax.figure.savefig(f"{self.FIGURE_BASENAME}{self.figure_extension}") | |
| ## detect axis pixel positions | |
| #if self.dataset_type == 'Draw2D': | |
| #array = np.array(img.convert("RGB")) | |
| #bgr = cv2.cvtColor(array, cv2.COLOR_RGB2BGR) | |
| #gray = cv2.cvtColor(bgr, cv2.COLOR_BGR2GRAY) | |
| #black_mask = gray < 0.05 * 255 | |
| #contours, _ = cv2.findContours(black_mask.astype(np.uint8), cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE) | |
| ## find the contour with the largest area | |
| #max_area = 0 | |
| #most_likely_topleft = 0, 0 | |
| #for contour in contours: | |
| #x, y, w, h = cv2.boundingRect(contour) | |
| #area = w * h | |
| #if w * h > max_area: | |
| #max_area = area | |
| #most_likely_topleft = x, y | |
| #self._axis_topleft = most_likely_topleft | |
| #else: | |
| #self._axis_topleft = 0, 0 | |
| ## TODO: add a save function for saving screenshot | |
| ##img.save('image.png') | |
| #return img | |
| # -------------------------------------------------- | |
| # Reset logic | |
| # -------------------------------------------------- | |
| #@reactive.effect | |
| #@reactive.event(input.reset) | |
| #def _(): | |
| ##ui.update_slider("grid_resolution", value=200) | |
| ##ui.update_slider("marker_size", value=8) | |
| #pass | |
# Shiny entry point: binds the static page layout to the server function.
app = App(ui=app_ui, server=server)