File size: 7,902 Bytes
4fd5d51 3069a62 4d7a703 ef4d0d3 64a5764 b52cbbd 018207f 3e26956 64a5764 51b97c4 cf617ec 51b97c4 ab072bd 51b97c4 505126b 51b97c4 a2fe1ac f6ed366 076f58d 51b97c4 ab072bd 51b97c4 84a6547 d1fb6e6 78ed822 84a6547 51b97c4 84a6547 64a5764 3ba27a3 64a5764 95f6158 a48895c 95f6158 a48895c 8054117 64a5764 4c2ac72 3ba27a3 7ba55c8 4c2ac72 fec8171 3ba27a3 7ba55c8 4c2ac72 64a5764 076f58d 089826a 3e26956 2a21d97 9fd2556 822e61a 51b97c4 2a21d97 9fd2556 2a21d97 9fd2556 a48895c 734fc7e 9fd2556 9526432 64a5764 b52cbbd 64a5764 51fd64c 64a5764 4ee9ffb efe74e7 c3200ea 7e1d4d8 c3200ea e097964 4ee9ffb efe74e7 4ee9ffb 7e1d4d8 24b5e1f e097964 4ee9ffb 9218d53 734fc7e 64a5764 f40fcd8 eb5e924 64a5764 4ee9ffb 3e26956 4ee9ffb 07ccc15 3e26956 07ccc15 d1fb6e6 36ea705 07ccc15 3e26956 64a5764 0d90d80 b52cbbd f19061c | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 | import streamlit as st
import pandas as pd
import numpy as np
import plotly.express as px
# Use the full browser width for the page layout.
# NOTE(review): Streamlit expects set_page_config before other st.* output
# calls; it currently precedes all of them in this file — keep it that way.
st.set_page_config(layout="wide")
import streamlit as st
import pandas as pd
import matplotlib.pyplot as plt
import streamlit as st
import pandas as pd
import matplotlib.pyplot as plt
import streamlit as st
import pandas as pd
import matplotlib.pyplot as plt
import streamlit as st
import pandas as pd
import plotly.express as px
import plotly.graph_objects as go
# Page-wide CSS injected as raw HTML. It resets default spacing and defines
# classes used by later markup: .centered-title (page heading) and
# .fixed-col / .scroller (right-hand filter column). .maint is defined but is
# not referenced by any markup visible in this file.
page_css = """
<style>
*{
padding:0;
margin:0;
}
.fixed-col {
position: fixed;
top: 4rem;
right: 0;
width: 30%;
padding-left: 0rem;
background: white;
z-index: 100;
}
body {
margin: 0;
padding: 0;
}
.maint {
margin: auto;
margin-bottom:1.5rem;
}
.centered-title {
text-align: center;
}
.scroller {
margin-top: 2rem; /* Adjust as necessary to avoid overlap */
}
</style>
"""

# Side and top padding for the content area. The `.main > div` selector
# presumably targets Streamlit's main container — it depends on the
# Streamlit version's DOM, so re-check after upgrades.
margins_css = """
<style>
.main > div {
padding-left: 3rem;
padding-right:3rem;
padding-top:0.4rem;
}
</style>
"""

# Inject both stylesheets, page-wide rules first.
for css_snippet in (page_css, margins_css):
    st.markdown(css_snippet, unsafe_allow_html=True)
# Hard-coded data displayed by the dashboard (described by the author as
# sample data for demonstration purposes).
# Selector options: all models, the pruning methods offered for the SSD
# models, and the datasets offered for DETR.
models = ['SSD300', 'SSD512', 'DETR']
pruning_methods = ['VIB Pruning','Transfer Pruning']
datasets = ['VOC','SPARK']
# hyperparameters[model][key] -> list of row tuples for the hyperparameter table.
# SSD models are keyed by pruning method; rows are
#   (model name, kl factor, pruning epochs, finetuning epochs).
# DETR is keyed by dataset; rows are
#   (model name, kl backbone, kl transformer, pruning epochs, finetuning epochs).
# '-' appears to mean "not applicable / not reported" — TODO confirm.
hyperparameters = {
'SSD300': {'Transfer Pruning': [('SSD300-ResNet50', '-', '-', 120), ('SSD300-ITPCC-A', '-', '-', 120), ('SSD300-ITPCC-B', '-', '-', 120), ('SSD300-ITPCC-C', '-', '-', 120)],
'VIB Pruning': [('SSD300-ResNet50', '-', '-', 120), ('SSD300-VIB-v1', "0.0001", 240, 100), ('SSD300-VIB-v2', "0.0002", 240,100)]},
'SSD512': {'Transfer Pruning': [('SSD512-ResNet50', '-', '-', 120), ('SSD512-ITPCC-A', '-', '-', 120), ('SSD512-ITPCC-B', '-', '-', 120), ('SSD512-ITPCC-C', '-', '-', 120)],
'VIB Pruning': [('SSD512-ResNet50', '-', '-', 120),('SSD512-VIB-v1', "0.0003", 200, 100)]},
'DETR': {'SPARK': [("DETR-baseline", "-", "-","-", 20), ("DETR-SPARK-A", "-","-", 30, 40), ("DETR-SPARK-B", "-","-", 30, 40)],
'VOC': [("DETR-baseline", "-", "-","-", 130), ("DETR-VOC-A", "0.0001","0.00001", 80, 200), ("DETR-VOC-B", "0.00005","0.0001", 80, 200)]},
}
# results_data[model][key] -> dict of parallel lists, one item per table row:
#   'model'   row labels (the first row is the unpruned baseline),
#   'map'     mAP, 'flops' FLOPs, 'params' parameter count
#             (the SSD results table labels these mAP (%), GFLOPs, Parameters (M)),
#   'flopsd' / 'paramsd'  percentage reduction relative to the first row ('0.0%').
# Keys follow the same scheme as `hyperparameters` (pruning method for SSD,
# dataset for DETR). Every value is stored as a string.
results_data = {
'SSD300': {
'VIB Pruning':{'model':['SSD300-ResNet50','SSD300-VIB-v1','SSD300-VIB-v2'],'map': ["77.79", "78.71", "77.41"], 'flops': ["11.1", "5.04", "3.49"],'flopsd':['0.0%','54.55%','68.54%'], 'params': ["49.2", "19.84", "11.18"],'paramsd':['0.0%','59.68%','77.28%'],},
'Transfer Pruning':{'model':["SSD300-ResNet50",'SSD300-ITPCC-A','SSD300-ITPCC-B','SSD300-ITPCC-C'],'map': ["77.79", "77.86" , "77.06", "75.08"], 'flops': ["11.1", "6.85", "5.08", "3.38"],'flopsd':['0.0%','38.2%','54.2%',"69.5%"], 'params': ["49.2", "32.5", "25.7", "19.4"],'paramsd':['0.0%','33.94%','47.77%',"60.5%"]},
},
# NOTE(review): two SSD512 'paramsd' entries disagree with the 'params'
# values: VIB-v1 gives 1 - 27.2/58.52 ~ 53.5% (listed 53.42%) and ITPCC-C
# gives 1 - 28.7/58.5 ~ 50.9% (listed 50.1%). Verify against the source results.
'SSD512': {
'VIB Pruning':{'model':["SSD512-ResNet50",'SSD512-VIB-v1'],'map': ["80.9","81.43"], 'flops': ["46.24", "9.73"],'flopsd':['0.0%','78.94%'], 'params': ["58.52","27.2"],'paramsd':['0.0%','53.42%'],},
'Transfer Pruning':{'model':["SSD512-ResNet50",'SSD512-ITPCC-A','SSD512-ITPCC-B','SSD512-ITPCC-C'],'map': ["80.9","81.05" , "80.45", "78.82"], 'flops': ["46.2", "31.42", "25.6", "20.1"],'flopsd':['0.0%','31.9%','44.6%',"56.5%"], 'params': ["58.5", "41.8", "35.0", "28.7"],'paramsd':['0.0%','28.5%','40.17%',"50.1%"],},
},
# NOTE(review): the DETR 'paramsd' lists look swapped between SPARK and VOC.
# From the 'params' values, SPARK-A/B give ~43.4%/35.4% and VOC-A/B give
# ~47.3%/45.5%, but SPARK lists 47.3%/45.4% and VOC lists 42.65%/35.5%.
# Verify against the source results before publishing.
'DETR': {'SPARK':{'model':["DETR-baseline",'DETR-SPARK-A','DETR-SPARK-B'],'map': ["96.77", "94.5", "95.18"], 'flops': ["85", "56", "58"],'flopsd':['0.0%','34.1%','31.7%'], 'params': ["41.2", "23.3", "26.6"],'paramsd':['0.0%','47.3%','45.4%'],},
'VOC':{'model':["DETR-baseline",'DETR-VOC-A','DETR-VOC-B'],'map': ["79.34", "77.2", "78.0"], 'flops': ["85", "55", "60"],'flopsd':['0.0%','35.29%','29.41%'], 'params': ["41.2", "21.71", "22.47"],'paramsd':['0.0%','42.65%','35.5%'],}},
}
# Page heading, centred by the .centered-title CSS class.
title_html = '<h1 class="centered-title">Variational Information bottleneck pruning for Object detection</h1>'
st.markdown(title_html, unsafe_allow_html=True)

# Two side-by-side columns: results and graphs on the (slightly wider) left,
# filters and hyperparameters on the right.
column_weights = [5.2, 4.8]
col1, col2 = st.columns(column_weights)
# Right section: Filters and Hyperparameters
with col2:
    # NOTE(review): the opening/closing <div> tags are emitted by separate
    # st.markdown calls; Streamlit likely renders each call as its own element,
    # so the widgets in between may not actually sit inside .fixed-col /
    # .scroller — confirm the styling takes effect.
    st.markdown('<div class="fixed-col">', unsafe_allow_html=True)
    st.subheader('Filters')
    model = st.selectbox('Select model:', models)

    # SSD variants are grouped by pruning method, DETR variants by dataset;
    # only the matching selector is shown.
    ssd_selected = model in ['SSD300', 'SSD512']
    if ssd_selected:
        pruning = st.selectbox('Select pruning method:', pruning_methods)
        selection = pruning
    else:
        dataset = st.selectbox('Select dataset:', datasets)
        selection = dataset
    hyperparameter_data = hyperparameters[model][selection]

    st.markdown('<div class="scroller">', unsafe_allow_html=True)
    st.subheader('Hyperparameters')
    # Vertical gap between the filters and the table.
    st.markdown('<br>', unsafe_allow_html=True)

    # Header row matching the tuple layout of `hyperparameters` for the family.
    if ssd_selected:
        hyper_columns = ['Model', 'kl factor', 'Pruning Epochs', 'Finetuning Epochs']
    else:
        hyper_columns = ['Model', 'kl backbone', 'kl transformer', 'Pruning Epochs', 'Finetuning Epochs']
    df_hyperparams = pd.DataFrame(hyperparameter_data, columns=hyper_columns)
    st.markdown(df_hyperparams.style.hide(axis="index").to_html(), unsafe_allow_html=True)

    # Close .scroller, then .fixed-col.
    st.markdown('</div>', unsafe_allow_html=True)
    st.markdown('</div>', unsafe_allow_html=True)
# Left section: Results and Evolution Graphs


def _results_table(entry):
    """Build the results table for one ``results_data[model][key]`` entry.

    ``entry`` maps 'model', 'map', 'flops', 'flopsd', 'params' and 'paramsd'
    to parallel lists with one item per row. Every model family uses the same
    column labels, so the table reads consistently whichever model is selected.
    """
    return pd.DataFrame({
        'Model': entry['model'],
        'mAP (%)': entry['map'],
        'GFLOPs': entry['flops'],
        'down by': entry['flopsd'],
        'Parameters (M)': entry['params'],
        # Trailing space keeps this key distinct from the 'down by' key above;
        # identical dict keys would collapse into a single column.
        'down by ': entry['paramsd'],
    })


def _flops_params_figure(entry):
    """Return a grouped bar chart comparing FLOPs and parameters per model."""
    chart = go.Figure()
    for field, trace_name, colour in (('flops', 'FLOPs', 'orange'),
                                      ('params', 'Params', 'green')):
        chart.add_trace(go.Bar(
            x=entry['model'],
            # Values are stored as strings; convert explicitly rather than
            # relying on Plotly's axis auto-typing. Non-numeric values become
            # NaN and simply produce no bar.
            y=pd.to_numeric(entry[field], errors='coerce'),
            name=trace_name,
            marker_color=colour,
        ))
    chart.update_layout(
        barmode='group',
        title='FLOPs and Params per model',
        xaxis_title='Model',
        # Both series share one axis, so name both units instead of 'Count'.
        yaxis_title='GFLOPs / Params (M)',
        legend_title='Metric',
        height=280,
        width=500,
    )
    return chart


with col1:
    st.subheader('Results')
    results = results_data[model]
    # SSD results are keyed by pruning method and DETR results by dataset,
    # mirroring the selectors in the right-hand column. The conditional
    # expression only evaluates the variable that was defined on this run.
    result_key = pruning if model in ['SSD300', 'SSD512'] else dataset
    selected_results = results[result_key]

    # Display results table
    df_results = _results_table(selected_results)
    st.markdown(df_results.style.hide(axis="index").to_html(), unsafe_allow_html=True)

    # Display evolution graphs
    st.markdown("<div style='margin-top: 15px;'></div>", unsafe_allow_html=True)
    st.markdown(
"""
<h2 id="evolution-graphs" style="margin-bottom: 0px; padding-bottom: 0px;">
Evolution Graphs
</h2>
""",
        unsafe_allow_html=True
    )
    fig = _flops_params_figure(selected_results)
    # Show the plot using Streamlit
    st.plotly_chart(fig, use_container_width=True)
|