Spaces:
Runtime error
Runtime error
none
committed on
Commit
·
045d7d4
0
Parent(s):
Working version of the streamlit animation
Browse files- README.md +1 -0
- streamlit_viz.py +254 -0
- train_classifier.py +86 -0
- viz_classifier.py +215 -0
README.md
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
The `id` column is baloney. There are lots of duplicates.
|
streamlit_viz.py
ADDED
|
@@ -0,0 +1,254 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import joblib
|
| 2 |
+
import time
|
| 3 |
+
|
| 4 |
+
import plotly.graph_objects as go
|
| 5 |
+
import streamlit as st
|
| 6 |
+
import pandas as pd
|
| 7 |
+
import numpy as np
|
| 8 |
+
|
| 9 |
+
# Feature names, in the column order the classifier was trained on
# (train_classifier.py drops `id` before fitting); each tree node's
# `feature_idx` indexes into this list.
FEATS = [
    'srcip',
    'sport',
    'dstip',
    'dsport',
    'proto',
    #'state', I dropped this one when I trained the model
    'dur',
    'sbytes',
    'dbytes',
    'sttl',
    'dttl',
    'sloss',
    'dloss',
    'service',
    'Sload',
    'Dload',
    'Spkts',
    'Dpkts',
    'swin',
    'dwin',
    'stcpb',
    'dtcpb',
    'smeansz',
    'dmeansz',
    'trans_depth',
    'res_bdy_len',
    'Sjit',
    'Djit',
    'Stime',
    'Ltime',
    'Sintpkt',
    'Dintpkt',
    'tcprtt',
    'synack',
    'ackdat',
    'is_sm_ips_ports',
    'ct_state_ttl',
    'ct_flw_http_mthd',
    'is_ftp_login',
    'ct_ftp_cmd',
    'ct_srv_src',
    'ct_srv_dst',
    'ct_dst_ltm',
    'ct_src_ltm',
    'ct_src_dport_ltm',
    'ct_dst_sport_ltm',
    'ct_dst_src_ltm',
]
|
| 58 |
+
|
| 59 |
+
# CSS named colors, indexed by feature (COLORS[FEATS.index(feat)]), so every
# split on the same feature gets the same color on the treemap.  See the
# longer note in viz_classifier.py: whites and colors that looked too close
# on the tree were removed by hand, so this is tuned to this specific model.
COLORS = [
    'aliceblue','aqua','aquamarine','azure',
    'bisque','black','blanchedalmond','blue',
    'blueviolet','brown','burlywood','cadetblue',
    'chartreuse','chocolate','coral','cornflowerblue',
    'cornsilk','crimson','cyan','darkblue','darkcyan',
    'darkgoldenrod','darkgray','darkgreen',
    'darkkhaki','darkmagenta','darkolivegreen','darkorange',
    'darkorchid','darkred','darksalmon','darkseagreen',
    'darkslateblue','darkslategray',
    'darkturquoise','darkviolet','deeppink','deepskyblue',
    'dimgray','dodgerblue',
    'forestgreen','fuchsia','gainsboro',
    'gold','goldenrod','gray','green',
    'greenyellow','honeydew','hotpink','indianred','indigo',
    'ivory','khaki','lavender','lavenderblush','lawngreen',
    'lemonchiffon','lightblue','lightcoral','lightcyan',
    'lightgoldenrodyellow','lightgray',
    'lightgreen','lightpink','lightsalmon','lightseagreen',
    'lightskyblue','lightslategray',
    'lightsteelblue','lightyellow','lime','limegreen',
    'linen','magenta','maroon','mediumaquamarine',
    'mediumblue','mediumorchid','mediumpurple',
    'mediumseagreen','mediumslateblue','mediumspringgreen',
    'mediumturquoise','mediumvioletred','midnightblue',
    'mintcream','mistyrose','moccasin','navy',
    'oldlace','olive','olivedrab','orange','orangered',
    'orchid','palegoldenrod','palegreen','paleturquoise',
    'palevioletred','papayawhip','peachpuff','peru','pink',
    'plum','powderblue','purple','red','rosybrown',
    'royalblue','saddlebrown','salmon','sandybrown',
    'seagreen','seashell','sienna','silver','skyblue',
    'slateblue','slategray','slategrey','snow','springgreen',
    'steelblue','tan','teal','thistle','tomato','turquoise',
    'violet','wheat','yellow','yellowgreen'
]
|
| 95 |
+
|
| 96 |
+
def build_parents(tree, visit_order, node_id2plot_id):
    """Find the parent of every node visited after the root.

    Returns three parallel lists, each starting with None for the root:
    the parent node ids, the parents' plot ids as strings (what plotly's
    `parents` argument needs), and 'l'/'r' flags saying whether each node
    is its parent's left or right child.
    """
    parents = [None]
    parent_plot_ids = [None]
    directions = [None]
    for node_id in visit_order[1:]:
        right_matches = tree.index[tree['right'] == node_id]
        if len(right_matches) > 0:
            parent_id = right_matches[0]
            side = 'r'
        else:
            # not anyone's right child, so it must be a left child
            parent_id = tree.index[tree['left'] == node_id][0]
            side = 'l'
        parents.append(parent_id)
        parent_plot_ids.append(str(node_id2plot_id[parent_id]))
        directions.append(side)
    return parents, parent_plot_ids, directions
|
| 112 |
+
|
| 113 |
+
|
| 114 |
+
def build_labels_colors(tree, visit_order, parents, parent_plot_ids, directions):
    """Build the treemap label and color for every node.

    Each non-root node is labelled with the split rule on the edge from its
    parent: the parent's split feature and threshold, with `<=` for a left
    child and `>` for a right child.  Nodes are colored by that feature, so
    every split on the same feature shares a color.
    """
    labels = ['Histogram Gradient-Boosted Decision Tree']
    colors = ['white']
    for i, parent, parent_plot_id, direction in zip(
        visit_order,
        parents,
        parent_plot_ids,
        directions
    ):
        # skip the first one (the root): it has no parent edge to describe
        if i == 0:
            continue
        # the split rule lives on the parent node
        feat = FEATS[int(tree.loc[int(parent), 'feature_idx'])]
        thresh = tree.loc[int(parent), 'num_threshold']
        if direction == 'l':
            labels.append(f"[{parent_plot_id}.L] {feat} <= {thresh}")
        else:
            labels.append(f"[{parent_plot_id}.R] {feat} > {thresh}")

        # one color per feature index
        offset = FEATS.index(feat)
        colors.append(COLORS[offset])
    return labels, colors
|
| 139 |
+
|
| 140 |
+
|
| 141 |
+
def build_plot(tree):
    """Turn one predictor tree (a DataFrame of its nodes) into a go.Treemap.

    Plotly identifies treemap nodes by `ids`, and when `ids` is used the
    `parents` argument must also be expressed in terms of those ids, so
    nodes are renumbered by breadth-first visit order.
    https://stackoverflow.com/questions/64393535/python-plotly-treemap-ids-format-and-how-to-display-multiple-duplicated-labels-i
    """
    visit_order = breadth_first_traverse(tree)
    node_id2plot_id = {}
    for plot_id, node_id in enumerate(visit_order):
        node_id2plot_id[node_id] = plot_id
    parents, parent_plot_ids, directions = build_parents(
        tree, visit_order, node_id2plot_id)
    labels, colors = build_labels_colors(
        tree, visit_order, parents, parent_plot_ids, directions)
    # this should just be ['0', '1', '2', . . .]
    plot_ids = [str(node_id2plot_id[node_id]) for node_id in visit_order]

    treemap = go.Treemap(
        values=tree['count'].to_numpy(),
        labels=labels,
        ids=plot_ids,
        parents=parent_plot_ids,
        marker_colors=colors,
    )
    return treemap
|
| 158 |
+
|
| 159 |
+
|
| 160 |
+
def breadth_first_traverse(tree):
    """Return the node ids of `tree` in breadth-first (level) order.

    https://www.101computing.net/breadth-first-traversal-of-a-binary-tree/
    Iterative version makes more sense since I have the whole tree in a table
    instead of just nodes and pointers.

    The node table uses 0 to mean "no child" (leaves have left == right == 0),
    so a child id of 0 is skipped rather than queued — which is safe because
    node 0 is the root and can never be anyone's child.
    """
    # local import: list.pop(0) shifts the whole list (O(n) per pop);
    # deque.popleft() is O(1)
    from collections import deque

    queue = deque([0])
    visited_nodes = []
    while queue:
        cur = queue.popleft()
        visited_nodes.append(cur)

        if tree.loc[cur, 'left'] != 0:
            queue.append(tree.loc[cur, 'left'])

        if tree.loc[cur, 'right'] != 0:
            queue.append(tree.loc[cur, 'right'])

    return visited_nodes
|
| 179 |
+
|
| 180 |
+
|
| 181 |
+
def main():
    """Streamlit app: load the trained HistGradientBoosting model and show
    each boosting stage's tree as a plotly treemap — a slider to step
    through stages, the raw node table, and an animated version below.
    """
    # load the data
    hgb = joblib.load('hgb_classifier.joblib')
    # one DataFrame of nodes per boosting stage (hgb._predictors is a list
    # of single-element lists — TODO confirm against sklearn internals)
    trees = [pd.DataFrame(x[0].nodes) for x in hgb._predictors]
    # make the plots
    graph_objs = [build_plot(tree) for tree in trees]
    figures = [go.Figure(graph_obj) for graph_obj in graph_objs]
    frames = [go.Frame(data=graph_obj) for graph_obj in graph_objs]
    # show them with streamlit

    # this puts them all on the screen at once
    # like each new one shows up below the previous one
    # instead of replacing the previous one
    #for fig in figures:
    #    st.plotly_chart(fig)
    #    time.sleep(1)

    # This works the way I want
    # but the plot is tiny
    # also it recalculates all of the plots
    # every time the slider value changes
    #
    # I tried to cache the plots but build_plot() takes
    # a DataFrame which is mutable and therefore unhashable I guess
    # so it won't let me cache that function
    # I could pack the dataframe bytes to smuggle them past that check
    # but whatever
    idx = st.slider(
        label='which step to show',
        min_value=0,
        max_value=len(figures)-1,
        value=0,
        step=1
    )
    st.plotly_chart(figures[idx])
    st.markdown(f'## Tree {idx}')
    st.dataframe(trees[idx])

    # Maybe just show a Plotly animated chart
    # https://plotly.com/python/animations/#using-a-slider-and-buttons
    # They don't really document the animation stuff on their website
    # but it's in here
    # https://raw.githubusercontent.com/plotly/plotly.js/master/dist/plot-schema.json
    # I guess it's only in the JS docs and hasn't made it to the Python docs yet
    # https://plotly.com/javascript/animations/
    # trying to find stuff here instead
    # https://plotly.com/python-api-reference/generated/plotly.graph_objects.layout.updatemenu.html?highlight=updatemenu

    # this one finally set the speed
    # no mention of how they figured this out but thank goodness I found it
    # https://towardsdatascience.com/basic-animation-with-matplotlib-and-plotly-5eef4ad6c5aa
    ani_fig = go.Figure(
        data=graph_objs[0],
        frames=frames,
        layout=go.Layout(
            updatemenus=[{
                'type':'buttons',
                'buttons':[{
                    'label':'Play',
                    'method': 'animate',
                    'args':[None, {
                        'frame': {'duration':5000},
                        'transition': {'duration': 2500}
                    }]
                }]
            }]
        )
    )
    st.plotly_chart(ani_fig)

if __name__=='__main__':
    main()
|
| 253 |
+
|
| 254 |
+
|
train_classifier.py
ADDED
|
@@ -0,0 +1,86 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import joblib
|
| 2 |
+
|
| 3 |
+
import pandas as pd
|
| 4 |
+
|
| 5 |
+
from sklearn.preprocessing import OrdinalEncoder, LabelEncoder
|
| 6 |
+
from sklearn.ensemble import RandomForestClassifier, HistGradientBoostingClassifier
|
| 7 |
+
from sklearn.metrics import classification_report
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
def main():
    """Train a HistGradientBoostingClassifier on train_data.csv, save it to
    hgb_classifier.joblib, and print a classification report against
    test_data.csv.
    """
    # '-' is this dataset's missing-value marker
    train_df = pd.read_csv('train_data.csv', na_values='-')
    # `service` is about half-empty and the rest are completely full
    # one of the rows has `no` for `state` which isn't listed as an option in the description of the fields
    # I'm just going to delete that
    train_df = train_df.drop(columns=['id'])
    train_df = train_df.drop(index=train_df[train_df['state']=='no'].index)

    # It can predict `label` really well ~0.95 accuracy/f1/whatever other stat you care about
    # It does a lot worse trying to predict `attack_cat` b/c there are 10 classes
    # and some of them are not well-represented
    # so that might be more interesting to visualize
    # pop attack_cat out of the feature columns (kept in a variable in case
    # it's wanted later)
    cheating = train_df.pop('attack_cat')
    y_enc = LabelEncoder().fit(train_df['label'])
    train_y = y_enc.transform(train_df.pop('label'))
    # ordinal-encode every remaining column (NaNs pass through)
    x_enc = OrdinalEncoder().fit(train_df)
    train_df = x_enc.transform(train_df)

    # Random forest doesn't handle NaNs
    # I could drop the `service` column or I can use the HistGradientBoostingClassifier
    # super helpful error message from sklearn pointed me to this list
    # https://scikit-learn.org/stable/modules/impute.html#estimators-that-handle-nan-values
    # NOTE(review): if this is ever re-enabled, the variable is train_y, not y_train
    #rf = RandomForestClassifier()
    #rf.fit(train_df, y_train)

    # max_iter is the number of time it builds a gradient-boosted tree
    # so it's the number of estimators
    hgb = HistGradientBoostingClassifier(max_iter=10).fit(train_df, train_y)
    joblib.dump(hgb, 'hgb_classifier.joblib', compress=9)

    # apply the same cleanup and the encoders fit on the training data
    test_df = pd.read_csv('test_data.csv', na_values='-')
    test_df = test_df.drop(columns=['id', 'attack_cat'])
    test_y = y_enc.transform(test_df.pop('label'))
    test_df = x_enc.transform(test_df)
    test_preds = hgb.predict(test_df)
    print(classification_report(test_y, test_preds))

    # I guess they took out the RF feature importance
    # or maybe that's only in XGBoost
    # you can still kind of get to it
    # with RandomForestClassifier.feature_importances_
    # or like this
    # https://scikit-learn.org/stable/auto_examples/ensemble/plot_forest_importances.html
    # but there's really nothing for the HistGradientBoostingClassifier
    # but you can get to the actual nodes for each predictor/estimator like this
    # hgb._predictors[i][0].nodes
    # and that has information gain metric for each node which might be viz-able
    # so that might be an interesting viz
    # like plot the whole forest
    # maybe only do like 10 estimators to keep it smaller
    # or stick with 100 and figure out a good way to viz big models
    # the first two estimators are almost identical
    # so maybe like plot the first estimator
    # and then fuzz the nodes by how much the other estimators differ
    # assuming there's some things they all agree on exactly and others where they differ a little bit
    # idk I don't really know how the algorithm works
    # the 96th estimator looks pretty different (I'm assuming from boosting)
    # so maybe like an evolution animation from the first to the last
    # to see the effect of the boosting
    # like plot the points and show how the decision boundary shifts with each generation
    # alongside an animation of the actual decision tree morphing each step
    # That might look too much like an animation of the model being trained though
    # which I guess that's sort of what it is so idk

    # https://scikit-learn.org/stable/modules/ensemble.html#interpretation-with-feature-importance

    # also
    # you can see what path a data point takes through the forest
    # with RandomForestClassifier.decision_path()
    # which might be really cool
    # to see like 10 trees and the path through each tree and what each tree predicted

if __name__ == '__main__':
    main()
|
| 85 |
+
|
| 86 |
+
|
viz_classifier.py
ADDED
|
@@ -0,0 +1,215 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import joblib
|
| 2 |
+
|
| 3 |
+
import pandas as pd
|
| 4 |
+
|
| 5 |
+
import plotly.graph_objects as go
|
| 6 |
+
import plotly.express as px
|
| 7 |
+
|
| 8 |
+
# the model trained and saved by train_classifier.py
hgb = joblib.load('hgb_classifier.joblib')
# Feature names in the column order the model was trained on; each tree
# node's `feature_idx` indexes into this list.
FEATS = [
    'srcip',
    'sport',
    'dstip',
    'dsport',
    'proto',
    #'state', I dropped this one when I trained the model
    'dur',
    'sbytes',
    'dbytes',
    'sttl',
    'dttl',
    'sloss',
    'dloss',
    'service',
    'Sload',
    'Dload',
    'Spkts',
    'Dpkts',
    'swin',
    'dwin',
    'stcpb',
    'dtcpb',
    'smeansz',
    'dmeansz',
    'trans_depth',
    'res_bdy_len',
    'Sjit',
    'Djit',
    'Stime',
    'Ltime',
    'Sintpkt',
    'Dintpkt',
    'tcprtt',
    'synack',
    'ackdat',
    'is_sm_ips_ports',
    'ct_state_ttl',
    'ct_flw_http_mthd',
    'is_ftp_login',
    'ct_ftp_cmd',
    'ct_srv_src',
    'ct_srv_dst',
    'ct_dst_ltm',
    'ct_src_ltm',
    'ct_src_dport_ltm',
    'ct_dst_sport_ltm',
    'ct_dst_src_ltm',
]
|
| 58 |
+
|
| 59 |
+
# plotly only has the CSS named colors
# I don't think I can use xkcd colors
# I copied a bunch of CSS colors from somewhere online
# and then deleted whites and things that showed up too close on the tree
# this is not really a general solution, it just works for this specific tree
# I'll have to come up with a better colormap at some point
# (indexed by feature: COLORS[FEATS.index(feat)])
COLORS = [
    'aliceblue','aqua','aquamarine','azure',
    'bisque','black','blanchedalmond','blue',
    'blueviolet','brown','burlywood','cadetblue',
    'chartreuse','chocolate','coral','cornflowerblue',
    'cornsilk','crimson','cyan','darkblue','darkcyan',
    'darkgoldenrod','darkgray','darkgreen',
    'darkkhaki','darkmagenta','darkolivegreen','darkorange',
    'darkorchid','darkred','darksalmon','darkseagreen',
    'darkslateblue','darkslategray',
    'darkturquoise','darkviolet','deeppink','deepskyblue',
    'dimgray','dodgerblue',
    'forestgreen','fuchsia','gainsboro',
    'gold','goldenrod','gray','green',
    'greenyellow','honeydew','hotpink','indianred','indigo',
    'ivory','khaki','lavender','lavenderblush','lawngreen',
    'lemonchiffon','lightblue','lightcoral','lightcyan',
    'lightgoldenrodyellow','lightgray',
    'lightgreen','lightpink','lightsalmon','lightseagreen',
    'lightskyblue','lightslategray',
    'lightsteelblue','lightyellow','lime','limegreen',
    'linen','magenta','maroon','mediumaquamarine',
    'mediumblue','mediumorchid','mediumpurple',
    'mediumseagreen','mediumslateblue','mediumspringgreen',
    'mediumturquoise','mediumvioletred','midnightblue',
    'mintcream','mistyrose','moccasin','navy',
    'oldlace','olive','olivedrab','orange','orangered',
    'orchid','palegoldenrod','palegreen','paleturquoise',
    'palevioletred','papayawhip','peachpuff','peru','pink',
    'plum','powderblue','purple','red','rosybrown',
    'royalblue','saddlebrown','salmon','sandybrown',
    'seagreen','seashell','sienna','silver','skyblue',
    'slateblue','slategray','slategrey','snow','springgreen',
    'steelblue','tan','teal','thistle','tomato','turquoise',
    'violet','wheat','yellow','yellowgreen'
]
|
| 101 |
+
|
| 102 |
+
# one node table per boosting stage
trees = [x[0].nodes for x in hgb._predictors]

# the final tree definitely has a similar structure but is noticeably different
# that's really cool
# I think this will make a cool animation
# if I can figure it out
tree = pd.DataFrame(trees[0])
#tree = pd.DataFrame(trees[9])



# parents is going to be tricky
# I need to get the index of whichever node has the current node listed in either left or right

parents = [None]
# keep track of whether each node is a left or right child of the parent in the list
directions = [None]
# it uses 0 to say "no left/right child"
# so I have to skip searching for node 0
# which is fine b/c node 0 is the root
for i in tree.index[1:]:
    # it seems to make a very even tree
    # so just guess it's in the right side
    # and that will be right half the time
    parent = tree[tree['right']==i].index
    if parent.empty:
        parents.append(str(tree[tree['left']==i].index[0]))
        directions.append('l')
    else:
        parents.append(str(parent[0]))
        directions.append('r')


# generate the labels
# and the colors
labels = ['Histogram Gradient-Boosted Decision Tree']
colors = ['white']
for i, node, parent, direction in zip(
    tree.index.to_numpy(),
    tree.iterrows(),
    parents,
    directions
):
    # skip the first one (the root)
    if i == 0:
        continue
    # iterrows() yields (index, row) pairs; keep just the row
    node = node[1]
    # the split rule lives on the parent node
    feat = FEATS[int(tree.loc[int(parent), 'feature_idx'])]
    thresh = tree.loc[int(parent), 'num_threshold']
    if direction == 'l':
        labels.append(f"[{i}] {feat} <= {thresh}")
    else:
        labels.append(f"[{i}] {feat} > {thresh}")

    # colors
    offset = FEATS.index(feat)
    colors.append(COLORS[offset])


# actual plot
f = go.Figure(
    go.Treemap(
        values=tree['count'].to_numpy(),
        labels=labels,
        ids=tree.index.to_numpy(),
        parents=parents,
        marker_colors=colors,
    )
)

#f.update_layout(
#    treemapcolorway = ['pink']
#)

# NOTE(review): leftover debug stop — drops into pdb so `f`, `tree`, etc. can
# be inspected interactively; remove before running this script unattended
breakpoint()
|
| 177 |
+
|
| 178 |
+
|
| 179 |
+
# converting the ndarray with column names to a pandas df
|
| 180 |
+
# 3284 bytes as an ndarray
|
| 181 |
+
# 3300 bytes as a dataframe
|
| 182 |
+
# so they're the same size
|
| 183 |
+
# do I need to convert it to pandas? idk
|
| 184 |
+
# just curious
|
| 185 |
+
|
| 186 |
+
# https://linuxtut.com/en/ffb2e319db5545965933/
|
| 187 |
+
|
| 188 |
+
# https://github.com/scikit-learn/scikit-learn/blob/main/sklearn/ensemble/_hist_gradient_boosting/_predictor.pyx
|
| 189 |
+
# figuring out how the thing works
|
| 190 |
+
|
| 191 |
+
# `value` is the predicted class / value / whatever
|
| 192 |
+
# so if it's a leaf node, it returns that value as the prediction
|
| 193 |
+
# there are negative values in some of the leaves
|
| 194 |
+
# maybe the classes are +/-1 instead of 0/1?
|
| 195 |
+
|
| 196 |
+
# if the data value is <= `num_threshold` then it goes in the left node
|
| 197 |
+
# if it's > `num_threshold` then it goes in the right node
|
| 198 |
+
|
| 199 |
+
# okay and then all the leaves have feature_idx=0, num_threshold=0, left=0, right=0
|
| 200 |
+
# that makes sense
|
| 201 |
+
# still kind of annoying that they use 0 instead of np.nan but oh well
|
| 202 |
+
|
| 203 |
+
# also super super hard to figure out what the labels on the tree map should be
|
| 204 |
+
# like it has to check the parent's feature_idx and num_threshold
|
| 205 |
+
# which I guess isn't too bad once we have the list of parents already built
|
| 206 |
+
# except that I don't know whether a node is left or right from its parent
|
| 207 |
+
# hmmmm
|
| 208 |
+
|
| 209 |
+
|
| 210 |
+
|
| 211 |
+
|
| 212 |
+
|
| 213 |
+
|
| 214 |
+
|
| 215 |
+
|