Update src/streamlit_app.py
Browse files- src/streamlit_app.py +66 -15
src/streamlit_app.py
CHANGED
|
@@ -151,23 +151,74 @@ def create_bar_chart(df, view_type):
|
|
| 151 |
"""Create interactive bar chart based on view type"""
|
| 152 |
|
| 153 |
if view_type == "Total Score":
|
| 154 |
-
|
| 155 |
-
|
| 156 |
-
|
| 157 |
-
|
| 158 |
-
|
| 159 |
-
|
| 160 |
-
|
| 161 |
-
|
| 162 |
-
|
| 163 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 164 |
fig.update_layout(
|
| 165 |
-
|
| 166 |
-
|
| 167 |
-
|
| 168 |
-
|
| 169 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 170 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 171 |
|
| 172 |
elif view_type == "Per Embodiment":
|
| 173 |
embodiment_cols = [col for col in df.columns if col.startswith('Embodiment-')]
|
|
|
|
| 151 |
"""Create interactive bar chart based on view type"""
|
| 152 |
|
| 153 |
if view_type == "Total Score":
|
| 154 |
+
|
| 155 |
+
# Format df
|
| 156 |
+
df_fig = df.copy()
|
| 157 |
+
df_fig["Models"] = df_fig["model"].str.replace('_', ' ')
|
| 158 |
+
df_fig = df_fig[df_fig["score"] != np.inf]
|
| 159 |
+
|
| 160 |
+
# Calculate mean score per model
|
| 161 |
+
df_fig = df_fig.groupby("Models")[["score"]].mean().reset_index()
|
| 162 |
+
|
| 163 |
+
# Sort the results from best to worst
|
| 164 |
+
df_fig = df_fig.sort_values(by="score", ascending=True)
|
| 165 |
+
|
| 166 |
+
# Create the Plotly figure using Plotly Express, now plotting only the 'frechet' score.
|
| 167 |
+
fig = px.bar(
|
| 168 |
+
model_scores,
|
| 169 |
+
x="Models",
|
| 170 |
+
y="score",
|
| 171 |
+
color="score",
|
| 172 |
+
color_continuous_scale=px.colors.diverging.Fall,
|
| 173 |
+
template="plotly_white",
|
| 174 |
+
orientation="v",
|
| 175 |
+
)
|
| 176 |
+
|
| 177 |
fig.update_layout(
|
| 178 |
+
xaxis_title_text="Model",
|
| 179 |
+
yaxis_title_text="Score (Lower is better)",
|
| 180 |
+
title_text="",
|
| 181 |
+
font=dict(size=15, color="black"),
|
| 182 |
+
xaxis_tickangle=-45,
|
| 183 |
+
bargap=0.2, # Increase gap for slimmer bars
|
| 184 |
+
height=500, # Set the height of the plot
|
| 185 |
+
margin=dict(
|
| 186 |
+
l=0, # Left
|
| 187 |
+
r=0, # Right
|
| 188 |
+
b=0, # Bottom
|
| 189 |
+
t=5, # Top
|
| 190 |
+
pad=0 # Padding
|
| 191 |
+
),
|
| 192 |
)
|
| 193 |
+
|
| 194 |
+
# Remove the color legend from the chart.
|
| 195 |
+
fig.update_coloraxes(showscale=False)
|
| 196 |
+
|
| 197 |
+
# Add annotations to show the exact score on each bar.
|
| 198 |
+
fig.update_traces(
|
| 199 |
+
texttemplate=""%{y:.2f}",
|
| 200 |
+
textposition="outside"
|
| 201 |
+
)
|
| 202 |
+
|
| 203 |
+
|
| 204 |
+
# TODO
|
| 205 |
+
# fig = go.Figure(data=[
|
| 206 |
+
# go.Bar(
|
| 207 |
+
# x=df['Model'],
|
| 208 |
+
# y=df['Total Score'],
|
| 209 |
+
# orientation='v',
|
| 210 |
+
# marker_color=px.colors.sequential.Blues,
|
| 211 |
+
# text=df['Total Score'].round(1),
|
| 212 |
+
# textposition='outside',
|
| 213 |
+
# )
|
| 214 |
+
# ])
|
| 215 |
+
# fig.update_layout(
|
| 216 |
+
# title="Model Performance - Total Score",
|
| 217 |
+
# xaxis_title="Model",
|
| 218 |
+
# yaxis_title="Score",
|
| 219 |
+
# yaxis_range=[0, 100],
|
| 220 |
+
# height=500,
|
| 221 |
+
# )
|
| 222 |
|
| 223 |
elif view_type == "Per Embodiment":
|
| 224 |
embodiment_cols = [col for col in df.columns if col.startswith('Embodiment-')]
|