File size: 7,902 Bytes
4fd5d51
 
 
 
 
3069a62
4d7a703
ef4d0d3
64a5764
b52cbbd
 
 
018207f
 
 
3e26956
 
 
 
64a5764
51b97c4
 
 
 
 
cf617ec
 
 
 
51b97c4
 
ab072bd
51b97c4
 
505126b
51b97c4
 
 
a2fe1ac
 
 
 
f6ed366
 
 
 
076f58d
 
 
51b97c4
ab072bd
51b97c4
 
 
 
84a6547
 
 
d1fb6e6
 
78ed822
84a6547
 
 
51b97c4
84a6547
64a5764
 
3ba27a3
 
64a5764
 
95f6158
a48895c
95f6158
a48895c
8054117
 
64a5764
 
 
4c2ac72
 
3ba27a3
7ba55c8
4c2ac72
fec8171
3ba27a3
7ba55c8
4c2ac72
 
 
64a5764
 
 
076f58d
 
089826a
3e26956
2a21d97
9fd2556
822e61a
51b97c4
 
2a21d97
 
 
 
 
 
 
9fd2556
2a21d97
9fd2556
a48895c
 
 
 
734fc7e
9fd2556
 
9526432
64a5764
b52cbbd
64a5764
 
51fd64c
64a5764
4ee9ffb
 
efe74e7
c3200ea
 
7e1d4d8
c3200ea
e097964
4ee9ffb
 
 
efe74e7
4ee9ffb
 
7e1d4d8
24b5e1f
e097964
4ee9ffb
 
9218d53
734fc7e
64a5764
 
f40fcd8
eb5e924
 
 
 
 
 
 
 
64a5764
 
4ee9ffb
 
 
 
3e26956
 
4ee9ffb
 
 
 
07ccc15
3e26956
07ccc15
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d1fb6e6
36ea705
 
07ccc15
 
 
3e26956
64a5764
 
0d90d80
b52cbbd
f19061c
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
import streamlit as st
import pandas as pd
import numpy as np
import plotly.express as px

# Switch the page to Streamlit's "wide" layout before any content is rendered.
st.set_page_config(layout="wide")
import streamlit as st
import pandas as pd
import matplotlib.pyplot as plt
import streamlit as st
import pandas as pd
import matplotlib.pyplot as plt
import streamlit as st
import pandas as pd
import matplotlib.pyplot as plt
import streamlit as st
import pandas as pd
import plotly.express as px
import plotly.graph_objects as go



# Page-wide stylesheet: resets default spacing and defines the helper classes
# (.fixed-col, .maint, .centered-title, .scroller) used by the raw HTML
# snippets that the page injects.
base_css = """
    <style>
    *{
    padding:0;
    margin:0;
    }
    .fixed-col {
        position: fixed;
        top: 4rem;
        right: 0;
        width: 30%;
        padding-left: 0rem;
        background: white;
        z-index: 100;
    }
    body {
        margin: 0;
        padding: 0;
    }
    .maint {
        margin: auto;
        margin-bottom:1.5rem;
    }
    .centered-title {
        text-align: center;
    }
    .scroller {
        margin-top: 2rem; /* Adjust as necessary to avoid overlap */
    }
    </style>
    """

# Padding overrides for the direct children of the `.main` container.
margins_css = """
    <style>
        .main > div {
            padding-left: 3rem;
            padding-right:3rem;
            padding-top:0.4rem;
        }
    </style>
"""

# Inject both stylesheets, in the same order as before; raw HTML has to be
# explicitly allowed for the <style> tags to be emitted.
for stylesheet in (base_css, margins_css):
    st.markdown(stylesheet, unsafe_allow_html=True)
# Sample data for demonstration purposes.

# Choices offered by the filter widgets.
models = ['SSD300', 'SSD512', 'DETR']
pruning_methods = ['VIB Pruning', 'Transfer Pruning']
datasets = ['VOC', 'SPARK']

# Hyperparameter rows per model.  SSD models are keyed by pruning method and
# hold 4-tuples; DETR is keyed by dataset and holds 5-tuples.  '-' marks a
# value that is not given for that row.
hyperparameters = {
    'SSD300': {
        'Transfer Pruning': [
            ('SSD300-ResNet50', '-', '-', 120),
            ('SSD300-ITPCC-A', '-', '-', 120),
            ('SSD300-ITPCC-B', '-', '-', 120),
            ('SSD300-ITPCC-C', '-', '-', 120),
        ],
        'VIB Pruning': [
            ('SSD300-ResNet50', '-', '-', 120),
            ('SSD300-VIB-v1', "0.0001", 240, 100),
            ('SSD300-VIB-v2', "0.0002", 240, 100),
        ],
    },
    'SSD512': {
        'Transfer Pruning': [
            ('SSD512-ResNet50', '-', '-', 120),
            ('SSD512-ITPCC-A', '-', '-', 120),
            ('SSD512-ITPCC-B', '-', '-', 120),
            ('SSD512-ITPCC-C', '-', '-', 120),
        ],
        'VIB Pruning': [
            ('SSD512-ResNet50', '-', '-', 120),
            ('SSD512-VIB-v1', "0.0003", 200, 100),
        ],
    },
    'DETR': {
        'SPARK': [
            ("DETR-baseline", "-", "-", "-", 20),
            ("DETR-SPARK-A", "-", "-", 30, 40),
            ("DETR-SPARK-B", "-", "-", 30, 40),
        ],
        'VOC': [
            ("DETR-baseline", "-", "-", "-", 130),
            ("DETR-VOC-A", "0.0001", "0.00001", 80, 200),
            ("DETR-VOC-B", "0.00005", "0.0001", 80, 200),
        ],
    },
}


def _result_columns(names, map_scores, flops, flops_drop, params, params_drop):
    """Bundle parallel per-variant columns under the keys used by the page."""
    return {
        'model': names,
        'map': map_scores,
        'flops': flops,
        'flopsd': flops_drop,
        'params': params,
        'paramsd': params_drop,
    }


# Result columns per model.  As with the hyperparameters, SSD entries are keyed
# by pruning method and DETR entries by dataset.  All figures are stored as
# strings; the first row of every table is the unpruned baseline.
results_data = {
    'SSD300': {
        'VIB Pruning': _result_columns(
            names=['SSD300-ResNet50', 'SSD300-VIB-v1', 'SSD300-VIB-v2'],
            map_scores=["77.79", "78.71", "77.41"],
            flops=["11.1", "5.04", "3.49"],
            flops_drop=['0.0%', '54.55%', '68.54%'],
            params=["49.2", "19.84", "11.18"],
            params_drop=['0.0%', '59.68%', '77.28%'],
        ),
        'Transfer Pruning': _result_columns(
            names=["SSD300-ResNet50", 'SSD300-ITPCC-A', 'SSD300-ITPCC-B', 'SSD300-ITPCC-C'],
            map_scores=["77.79", "77.86", "77.06", "75.08"],
            flops=["11.1", "6.85", "5.08", "3.38"],
            flops_drop=['0.0%', '38.2%', '54.2%', "69.5%"],
            params=["49.2", "32.5", "25.7", "19.4"],
            params_drop=['0.0%', '33.94%', '47.77%', "60.5%"],
        ),
    },
    'SSD512': {
        'VIB Pruning': _result_columns(
            names=["SSD512-ResNet50", 'SSD512-VIB-v1'],
            map_scores=["80.9", "81.43"],
            flops=["46.24", "9.73"],
            flops_drop=['0.0%', '78.94%'],
            params=["58.52", "27.2"],
            params_drop=['0.0%', '53.42%'],
        ),
        'Transfer Pruning': _result_columns(
            names=["SSD512-ResNet50", 'SSD512-ITPCC-A', 'SSD512-ITPCC-B', 'SSD512-ITPCC-C'],
            map_scores=["80.9", "81.05", "80.45", "78.82"],
            flops=["46.2", "31.42", "25.6", "20.1"],
            flops_drop=['0.0%', '31.9%', '44.6%', "56.5%"],
            params=["58.5", "41.8", "35.0", "28.7"],
            params_drop=['0.0%', '28.5%', '40.17%', "50.1%"],
        ),
    },
    'DETR': {
        'SPARK': _result_columns(
            names=["DETR-baseline", 'DETR-SPARK-A', 'DETR-SPARK-B'],
            map_scores=["96.77", "94.5", "95.18"],
            flops=["85", "56", "58"],
            flops_drop=['0.0%', '34.1%', '31.7%'],
            params=["41.2", "23.3", "26.6"],
            params_drop=['0.0%', '47.3%', '45.4%'],
        ),
        'VOC': _result_columns(
            names=["DETR-baseline", 'DETR-VOC-A', 'DETR-VOC-B'],
            map_scores=["79.34", "77.2", "78.0"],
            flops=["85", "55", "60"],
            flops_drop=['0.0%', '35.29%', '29.41%'],
            params=["41.2", "21.71", "22.47"],
            params_drop=['0.0%', '42.65%', '35.5%'],
        ),
    },
}

# Page title, centred through the `.centered-title` CSS class.
st.markdown(
    '<h1 class="centered-title">Variational Information bottleneck pruning for Object detection</h1>',
    unsafe_allow_html=True,
)

# Two columns with a 5.2 : 4.8 width ratio (col1 = results, col2 = filters).
col1, col2 = st.columns([5.2, 4.8])

# Right section: filters and hyperparameters
with col2:
    # NOTE(review): the opening and closing <div> tags are emitted by separate
    # st.markdown calls -- confirm that Streamlit really nests the widgets in
    # between inside these wrappers.
    st.markdown('<div class="fixed-col">', unsafe_allow_html=True)
    st.subheader('Filters')
    model = st.selectbox('Select model:', models)

    # The SSD variants are filtered by pruning method, every other model by
    # dataset.  The second filter also decides which hyperparameter rows are
    # shown and how their columns are labelled.  `pruning` / `dataset` stay
    # module-level names because the results column reads them as well.
    uses_pruning_filter = model in ('SSD300', 'SSD512')
    if uses_pruning_filter:
        pruning = st.selectbox('Select pruning method:', pruning_methods)
        hyperparameter_data = hyperparameters[model][pruning]
        hyperparameter_columns = [
            'Model',
            'kl factor',
            'Pruning Epochs',
            'Finetuning Epochs',
        ]
    else:
        dataset = st.selectbox('Select dataset:', datasets)
        hyperparameter_data = hyperparameters[model][dataset]
        hyperparameter_columns = [
            'Model',
            'kl backbone',
            'kl transformer',
            'Pruning Epochs',
            'Finetuning Epochs',
        ]

    st.markdown('<div class="scroller">', unsafe_allow_html=True)
    st.subheader('Hyperparameters')
    # Add space between the filters and the hyperparameter table.
    st.markdown('<br>', unsafe_allow_html=True)

    # Render the table as static HTML with the index column hidden.
    df_hyperparams = pd.DataFrame(hyperparameter_data, columns=hyperparameter_columns)
    hyperparams_html = df_hyperparams.style.hide(axis="index").to_html()
    st.markdown(hyperparams_html, unsafe_allow_html=True)

    st.markdown('</div>', unsafe_allow_html=True)  # closes .scroller
    st.markdown('</div>', unsafe_allow_html=True)  # closes .fixed-col

# Left section: results table and FLOPs/Params comparison chart
with col1:
    st.subheader('Results')

    # SSD results are keyed by pruning method and DETR results by dataset;
    # `pruning` / `dataset` are bound by the filter widgets in the right-hand
    # column.  Resolve the key once instead of re-testing the model for the
    # table and again for the chart.
    is_ssd = model in ['SSD300', 'SSD512']
    ff = pruning if is_ssd else dataset
    results = results_data[model]
    selected = results[ff]

    # The SSD and DETR tables use different header texts for the same columns.
    if is_ssd:
        map_header, flops_header, params_header = 'mAP (%)', 'GFLOPs', 'Parameters (M)'
    else:
        map_header, flops_header, params_header = 'mAP', 'FLOPs', 'Params'

    # The two "down by" headers differ only by a trailing space so that both
    # can exist as distinct dict keys / column names.
    df_results = pd.DataFrame({
        'Model': selected['model'],
        map_header: selected['map'],
        flops_header: selected['flops'],
        'down by': selected['flopsd'],
        params_header: selected['params'],
        'down by ': selected['paramsd'],
    })

    # Render the table as static HTML with the index column hidden.
    st.markdown(df_results.style.hide(axis="index").to_html(), unsafe_allow_html=True)

    # Spacer followed by the chart section heading.
    st.markdown("<div style='margin-top: 15px;'></div>", unsafe_allow_html=True)
    st.markdown(
    """
    <h2 id="evolution-graphs" style="margin-bottom: 0px; padding-bottom: 0px;">
    Evolution Graphs
    </h2>
    """, 
    unsafe_allow_html=True
    )

    # The data tables store every figure as a string; convert to float so the
    # bar heights are plotted as numbers rather than relying on the charting
    # library to interpret numeric-looking strings.
    model_names = selected['model']
    flops_values = [float(value) for value in selected['flops']]
    params_values = [float(value) for value in selected['params']]

    # Grouped bars: one FLOPs bar and one Params bar per model variant.
    fig = go.Figure()
    fig.add_trace(go.Bar(
        x=model_names,
        y=flops_values,
        name='FLOPs',
        marker_color='orange',
    ))
    fig.add_trace(go.Bar(
        x=model_names,
        y=params_values,
        name='Params',
        marker_color='green',
    ))

    # No explicit width: use_container_width=True below stretches the figure
    # to the column width and overrides any figure-level width.
    fig.update_layout(
        barmode='group',
        title='FLOPs and Params per model',
        xaxis_title='Model',
        yaxis_title='Count',
        legend_title='Metric',
        height=280,
    )

    st.plotly_chart(fig, use_container_width=True)