Spaces:
Runtime error
Runtime error
| import joblib | |
| import time | |
| import plotly.graph_objects as go | |
| import plotly.express as px | |
| import streamlit as st | |
| import pandas as pd | |
| import numpy as np | |
# Feature names, in positional order.
# build_labels_colors() turns a tree node's `feature_idx` into a name with
# FEATS[feature_idx], so the order of this list is significant.
# NOTE(review): presumably this matches the column order the model was
# trained on -- confirm against the training code before reordering.
FEATS = [
    'srcip',
    'sport',
    'dstip',
    'dsport',
    'proto',
    #'state', I dropped this one when I trained the model
    'dur',
    'sbytes',
    'dbytes',
    'sttl',
    'dttl',
    'sloss',
    'dloss',
    'service',
    'Sload',
    'Dload',
    'Spkts',
    'Dpkts',
    'swin',
    'dwin',
    'stcpb',
    'dtcpb',
    'smeansz',
    'dmeansz',
    'trans_depth',
    'res_bdy_len',
    'Sjit',
    'Djit',
    'Stime',
    'Ltime',
    'Sintpkt',
    'Dintpkt',
    'tcprtt',
    'synack',
    'ackdat',
    'is_sm_ips_ports',
    'ct_state_ttl',
    'ct_flw_http_mthd',
    'is_ftp_login',
    'ct_ftp_cmd',
    'ct_srv_src',
    'ct_srv_dst',
    'ct_dst_ltm',
    'ct_src_ltm',
    'ct_src_dport_ltm',
    'ct_dst_sport_ltm',
    'ct_dst_src_ltm',
]
# Generated from
# mokole.com/palette.html
#
# Hex colors used to tint the tree map nodes.
# build_labels_colors() pairs these with FEATS by position (the feature at
# FEATS[k] is drawn with COLORS[k]), so this list needs at least as many
# entries as FEATS.
COLORS = [
    '#000000',
    '#808080',
    '#2f4f4f',
    '#556b2f',
    '#8b4513',
    '#228b22',
    '#800000',
    '#808000',
    '#3cb371',
    '#663399',
    '#b8860b',
    '#008b8b',
    '#4682b4',
    '#d2691e',
    '#9acd32',
    '#cd5c5c',
    '#00008b',
    '#32cd32',
    '#8fbc8f',
    '#b03060',
    '#d2b48c',
    '#ff0000',
    '#ffa500',
    '#ffd700',
    '#ffff00',
    '#0000cd',
    '#00ff00',
    '#8a2be2',
    '#00ff7f',
    '#4169e1',
    '#dc143c',
    '#00ffff',
    '#00bfff',
    '#f4a460',
    '#adff2f',
    '#ff6347',
    '#da70d6',
    '#d8bfd8',
    '#ff00ff',
    '#f0e68c',
    '#6495ed',
    '#dda0dd',
    '#b0e0e6',
    '#98fb98',
    '#7fffd4',
    '#ff69b4',
]
def build_parents(tree, visit_order, node_id2plot_id):
    """
    Find the parent of every node in `visit_order`.

    Parameters
    ----------
    tree : pandas.DataFrame
        One row per node, indexed by node id, with `left` and `right`
        columns holding the ids of that node's children.
    visit_order : list
        Node ids in the order they will be plotted. The first entry is
        treated as the root and gets no parent.
    node_id2plot_id : dict
        Maps a node id to its position in `visit_order`.

    Returns
    -------
    tuple of (list, list, list)
        `parents` (parent node ids), `parent_plot_ids` (the parents' plot ids
        as strings) and `directions` ('l' or 'r': which child of its parent
        the node is). All three are aligned with `visit_order` and hold
        `None` in the root's slot.

    Raises
    ------
    IndexError
        If a non-root node is not listed as a child of any row in `tree`.
    """
    # Build both child -> parent maps in a single pass instead of scanning
    # the whole table twice for every node.
    # setdefault keeps the first matching row, like `.index[0]` on a filter.
    right_parent = {}
    left_parent = {}
    for node_id, left, right in zip(tree.index, tree['left'], tree['right']):
        left_parent.setdefault(left, node_id)
        right_parent.setdefault(right, node_id)

    parents = [None]
    parent_plot_ids = [None]
    directions = [None]
    for child in visit_order[1:]:
        # a match on `right` wins over a match on `left`
        if child in right_parent:
            parent, direction = right_parent[child], 'r'
        elif child in left_parent:
            parent, direction = left_parent[child], 'l'
        else:
            raise IndexError(f"node {child} is not a child of any node in the tree")
        parents.append(parent)
        parent_plot_ids.append(str(node_id2plot_id[parent]))
        directions.append(direction)
    return parents, parent_plot_ids, directions
def build_labels_colors(tree, visit_order, parents, parent_plot_ids, directions):
    """
    Build the label and the color for every node in `visit_order`.

    A node is described by the split that leads to it: its label shows the
    parent's feature and threshold, and its color is the color assigned to
    that feature.

    Parameters
    ----------
    tree : pandas.DataFrame
        One row per node, indexed by node id, with `feature_idx` and
        `num_threshold` columns.
    visit_order : list
        Node ids in plotting order.
    parents, parent_plot_ids, directions : list
        The three lists returned by `build_parents`, aligned with
        `visit_order`.

    Returns
    -------
    tuple of (list, list)
        `labels` and `colors`. The first entry of each is the fixed root
        label and 'white'.
    """
    labels = ['Histogram Gradient-Boosted Decision Tree']
    colors = ['white']
    for i, parent, parent_plot_id, direction in zip(
        visit_order,
        parents,
        parent_plot_ids,
        directions
    ):
        # skip the first one (the root)
        if i == 0:
            continue
        split = tree.loc[int(parent)]
        feat_idx = int(split['feature_idx'])
        feat = FEATS[feat_idx]
        thresh = split['num_threshold']
        if direction == 'l':
            labels.append(f"[{parent_plot_id}.L] {feat} <= {thresh}")
        else:
            labels.append(f"[{parent_plot_id}.R] {feat} > {thresh}")
        # colors
        # COLORS lines up with FEATS by position, so the feature index picks
        # the color directly; no need to search FEATS for the name again
        colors.append(COLORS[feat_idx])
    return labels, colors
def build_plot(tree):
    """
    Build a Plotly tree map trace for one decision tree.

    Parameters
    ----------
    tree : pandas.DataFrame
        One row per node, indexed by node id, with `left`, `right`,
        `feature_idx`, `num_threshold` and `count` columns.

    Returns
    -------
    plotly.graph_objects.Treemap
        One sector per node, listed in breadth-first order.
    """
    #https://stackoverflow.com/questions/64393535/python-plotly-treemap-ids-format-and-how-to-display-multiple-duplicated-labels-i
    # if you use `ids`, then `parents` has to be in terms of `ids`
    visit_order = breadth_first_traverse(tree)
    node_id2plot_id = {node: i for i, node in enumerate(visit_order)}
    parents, parent_plot_ids, directions = build_parents(tree, visit_order, node_id2plot_id)
    labels, colors = build_labels_colors(tree, visit_order, parents, parent_plot_ids, directions)
    # this should just be ['0', '1', '2', . . .]
    plot_ids = [str(node_id2plot_id[x]) for x in visit_order]
    # ids, labels and parents are all in breadth-first order, so the values
    # have to be put in that order too. Taking the `count` column as-is would
    # attach each count to the wrong node whenever the rows of `tree` are not
    # already stored breadth-first.
    values = tree.loc[visit_order, 'count'].to_numpy()
    # NOTE(review): if `count` for a split node already includes its
    # children, `branchvalues='total'` may be the better fit here -- confirm
    # before changing, since it alters how the sectors are sized.
    return go.Treemap(
        values=values,
        labels=labels,
        ids=plot_ids,
        parents=parent_plot_ids,
        marker_colors=colors,
    )
def breadth_first_traverse(tree):
    """
    Return the node ids of `tree` in breadth-first order, starting at node 0.

    https://www.101computing.net/breadth-first-traversal-of-a-binary-tree/
    Iterative version makes more sense since I have the whole tree in a table
    instead of just nodes and pointers

    Parameters
    ----------
    tree : pandas.DataFrame
        One row per node, indexed by node id, with `left` and `right`
        columns holding the ids of the child nodes. A value of 0 is treated
        as "no child on this side".

    Returns
    -------
    list
        Node ids in the order they were visited; the first one is 0.
    """
    # imported here so the function carries its own dependency
    from collections import deque

    # a deque pops from the front in O(1); list.pop(0) shifts every element
    queue = deque([0])
    visited_nodes = []
    while queue:
        cur = queue.popleft()
        visited_nodes.append(cur)
        # left child first, then right, so siblings stay in left-to-right order
        for side in ('left', 'right'):
            child = tree.loc[cur, side]
            if child != 0:
                queue.append(child)
    return visited_nodes
def build_donut_plot():
    """
    Build the two-slice donut chart shown at the top of the page.

    Returns
    -------
    plotly.graph_objects.Figure
        A pie chart with a hole in the middle, no text on the slices and an
        enlarged hover label.
    """
    display_duration = ['a moment', 'an eternity']
    legend_name = [
        '<br>researching the data,<br>cleaning the data,<br>and building the model<br>for this project',
        'battling Plotly'
    ]
    duration = [1, 19]
    # from_dict() is a classmethod; there is no need to build an empty frame
    # first just to call it, so construct the frame directly
    df = pd.DataFrame({
        'duration': duration,
        'legend_name': legend_name,
        'display_duration': display_duration,
    })
    fig = px.pie(
        df,
        values='duration',
        names='legend_name',
        hover_name='display_duration',
        # The docs claim this is okay.
        # Turns out, this is not okay
        # because Plotly tries to call .append() on whatever you pass to `hover_data`
        # but only if you pass something to `color`
        # so I can't set the color if I want to turn off columns in the hover text
        hover_data={'duration': False, 'legend_name': False},
        #color=display_duration,
        #color_discrete_map={display_duration[0]:'fuchsia', display_duration[1]:'rebeccapurple'},
        hole=0.6,
        title="I suppose one more donut won't kill me"
    )
    fig.update_traces(
        textinfo='none',
        hoverlabel={'font':{'size': 20}},
        # These have to be valid CSS color names. The earlier spellings
        # ('fuschia', 'rebecca_purple') are not, which is the likely reason
        # this setting appeared to have no effect.
        marker_colors=['fuchsia', 'rebeccapurple'],
        selector=dict(type='pie')
    )
    return fig
| #def build_figures_cached(graph_objs): | |
| #return [go.Figure( | |
def main():
    """
    Render the whole Streamlit page.

    Loads the fitted classifier from `hgb_classifier.joblib`, turns each of
    its trees into a Plotly tree map, and writes the article text, the donut
    chart, the animated tree map and the slider-driven tree/table view to the
    page, in that order.
    """
    # load the data
    hgb = joblib.load('hgb_classifier.joblib')
    # NOTE(review): `_predictors` is an underscore-private attribute of the
    # loaded model, so it may change between library versions -- confirm the
    # pinned version if loading starts to fail.
    trees = [pd.DataFrame(x[0].nodes) for x in hgb._predictors]
    # make the plots
    graph_objs = [build_plot(tree) for tree in trees]
    figures = [go.Figure(graph_obj) for graph_obj in graph_objs]
    # each frame has to have a name
    # https://community.plotly.com/t/animation-with-slider-not-moving-when-pressing-play/34763/2
    frames = [go.Frame(data=graph_obj, name=str(i)) for i, graph_obj in enumerate(graph_objs)]
    # show them with streamlit
    #st.markdown('# Thankfully, Visualizing Decision Trees is Hard')
    st.markdown('# Thankfully, visualizing decision trees is hard')
    st.markdown('## Setting the scene')
    st.markdown("""
I make a lot of dashboards, which means I make a lot of the same plots over and over.
""")
    st.plotly_chart(build_donut_plot())
    st.markdown("""
Desperate for some creative outlet, I wanted to make a new visualization—
something I'd never seen before.
Inspired by interactive visualizations like
[Tensorflow Playground](https://playground.tensorflow.org)
and
[GAN Lab](https://poloclub.github.io/ganlab),
I decided I wanted to watch some kind of gradient-boosted tree as it learned.
""")
    st.markdown('## Some kind of gradient-boosted tree')
    st.markdown("""
I trained an ensemble of
[Histogram-based Gradient Boosting Decision Trees](https://scikit-learn.org/stable/modules/ensemble.html#histogram-based-gradient-boosting)
on some
[data](https://research.unsw.edu.au/projects/unsw-nb15-dataset).
That algorithm looks at its mistakes and tries to avoid those mistakes the next time around.
To do that, it starts off with a decision tree.
From there, it looks at the points that tree got wrong and makes another decision tree that tries
to get those points right.
Then it looks at that second tree's mistakes and makes a third tree that tries to fix those mistakes.
And so on.
My model ends up with 10 trees.
""")
    st.markdown('## Behold')
    st.markdown("""
I've plotted the progression of those 10 trees as an animated series of interactive Plotly tree maps.
The nodes are color-coded by which feature the decision tree used to make that split.
I've also labeled each node with the feature name and the decision boundary.
If you click on a node, Plotly will show the path to that node in a banner at the top of the plot so you can see how a point ends up in the node you clicked.
The numbers and letters in brackets like `[3.L]` refer to the parent node's position in a breadth-first traversal of the tree and whether the current node is a left or right child of that parent.
The trees are a lot deeper than what it shows in the small plot.
Hugging Face makes the plot ENORMOUS if you expand it, so that isn't much help.
Pick your poison.
Also, it seems to break on Firefox.
It works locally on Firefox, but it breaks when I look at it on Hugging Face on Firefox.
🤷
""")
    # Build the slider steps
    # One step per frame. Each step points at its frame by the frame's own
    # name, so the step and the frame can never disagree about what the
    # frame is called.
    slider_steps = [
        {
            'args': [
                [frame.name],
                {
                    'frame': {'duration': 300, 'redraw': True},
                    'mode': 'immediate',
                    'transition': {'duration': 300}
                }
            ],
            'label': frame.name,
            'method': 'animate',
        }
        for frame in frames
    ]
    sliders_dict = {
        'active': 0,
        'currentvalue': {
            'font': {'size': 20},
            'prefix': 'Tree ',
            'visible': True
        },
        'transition': {'duration': 300},
        'steps': slider_steps
    }
    # Maybe just show a Plotly animated chart
    # https://plotly.com/python/animations/#using-a-slider-and-buttons
    # They don't really document the animation stuff on their website
    # but it's in here
    # https://raw.githubusercontent.com/plotly/plotly.js/master/dist/plot-schema.json
    # I guess it's only in the JS docs and hasn't made it to the Python docs yet
    # https://plotly.com/javascript/animations/
    # trying to find stuff here instead
    # https://plotly.com/python-api-reference/generated/plotly.graph_objects.layout.updatemenu.html?highlight=updatemenu
    # this one finally set the speed
    # no mention of how they figured this out but thank goodness I found it
    # https://towardsdatascience.com/basic-animation-with-matplotlib-and-plotly-5eef4ad6c5aa
    # this also has custom animation speeds in it
    # https://plotly.com/python/custom-buttons/#reference
    ani_fig = go.Figure(
        data=graph_objs[0],
        frames=frames,
        layout=go.Layout(
            updatemenus=[{
                'type':'buttons',
                # https://plotly.com/python/reference/layout/updatemenus/
                # Always show the background color on buttons
                # streamlit breaks the background color of the active button in darkmode
                'showactive': False,
                # background color of the buttons
                'bgcolor': '#fff',
                # font in the buttons
                'font': {'color': '#000'},
                # border color of the buttons
                'bordercolor': '#000',
                # Play and Pause buttons
                # trying to copy this exactly
                # https://plotly.com/python/animations/#adding-control-buttons-to-animations
                'buttons':[{
                    'label':'Play',
                    'method': 'animate',
                    'args':[None, {
                        'fromcurrent': True,
                        'frame': {'duration':4000},
                        'transition': {'duration': 2000},
                    }],
                },
                {
                    'label': 'Pause',
                    'method': 'animate',
                    'args':[[None], {
                        'frame': {'duration': 0},
                        'transition': {'duration': 0},
                        'mode': 'immediate'
                    }]
                }
                ]
            }],
            # add the slider to the layout
            sliders=[sliders_dict]
        )
    )
    st.plotly_chart(ani_fig)
    st.markdown("""
This actually turned out to be a lot harder than I thought it would be.
Plotly doesn't have many examples of how to create animations like this in Python.
[The only example I could find](https://plotly.com/python/animations/#using-a-slider-and-buttons)
was derided as an
["old example [. . .] that is not the best one to learn how to define an animation with slider."](https://community.plotly.com/t/slider-not-updating-during-animation/37261)
That helpful poster didn't point out any other examples, so that one is still pretty much all I have to go on.
Later on,
[a different answer by the same poster](https://community.plotly.com/t/animation-with-slider-not-moving-when-pressing-play/34763)
got me out of a jam.
This `empet` character is pretty much the only one who answers Python posts on Plotly's forums.
As far as I can tell, that's because they're the only person in the world who understands Plotly's Python library.
""")
    st.markdown('## Check out the data!')
    st.markdown("""
This plot is similar to the plot above, but the slider here coordinates with a table of the data I extracted to plot each tree.
""")
    # This works the way I want
    # but the plot is tiny
    # also it recalculates all of the plots
    # every time the slider value changes
    #
    # This seems to be affecting the animation too
    # so I'm going to leave it out
    # It's the largest thing by far in the flame graph
    #
    # I tried to cache the plots but build_plot() takes
    # a DataFrame which is mutable and therefore unhashable I guess
    # so it won't let me cache that function
    # I could pack the dataframe bytes to smuggle them past that check
    # but whatever
    idx = st.slider(
        label='Which tree do you want to see?',
        min_value=0,
        max_value=len(figures)-1,
        value=0,
        step=1
    )
    st.markdown(f'### Tree {idx}')
    st.plotly_chart(figures[idx])
    st.dataframe(trees[idx])
    st.markdown("""
This section is mostly just to warn you against making the same foolhardy decision to marry the innermost guts of SciKit-Learn to the sparsely documented world of Plotly animations in Python.
I'm glad it was challenging, though.
I did go into this hoping for something more interesting than a donut plot.
Maybe I'll think on the `value` and `gain` fields a bit and come up with a version 2.
""")
    # This is still super slow even if it's only showing the dataframes
    # I'm just going to leave it out entirely
    #st.markdown('## Check out the data!')
    #idx = st.slider(
        #label='Which tree do you want to see?',
        #min_value=0,
        #max_value=len(figures)-1,
        #value=0,
        #step=1,
    #)
    #st.dataframe(trees[idx])
    # Cutting this out fixed the broken animation on Hugging Face
    # So that issue was perf-based
    #
    # The issue is back. My theory now is that it's just a Firefox problem.
    # works fine on Chromium
    #st.markdown('#### (secret third plot)')
    #st.markdown("""
    #I orginally had a third viz here where you could move a slider to see the data I used to make each plot.
    #That viz recalulated every value in the entire app each time the slider moved.
    #I had to remove it to get enough perf for the animation to play correctly.
    #If you're feeling brave, you can follow the Quickstart in the README to run this app yourself.
    #Then you can uncomment that viz to satisfy your curiosity.
    #There's definitely some way to fix it.
    #Maybe another milestone for v2.
    #""")


if __name__ == '__main__':
    main()