emsesc committed on
Commit
66b3482
·
1 Parent(s): b233a23

move legends

Browse files
app.py CHANGED
@@ -23,6 +23,9 @@ nat_topk_df = pd.read_pickle("data_frames/nat_topk_df.pkl")
23
  country_concentration_df = pd.read_pickle("data_frames/country_concentration_df.pkl")
24
  author_concentration_df = pd.read_pickle("data_frames/author_concentration_df.pkl")
25
  model_concentration_df = pd.read_pickle("data_frames/model_concentration_df.pkl")
 
 
 
26
 
27
  # Configurations
28
  TEMP_MODEL_EVENTS = {
@@ -76,10 +79,46 @@ ARCHITECTURE_PLOT_CHOICES = {
76
  "period": "W",
77
  }
78
 
79
- # Create initial figures
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
80
  # Model Market Share Tab
 
 
 
 
 
 
 
 
 
 
 
 
 
 
81
  model_market_share_area = create_stacked_area_chart(
82
- model_topk_df, model_gini_df, model_hhi_df, TEMP_MODEL_EVENTS, PALETTE_0
83
  )
84
 
85
  world_map = create_world_map(
@@ -107,7 +146,18 @@ time_slider = dmc.RangeSlider(
107
  {"value": model_topk_df['time'].max().timestamp(), "label": model_topk_df['time'].max().strftime("%b %Y")}
108
  ],
109
  style={"width": "70%", "margin": "0 auto"},
110
- labelAlwaysOn=False
 
 
 
 
 
 
 
 
 
 
 
111
  )
112
 
113
  # Model Characteristics Tab
@@ -196,14 +246,15 @@ app.layout = dmc.MantineProvider(
196
  ),
197
  ], style={'marginBottom': 12, 'justifyContent': 'center', 'textAlign': 'center'}),
198
  html.Div([
199
- dcc.Graph(id='stacked-area-chart'),
 
200
  ], style={'marginBottom': 12}),
201
  html.Div([
202
  html.Div(
203
  dcc.Graph(id='world-map-with-slider'),
204
- style={'display': 'flex', 'justifyContent': 'center'}
205
  ),
206
- # dcc.Graph(id='leaderboard'),
207
  ], style={'marginBottom': 12})
208
  ]),
209
  dcc.Tab(label='Leaderboard', children=[
@@ -259,18 +310,19 @@ def update_output(value):
259
  # On slider change, update world map
260
  @app.callback(
261
  Output('world-map-with-slider', 'figure'),
262
- Input('time-slider', 'value')
263
  )
264
  def update_world_map(value):
265
- if value and len(value) == 2:
266
- start_time = pd.to_datetime(value[0], unit='s').strftime('%Y-%m-%d')
267
- end_time = pd.to_datetime(value[1], unit='s').strftime('%Y-%m-%d')
268
- updated_fig = create_world_map(
269
- filtered_df
270
- )
271
- updated_fig.update_layout(font_family="Inter")
272
- return updated_fig
273
- return world_map
 
274
 
275
 
276
  # On slider change, update leaderboard
@@ -291,22 +343,22 @@ def update_world_map(value):
291
  # return leaderboard
292
 
293
  # On slider change, update stacked area chart
294
- @app.callback(
295
- Output('stacked-area-chart', 'figure'),
296
- Input('time-slider', 'value')
297
- )
298
- def update_stacked_area(value):
299
- if value and len(value) == 2:
300
- start_time = pd.to_datetime(value[0], unit='s').strftime('%Y-%m-%d')
301
- end_time = pd.to_datetime(value[1], unit='s').strftime('%Y-%m-%d')
302
- updated_fig = create_stacked_area_chart(
303
- model_topk_df, model_gini_df, model_hhi_df,
304
- TEMP_MODEL_EVENTS, PALETTE_0,
305
- start_time=start_time, end_time=end_time
306
- )
307
- updated_fig.update_layout(font_family="Inter")
308
- return updated_fig
309
- return model_market_share_area
310
 
311
  @app.callback(
312
  Output("top_countries-table", "children"),
 
23
  country_concentration_df = pd.read_pickle("data_frames/country_concentration_df.pkl")
24
  author_concentration_df = pd.read_pickle("data_frames/author_concentration_df.pkl")
25
  model_concentration_df = pd.read_pickle("data_frames/model_concentration_df.pkl")
26
+ derived_country_concentration_df = pd.read_pickle("data_frames/derived_country_concentration_df_rolling.pkl")
27
+ nat_gini_df = pd.read_pickle("data_frames/nat_gini_df.pkl")
28
+ nat_hhi_df = pd.read_pickle("data_frames/nat_hhi_df.pkl")
29
 
30
  # Configurations
31
  TEMP_MODEL_EVENTS = {
 
79
  "period": "W",
80
  }
81
 
82
+ metric_order = [
83
+ 'USA', 'China', 'Germany', 'France', 'International / Online',
84
+ 'Asia', 'Middle East', 'Rest of Europe', 'South America', 'UK',
85
+ 'Africa', 'Other', "User",
86
+ ]
87
+
88
+ palette = [
89
+ "#3870f2",
90
+ "#e74c3c", # Red (Top 1%)
91
+ "#f39c12", # Orange (Top 1-10%)
92
+ "#3498db", # Blue (Top 100-1000)
93
+ "#7C2A50",
94
+ "#9467bd",
95
+ "#8c564b",
96
+ "#e377c2",
97
+ "#7f7f7f",
98
+ "#27ae60",
99
+ "#5ce7f6",
100
+ "#f0e442",
101
+ "#c2cbcc", # Gray (Rest)
102
+ "#56b4e9",
103
+ ]
104
+
105
  # Model Market Share Tab
106
+ country_market_share_area = create_stacked_area_chart(
107
+ derived_country_concentration_df, nat_gini_df, nat_hhi_df, TEMP_MODEL_EVENTS, palette, metric_order
108
+ )
109
+
110
+ # Define metric order
111
+ metric_order = [
112
+ "Top 1",
113
+ "Top 1 - 10",
114
+ "Top 10 - 100",
115
+ "Top 100 - 1000",
116
+ "Top 1000 - 10000",
117
+ "Rest",
118
+ ]
119
+
120
  model_market_share_area = create_stacked_area_chart(
121
+ model_topk_df, model_gini_df, model_hhi_df, TEMP_MODEL_EVENTS, PALETTE_0, metric_order
122
  )
123
 
124
  world_map = create_world_map(
 
146
  {"value": model_topk_df['time'].max().timestamp(), "label": model_topk_df['time'].max().strftime("%b %Y")}
147
  ],
148
  style={"width": "70%", "margin": "0 auto"},
149
+ labelAlwaysOn=False,
150
+ )
151
+
152
+ # Create a dcc slider for time range selection by year
153
+ created_slider = dcc.Slider(
154
+ id='created-slider',
155
+ min=filtered_df['time'].min().year,
156
+ max=filtered_df['time'].max().year,
157
+ marks={year: str(year) for year in range(filtered_df['time'].min().year, filtered_df['time'].max().year + 1)},
158
+ step=1,
159
+ tooltip={"placement": "bottom", "always_visible": True},
160
+ updatemode='mouseup',
161
  )
162
 
163
  # Model Characteristics Tab
 
246
  ),
247
  ], style={'marginBottom': 12, 'justifyContent': 'center', 'textAlign': 'center'}),
248
  html.Div([
249
+ # dcc.Graph(id='stacked-area-chart'),
250
+ dcc.Graph(figure=country_market_share_area),
251
  ], style={'marginBottom': 12}),
252
  html.Div([
253
  html.Div(
254
  dcc.Graph(id='world-map-with-slider'),
255
+ style={'display': 'flex', 'justifyContent': 'center', 'marginBottom': 0}
256
  ),
257
+ created_slider,
258
  ], style={'marginBottom': 12})
259
  ]),
260
  dcc.Tab(label='Leaderboard', children=[
 
310
  # On slider change, update world map
311
  @app.callback(
312
  Output('world-map-with-slider', 'figure'),
313
+ Input('created-slider', 'value')
314
  )
315
  def update_world_map(value):
316
+ # Filter by created year
317
+ if value is None:
318
+ return world_map
319
+
320
+ created_after = f"{int(value)}-01-01"
321
+ updated_fig = create_world_map(
322
+ filtered_df,
323
+ created_after=created_after
324
+ )
325
+ return updated_fig
326
 
327
 
328
  # On slider change, update leaderboard
 
343
  # return leaderboard
344
 
345
  # On slider change, update stacked area chart
346
+ # @app.callback(
347
+ # Output('stacked-area-chart', 'figure'),
348
+ # Input('time-slider', 'value')
349
+ # )
350
+ # def update_stacked_area(value):
351
+ # if value and len(value) == 2:
352
+ # start_time = pd.to_datetime(value[0], unit='s').strftime('%Y-%m-%d')
353
+ # end_time = pd.to_datetime(value[1], unit='s').strftime('%Y-%m-%d')
354
+ # updated_fig = create_stacked_area_chart(
355
+ # model_topk_df, model_gini_df, model_hhi_df,
356
+ # TEMP_MODEL_EVENTS, PALETTE_0,
357
+ # start_time=start_time, end_time=end_time
358
+ # )
359
+ # updated_fig.update_layout(font_family="Inter")
360
+ # return updated_fig
361
+ # return model_market_share_area
362
 
363
  @app.callback(
364
  Output("top_countries-table", "children"),
data_frames/derived_country_concentration_df_rolling.pkl ADDED
Binary file (83 kB). View file
 
graphs/leaderboard.py CHANGED
@@ -32,6 +32,17 @@ country_icon_map = {
32
  "Unknown": "❓",
33
  "Finland": "🇫🇮",
34
  "Lebanon": "🇱🇧",
 
 
 
 
 
 
 
 
 
 
 
35
  "User": "👤",
36
  "International/Online": "🌐",
37
  }
 
32
  "Unknown": "❓",
33
  "Finland": "🇫🇮",
34
  "Lebanon": "🇱🇧",
35
+ "Iceland": "🇮🇸",
36
+ "Singapore": "🇸🇬",
37
+ "Israel": "🇮🇱",
38
+ "Iran": "🇮🇷",
39
+ "Hong Kong": "🇭🇰",
40
+ "Netherlands": "🇳🇱",
41
+ "Chile": "🇨🇱",
42
+ "Vietnam": "🇻🇳",
43
+ "Russia": "🇷🇺",
44
+ "Qatar": "🇶🇦",
45
+ "Switzerland": "🇨🇭",
46
  "User": "👤",
47
  "International/Online": "🌐",
48
  }
graphs/model_characteristics.py CHANGED
@@ -41,17 +41,16 @@ def create_concentration_chart(
41
  autosize=True,
42
  font_size=14,
43
  showlegend=True,
 
 
 
44
  legend=dict(
45
- title="Language Concentration",
46
- orientation="v",
47
- yanchor="top",
48
- y=1,
49
- xanchor="left",
50
- x=1.02
51
- ),
52
- margin=dict(l=60, r=150, t=40, b=60), # Extra right margin for legend
53
- plot_bgcolor='white',
54
- hovermode='x unified'
55
  )
56
 
57
  fig.update_xaxes(
@@ -111,19 +110,19 @@ def create_line_plot(
111
  )
112
 
113
  fig.update_layout(
114
- width=1125,
115
- height=225,
116
  showlegend=True,
 
 
 
117
  legend=dict(
118
- orientation="h",
119
- yanchor="bottom",
120
- y=1.02,
121
- xanchor="right",
122
- x=1
123
- ),
124
- margin=dict(l=60, r=60, t=60, b=60),
125
- plot_bgcolor='white',
126
- hovermode='x unified'
127
  )
128
 
129
  fig.update_xaxes(
 
41
  autosize=True,
42
  font_size=14,
43
  showlegend=True,
44
+ margin=dict(l=60, r=60, t=40, b=80), # Increased bottom margin
45
+ plot_bgcolor="white",
46
+ hovermode="x unified",
47
  legend=dict(
48
+ orientation="h", # Horizontal legend
49
+ yanchor="top", # Anchor the top of the legend box
50
+ y=-0.25, # Place it below the plot
51
+ xanchor="center",
52
+ x=0.5
53
+ )
 
 
 
 
54
  )
55
 
56
  fig.update_xaxes(
 
110
  )
111
 
112
  fig.update_layout(
113
+ autosize=True,
114
+ font_size=14,
115
  showlegend=True,
116
+ margin=dict(l=60, r=60, t=40, b=80), # Increased bottom margin
117
+ plot_bgcolor="white",
118
+ hovermode="x unified",
119
  legend=dict(
120
+ orientation="h", # Horizontal legend
121
+ yanchor="top", # Anchor the top of the legend box
122
+ y=-0.25, # Place it below the plot
123
+ xanchor="center",
124
+ x=0.5
125
+ )
 
 
 
126
  )
127
 
128
  fig.update_xaxes(
graphs/model_market_share.py CHANGED
@@ -4,22 +4,12 @@ import plotly.graph_objects as go
4
  from plotly.subplots import make_subplots
5
 
6
  def create_stacked_area_chart(
7
- topk_df, gini_df, hhi_df, events, palette, start_time=None, end_time=None
8
  ):
9
 
10
  # Create subplot with secondary y-axis
11
  fig = make_subplots(specs=[[{"secondary_y": True}]])
12
 
13
- # Define metric order
14
- metric_order = [
15
- "Top 1",
16
- "Top 1 - 10",
17
- "Top 10 - 100",
18
- "Top 100 - 1000",
19
- "Top 1000 - 10000",
20
- "Rest",
21
- ]
22
-
23
  # Create stacked area traces
24
  for i, metric in enumerate(metric_order):
25
  metric_data = topk_df[topk_df["metric"] == metric]
@@ -54,46 +44,46 @@ def create_stacked_area_chart(
54
 
55
  # Add overlay lines
56
  # Gini Coefficient
57
- gini_data = gini_df.sort_values("time")
58
- if start_time:
59
- gini_data = gini_data[gini_data["time"] >= start_time]
60
- if end_time:
61
- gini_data = gini_data[gini_data["time"] <= end_time]
62
- fig.add_trace(
63
- go.Scatter(
64
- x=gini_data["time"],
65
- y=gini_data["value"],
66
- name="Gini Coefficient",
67
- mode="lines",
68
- line=dict(color="#6b46c1", width=3),
69
- yaxis="y2",
70
- hovertemplate="<b>Gini Coefficient</b><br>"
71
- + "Time: %{x}<br>"
72
- + "Value: %{y:.3f}<extra></extra>",
73
- ),
74
- secondary_y=True,
75
- )
76
-
77
- # HHI (×10)
78
- hhi_data = hhi_df.sort_values("time")
79
- if start_time:
80
- hhi_data = hhi_data[hhi_data["time"] >= start_time]
81
- if end_time:
82
- hhi_data = hhi_data[hhi_data["time"] <= end_time]
83
- fig.add_trace(
84
- go.Scatter(
85
- x=hhi_data["time"],
86
- y=hhi_data["value"] * 10,
87
- name="HHI (×10)",
88
- mode="lines",
89
- line=dict(color="#ec4899", width=3),
90
- yaxis="y2",
91
- hovertemplate="<b>HHI (×10)</b><br>"
92
- + "Time: %{x}<br>"
93
- + "Value: %{y:.3f}<extra></extra>",
94
- ),
95
- secondary_y=True,
96
- )
97
 
98
  # Add vertical lines for events
99
  for event_name, event_date in events.items():
@@ -124,11 +114,19 @@ def create_stacked_area_chart(
124
  autosize=True,
125
  font_size=14,
126
  showlegend=True,
127
- margin=dict(l=60, r=60, t=40, b=60),
128
  plot_bgcolor="white",
129
  hovermode="x unified",
 
 
 
 
 
 
 
130
  )
131
 
 
132
  # Update x-axis to be governed by start_time/end_time
133
  xaxis_range = None
134
  if start_time is not None and end_time is not None:
@@ -148,7 +146,7 @@ def create_stacked_area_chart(
148
 
149
  # Update primary y-axis (left)
150
  fig.update_yaxes(
151
- title_text="Model Market Share",
152
  showgrid=True,
153
  gridcolor="lightgray",
154
  gridwidth=1,
@@ -164,7 +162,7 @@ def create_stacked_area_chart(
164
 
165
 
166
  def create_world_map(
167
- df, top_n_labels=20
168
  ):
169
  # Create a filtered_df with only countries
170
  df = df[df['org_country_single'] != 'HF']
@@ -173,8 +171,9 @@ def create_world_map(
173
  df = df[df['org_country_single'] != 'user']
174
 
175
  # Filter out models created after 2024-01-01 and downloads after 2024-01-01
176
- # df = df[df['created'] > '2024-01-01']
177
- # df = df[df['time'] > '2024-01-01']
 
178
 
179
  # Country code mapping
180
  country_code_map = {
@@ -239,8 +238,6 @@ def create_world_map(
239
  .sum()
240
  .reset_index()
241
  )
242
-
243
- print(downloads_by_country.columns)
244
 
245
  # Prepare top countries for annotation
246
  total_downloads = float(downloads_by_country['downloads'].sum())
@@ -297,12 +294,12 @@ def create_world_map(
297
  text="Model Downloads by Country",
298
  x=0.5,
299
  font=dict(size=20),
 
300
  ),
301
  width=1200,
302
- height=800,
303
  plot_bgcolor="#ffffff",
304
  paper_bgcolor="#ffffff",
305
- margin=dict(l=0, r=120, t=100, b=60),
306
  )
307
 
308
  # Update geo layout
 
4
  from plotly.subplots import make_subplots
5
 
6
  def create_stacked_area_chart(
7
+ topk_df, gini_df, hhi_df, events, palette, metric_order, start_time=None, end_time=None
8
  ):
9
 
10
  # Create subplot with secondary y-axis
11
  fig = make_subplots(specs=[[{"secondary_y": True}]])
12
 
 
 
 
 
 
 
 
 
 
 
13
  # Create stacked area traces
14
  for i, metric in enumerate(metric_order):
15
  metric_data = topk_df[topk_df["metric"] == metric]
 
44
 
45
  # Add overlay lines
46
  # Gini Coefficient
47
+ # gini_data = gini_df.sort_values("time")
48
+ # if start_time:
49
+ # gini_data = gini_data[gini_data["time"] >= start_time]
50
+ # if end_time:
51
+ # gini_data = gini_data[gini_data["time"] <= end_time]
52
+ # fig.add_trace(
53
+ # go.Scatter(
54
+ # x=gini_data["time"],
55
+ # y=gini_data["value"],
56
+ # name="Gini Coefficient",
57
+ # mode="lines",
58
+ # line=dict(color="#6b46c1", width=3),
59
+ # yaxis="y2",
60
+ # hovertemplate="<b>Gini Coefficient</b><br>"
61
+ # + "Time: %{x}<br>"
62
+ # + "Value: %{y:.3f}<extra></extra>",
63
+ # ),
64
+ # secondary_y=True,
65
+ # )
66
+
67
+ # # HHI (×10)
68
+ # hhi_data = hhi_df.sort_values("time")
69
+ # if start_time:
70
+ # hhi_data = hhi_data[hhi_data["time"] >= start_time]
71
+ # if end_time:
72
+ # hhi_data = hhi_data[hhi_data["time"] <= end_time]
73
+ # fig.add_trace(
74
+ # go.Scatter(
75
+ # x=hhi_data["time"],
76
+ # y=hhi_data["value"] * 10,
77
+ # name="HHI (×10)",
78
+ # mode="lines",
79
+ # line=dict(color="#ec4899", width=3),
80
+ # yaxis="y2",
81
+ # hovertemplate="<b>HHI (×10)</b><br>"
82
+ # + "Time: %{x}<br>"
83
+ # + "Value: %{y:.3f}<extra></extra>",
84
+ # ),
85
+ # secondary_y=True,
86
+ # )
87
 
88
  # Add vertical lines for events
89
  for event_name, event_date in events.items():
 
114
  autosize=True,
115
  font_size=14,
116
  showlegend=True,
117
+ margin=dict(l=60, r=60, t=40, b=80), # Increased bottom margin
118
  plot_bgcolor="white",
119
  hovermode="x unified",
120
+ legend=dict(
121
+ orientation="h", # Horizontal legend
122
+ yanchor="top", # Anchor the top of the legend box
123
+ y=-0.25, # Place it below the plot
124
+ xanchor="center",
125
+ x=0.5
126
+ )
127
  )
128
 
129
+
130
  # Update x-axis to be governed by start_time/end_time
131
  xaxis_range = None
132
  if start_time is not None and end_time is not None:
 
146
 
147
  # Update primary y-axis (left)
148
  fig.update_yaxes(
149
+ title_text="National Concentration (%)",
150
  showgrid=True,
151
  gridcolor="lightgray",
152
  gridwidth=1,
 
162
 
163
 
164
  def create_world_map(
165
+ df, top_n_labels=20, created_after=None
166
  ):
167
  # Create a filtered_df with only countries
168
  df = df[df['org_country_single'] != 'HF']
 
171
  df = df[df['org_country_single'] != 'user']
172
 
173
  # When created_after is given, keep only models created after it and downloads recorded after it
174
+ if created_after:
175
+ df = df[df['created'] > created_after]
176
+ df = df[df['time'] > created_after]
177
 
178
  # Country code mapping
179
  country_code_map = {
 
238
  .sum()
239
  .reset_index()
240
  )
 
 
241
 
242
  # Prepare top countries for annotation
243
  total_downloads = float(downloads_by_country['downloads'].sum())
 
294
  text="Model Downloads by Country",
295
  x=0.5,
296
  font=dict(size=20),
297
+ pad=dict(t=10),
298
  ),
299
  width=1200,
300
+ height=700, # Increased height for a larger map
301
  plot_bgcolor="#ffffff",
302
  paper_bgcolor="#ffffff",
 
303
  )
304
 
305
  # Update geo layout
requirements.txt CHANGED
@@ -3,3 +3,4 @@ dash
3
  plotly
4
  gunicorn
5
  dash-mantine-components
 
 
3
  plotly
4
  gunicorn
5
  dash-mantine-components
6
+ dash-bootstrap-components