Refactor Streamlit app: improve data loading, error handling, and plot labeling

#3
Files changed (1) hide show
  1. app.py +81 -37
app.py CHANGED
@@ -12,45 +12,89 @@
12
  # See the License for the specific language governing permissions and
13
  # limitations under the License.
14
 
 
15
  import streamlit as st
16
- import json
17
  import pandas as pd
18
  import plotly.express as px
19
  import numpy as np
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
20
 
21
- carbon_df= pd.read_pickle('./data/carbon_df.pkl')
22
- carbon_df.drop(carbon_df.loc[carbon_df['task']==''].index, inplace=True)
23
-
24
- st.set_page_config(
25
- page_title="Comparing the Carbon Footprint of Transformers",
26
- page_icon="./hf-earth.png",
27
- layout="wide",
28
- )
29
-
30
- st.title("Hugging Face Carbon Compare Tool")
31
-
32
- # Get the sidebar details
33
- with st.sidebar.expander("Models", expanded=True):
34
- st.image('./hf-earth.png')
35
- models=[]
36
- # choose a dataset to analyze
37
- for m in carbon_df['name'].items():
38
- try:
39
- modelname= m[1].split('/')[1]
40
- except:
41
- modelname = m[1]
42
- models.append(modelname)
43
- model_name = st.selectbox(
44
- f"Choose model to explore:",
45
- models)
46
-
47
- with st.expander("Model Comparison", expanded=False):
48
-
49
- st.markdown("### Here is how the model " + model_name + " compares to other models:")
50
- fig_model = px.bar(carbon_df.sort_values(by=['carbon']), x=models, y='carbon', hover_name= models, color_discrete_map = {model_name : 'red'})
51
- st.plotly_chart(fig_model, use_container_width=True)
52
-
53
- with st.expander("Task Comparison", expanded=False):
54
- fig = px.box(carbon_df, x=carbon_df['task'], y=carbon_df['carbon'], color='task', hover_name=carbon_df['name'])
55
- #fig.update_traces(quartilemethod="exclusive") # or "inclusive", or "linear" by default
56
- st.plotly_chart(fig, use_container_width=True)
 
12
  # See the License for the specific language governing permissions and
13
  # limitations under the License.
14
 
15
+
16
  import streamlit as st
 
17
  import pandas as pd
18
  import plotly.express as px
19
  import numpy as np
20
+ import os
21
+
22
+ # --- Utility Functions ---
23
def load_carbon_data(path):
    """Load the carbon-footprint table from a pickle file.

    Failures are reported in the Streamlit UI through ``st.error`` instead of
    being raised, so the caller only has to test the result for ``None``.

    Args:
        path: Filesystem path of the pickle file to read.

    Returns:
        The object stored in the pickle (the app expects a pandas DataFrame),
        or ``None`` when the file is missing or cannot be read.
    """
    # SECURITY: unpickling can execute arbitrary code. Only point this at
    # trusted, locally generated files.
    try:
        return pd.read_pickle(path)
    except FileNotFoundError:
        # Attempt the read and react to the failure rather than calling
        # os.path.exists() first: a single code path, no check-then-use race,
        # and a bad `path` value can no longer raise outside the handler.
        st.error(f"Data file not found: {path}")
    except Exception as e:
        # UI boundary: show any other load problem instead of crashing the app.
        st.error(f"Failed to load data: {e}")
    return None
34
+
35
# --- Load Data ---
# On failure load_carbon_data() has already shown an error in the UI and
# returned None; halt this script run here instead of nesting the whole page
# under an `if`.
carbon_df = load_carbon_data('./data/carbon_df.pkl')
if carbon_df is None:
    st.stop()

# Drop rows with a missing task. NaN/None must be tested before the string
# cast, because astype(str) would turn them into the non-empty text
# 'nan'/'None' and the rows would be kept.
task_column = carbon_df['task']
has_task = task_column.notna() & (task_column.astype(str).str.strip() != '')
carbon_df = carbon_df[has_task]

st.set_page_config(
    page_title="Comparing the Carbon Footprint of Transformers",
    page_icon="./hf-earth.png",
    layout="wide",
)

st.title("Hugging Face Carbon Compare Tool")

# Short display names, in the same row order as carbon_df: the segment after
# the first '/' of the 'name' value ("org/model" -> "model"); values without
# a '/' (or that are not strings) are used unchanged.
models = [
    name.split('/')[1] if isinstance(name, str) and '/' in name else name
    for name in carbon_df['name']
]

# --- Sidebar: Model Selection ---
with st.sidebar.expander("Models", expanded=True):
    st.image('./hf-earth.png')
    model_name = st.selectbox(
        "Choose model to explore:",
        models,
        help="Select a model to compare its carbon footprint.",
    )

# --- Model Comparison ---
with st.expander("Model Comparison", expanded=False):
    st.markdown(f"### How does **{model_name}** compare to other models?")
    # Highlight the selected model in red, others in blue.
    color_map = {m: ('red' if m == model_name else '#1f77b4') for m in models}
    # Attach the short names before sorting so each label stays with its row.
    sorted_df = carbon_df.copy()
    sorted_df['short_name'] = models
    sorted_df = sorted_df.sort_values(by=['carbon'])
    fig_model = px.bar(
        sorted_df,
        x='short_name',
        y='carbon',
        hover_name='name',
        color='short_name',
        color_discrete_map=color_map,
        labels={'short_name': 'Model', 'carbon': 'Carbon Footprint (kg CO₂e)'},
    )
    fig_model.update_layout(showlegend=False)
    st.plotly_chart(fig_model, use_container_width=True)

# --- Task Comparison ---
with st.expander("Task Comparison", expanded=False):
    fig = px.box(
        carbon_df,
        x='task',
        y='carbon',
        color='task',
        hover_name='name',
        labels={'task': 'Task', 'carbon': 'Carbon Footprint (kg CO₂e)'},
    )
    fig.update_layout(showlegend=False)
    st.plotly_chart(fig, use_container_width=True)