Spaces:
Runtime error
Runtime error
| import joblib | |
| import time | |
| import plotly.graph_objects as go | |
| import streamlit as st | |
| import pandas as pd | |
| import numpy as np | |
# Feature names, in positional order: build_labels_colors() turns a node's
# `feature_idx` into a readable name by indexing into this list, so the order
# here must not change.
FEATS = [
    'srcip', 'sport', 'dstip', 'dsport', 'proto',
    # 'state' is intentionally absent: it was dropped when the model was trained
    'dur', 'sbytes', 'dbytes', 'sttl', 'dttl', 'sloss', 'dloss',
    'service', 'Sload', 'Dload', 'Spkts', 'Dpkts',
    'swin', 'dwin', 'stcpb', 'dtcpb', 'smeansz', 'dmeansz',
    'trans_depth', 'res_bdy_len', 'Sjit', 'Djit', 'Stime', 'Ltime',
    'Sintpkt', 'Dintpkt', 'tcprtt', 'synack', 'ackdat',
    'is_sm_ips_ports', 'ct_state_ttl', 'ct_flw_http_mthd',
    'is_ftp_login', 'ct_ftp_cmd', 'ct_srv_src', 'ct_srv_dst',
    'ct_dst_ltm', 'ct_src_ltm', 'ct_src_dport_ltm',
    'ct_dst_sport_ltm', 'ct_dst_src_ltm',
]
# Named colors, grouped here by initial letter. build_labels_colors() picks
# the color at the same position as a feature's position in FEATS, so the
# order of this list determines which feature gets which color.
COLORS = [
    # a
    'aliceblue', 'aqua', 'aquamarine', 'azure',
    # b
    'bisque', 'black', 'blanchedalmond', 'blue', 'blueviolet', 'brown',
    'burlywood',
    # c
    'cadetblue', 'chartreuse', 'chocolate', 'coral', 'cornflowerblue',
    'cornsilk', 'crimson', 'cyan',
    # d
    'darkblue', 'darkcyan', 'darkgoldenrod', 'darkgray', 'darkgreen',
    'darkkhaki', 'darkmagenta', 'darkolivegreen', 'darkorange',
    'darkorchid', 'darkred', 'darksalmon', 'darkseagreen',
    'darkslateblue', 'darkslategray', 'darkturquoise', 'darkviolet',
    'deeppink', 'deepskyblue', 'dimgray', 'dodgerblue',
    # f
    'forestgreen', 'fuchsia',
    # g
    'gainsboro', 'gold', 'goldenrod', 'gray', 'green', 'greenyellow',
    # h
    'honeydew', 'hotpink',
    # i
    'indianred', 'indigo', 'ivory',
    # k
    'khaki',
    # l
    'lavender', 'lavenderblush', 'lawngreen', 'lemonchiffon', 'lightblue',
    'lightcoral', 'lightcyan', 'lightgoldenrodyellow', 'lightgray',
    'lightgreen', 'lightpink', 'lightsalmon', 'lightseagreen',
    'lightskyblue', 'lightslategray', 'lightsteelblue', 'lightyellow',
    'lime', 'limegreen', 'linen',
    # m
    'magenta', 'maroon', 'mediumaquamarine', 'mediumblue', 'mediumorchid',
    'mediumpurple', 'mediumseagreen', 'mediumslateblue',
    'mediumspringgreen', 'mediumturquoise', 'mediumvioletred',
    'midnightblue', 'mintcream', 'mistyrose', 'moccasin',
    # n
    'navy',
    # o
    'oldlace', 'olive', 'olivedrab', 'orange', 'orangered', 'orchid',
    # p
    'palegoldenrod', 'palegreen', 'paleturquoise', 'palevioletred',
    'papayawhip', 'peachpuff', 'peru', 'pink', 'plum', 'powderblue',
    'purple',
    # r
    'red', 'rosybrown', 'royalblue',
    # s
    'saddlebrown', 'salmon', 'sandybrown', 'seagreen', 'seashell',
    'sienna', 'silver', 'skyblue', 'slateblue', 'slategray', 'slategrey',
    'snow', 'springgreen', 'steelblue',
    # t
    'tan', 'teal', 'thistle', 'tomato', 'turquoise',
    # v, w, y
    'violet', 'wheat', 'yellow', 'yellowgreen',
]
def build_parents(tree, visit_order, node_id2plot_id):
    """Find each visited node's parent and the side it hangs from.

    Args:
        tree: DataFrame of tree nodes, indexed by node id, whose ``left`` and
            ``right`` columns hold the node ids of each node's children.
        visit_order: Node ids in plotting order. The first entry is treated
            as the root and is given no parent.
        node_id2plot_id: Mapping from node id to that node's plot id.

    Returns:
        A tuple ``(parents, parent_plot_ids, directions)`` of lists aligned
        with ``visit_order``: the parent's node id, the parent's plot id as a
        string, and ``'l'`` or ``'r'`` for a left or right child. Each list
        holds ``None`` in the root's slot.

    Raises:
        IndexError: If a non-root node is not listed as a child of any row.
    """
    # Build both child -> parent lookups in a single pass instead of
    # re-scanning the whole table with a boolean mask for every node.
    right_parent_of = {}
    left_parent_of = {}
    for node_id, left, right in zip(tree.index, tree['left'], tree['right']):
        # setdefault keeps the first row that names a child, which is the
        # row a top-to-bottom scan of the table would have returned.
        left_parent_of.setdefault(left, node_id)
        right_parent_of.setdefault(right, node_id)

    parents = [None]
    parent_plot_ids = [None]
    directions = [None]
    for node in visit_order[1:]:
        # A match in the 'right' column takes precedence over 'left'.
        if node in right_parent_of:
            parent, direction = right_parent_of[node], 'r'
        elif node in left_parent_of:
            parent, direction = left_parent_of[node], 'l'
        else:
            raise IndexError(f'node {node} is not a child of any node in the tree')
        parents.append(parent)
        parent_plot_ids.append(str(node_id2plot_id[parent]))
        directions.append(direction)
    return parents, parent_plot_ids, directions
def build_labels_colors(tree, visit_order, parents, parent_plot_ids, directions):
    """Build the display label and fill color for every plotted node.

    A child's label describes the split made at its parent: the parent's plot
    id, which side the child is on, the feature name and the threshold.
    The color is chosen by the split feature's position in ``FEATS``.

    Args:
        tree: DataFrame of tree nodes, indexed by node id, with
            ``feature_idx`` and ``num_threshold`` columns.
        visit_order: Node ids in plotting order.
        parents: Parent node id for each entry of ``visit_order``.
        parent_plot_ids: Parent plot id (string) for each entry.
        directions: ``'l'`` or ``'r'`` for each entry.

    Returns:
        A tuple ``(labels, colors)``. The first entry of each is the fixed
        root label and ``'white'``; one further entry is appended for every
        node whose id is not 0.
    """
    labels = ['Histogram Gradient-Boosted Decision Tree']
    colors = ['white']
    for i, parent, parent_plot_id, direction in zip(
        visit_order,
        parents,
        parent_plot_ids,
        directions,
    ):
        # Node 0 is the root; its label and color are seeded above.
        if i == 0:
            continue
        parent_id = int(parent)
        feat_idx = int(tree.loc[parent_id, 'feature_idx'])
        feat = FEATS[feat_idx]
        thresh = tree.loc[parent_id, 'num_threshold']
        if direction == 'l':
            labels.append(f"[{parent_plot_id}.L] {feat} <= {thresh}")
        else:
            labels.append(f"[{parent_plot_id}.R] {feat} > {thresh}")
        # FEATS entries are unique, so the feature's position in FEATS is the
        # index we already have; no need to search the list for it again.
        colors.append(COLORS[feat_idx])
    return labels, colors
def build_plot(tree):
    """Build a Plotly treemap trace for one decision tree.

    Args:
        tree: DataFrame of tree nodes, indexed by node id, with ``left``,
            ``right``, ``feature_idx``, ``num_threshold`` and ``count``
            columns.

    Returns:
        A ``go.Treemap`` whose ids, labels, parents, colors and values are all
        listed in breadth-first visit order.
    """
    # https://stackoverflow.com/questions/64393535/python-plotly-treemap-ids-format-and-how-to-display-multiple-duplicated-labels-i
    # If you use `ids`, then `parents` has to be expressed in terms of `ids`.
    visit_order = breadth_first_traverse(tree)
    node_id2plot_id = {node: i for i, node in enumerate(visit_order)}
    parents, parent_plot_ids, directions = build_parents(tree, visit_order, node_id2plot_id)
    labels, colors = build_labels_colors(tree, visit_order, parents, parent_plot_ids, directions)
    # This should just be ['0', '1', '2', ...].
    plot_ids = [str(node_id2plot_id[x]) for x in visit_order]
    # Every other per-node list is in visit order, so the counts must be
    # selected in that order too. Taking the column in table order attaches
    # counts to the wrong nodes whenever the table is not stored breadth-first.
    values = tree.loc[visit_order, 'count'].to_numpy()
    return go.Treemap(
        values=values,
        labels=labels,
        ids=plot_ids,
        parents=parent_plot_ids,
        marker_colors=colors,
    )
def breadth_first_traverse(tree):
    """Return the tree's node ids in breadth-first order, starting at node 0.

    https://www.101computing.net/breadth-first-traversal-of-a-binary-tree/
    The iterative version makes more sense here because the whole tree is
    available as a table rather than as nodes and pointers.

    Args:
        tree: DataFrame of tree nodes, indexed by node id, whose ``left`` and
            ``right`` columns hold child node ids. A value of 0 means there is
            no child on that side.

    Returns:
        List of node ids in visit order; the first entry is always 0.
    """
    # Local import keeps this block self-contained; deque gives O(1) pops
    # from the front, unlike list.pop(0).
    from collections import deque

    queue = deque([0])
    visited_nodes = []
    while queue:
        cur = queue.popleft()
        visited_nodes.append(cur)
        left = tree.loc[cur, 'left']
        right = tree.loc[cur, 'right']
        # Enqueue left before right so siblings are visited left-to-right.
        if left != 0:
            queue.append(left)
        if right != 0:
            queue.append(right)
    return visited_nodes
def main():
    """Load the saved classifier and show its trees in Streamlit.

    Renders a slider-selected treemap with that tree's node table beneath it,
    followed by an animated Plotly figure that steps through every tree.
    """
    # Load the model and pull one node table out of each entry of
    # `_predictors` (the first element of each entry).
    model = joblib.load('hgb_classifier.joblib')
    trees = [pd.DataFrame(entry[0].nodes) for entry in model._predictors]

    # Build one treemap trace per tree, then wrap each trace both as a
    # standalone figure and as an animation frame.
    graph_objs = [build_plot(tree) for tree in trees]
    figures = [go.Figure(trace) for trace in graph_objs]
    frames = [go.Frame(data=trace) for trace in graph_objs]

    # Calling st.plotly_chart(fig) in a loop with a sleep between calls was
    # tried and rejected: it stacks every chart on the page, each new one
    # below the previous one, instead of replacing it.

    # The slider behaves the way I want, but the plot is tiny, and all of the
    # plots are recalculated every time the slider value changes.
    #
    # Caching build_plot() did not work: it takes a DataFrame, which is
    # mutable and therefore (I guess) unhashable, so Streamlit refuses to
    # cache it. Packing the DataFrame into bytes would get around that check,
    # but it is not worth the trouble.
    last_step = len(figures) - 1
    idx = st.slider(
        label='which step to show',
        min_value=0,
        max_value=last_step,
        value=0,
        step=1,
    )
    st.plotly_chart(figures[idx])
    st.markdown(f'## Tree {idx}')
    st.dataframe(trees[idx])

    # Alternative: a Plotly animated chart.
    # https://plotly.com/python/animations/#using-a-slider-and-buttons
    # The animation options are not really documented on the Python site, but
    # they are in the plot schema:
    # https://raw.githubusercontent.com/plotly/plotly.js/master/dist/plot-schema.json
    # It seems they are only in the JS docs and have not reached the Python
    # docs yet:
    # https://plotly.com/javascript/animations/
    # The Python API reference for updatemenus:
    # https://plotly.com/python-api-reference/generated/plotly.graph_objects.layout.updatemenu.html?highlight=updatemenu
    # This article is what finally showed how to set the speed:
    # https://towardsdatascience.com/basic-animation-with-matplotlib-and-plotly-5eef4ad6c5aa
    animation_timing = {
        'frame': {'duration': 5000},
        'transition': {'duration': 2500},
    }
    play_button = {
        'label': 'Play',
        'method': 'animate',
        'args': [None, animation_timing],
    }
    layout = go.Layout(
        updatemenus=[{'type': 'buttons', 'buttons': [play_button]}],
    )
    ani_fig = go.Figure(data=graph_objs[0], frames=frames, layout=layout)
    st.plotly_chart(ani_fig)
# Run the app only when this file is executed as a script.
if __name__ == '__main__':
    main()