emsesc committed on
Commit
b466419
·
1 Parent(s): 77e9502

refactor code

Browse files
app.py CHANGED
@@ -1,9 +1,7 @@
1
- # Import packages
2
  from dash import Dash, html, dcc, Input, Output
3
  import pandas as pd
4
- import plotly.express as px
5
- from graphs.model_market_share import create_plotly_stacked_area_chart, create_plotly_world_map, create_plotly_range_slider, create_leaderboard
6
- from graphs.model_characteristics import create_plotly_language_concentration_chart, create_plotly_publication_curves_with_legend
7
 
8
  # Initialize the app
9
  app = Dash()
@@ -22,7 +20,7 @@ country_concentration_df = pd.read_pickle("data_frames/country_concentration_df.
22
  author_concentration_df = pd.read_pickle("data_frames/author_concentration_df.pkl")
23
  model_concentration_df = pd.read_pickle("data_frames/model_concentration_df.pkl")
24
 
25
-
26
  TEMP_MODEL_EVENTS = {
27
  # "Yolo World Mirror": "2024-03-01",
28
  "Llama 3": "2024-04-17",
@@ -50,25 +48,15 @@ PALETTE_0 = [
50
  "#540B0E"
51
  ]
52
 
53
- fig = create_plotly_stacked_area_chart(
54
- model_topk_df, model_gini_df, model_hhi_df, TEMP_MODEL_EVENTS, PALETTE_0
55
- )
56
-
57
  LANG_SEGMENT_ORDER = [
58
  'Monolingual: EN', 'Monolingual: HR', 'Monolingual: M/LR',
59
  'Multilingual: HR', 'Multilingual', 'Unknown',
60
  ]
61
- fig2 = create_plotly_language_concentration_chart(
62
- language_concentration_df, 'time', 'metric', 'value', LANG_SEGMENT_ORDER, PALETTE_0
63
- )
64
 
65
  LICENSE_SEGMENT_ORDER = [
66
  "Open Use", "Open Use (Acceptable Use Policy)", "Open Use (Non-Commercial Only)", "Attribution",
67
  "Acceptable Use Policy", "Non-Commercial Only", "Undocumented", "Undocumented (Acceptable Use Policy)",
68
  ]
69
- fig3 = create_plotly_language_concentration_chart(
70
- license_concentration_df, 'period', 'status', 'percent', LICENSE_SEGMENT_ORDER, PALETTE_0
71
- )
72
 
73
  METHOD_PLOT_CHOICES = {
74
  "cumulative": "none", # none, mean, sum
@@ -76,9 +64,6 @@ METHOD_PLOT_CHOICES = {
76
  "y_log": False, # True, False
77
  "period": "W",
78
  }
79
- fig4 = create_plotly_publication_curves_with_legend(
80
- download_method_cumsum_df, METHOD_PLOT_CHOICES, PALETTE_0
81
- )
82
 
83
  ARCHITECTURE_PLOT_CHOICES = {
84
  "cumulative": "none", # none, mean, sum
@@ -86,35 +71,41 @@ ARCHITECTURE_PLOT_CHOICES = {
86
  "y_log": False, # True, False
87
  "period": "W",
88
  }
89
- fig5 = create_plotly_publication_curves_with_legend(
90
- download_arch_cumsum_df, ARCHITECTURE_PLOT_CHOICES, PALETTE_0
 
 
 
91
  )
92
 
93
- fig6 = create_plotly_world_map(
94
  country_concentration_df, "time", "metric", "value"
95
  )
96
 
97
- fig7 = create_leaderboard(
98
  country_concentration_df, author_concentration_df, model_concentration_df
99
  )
100
 
101
- slider = create_plotly_range_slider(
102
  model_topk_df
103
  )
104
 
105
- slider2 = create_plotly_range_slider(
106
- country_concentration_df
 
 
 
 
 
 
 
 
 
107
  )
108
 
109
- # Make global font family
110
- fig.update_layout(font_family="Inter")
111
- fig2.update_layout(font_family="Inter")
112
- fig3.update_layout(font_family="Inter")
113
- fig4.update_layout(font_family="Inter")
114
- fig5.update_layout(font_family="Inter")
115
- fig6.update_layout(font_family="Inter")
116
- slider.update_layout(font_family="Inter")
117
- slider2.update_layout(font_family="Inter")
118
 
119
  # App layout
120
  app.layout = html.Div(
@@ -123,7 +114,34 @@ app.layout = html.Div(
123
  [
124
  html.Div(children='Visualizing the Open Model Ecosystem', style={'fontSize': 28, 'fontWeight': 'bold', 'marginBottom': 6}),
125
  html.Div(children='An interactive dashboard to explore trends in open models on Hugging Face', style={'fontSize': 16, 'marginBottom': 12}),
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
126
  html.Hr(style={'marginTop': 8, 'marginBottom': 8}),
 
127
  ],
128
  style={'textAlign': 'center'}
129
  ),
@@ -133,7 +151,7 @@ app.layout = html.Div(
133
  dcc.Tab(label='Model Market Share', children=[
134
  html.Div([
135
  html.Div(children='Select time range to update all graphs below:', style={'fontSize': 16, 'marginBottom': 6, 'marginTop': 10}),
136
- dcc.Graph(figure=slider2, id='time-slider', style={'height': '100px'}),
137
  html.Div(
138
  id='output-container-range-slider',
139
  style={
@@ -166,6 +184,8 @@ app.layout = html.Div(
166
  dcc.Dropdown(['Language Concentration', 'Architecture', 'License', 'Method'], 'Language Concentration', id='dropdown'),
167
  ], style={'marginTop': 6}),
168
  ]),
 
 
169
  ])
170
  ],
171
  style={
@@ -181,6 +201,10 @@ app.layout = html.Div(
181
  style={'fontFamily': 'Inter', 'backgroundColor': '#f7f7fa', 'minHeight': '100vh'}
182
  )
183
 
 
 
 
 
184
  @app.callback(
185
  Output('output-container-range-slider', 'children'),
186
  [Input('time-slider', 'relayoutData')]
@@ -192,22 +216,8 @@ def update_output(relayout_data):
192
  return f'Selected time range: {start_time} to {end_time}'
193
  else:
194
  return 'Selected time range: All data'
195
-
196
- # On dropdown change, update graph
197
- @app.callback(
198
- Output('language-concentration-chart', 'figure'),
199
- [Input('dropdown', 'value')]
200
- )
201
- def update_graph(selected_metric):
202
- if selected_metric == 'Language Concentration':
203
- return fig2
204
- elif selected_metric == 'License':
205
- return fig3
206
- elif selected_metric == 'Method':
207
- return fig4
208
- elif selected_metric == 'Architecture':
209
- return fig5
210
-
211
  @app.callback(
212
  Output('world-map-with-slider', 'figure'),
213
  [Input('time-slider', 'relayoutData')]
@@ -216,14 +226,15 @@ def update_map(relayout_data):
216
  if relayout_data and 'xaxis.range[0]' in relayout_data and 'xaxis.range[1]' in relayout_data:
217
  start_time = pd.to_datetime(relayout_data['xaxis.range[0]']).strftime('%Y-%m-%d')
218
  end_time = pd.to_datetime(relayout_data['xaxis.range[1]']).strftime('%Y-%m-%d')
219
- updated_fig = create_plotly_world_map(
220
  country_concentration_df, "time", "metric", "value", start_time=start_time, end_time=end_time
221
  )
222
  updated_fig.update_layout(font_family="Inter")
223
  return updated_fig
224
  else:
225
- return fig6
226
 
 
227
  @app.callback(
228
  Output('leaderboard', 'figure'),
229
  [Input('time-slider', 'relayoutData')]
@@ -238,8 +249,9 @@ def update_leaderboard(relayout_data):
238
  updated_fig.update_layout(font_family="Inter")
239
  return updated_fig
240
  else:
241
- return fig7
242
-
 
243
  @app.callback(
244
  Output('stacked-area-chart', 'figure'),
245
  [Input('time-slider', 'relayoutData')]
@@ -248,14 +260,30 @@ def update_stacked_area(relayout_data):
248
  if relayout_data and 'xaxis.range[0]' in relayout_data and 'xaxis.range[1]' in relayout_data:
249
  start_time = pd.to_datetime(relayout_data['xaxis.range[0]']).strftime('%Y-%m-%d')
250
  end_time = pd.to_datetime(relayout_data['xaxis.range[1]']).strftime('%Y-%m-%d')
251
- updated_fig = create_plotly_stacked_area_chart(
252
  model_topk_df, model_gini_df, model_hhi_df, TEMP_MODEL_EVENTS, PALETTE_0,
253
  start_time=start_time, end_time=end_time
254
  )
255
  updated_fig.update_layout(font_family="Inter")
256
  return updated_fig
257
  else:
258
- return fig
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
259
 
260
  # Run the app
261
  if __name__ == '__main__':
 
 
1
  from dash import Dash, html, dcc, Input, Output
2
  import pandas as pd
3
+ from graphs.model_market_share import create_stacked_area_chart, create_world_map, create_range_slider, create_leaderboard
4
+ from graphs.model_characteristics import create_concentration_chart, create_line_plot
 
5
 
6
  # Initialize the app
7
  app = Dash()
 
20
  author_concentration_df = pd.read_pickle("data_frames/author_concentration_df.pkl")
21
  model_concentration_df = pd.read_pickle("data_frames/model_concentration_df.pkl")
22
 
23
+ # Configurations
24
  TEMP_MODEL_EVENTS = {
25
  # "Yolo World Mirror": "2024-03-01",
26
  "Llama 3": "2024-04-17",
 
48
  "#540B0E"
49
  ]
50
 
 
 
 
 
51
  LANG_SEGMENT_ORDER = [
52
  'Monolingual: EN', 'Monolingual: HR', 'Monolingual: M/LR',
53
  'Multilingual: HR', 'Multilingual', 'Unknown',
54
  ]
 
 
 
55
 
56
  LICENSE_SEGMENT_ORDER = [
57
  "Open Use", "Open Use (Acceptable Use Policy)", "Open Use (Non-Commercial Only)", "Attribution",
58
  "Acceptable Use Policy", "Non-Commercial Only", "Undocumented", "Undocumented (Acceptable Use Policy)",
59
  ]
 
 
 
60
 
61
  METHOD_PLOT_CHOICES = {
62
  "cumulative": "none", # none, mean, sum
 
64
  "y_log": False, # True, False
65
  "period": "W",
66
  }
 
 
 
67
 
68
  ARCHITECTURE_PLOT_CHOICES = {
69
  "cumulative": "none", # none, mean, sum
 
71
  "y_log": False, # True, False
72
  "period": "W",
73
  }
74
+
75
+ # Create initial figures
76
+ # Model Market Share Tab
77
+ model_market_share_area = create_stacked_area_chart(
78
+ model_topk_df, model_gini_df, model_hhi_df, TEMP_MODEL_EVENTS, PALETTE_0
79
  )
80
 
81
+ world_map = create_world_map(
82
  country_concentration_df, "time", "metric", "value"
83
  )
84
 
85
+ leaderboard = create_leaderboard(
86
  country_concentration_df, author_concentration_df, model_concentration_df
87
  )
88
 
89
+ slider = create_range_slider(
90
  model_topk_df
91
  )
92
 
93
+ # Model Characteristics Tab
94
+ language_concentration_area = create_concentration_chart(
95
+ language_concentration_df, 'time', 'metric', 'value', LANG_SEGMENT_ORDER, PALETTE_0
96
+ )
97
+
98
+ license_concentration_area = create_concentration_chart(
99
+ license_concentration_df, 'period', 'status', 'percent', LICENSE_SEGMENT_ORDER, PALETTE_0
100
+ )
101
+
102
+ download_method_cumsum_line = create_line_plot(
103
+ download_method_cumsum_df, METHOD_PLOT_CHOICES, PALETTE_0
104
  )
105
 
106
+ download_arch_cumsum_line = create_line_plot(
107
+ download_arch_cumsum_df, ARCHITECTURE_PLOT_CHOICES, PALETTE_0
108
+ )
 
 
 
 
 
 
109
 
110
  # App layout
111
  app.layout = html.Div(
 
114
  [
115
  html.Div(children='Visualizing the Open Model Ecosystem', style={'fontSize': 28, 'fontWeight': 'bold', 'marginBottom': 6}),
116
  html.Div(children='An interactive dashboard to explore trends in open models on Hugging Face', style={'fontSize': 16, 'marginBottom': 12}),
117
+ html.Div(
118
+ children=[
119
+ html.A(
120
+ "Data Provenance Initiative",
121
+ href="https://www.dataprovenance.org/",
122
+ target="_blank",
123
+ style={
124
+ 'display': 'inline-block',
125
+ 'padding': '4px 14px',
126
+ 'fontSize': 13,
127
+ 'color': 'white',
128
+ 'backgroundColor': '#2563eb',
129
+ 'border': 'none',
130
+ 'borderRadius': '18px',
131
+ 'textDecoration': 'none',
132
+ 'fontWeight': 'bold',
133
+ 'boxShadow': '0 2px 8px rgba(37,99,235,0.08)',
134
+ 'marginLeft': '6px',
135
+ 'marginBottom': '4px',
136
+ 'transition': 'background 0.2s',
137
+ 'cursor': 'pointer'
138
+ }
139
+ )
140
+ ],
141
+ style={'fontSize': 14, 'marginBottom': 12}
142
+ ),
143
  html.Hr(style={'marginTop': 8, 'marginBottom': 8}),
144
+ html.Div(children='Lorem Ipsum is simply dummy text of the printing and typesetting industry. Lorem Ipsum has been the industry\'s standard dummy text ever since the 1500s, when an unknown printer took a galley of type and scrambled it to make a type specimen book. It has survived not only five centuries, but also the leap into electronic typesetting, remaining essentially unchanged. It was popularised in the 1960s with the release of Letraset sheets containing Lorem Ipsum passages, and more recently with desktop publishing software like Aldus PageMaker including versions of Lorem Ipsum.', style={'fontSize': 14, 'marginBottom': 12, 'marginLeft': 100, 'marginRight': 100}),
145
  ],
146
  style={'textAlign': 'center'}
147
  ),
 
151
  dcc.Tab(label='Model Market Share', children=[
152
  html.Div([
153
  html.Div(children='Select time range to update all graphs below:', style={'fontSize': 16, 'marginBottom': 6, 'marginTop': 10}),
154
+ dcc.Graph(figure=slider, id='time-slider', style={'height': '100px'}),
155
  html.Div(
156
  id='output-container-range-slider',
157
  style={
 
184
  dcc.Dropdown(['Language Concentration', 'Architecture', 'License', 'Method'], 'Language Concentration', id='dropdown'),
185
  ], style={'marginTop': 6}),
186
  ]),
187
+ dcc.Tab(label='Model Relationships', children=[
188
+ ]),
189
  ])
190
  ],
191
  style={
 
201
  style={'fontFamily': 'Inter', 'backgroundColor': '#f7f7fa', 'minHeight': '100vh'}
202
  )
203
 
204
+ # Callbacks for interactivity
205
+
206
+ # Model Market Share Tab
207
+ # On slider change, update output text
208
  @app.callback(
209
  Output('output-container-range-slider', 'children'),
210
  [Input('time-slider', 'relayoutData')]
 
216
  return f'Selected time range: {start_time} to {end_time}'
217
  else:
218
  return 'Selected time range: All data'
219
+
220
+ # On slider change, update world map
 
 
 
 
 
 
 
 
 
 
 
 
 
 
221
  @app.callback(
222
  Output('world-map-with-slider', 'figure'),
223
  [Input('time-slider', 'relayoutData')]
 
226
  if relayout_data and 'xaxis.range[0]' in relayout_data and 'xaxis.range[1]' in relayout_data:
227
  start_time = pd.to_datetime(relayout_data['xaxis.range[0]']).strftime('%Y-%m-%d')
228
  end_time = pd.to_datetime(relayout_data['xaxis.range[1]']).strftime('%Y-%m-%d')
229
+ updated_fig = create_world_map(
230
  country_concentration_df, "time", "metric", "value", start_time=start_time, end_time=end_time
231
  )
232
  updated_fig.update_layout(font_family="Inter")
233
  return updated_fig
234
  else:
235
+ return world_map
236
 
237
+ # On slider change, update leaderboard
238
  @app.callback(
239
  Output('leaderboard', 'figure'),
240
  [Input('time-slider', 'relayoutData')]
 
249
  updated_fig.update_layout(font_family="Inter")
250
  return updated_fig
251
  else:
252
+ return leaderboard
253
+
254
+ # On slider change, update stacked area chart
255
  @app.callback(
256
  Output('stacked-area-chart', 'figure'),
257
  [Input('time-slider', 'relayoutData')]
 
260
  if relayout_data and 'xaxis.range[0]' in relayout_data and 'xaxis.range[1]' in relayout_data:
261
  start_time = pd.to_datetime(relayout_data['xaxis.range[0]']).strftime('%Y-%m-%d')
262
  end_time = pd.to_datetime(relayout_data['xaxis.range[1]']).strftime('%Y-%m-%d')
263
+ updated_fig = create_stacked_area_chart(
264
  model_topk_df, model_gini_df, model_hhi_df, TEMP_MODEL_EVENTS, PALETTE_0,
265
  start_time=start_time, end_time=end_time
266
  )
267
  updated_fig.update_layout(font_family="Inter")
268
  return updated_fig
269
  else:
270
+ return model_market_share_area
271
+
272
+ # Model Characteristics Tab
273
+ # On dropdown change, update graph
274
+ @app.callback(
275
+ Output('language-concentration-chart', 'figure'),
276
+ [Input('dropdown', 'value')]
277
+ )
278
+ def update_graph(selected_metric):
279
+ if selected_metric == 'Language Concentration':
280
+ return language_concentration_area
281
+ elif selected_metric == 'License':
282
+ return license_concentration_area
283
+ elif selected_metric == 'Method':
284
+ return download_method_cumsum_line
285
+ elif selected_metric == 'Architecture':
286
+ return download_arch_cumsum_line
287
 
288
  # Run the app
289
  if __name__ == '__main__':
graphs/__pycache__/model_characteristics.cpython-39.pyc CHANGED
Binary files a/graphs/__pycache__/model_characteristics.cpython-39.pyc and b/graphs/__pycache__/model_characteristics.cpython-39.pyc differ
 
graphs/__pycache__/model_market_share.cpython-39.pyc CHANGED
Binary files a/graphs/__pycache__/model_market_share.cpython-39.pyc and b/graphs/__pycache__/model_market_share.cpython-39.pyc differ
 
graphs/model_characteristics.py CHANGED
@@ -1,27 +1,19 @@
1
  import plotly.graph_objects as go
2
  import plotly.express as px
3
 
4
- def create_plotly_language_concentration_chart(
5
- language_concentration_df,
6
  period_col,
7
  metric_col,
8
  value_col,
9
- LANG_SEGMENT_ORDER,
10
- PALETTE_0
11
- ):
12
- """
13
- Convert the language concentration visualization to Plotly
14
- """
15
-
16
- # Create figure
17
  fig = go.Figure()
18
-
19
- # Get unique time periods
20
- time_periods = sorted(language_concentration_df[period_col].unique())
21
-
22
  # Create stacked area traces
23
- for i, metric in enumerate(LANG_SEGMENT_ORDER):
24
- metric_data = language_concentration_df[language_concentration_df[metric_col] == metric]
25
 
26
  # Sort by time and get values
27
  metric_data = metric_data.sort_values(period_col)
@@ -37,7 +29,7 @@ def create_plotly_language_concentration_chart(
37
  mode='lines',
38
  line=dict(width=0),
39
  fill='tonexty' if i > 0 else 'tozeroy',
40
- fillcolor=PALETTE_0[i % len(PALETTE_0)],
41
  stackgroup='one',
42
  hovertemplate='<b>%{fullData.name}</b><br>' +
43
  'Time: %{x}<br>' +
@@ -45,12 +37,10 @@ def create_plotly_language_concentration_chart(
45
  )
46
  )
47
 
48
- # Update layout
49
  fig.update_layout(
50
  autosize=True,
51
- font_family="Times New Roman",
52
  font_size=14,
53
- showlegend=True, # Show legend for language concentration
54
  legend=dict(
55
  title="Language Concentration",
56
  orientation="v",
@@ -64,7 +54,6 @@ def create_plotly_language_concentration_chart(
64
  hovermode='x unified'
65
  )
66
 
67
- # Update x-axis
68
  fig.update_xaxes(
69
  title_text="",
70
  showgrid=True,
@@ -72,7 +61,6 @@ def create_plotly_language_concentration_chart(
72
  gridwidth=1
73
  )
74
 
75
- # Update y-axis
76
  fig.update_yaxes(
77
  title_text="",
78
  showgrid=True,
@@ -82,30 +70,26 @@ def create_plotly_language_concentration_chart(
82
 
83
  return fig
84
 
85
- def create_plotly_publication_curves_with_legend(
86
- download_method_cumsum_df,
87
- METHOD_PLOT_CHOICES,
88
  color_palette=None
89
  ):
90
- """
91
- Version with traditional legend instead of inline labels
92
- """
93
-
94
  fig = go.Figure()
95
 
96
- groups = download_method_cumsum_df['status'].unique()
97
 
98
  if color_palette is None:
99
  color_palette = px.colors.qualitative.Set1
100
 
101
  for i, group in enumerate(groups):
102
- group_data = download_method_cumsum_df[download_method_cumsum_df['status'] == group]
103
  group_data = group_data.sort_values('period')
104
 
105
  x_vals = group_data['period']
106
- y_vals = group_data[METHOD_PLOT_CHOICES["y_col"]]
107
-
108
- if METHOD_PLOT_CHOICES.get("y_format") == "percent":
109
  y_vals = y_vals * 100
110
 
111
  fig.add_trace(
@@ -121,7 +105,7 @@ def create_plotly_publication_curves_with_legend(
121
  opacity=0.85,
122
  hovertemplate='<b>%{fullData.name}</b><br>' +
123
  'Period: %{x}<br>' +
124
- 'Value: %{y:.2f}%<extra></extra>' if METHOD_PLOT_CHOICES.get("y_format") == "percent"
125
  else '<b>%{fullData.name}</b><br>Period: %{x}<br>Value: %{y}<extra></extra>'
126
  )
127
  )
@@ -147,16 +131,16 @@ def create_plotly_publication_curves_with_legend(
147
  showgrid=False,
148
  zeroline=False
149
  )
150
-
151
- y_title = METHOD_PLOT_CHOICES["y_col"]
152
- if METHOD_PLOT_CHOICES.get("y_format") == "percent":
153
  y_title += " (%)"
154
 
155
  fig.update_yaxes(
156
  title_text=y_title,
157
  showgrid=False,
158
  zeroline=False,
159
- type='log' if METHOD_PLOT_CHOICES.get("y_log") else 'linear'
160
  )
161
 
162
  return fig
 
1
  import plotly.graph_objects as go
2
  import plotly.express as px
3
 
4
+ def create_concentration_chart(
5
+ df,
6
  period_col,
7
  metric_col,
8
  value_col,
9
+ order,
10
+ palette
11
+ ):
 
 
 
 
 
12
  fig = go.Figure()
13
+
 
 
 
14
  # Create stacked area traces
15
+ for i, metric in enumerate(order):
16
+ metric_data = df[df[metric_col] == metric]
17
 
18
  # Sort by time and get values
19
  metric_data = metric_data.sort_values(period_col)
 
29
  mode='lines',
30
  line=dict(width=0),
31
  fill='tonexty' if i > 0 else 'tozeroy',
32
+ fillcolor=palette[i % len(palette)],
33
  stackgroup='one',
34
  hovertemplate='<b>%{fullData.name}</b><br>' +
35
  'Time: %{x}<br>' +
 
37
  )
38
  )
39
 
 
40
  fig.update_layout(
41
  autosize=True,
 
42
  font_size=14,
43
+ showlegend=True,
44
  legend=dict(
45
  title="Language Concentration",
46
  orientation="v",
 
54
  hovermode='x unified'
55
  )
56
 
 
57
  fig.update_xaxes(
58
  title_text="",
59
  showgrid=True,
 
61
  gridwidth=1
62
  )
63
 
 
64
  fig.update_yaxes(
65
  title_text="",
66
  showgrid=True,
 
70
 
71
  return fig
72
 
73
+ def create_line_plot(
74
+ df,
75
+ plot_choices,
76
  color_palette=None
77
  ):
 
 
 
 
78
  fig = go.Figure()
79
 
80
+ groups = df['status'].unique()
81
 
82
  if color_palette is None:
83
  color_palette = px.colors.qualitative.Set1
84
 
85
  for i, group in enumerate(groups):
86
+ group_data = df[df['status'] == group]
87
  group_data = group_data.sort_values('period')
88
 
89
  x_vals = group_data['period']
90
+ y_vals = group_data[plot_choices["y_col"]]
91
+
92
+ if plot_choices.get("y_format") == "percent":
93
  y_vals = y_vals * 100
94
 
95
  fig.add_trace(
 
105
  opacity=0.85,
106
  hovertemplate='<b>%{fullData.name}</b><br>' +
107
  'Period: %{x}<br>' +
108
+ 'Value: %{y:.2f}%<extra></extra>' if plot_choices.get("y_format") == "percent"
109
  else '<b>%{fullData.name}</b><br>Period: %{x}<br>Value: %{y}<extra></extra>'
110
  )
111
  )
 
131
  showgrid=False,
132
  zeroline=False
133
  )
134
+
135
+ y_title = plot_choices["y_col"]
136
+ if plot_choices.get("y_format") == "percent":
137
  y_title += " (%)"
138
 
139
  fig.update_yaxes(
140
  title_text=y_title,
141
  showgrid=False,
142
  zeroline=False,
143
+ type='log' if plot_choices.get("y_log") else 'linear'
144
  )
145
 
146
  return fig
graphs/model_market_share.py CHANGED
@@ -4,12 +4,9 @@ import pandas as pd
4
 
5
  filtered_df = pd.read_pickle("data_frames/filtered_df.pkl")
6
 
7
- def create_plotly_stacked_area_chart(
8
- model_topk_df, model_gini_df, model_hhi_df, TEMP_MODEL_EVENTS, PALETTE_0, start_time=None, end_time=None
9
  ):
10
- """
11
- Convert the visualization_util stacked area chart to Plotly
12
- """
13
 
14
  # Create subplot with secondary y-axis
15
  fig = make_subplots(specs=[[{"secondary_y": True}]])
@@ -26,7 +23,7 @@ def create_plotly_stacked_area_chart(
26
 
27
  # Create stacked area traces
28
  for i, metric in enumerate(metric_order):
29
- metric_data = model_topk_df[model_topk_df["metric"] == metric]
30
 
31
  # Sort by time and get values
32
  metric_data = metric_data.sort_values("time")
@@ -45,9 +42,9 @@ def create_plotly_stacked_area_chart(
45
  y=y_vals,
46
  name=metric,
47
  mode="lines",
48
- line=dict(width=0, color=PALETTE_0[i % len(PALETTE_0)]),
49
  fill="tonexty" if i > 0 else "tozeroy",
50
- fillcolor=PALETTE_0[i % len(PALETTE_0)], # Add opacity
51
  stackgroup="one",
52
  hovertemplate="<b>%{fullData.name}</b><br>"
53
  + "Time: %{x}<br>"
@@ -58,7 +55,7 @@ def create_plotly_stacked_area_chart(
58
 
59
  # Add overlay lines
60
  # Gini Coefficient
61
- gini_data = model_gini_df.sort_values("time")
62
  if start_time:
63
  gini_data = gini_data[gini_data["time"] >= start_time]
64
  if end_time:
@@ -79,7 +76,7 @@ def create_plotly_stacked_area_chart(
79
  )
80
 
81
  # HHI (×10)
82
- hhi_data = model_hhi_df.sort_values("time")
83
  if start_time:
84
  hhi_data = hhi_data[hhi_data["time"] >= start_time]
85
  if end_time:
@@ -87,7 +84,7 @@ def create_plotly_stacked_area_chart(
87
  fig.add_trace(
88
  go.Scatter(
89
  x=hhi_data["time"],
90
- y=hhi_data["value"] * 10, # Multiply by 10 as indicated
91
  name="HHI (ร—10)",
92
  mode="lines",
93
  line=dict(color="#ec4899", width=3),
@@ -100,7 +97,7 @@ def create_plotly_stacked_area_chart(
100
  )
101
 
102
  # Add vertical lines for events
103
- for event_name, event_date in TEMP_MODEL_EVENTS.items():
104
  fig.add_shape(
105
  type="line",
106
  x0=event_date,
@@ -122,12 +119,10 @@ def create_plotly_stacked_area_chart(
122
  font=dict(size=12),
123
  )
124
 
125
- # Update layout
126
  fig.update_layout(
127
  autosize=True,
128
- font_family="Inter",
129
  font_size=14,
130
- showlegend=False, # Set to True if you want to show legend
131
  margin=dict(l=60, r=60, t=40, b=60),
132
  plot_bgcolor="white",
133
  hovermode="x unified",
@@ -167,30 +162,12 @@ def create_plotly_stacked_area_chart(
167
  return fig
168
 
169
 
170
- def create_plotly_world_map(
171
  df, time_col="time", metric_col="metric", value_col="value", top_n_labels=10, start_time=None, end_time=None
172
  ):
173
  # Get all unique times and sort them
174
  times = sorted(df[time_col].unique())
175
 
176
- # Create aggregated data across the full time range initially
177
- regions_to_exclude = [
178
- "Asia",
179
- "Europe",
180
- "North America",
181
- "South America",
182
- "Africa",
183
- "Oceania",
184
- "Middle East",
185
- "Unknown",
186
- "Online",
187
- "International",
188
- "HF",
189
- ]
190
-
191
- # Filter out regions
192
- country_data = df[~df[metric_col].isin(regions_to_exclude)].copy()
193
-
194
  # Country code mapping
195
  country_code_map = {
196
  "Germany": "DEU",
@@ -238,16 +215,13 @@ def create_plotly_world_map(
238
  "Turkey": "TUR",
239
  }
240
 
241
- country_data["country_code"] = country_data[metric_col].map(country_code_map)
242
- mapped_data = country_data.dropna(subset=["country_code"])
243
 
244
- # Create subplot with secondary plot for range slider
245
  fig = make_subplots(
246
- rows=2,
247
  cols=1,
248
- row_heights=[0.85, 0.15],
249
- vertical_spacing=0.02,
250
- specs=[[{"type": "geo"}], [{"type": "scatter"}]],
251
  )
252
 
253
  # Function to aggregate data for time range
@@ -264,13 +238,13 @@ def create_plotly_world_map(
264
  agg_data["percentage"] = agg_data[value_col] * 100
265
  return agg_data.sort_values("percentage", ascending=False)
266
 
267
- # Initial data (full range)
268
  if start_time is None:
269
  start_time = times[0]
270
  if end_time is None:
271
  end_time = times[-1]
272
  initial_data = aggregate_time_range(start_time, end_time)
273
- top_countries = initial_data.head(top_n_labels)
274
 
275
  # Create hover text
276
  hover_text = []
@@ -281,7 +255,7 @@ def create_plotly_world_map(
281
  f"Avg Value: {row[value_col]:.6f}"
282
  )
283
 
284
- # Add choropleth to first subplot
285
  fig.add_trace(
286
  go.Choropleth(
287
  locations=initial_data["country_code"],
@@ -300,13 +274,13 @@ def create_plotly_world_map(
300
  ],
301
  colorbar=dict(
302
  title="Avg % of Total Downloads",
303
- tickfont=dict(size=12, family="Inter, system-ui, sans-serif"),
304
  len=0.6,
305
  x=1.02,
306
  y=0.7,
307
  ),
308
- marker_line_color="#219ebc",
309
- marker_line_width=0.4,
310
  geo="geo",
311
  ),
312
  row=1,
@@ -314,63 +288,62 @@ def create_plotly_world_map(
314
  )
315
 
316
  # Country center coordinates for labels
317
- country_centers = {
318
- "USA": {"lat": 39.8, "lon": -98.5},
319
- "CHN": {"lat": 35.8, "lon": 104.2},
320
- "DEU": {"lat": 51.2, "lon": 10.4},
321
- "GBR": {"lat": 55.4, "lon": -3.4},
322
- "FRA": {"lat": 46.6, "lon": 2.2},
323
- "JPN": {"lat": 36.2, "lon": 138.3},
324
- "IND": {"lat": 20.6, "lon": 78.9},
325
- "CAN": {"lat": 56.1, "lon": -106.3},
326
- "RUS": {"lat": 61.5, "lon": 105.3},
327
- "BRA": {"lat": -14.2, "lon": -51.9},
328
- "AUS": {"lat": -25.3, "lon": 133.8},
329
- "KOR": {"lat": 35.9, "lon": 127.8},
330
- }
331
-
332
- # Add initial labels using scattergeo instead of annotations
333
- label_lons = []
334
- label_lats = []
335
- label_texts = []
336
-
337
- for _, country in top_countries.iterrows():
338
- country_code = country["country_code"]
339
- if country_code in country_centers:
340
- center = country_centers[country_code]
341
- label_lons.append(center["lon"])
342
- label_lats.append(center["lat"])
343
- label_texts.append(f"{country['percentage']:.1f}%")
344
-
345
- # Add text labels as a scattergeo trace
346
- fig.add_trace(
347
- go.Scattergeo(
348
- lon=label_lons,
349
- lat=label_lats,
350
- text=label_texts,
351
- mode="text",
352
- textfont=dict(
353
- color="#ffffff", size=13, family="Inter, system-ui, sans-serif"
354
- ),
355
- textposition="middle center",
356
- showlegend=False,
357
- hoverinfo="skip",
358
- geo="geo",
359
- ),
360
- row=1,
361
- col=1,
362
- )
363
 
364
  # Update layout
365
  fig.update_layout(
366
  title=dict(
367
  text="Model Downloads by Country",
368
  x=0.5,
369
- font=dict(size=20, family="Inter, system-ui, sans-serif", color="#212529"),
370
  ),
371
  width=1200,
372
  height=800,
373
- font=dict(family="Inter, system-ui, sans-serif"),
374
  plot_bgcolor="#ffffff",
375
  paper_bgcolor="#ffffff",
376
  margin=dict(l=0, r=120, t=100, b=60),
@@ -379,35 +352,27 @@ def create_plotly_world_map(
379
  # Update geo layout
380
  fig.update_geos(
381
  showframe=False,
382
- showcoastlines=True,
383
  showland=True,
384
- landcolor="#f8f9fa",
385
- coastlinecolor="#023047",
386
- oceancolor="#8ecae6",
387
- projection_type="equirectangular",
388
  bgcolor="#ffffff",
389
  )
390
 
391
- # Remove excessive whitespace below the map by adjusting subplot row heights and margins
392
- fig.update_layout(
393
- margin=dict(l=0, r=120, t=100, b=20), # Reduce bottom margin
394
- height=600, # Reduce overall figure height
395
- )
396
  return fig
397
 
398
- def create_plotly_range_slider(df):
399
  if df.empty or "time" not in df.columns:
400
  return go.Figure()
401
 
402
  times = sorted(df["time"].unique())
403
-
404
  fig = go.Figure()
405
 
406
  # Invisible trace just to attach slider to the x-axis
407
  fig.add_trace(
408
  go.Scatter(
409
  x=times,
410
- y=[0] * len(times), # Dummy y-values
411
  mode="lines",
412
  line=dict(color="rgba(0,0,0,0)"), # Invisible line
413
  hoverinfo="skip",
@@ -421,14 +386,33 @@ def create_plotly_range_slider(df):
421
  rangeslider=dict(visible=False),
422
  type="date"
423
  ),
424
- yaxis=dict(visible=False), # Hide y-axis since it's dummy
425
  margin=dict(t=20, b=20, l=20, r=20),
426
- height=100 # Compact slider-only view
427
  )
428
 
429
  return fig
430
 
431
  def create_leaderboard(country_df, developer_df, model_df, start_time=None, end_time=None, top_n=10):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
432
  # Ensure datetime
433
  country_df["time"] = pd.to_datetime(country_df["time"])
434
  developer_df["time"] = pd.to_datetime(developer_df["time"])
@@ -449,25 +433,6 @@ def create_leaderboard(country_df, developer_df, model_df, start_time=None, end_
449
  how="left"
450
  ).rename(columns={"country": "country_metric"}).drop(columns=["model"])
451
 
452
- # Country -> Emoji mapping
453
- country_emoji_map = {
454
- "United States of America": "๐Ÿ‡บ๐Ÿ‡ธ",
455
- "China": "๐Ÿ‡จ๐Ÿ‡ณ",
456
- "Germany": "๐Ÿ‡ฉ๐Ÿ‡ช",
457
- "France": "๐Ÿ‡ซ๐Ÿ‡ท",
458
- "India": "๐Ÿ‡ฎ๐Ÿ‡ณ",
459
- "Italy": "๐Ÿ‡ฎ๐Ÿ‡น",
460
- "Japan": "๐Ÿ‡ฏ๐Ÿ‡ต",
461
- "South Korea": "๐Ÿ‡ฐ๐Ÿ‡ท",
462
- "United Kingdom": "๐Ÿ‡ฌ๐Ÿ‡ง",
463
- "Canada": "๐Ÿ‡จ๐Ÿ‡ฆ",
464
- "Brazil": "๐Ÿ‡ง๐Ÿ‡ท",
465
- "Australia": "๐Ÿ‡ฆ๐Ÿ‡บ",
466
- "Unknown": "โ“",
467
- "Finland": "๐Ÿ‡ซ๐Ÿ‡ฎ",
468
- "Lebanon": "๐Ÿ‡ฑ๐Ÿ‡ง ",
469
- }
470
-
471
  if start_time is None:
472
  start_time = country_df["time"].min()
473
  if end_time is None:
@@ -487,6 +452,7 @@ def create_leaderboard(country_df, developer_df, model_df, start_time=None, end_
487
  if country_df_filtered.empty and developer_df_filtered.empty and model_df_filtered.empty:
488
  return go.Figure()
489
 
 
490
  def get_top_n_leaderboard(df, group_col, label, top_n=10):
491
  top = (
492
  df.groupby(group_col)["value"]
@@ -501,9 +467,10 @@ def create_leaderboard(country_df, developer_df, model_df, start_time=None, end_
501
  top["% of total"] = top["Total Value"] / total_value * 100
502
  else:
503
  top["% of total"] = 0
 
504
  # add column with metadata (country emoji for country, country for developer/model)
505
  if label == "Country":
506
- top["Metadata"] = top[label].map(country_emoji_map).fillna("")
507
  else:
508
  # Get the country_metric for each developer/model with the already merged info
509
  top = top.merge(
@@ -512,8 +479,8 @@ def create_leaderboard(country_df, developer_df, model_df, start_time=None, end_
512
  right_on=group_col,
513
  how="left"
514
  ).drop(columns=[group_col])
515
- top["Metadata"] = top["country_metric"].map(country_emoji_map).fillna("")
516
- return top[[label, "Metadata", "% of total"]]
517
 
518
  top_countries = get_top_n_leaderboard(country_df_filtered, "metric", "Country", top_n=top_n)
519
  top_developers = get_top_n_leaderboard(developer_df_filtered, "metric", "Developer", top_n=top_n)
 
4
 
5
  filtered_df = pd.read_pickle("data_frames/filtered_df.pkl")
6
 
7
+ def create_stacked_area_chart(
8
+ topk_df, gini_df, hhi_df, events, palette, start_time=None, end_time=None
9
  ):
 
 
 
10
 
11
  # Create subplot with secondary y-axis
12
  fig = make_subplots(specs=[[{"secondary_y": True}]])
 
23
 
24
  # Create stacked area traces
25
  for i, metric in enumerate(metric_order):
26
+ metric_data = topk_df[topk_df["metric"] == metric]
27
 
28
  # Sort by time and get values
29
  metric_data = metric_data.sort_values("time")
 
42
  y=y_vals,
43
  name=metric,
44
  mode="lines",
45
+ line=dict(width=0, color=palette[i % len(palette)]),
46
  fill="tonexty" if i > 0 else "tozeroy",
47
+ fillcolor=palette[i % len(palette)],
48
  stackgroup="one",
49
  hovertemplate="<b>%{fullData.name}</b><br>"
50
  + "Time: %{x}<br>"
 
55
 
56
  # Add overlay lines
57
  # Gini Coefficient
58
+ gini_data = gini_df.sort_values("time")
59
  if start_time:
60
  gini_data = gini_data[gini_data["time"] >= start_time]
61
  if end_time:
 
76
  )
77
 
78
  # HHI (×10)
79
+ hhi_data = hhi_df.sort_values("time")
80
  if start_time:
81
  hhi_data = hhi_data[hhi_data["time"] >= start_time]
82
  if end_time:
 
84
  fig.add_trace(
85
  go.Scatter(
86
  x=hhi_data["time"],
87
+ y=hhi_data["value"] * 10,
88
  name="HHI (ร—10)",
89
  mode="lines",
90
  line=dict(color="#ec4899", width=3),
 
97
  )
98
 
99
  # Add vertical lines for events
100
+ for event_name, event_date in events.items():
101
  fig.add_shape(
102
  type="line",
103
  x0=event_date,
 
119
  font=dict(size=12),
120
  )
121
 
 
122
  fig.update_layout(
123
  autosize=True,
 
124
  font_size=14,
125
+ showlegend=True,
126
  margin=dict(l=60, r=60, t=40, b=60),
127
  plot_bgcolor="white",
128
  hovermode="x unified",
 
162
  return fig
163
 
164
 
165
+ def create_world_map(
166
  df, time_col="time", metric_col="metric", value_col="value", top_n_labels=10, start_time=None, end_time=None
167
  ):
168
  # Get all unique times and sort them
169
  times = sorted(df[time_col].unique())
170
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
171
  # Country code mapping
172
  country_code_map = {
173
  "Germany": "DEU",
 
215
  "Turkey": "TUR",
216
  }
217
 
218
+ df["country_code"] = df[metric_col].map(country_code_map)
219
+ mapped_data = df.dropna(subset=["country_code"])
220
 
 
221
  fig = make_subplots(
222
+ rows=1,
223
  cols=1,
224
+ specs=[[{"type": "geo"}]],
 
 
225
  )
226
 
227
  # Function to aggregate data for time range
 
238
  agg_data["percentage"] = agg_data[value_col] * 100
239
  return agg_data.sort_values("percentage", ascending=False)
240
 
241
+ # Initial data if start or end time are not set (full range)
242
  if start_time is None:
243
  start_time = times[0]
244
  if end_time is None:
245
  end_time = times[-1]
246
  initial_data = aggregate_time_range(start_time, end_time)
247
+ # top_countries = initial_data.head(top_n_labels)
248
 
249
  # Create hover text
250
  hover_text = []
 
255
  f"Avg Value: {row[value_col]:.6f}"
256
  )
257
 
258
+ # Add choropleth to plot
259
  fig.add_trace(
260
  go.Choropleth(
261
  locations=initial_data["country_code"],
 
274
  ],
275
  colorbar=dict(
276
  title="Avg % of Total Downloads",
277
+ tickfont=dict(size=12),
278
  len=0.6,
279
  x=1.02,
280
  y=0.7,
281
  ),
282
+ marker_line_color="#ffffff",
283
+ marker_line_width=1.5,
284
  geo="geo",
285
  ),
286
  row=1,
 
288
  )
289
 
290
  # Country center coordinates for labels
291
+ # country_centers = {
292
+ # "USA": {"lat": 39.8, "lon": -98.5},
293
+ # "CHN": {"lat": 35.8, "lon": 104.2},
294
+ # "DEU": {"lat": 51.2, "lon": 10.4},
295
+ # "GBR": {"lat": 55.4, "lon": -3.4},
296
+ # "FRA": {"lat": 46.6, "lon": 2.2},
297
+ # "JPN": {"lat": 36.2, "lon": 138.3},
298
+ # "IND": {"lat": 20.6, "lon": 78.9},
299
+ # "CAN": {"lat": 56.1, "lon": -106.3},
300
+ # "RUS": {"lat": 61.5, "lon": 105.3},
301
+ # "BRA": {"lat": -14.2, "lon": -51.9},
302
+ # "AUS": {"lat": -25.3, "lon": 133.8},
303
+ # "KOR": {"lat": 35.9, "lon": 127.8},
304
+ # }
305
+
306
+ # # Add initial labels using scattergeo instead of annotations
307
+ # label_lons = []
308
+ # label_lats = []
309
+ # label_texts = []
310
+
311
+ # for _, country in top_countries.iterrows():
312
+ # country_code = country["country_code"]
313
+ # if country_code in country_centers:
314
+ # center = country_centers[country_code]
315
+ # label_lons.append(center["lon"])
316
+ # label_lats.append(center["lat"])
317
+ # label_texts.append(f"{country['percentage']:.1f}%")
318
+
319
+ # # Add text labels as a scattergeo trace
320
+ # fig.add_trace(
321
+ # go.Scattergeo(
322
+ # lon=label_lons,
323
+ # lat=label_lats,
324
+ # text=label_texts,
325
+ # mode="text",
326
+ # textfont=dict(
327
+ # color="#ffffff", size=13, family="Inter, system-ui, sans-serif"
328
+ # ),
329
+ # textposition="middle center",
330
+ # showlegend=False,
331
+ # hoverinfo="skip",
332
+ # geo="geo",
333
+ # ),
334
+ # row=1,
335
+ # col=1,
336
+ # )
337
 
338
  # Update layout
339
  fig.update_layout(
340
  title=dict(
341
  text="Model Downloads by Country",
342
  x=0.5,
343
+ font=dict(size=20),
344
  ),
345
  width=1200,
346
  height=800,
 
347
  plot_bgcolor="#ffffff",
348
  paper_bgcolor="#ffffff",
349
  margin=dict(l=0, r=120, t=100, b=60),
 
352
  # Update geo layout
353
  fig.update_geos(
354
  showframe=False,
 
355
  showland=True,
356
+ landcolor="#d0cfcf",
357
+ coastlinecolor="#b8b8b8",
358
+ projection_type="natural earth",
 
359
  bgcolor="#ffffff",
360
  )
361
 
 
 
 
 
 
362
  return fig
363
 
364
+ def create_range_slider(df):
365
  if df.empty or "time" not in df.columns:
366
  return go.Figure()
367
 
368
  times = sorted(df["time"].unique())
 
369
  fig = go.Figure()
370
 
371
  # Invisible trace just to attach slider to the x-axis
372
  fig.add_trace(
373
  go.Scatter(
374
  x=times,
375
+ y=[0] * len(times),
376
  mode="lines",
377
  line=dict(color="rgba(0,0,0,0)"), # Invisible line
378
  hoverinfo="skip",
 
386
  rangeslider=dict(visible=False),
387
  type="date"
388
  ),
389
+ yaxis=dict(visible=False),
390
  margin=dict(t=20, b=20, l=20, r=20),
391
+ height=100
392
  )
393
 
394
  return fig
395
 
396
  def create_leaderboard(country_df, developer_df, model_df, start_time=None, end_time=None, top_n=10):
397
+ # Country -> Emoji mapping
398
+ country_emoji_map = {
399
+ "United States of America": "๐Ÿ‡บ๐Ÿ‡ธ",
400
+ "China": "๐Ÿ‡จ๐Ÿ‡ณ",
401
+ "Germany": "๐Ÿ‡ฉ๐Ÿ‡ช",
402
+ "France": "๐Ÿ‡ซ๐Ÿ‡ท",
403
+ "India": "๐Ÿ‡ฎ๐Ÿ‡ณ",
404
+ "Italy": "๐Ÿ‡ฎ๐Ÿ‡น",
405
+ "Japan": "๐Ÿ‡ฏ๐Ÿ‡ต",
406
+ "South Korea": "๐Ÿ‡ฐ๐Ÿ‡ท",
407
+ "United Kingdom": "๐Ÿ‡ฌ๐Ÿ‡ง",
408
+ "Canada": "๐Ÿ‡จ๐Ÿ‡ฆ",
409
+ "Brazil": "๐Ÿ‡ง๐Ÿ‡ท",
410
+ "Australia": "๐Ÿ‡ฆ๐Ÿ‡บ",
411
+ "Unknown": "โ“",
412
+ "Finland": "๐Ÿ‡ซ๐Ÿ‡ฎ",
413
+ "Lebanon": "๐Ÿ‡ฑ๐Ÿ‡ง ",
414
+ }
415
+
416
  # Ensure datetime
417
  country_df["time"] = pd.to_datetime(country_df["time"])
418
  developer_df["time"] = pd.to_datetime(developer_df["time"])
 
433
  how="left"
434
  ).rename(columns={"country": "country_metric"}).drop(columns=["model"])
435
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
436
  if start_time is None:
437
  start_time = country_df["time"].min()
438
  if end_time is None:
 
452
  if country_df_filtered.empty and developer_df_filtered.empty and model_df_filtered.empty:
453
  return go.Figure()
454
 
455
+ # Function to get top N leaderboard with percentage
456
  def get_top_n_leaderboard(df, group_col, label, top_n=10):
457
  top = (
458
  df.groupby(group_col)["value"]
 
467
  top["% of total"] = top["Total Value"] / total_value * 100
468
  else:
469
  top["% of total"] = 0
470
+
471
  # Add a flag-emoji column: the country's own flag for countries, or the flag of the associated country for developers/models
472
  if label == "Country":
473
+ top["Attributes"] = top[label].map(country_emoji_map).fillna("")
474
  else:
475
  # Get the country_metric for each developer/model with the already merged info
476
  top = top.merge(
 
479
  right_on=group_col,
480
  how="left"
481
  ).drop(columns=[group_col])
482
+ top["Attributes"] = top["country_metric"].map(country_emoji_map).fillna("")
483
+ return top[[label, "Attributes", "% of total"]]
484
 
485
  top_countries = get_top_n_leaderboard(country_df_filtered, "metric", "Country", top_n=top_n)
486
  top_developers = get_top_n_leaderboard(developer_df_filtered, "metric", "Developer", top_n=top_n)