emsesc committed on
Commit
8206722
·
1 Parent(s): fe161c7

time slider and country map

Browse files
app.py CHANGED
@@ -1,9 +1,8 @@
1
  # Import packages
2
  from dash import Dash, html, dcc, callback, Input, Output
3
  import pandas as pd
4
- import pickle
5
  import plotly.express as px
6
- from graphs.model_market_share import create_plotly_stacked_area_chart
7
  from graphs.model_characteristics import create_plotly_language_concentration_chart, create_plotly_publication_curves_with_legend
8
 
9
  # Incorporate data
@@ -13,21 +12,17 @@ df = pd.read_csv('https://raw.githubusercontent.com/plotly/datasets/master/gapmi
13
  app = Dash()
14
  server = app.server
15
 
16
- # Load all pickle files in data_frames/ as loop
17
- with open('data_frames/model_topk_df.pkl', 'rb') as f:
18
- model_topk_df = pickle.load(f)
19
- with open('data_frames/model_gini_df.pkl', 'rb') as f:
20
- model_gini_df = pickle.load(f)
21
- with open('data_frames/model_hhi_df.pkl', 'rb') as f:
22
- model_hhi_df = pickle.load(f)
23
- with open('data_frames/language_concentration_df.pkl', 'rb') as f:
24
- language_concentration_df = pickle.load(f)
25
- with open('data_frames/download_license_cumsum_df.pkl', 'rb') as f:
26
- license_concentration_df = pickle.load(f)
27
- with open('data_frames/download_method_cumsum_df.pkl', 'rb') as f:
28
- download_method_cumsum_df = pickle.load(f)
29
- with open('data_frames/download_arch_cumsum_df.pkl', 'rb') as f:
30
- download_arch_cumsum_df = pickle.load(f)
31
 
32
  TEMP_MODEL_EVENTS = {
33
  # "Yolo World Mirror": "2024-03-01",
@@ -96,12 +91,27 @@ fig5 = create_plotly_publication_curves_with_legend(
96
  download_arch_cumsum_df, ARCHITECTURE_PLOT_CHOICES, PALETTE_0
97
  )
98
 
 
 
 
 
 
 
 
 
 
 
 
 
99
  # Make global font family
100
  fig.update_layout(font_family="Inter")
101
  fig2.update_layout(font_family="Inter")
102
  fig3.update_layout(font_family="Inter")
103
  fig4.update_layout(font_family="Inter")
104
  fig5.update_layout(font_family="Inter")
 
 
 
105
 
106
  # App layout
107
  app.layout = html.Div(
@@ -111,7 +121,14 @@ app.layout = html.Div(
111
  html.Hr(),
112
  dcc.Tabs([
113
  dcc.Tab(label='Model Market Share', children=[
114
- dcc.Graph(figure=fig, id='stacked-area-chart'),
 
 
 
 
 
 
 
115
  ]),
116
  dcc.Tab(label='Model Characteristics', children=[
117
  dcc.Graph(id='language-concentration-chart'),
@@ -138,6 +155,39 @@ def update_graph(selected_metric):
138
  return fig4
139
  elif selected_metric == 'Architecture':
140
  return fig5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
141
 
142
  # Run the app
143
  if __name__ == '__main__':
 
1
  # Import packages
2
  from dash import Dash, html, dcc, callback, Input, Output
3
  import pandas as pd
 
4
  import plotly.express as px
5
+ from graphs.model_market_share import create_plotly_stacked_area_chart, create_plotly_world_map, create_plotly_range_slider
6
  from graphs.model_characteristics import create_plotly_language_concentration_chart, create_plotly_publication_curves_with_legend
7
 
8
  # Incorporate data
 
12
  app = Dash()
13
  server = app.server
14
 
15
+ # Load pre-processed data frames
16
+ filtered_df = pd.read_pickle("data_frames/filtered_df.pkl")
17
+ model_topk_df = pd.read_pickle("data_frames/model_topk_df.pkl")
18
+ model_gini_df = pd.read_pickle("data_frames/model_gini_df.pkl")
19
+ model_hhi_df = pd.read_pickle("data_frames/model_hhi_df.pkl")
20
+ language_concentration_df = pd.read_pickle("data_frames/language_concentration_df.pkl")
21
+ license_concentration_df = pd.read_pickle("data_frames/download_license_cumsum_df.pkl")
22
+ download_method_cumsum_df = pd.read_pickle("data_frames/download_method_cumsum_df.pkl")
23
+ download_arch_cumsum_df = pd.read_pickle("data_frames/download_arch_cumsum_df.pkl")
24
+ nat_topk_df = pd.read_pickle("data_frames/nat_topk_df.pkl")
25
+ country_concentration_df = pd.read_pickle("data_frames/country_concentration_df.pkl")
 
 
 
 
26
 
27
  TEMP_MODEL_EVENTS = {
28
  # "Yolo World Mirror": "2024-03-01",
 
91
  download_arch_cumsum_df, ARCHITECTURE_PLOT_CHOICES, PALETTE_0
92
  )
93
 
94
+ fig6 = create_plotly_world_map(
95
+ country_concentration_df, "time", "metric", "value"
96
+ )
97
+
98
+ slider = create_plotly_range_slider(
99
+ model_topk_df
100
+ )
101
+
102
+ slider2 = create_plotly_range_slider(
103
+ country_concentration_df
104
+ )
105
+
106
  # Make global font family
107
  fig.update_layout(font_family="Inter")
108
  fig2.update_layout(font_family="Inter")
109
  fig3.update_layout(font_family="Inter")
110
  fig4.update_layout(font_family="Inter")
111
  fig5.update_layout(font_family="Inter")
112
+ fig6.update_layout(font_family="Inter")
113
+ slider.update_layout(font_family="Inter")
114
+ slider2.update_layout(font_family="Inter")
115
 
116
  # App layout
117
  app.layout = html.Div(
 
121
  html.Hr(),
122
  dcc.Tabs([
123
  dcc.Tab(label='Model Market Share', children=[
124
+ html.Div([
125
+ dcc.Graph(id='stacked-area-chart'),
126
+ dcc.Graph(figure=slider, id='time-slider-stacked'),
127
+ ]),
128
+ html.Div([
129
+ dcc.Graph(id='world-map-with-slider'),
130
+ dcc.Graph(figure=slider2, id='time-slider'),
131
+ ])
132
  ]),
133
  dcc.Tab(label='Model Characteristics', children=[
134
  dcc.Graph(id='language-concentration-chart'),
 
155
  return fig4
156
  elif selected_metric == 'Architecture':
157
  return fig5
158
+
159
@app.callback(
    Output('world-map-with-slider', 'figure'),
    [Input('time-slider', 'relayoutData')]
)
def update_map(relayout_data):
    """Redraw the country map for the x-axis range chosen on 'time-slider'.

    Args:
        relayout_data: The ``relayoutData`` payload of the 'time-slider'
            graph. It is only acted upon when it contains both
            'xaxis.range[0]' and 'xaxis.range[1]'.

    Returns:
        A freshly built world-map figure restricted to the selected range,
        or the pre-built ``fig6`` when no complete range is present.
    """
    range_keys = ('xaxis.range[0]', 'xaxis.range[1]')

    # No payload, or a relayout event that does not carry both bounds:
    # fall back to the default full-range map.
    if not relayout_data or any(key not in relayout_data for key in range_keys):
        return fig6

    # Normalise both bounds to plain YYYY-MM-DD strings before filtering.
    start_time, end_time = [
        pd.to_datetime(relayout_data[key]).strftime('%Y-%m-%d')
        for key in range_keys
    ]

    world_map = create_plotly_world_map(
        country_concentration_df,
        "time",
        "metric",
        "value",
        start_time=start_time,
        end_time=end_time,
    )
    world_map.update_layout(font_family="Inter")
    return world_map
174
+
175
@app.callback(
    Output('stacked-area-chart', 'figure'),
    [Input('time-slider-stacked', 'relayoutData')]
)
def update_stacked_area(relayout_data):
    """Redraw the stacked area chart for the range chosen on 'time-slider-stacked'.

    Args:
        relayout_data: The ``relayoutData`` payload of the
            'time-slider-stacked' graph. It is only acted upon when it
            contains both 'xaxis.range[0]' and 'xaxis.range[1]'.

    Returns:
        A freshly built stacked area figure restricted to the selected
        range, or the pre-built ``fig`` when no complete range is present.
    """
    range_keys = ('xaxis.range[0]', 'xaxis.range[1]')

    # No payload, or a relayout event that does not carry both bounds:
    # fall back to the default full-range chart.
    if not relayout_data or any(key not in relayout_data for key in range_keys):
        return fig

    # Normalise both bounds to plain YYYY-MM-DD strings before filtering.
    start_time, end_time = [
        pd.to_datetime(relayout_data[key]).strftime('%Y-%m-%d')
        for key in range_keys
    ]

    area_chart = create_plotly_stacked_area_chart(
        model_topk_df,
        model_gini_df,
        model_hhi_df,
        TEMP_MODEL_EVENTS,
        PALETTE_0,
        start_time=start_time,
        end_time=end_time,
    )
    area_chart.update_layout(font_family="Inter")
    return area_chart
191
 
192
  # Run the app
193
  if __name__ == '__main__':
graphs/__pycache__/model_characteristics.cpython-39.pyc CHANGED
Binary files a/graphs/__pycache__/model_characteristics.cpython-39.pyc and b/graphs/__pycache__/model_characteristics.cpython-39.pyc differ
 
graphs/__pycache__/model_market_share.cpython-39.pyc CHANGED
Binary files a/graphs/__pycache__/model_market_share.cpython-39.pyc and b/graphs/__pycache__/model_market_share.cpython-39.pyc differ
 
graphs/model_characteristics.py CHANGED
@@ -1,4 +1,5 @@
1
  import plotly.graph_objects as go
 
2
 
3
  def create_plotly_language_concentration_chart(
4
  language_concentration_df,
 
1
  import plotly.graph_objects as go
2
+ import plotly.express as px
3
 
4
  def create_plotly_language_concentration_chart(
5
  language_concentration_df,
graphs/model_market_share.py CHANGED
@@ -1,98 +1,114 @@
1
  import plotly.graph_objects as go
2
  from plotly.subplots import make_subplots
 
3
 
4
  def create_plotly_stacked_area_chart(
5
- model_topk_df,
6
- model_gini_df,
7
- model_hhi_df,
8
- TEMP_MODEL_EVENTS,
9
- PALETTE_0
10
  ):
11
  """
12
  Convert the visualization_util stacked area chart to Plotly
13
  """
14
-
15
  # Create subplot with secondary y-axis
16
  fig = make_subplots(specs=[[{"secondary_y": True}]])
17
-
18
  # Define metric order
19
- metric_order = ['Top 1', 'Top 1 - 10', 'Top 10 - 100', 'Top 100 - 1000', 'Top 1000 - 10000', 'Rest']
20
-
21
- # Get unique time periods
22
- time_periods = sorted(model_topk_df['time'].unique())
23
-
 
 
 
 
24
  # Create stacked area traces
25
  for i, metric in enumerate(metric_order):
26
- metric_data = model_topk_df[model_topk_df['metric'] == metric]
27
-
28
  # Sort by time and get values
29
- metric_data = metric_data.sort_values('time')
30
- x_vals = metric_data['time']
31
- y_vals = metric_data['value']
32
-
 
 
 
 
 
33
  # Add area trace
34
  fig.add_trace(
35
  go.Scatter(
36
  x=x_vals,
37
  y=y_vals,
38
  name=metric,
39
- mode='lines',
40
  line=dict(width=0, color=PALETTE_0[i % len(PALETTE_0)]),
41
- fill='tonexty' if i > 0 else 'tozeroy',
42
  fillcolor=PALETTE_0[i % len(PALETTE_0)], # Add opacity
43
- stackgroup='one',
44
- hovertemplate='<b>%{fullData.name}</b><br>' +
45
- 'Time: %{x}<br>' +
46
- 'Value: %{y}<extra></extra>'
47
  ),
48
- secondary_y=False
49
  )
50
-
51
  # Add overlay lines
52
  # Gini Coefficient
53
- gini_data = model_gini_df.sort_values('time')
 
 
 
 
54
  fig.add_trace(
55
  go.Scatter(
56
- x=gini_data['time'],
57
- y=gini_data['value'],
58
- name='Gini Coefficient',
59
- mode='lines',
60
- line=dict(color='#6b46c1', width=3),
61
- yaxis='y2',
62
- hovertemplate='<b>Gini Coefficient</b><br>' +
63
- 'Time: %{x}<br>' +
64
- 'Value: %{y:.3f}<extra></extra>'
65
  ),
66
- secondary_y=True
67
  )
68
-
69
  # HHI (×10)
70
- hhi_data = model_hhi_df.sort_values('time')
 
 
 
 
71
  fig.add_trace(
72
  go.Scatter(
73
- x=hhi_data['time'],
74
- y=hhi_data['value'] * 10, # Multiply by 10 as indicated
75
- name='HHI (×10)',
76
- mode='lines',
77
- line=dict(color='#ec4899', width=3),
78
- yaxis='y2',
79
- hovertemplate='<b>HHI (×10)</b><br>' +
80
- 'Time: %{x}<br>' +
81
- 'Value: %{y:.3f}<extra></extra>'
82
  ),
83
- secondary_y=True
84
  )
85
-
86
  # Add vertical lines for events
87
  for event_name, event_date in TEMP_MODEL_EVENTS.items():
88
  fig.add_shape(
89
  type="line",
90
- x0=event_date, x1=event_date,
91
- y0=0, y1=1,
 
 
92
  yref="paper",
93
- line=dict(color='#333333', width=2, dash='dash')
94
  )
95
-
96
  # Add annotation for the event
97
  fig.add_annotation(
98
  x=event_date,
@@ -101,9 +117,9 @@ def create_plotly_stacked_area_chart(
101
  text=event_name,
102
  showarrow=False,
103
  yshift=10,
104
- font=dict(size=12)
105
  )
106
-
107
  # Update layout
108
  fig.update_layout(
109
  autosize=True,
@@ -111,32 +127,413 @@ def create_plotly_stacked_area_chart(
111
  font_size=14,
112
  showlegend=False, # Set to True if you want to show legend
113
  margin=dict(l=60, r=60, t=40, b=60),
114
- plot_bgcolor='white',
115
- hovermode='x unified'
116
  )
117
-
118
- # Update x-axis
 
 
 
 
 
 
 
 
119
  fig.update_xaxes(
120
  title_text="",
121
  showgrid=True,
122
- gridcolor='lightgray',
123
- gridwidth=1
 
124
  )
125
-
126
  # Update primary y-axis (left)
127
  fig.update_yaxes(
128
  title_text="Model Market Share",
129
  showgrid=True,
130
- gridcolor='lightgray',
131
  gridwidth=1,
132
- secondary_y=False
133
  )
134
-
135
  # Update secondary y-axis (right)
136
  fig.update_yaxes(
137
- title_text="Concentration Indices",
138
- showgrid=False,
139
- secondary_y=True
140
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
141
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
142
  return fig
 
 
1
  import plotly.graph_objects as go
2
  from plotly.subplots import make_subplots
3
+ import pandas as pd
4
 
5
  def create_plotly_stacked_area_chart(
6
+ model_topk_df, model_gini_df, model_hhi_df, TEMP_MODEL_EVENTS, PALETTE_0, start_time=None, end_time=None
 
 
 
 
7
  ):
8
  """
9
  Convert the visualization_util stacked area chart to Plotly
10
  """
11
+
12
  # Create subplot with secondary y-axis
13
  fig = make_subplots(specs=[[{"secondary_y": True}]])
14
+
15
  # Define metric order
16
+ metric_order = [
17
+ "Top 1",
18
+ "Top 1 - 10",
19
+ "Top 10 - 100",
20
+ "Top 100 - 1000",
21
+ "Top 1000 - 10000",
22
+ "Rest",
23
+ ]
24
+
25
  # Create stacked area traces
26
  for i, metric in enumerate(metric_order):
27
+ metric_data = model_topk_df[model_topk_df["metric"] == metric]
28
+
29
  # Sort by time and get values
30
+ metric_data = metric_data.sort_values("time")
31
+ if start_time:
32
+ metric_data = metric_data[metric_data["time"] >= start_time]
33
+ if end_time:
34
+ metric_data = metric_data[metric_data["time"] <= end_time]
35
+
36
+ x_vals = metric_data["time"]
37
+ y_vals = metric_data["value"]
38
+
39
  # Add area trace
40
  fig.add_trace(
41
  go.Scatter(
42
  x=x_vals,
43
  y=y_vals,
44
  name=metric,
45
+ mode="lines",
46
  line=dict(width=0, color=PALETTE_0[i % len(PALETTE_0)]),
47
+ fill="tonexty" if i > 0 else "tozeroy",
48
  fillcolor=PALETTE_0[i % len(PALETTE_0)], # Add opacity
49
+ stackgroup="one",
50
+ hovertemplate="<b>%{fullData.name}</b><br>"
51
+ + "Time: %{x}<br>"
52
+ + "Value: %{y}<extra></extra>",
53
  ),
54
+ secondary_y=False,
55
  )
56
+
57
  # Add overlay lines
58
  # Gini Coefficient
59
+ gini_data = model_gini_df.sort_values("time")
60
+ if start_time:
61
+ gini_data = gini_data[gini_data["time"] >= start_time]
62
+ if end_time:
63
+ gini_data = gini_data[gini_data["time"] <= end_time]
64
  fig.add_trace(
65
  go.Scatter(
66
+ x=gini_data["time"],
67
+ y=gini_data["value"],
68
+ name="Gini Coefficient",
69
+ mode="lines",
70
+ line=dict(color="#6b46c1", width=3),
71
+ yaxis="y2",
72
+ hovertemplate="<b>Gini Coefficient</b><br>"
73
+ + "Time: %{x}<br>"
74
+ + "Value: %{y:.3f}<extra></extra>",
75
  ),
76
+ secondary_y=True,
77
  )
78
+
79
  # HHI (×10)
80
+ hhi_data = model_hhi_df.sort_values("time")
81
+ if start_time:
82
+ hhi_data = hhi_data[hhi_data["time"] >= start_time]
83
+ if end_time:
84
+ hhi_data = hhi_data[hhi_data["time"] <= end_time]
85
  fig.add_trace(
86
  go.Scatter(
87
+ x=hhi_data["time"],
88
+ y=hhi_data["value"] * 10, # Multiply by 10 as indicated
89
+ name="HHI (×10)",
90
+ mode="lines",
91
+ line=dict(color="#ec4899", width=3),
92
+ yaxis="y2",
93
+ hovertemplate="<b>HHI (×10)</b><br>"
94
+ + "Time: %{x}<br>"
95
+ + "Value: %{y:.3f}<extra></extra>",
96
  ),
97
+ secondary_y=True,
98
  )
99
+
100
  # Add vertical lines for events
101
  for event_name, event_date in TEMP_MODEL_EVENTS.items():
102
  fig.add_shape(
103
  type="line",
104
+ x0=event_date,
105
+ x1=event_date,
106
+ y0=0,
107
+ y1=1,
108
  yref="paper",
109
+ line=dict(color="#333333", width=2, dash="dash"),
110
  )
111
+
112
  # Add annotation for the event
113
  fig.add_annotation(
114
  x=event_date,
 
117
  text=event_name,
118
  showarrow=False,
119
  yshift=10,
120
+ font=dict(size=12),
121
  )
122
+
123
  # Update layout
124
  fig.update_layout(
125
  autosize=True,
 
127
  font_size=14,
128
  showlegend=False, # Set to True if you want to show legend
129
  margin=dict(l=60, r=60, t=40, b=60),
130
+ plot_bgcolor="white",
131
+ hovermode="x unified",
132
  )
133
+
134
+ # Update x-axis to be governed by start_time/end_time
135
+ xaxis_range = None
136
+ if start_time is not None and end_time is not None:
137
+ xaxis_range = [start_time, end_time]
138
+ elif start_time is not None:
139
+ xaxis_range = [start_time, None]
140
+ elif end_time is not None:
141
+ xaxis_range = [None, end_time]
142
+
143
  fig.update_xaxes(
144
  title_text="",
145
  showgrid=True,
146
+ gridcolor="lightgray",
147
+ gridwidth=1,
148
+ range=xaxis_range,
149
  )
150
+
151
  # Update primary y-axis (left)
152
  fig.update_yaxes(
153
  title_text="Model Market Share",
154
  showgrid=True,
155
+ gridcolor="lightgray",
156
  gridwidth=1,
157
+ secondary_y=False,
158
  )
159
+
160
  # Update secondary y-axis (right)
161
  fig.update_yaxes(
162
+ title_text="Concentration Indices", showgrid=False, secondary_y=True
 
 
163
  )
164
+
165
+ return fig
166
+
167
+
168
def create_plotly_world_map(
    df, time_col="time", metric_col="metric", value_col="value", top_n_labels=10, start_time=None, end_time=None
):
    """Build a choropleth world map of per-country values averaged over a time range.

    Rows whose ``metric_col`` is a region/aggregate label are dropped, the
    remaining names are mapped to ISO-3 codes (unmapped names are dropped),
    and ``value_col`` is averaged per country over ``[start_time, end_time]``.
    The average is multiplied by 100 and shown as a percentage.

    Args:
        df: DataFrame containing ``time_col``, ``metric_col`` and ``value_col``.
        time_col: Name of the time column.
        metric_col: Name of the column holding country/region names.
        value_col: Name of the numeric column to average.
        top_n_labels: How many of the highest-ranked countries get a text
            label on the map (only those with known centre coordinates).
        start_time: Inclusive lower bound; defaults to the earliest time in ``df``.
        end_time: Inclusive upper bound; defaults to the latest time in ``df``.

    Returns:
        go.Figure: A two-row figure with the map in row 1. Row 2 is declared
        as a scatter cell but nothing is drawn into it by this function.
    """
    # Aggregate / non-country labels that must not be drawn on the map
    regions_to_exclude = [
        "Asia",
        "Europe",
        "North America",
        "South America",
        "Africa",
        "Oceania",
        "Middle East",
        "Unknown",
        "Online",
        "International",
        "HF",
    ]

    # Country name -> ISO-3 code
    country_code_map = {
        "Germany": "DEU",
        "United States of America": "USA",
        "China": "CHN",
        "France": "FRA",
        "India": "IND",
        "Israel": "ISR",
        "South Korea": "KOR",
        "United Kingdom": "GBR",
        "Switzerland": "CHE",
        "United Arab Emirates": "ARE",
        "Vietnam": "VNM",
        "Singapore": "SGP",
        "Chile": "CHL",
        "Hong Kong": "HKG",
        "Japan": "JPN",
        "Canada": "CAN",
        "Spain": "ESP",
        "Finland": "FIN",
        "Indonesia": "IDN",
        "Russia": "RUS",
        "Iran": "IRN",
        "Belarus": "BLR",
        "Thailand": "THA",
        "UAE": "ARE",
        "Argentina": "ARG",
        "Iceland": "ISL",
        "Poland": "POL",
        "Sweden": "SWE",
        "Taiwan": "TWN",
        "Lebanon": "LBN",
        "Algeria": "DZA",
        "Bulgaria": "BGR",
        "Norway": "NOR",
        "Netherlands": "NLD",
        "Hungary": "HUN",
        "Estonia": "EST",
        "Qatar": "QAT",
        "Brazil": "BRA",
        "Morocco": "MAR",
        "Slovenia": "SVN",
        "Ghana": "GHA",
        "Uganda": "UGA",
        "Turkey": "TUR",
    }

    # Approximate centre coordinates used to place the percentage labels
    country_centers = {
        "USA": {"lat": 39.8, "lon": -98.5},
        "CHN": {"lat": 35.8, "lon": 104.2},
        "DEU": {"lat": 51.2, "lon": 10.4},
        "GBR": {"lat": 55.4, "lon": -3.4},
        "FRA": {"lat": 46.6, "lon": 2.2},
        "JPN": {"lat": 36.2, "lon": 138.3},
        "IND": {"lat": 20.6, "lon": 78.9},
        "CAN": {"lat": 56.1, "lon": -106.3},
        "RUS": {"lat": 61.5, "lon": 105.3},
        "BRA": {"lat": -14.2, "lon": -51.9},
        "AUS": {"lat": -25.3, "lon": 133.8},
        "KOR": {"lat": 35.9, "lon": 127.8},
    }

    # Keep only real countries that have an ISO-3 code
    country_data = df[~df[metric_col].isin(regions_to_exclude)].copy()
    country_data["country_code"] = country_data[metric_col].map(country_code_map)
    mapped_data = country_data.dropna(subset=["country_code"])

    # Fill in missing bounds from the data. An empty frame has no times, so
    # the bounds stay None instead of raising IndexError on times[0].
    if start_time is None or end_time is None:
        times = sorted(df[time_col].unique())
        if times:
            if start_time is None:
                start_time = times[0]
            if end_time is None:
                end_time = times[-1]

    def aggregate_time_range(range_start, range_end):
        """Average value_col per country inside the inclusive time range."""
        in_range = mapped_data
        if range_start is not None:
            in_range = in_range[in_range[time_col] >= range_start]
        if range_end is not None:
            in_range = in_range[in_range[time_col] <= range_end]
        if in_range.empty:
            # Nothing to aggregate: return an empty frame with the expected columns
            return pd.DataFrame(
                columns=[metric_col, "country_code", value_col, "percentage"]
            )
        agg_data = (
            in_range.groupby([metric_col, "country_code"])[value_col]
            .mean()
            .reset_index()
        )
        agg_data["percentage"] = agg_data[value_col] * 100
        return agg_data.sort_values("percentage", ascending=False)

    initial_data = aggregate_time_range(start_time, end_time)
    top_countries = initial_data.head(top_n_labels)

    # One hover string per country row
    hover_text = [
        f"<b>{name}</b><br>"
        f"Avg Downloads: {percentage:.1f}% of total<br>"
        f"Avg Value: {value:.6f}"
        for name, percentage, value in zip(
            initial_data[metric_col],
            initial_data["percentage"],
            initial_data[value_col],
        )
    ]

    # Two stacked cells: the geo map on top, a scatter cell underneath
    fig = make_subplots(
        rows=2,
        cols=1,
        row_heights=[0.85, 0.15],
        vertical_spacing=0.02,
        specs=[[{"type": "geo"}], [{"type": "scatter"}]],
    )

    # Add choropleth to first subplot
    fig.add_trace(
        go.Choropleth(
            locations=initial_data["country_code"],
            z=initial_data["percentage"],
            text=hover_text,
            hovertemplate="%{text}<extra></extra>",
            colorscale=[
                "#001219",
                "#0a9396",
                "#94d2bd",
                "#e9d8a6",
                "#ee9b00",
                "#ca6702",
                "#bb3e03",
                "#9b2226",
            ],
            colorbar=dict(
                title="Avg % of Total Downloads",
                tickfont=dict(size=12, family="Inter, system-ui, sans-serif"),
                len=0.6,
                x=1.02,
                y=0.7,
            ),
            marker_line_color="#219ebc",
            marker_line_width=0.4,
            geo="geo",
        ),
        row=1,
        col=1,
    )

    # Label the top countries; those without known centre coordinates are skipped
    labelled = [
        (country_centers[code], percentage)
        for code, percentage in zip(
            top_countries["country_code"], top_countries["percentage"]
        )
        if code in country_centers
    ]
    label_lons = [center["lon"] for center, _ in labelled]
    label_lats = [center["lat"] for center, _ in labelled]
    label_texts = [f"{percentage:.1f}%" for _, percentage in labelled]

    # Text labels are drawn as a scattergeo trace rather than annotations
    fig.add_trace(
        go.Scattergeo(
            lon=label_lons,
            lat=label_lats,
            text=label_texts,
            mode="text",
            textfont=dict(
                color="#ffffff", size=13, family="Inter, system-ui, sans-serif"
            ),
            textposition="middle center",
            showlegend=False,
            hoverinfo="skip",
            geo="geo",
        ),
        row=1,
        col=1,
    )

    # Update layout
    fig.update_layout(
        title=dict(
            text="Model Downloads by Country - Time Range Analysis<br><sub>Select time range below to update map</sub>",
            x=0.5,
            font=dict(size=20, family="Inter, system-ui, sans-serif", color="#212529"),
        ),
        width=1200,
        height=800,
        font=dict(family="Inter, system-ui, sans-serif"),
        plot_bgcolor="#ffffff",
        paper_bgcolor="#ffffff",
        margin=dict(l=0, r=120, t=100, b=60),
    )

    # Update geo layout
    fig.update_geos(
        showframe=False,
        showcoastlines=True,
        showland=True,
        landcolor="#f8f9fa",
        coastlinecolor="#023047",
        oceancolor="#8ecae6",
        projection_type="equirectangular",
        bgcolor="#ffffff",
    )

    return fig
408
+
409
def create_plotly_range_slider(df):
    """
    Build a compact, axis-only figure spanning the times found in ``df``.

    The figure holds a single transparent trace over the sorted unique
    values of ``df["time"]`` on a date x-axis. The y-axis is hidden, the
    built-in Plotly rangeslider is explicitly set to not visible, and the
    figure is 100 px tall.

    Args:
        df (pd.DataFrame): A DataFrame expected to have a "time" column.

    Returns:
        go.Figure: The axis-only figure, or an empty ``go.Figure()`` when
        ``df`` is empty or has no "time" column.
    """
    has_time = "time" in df.columns
    if df.empty or not has_time:
        return go.Figure()

    timeline = sorted(df["time"].unique())

    # Transparent placeholder trace whose only job is to give the x-axis its extent
    placeholder = go.Scatter(
        x=timeline,
        y=[0 for _ in timeline],  # Dummy y-values
        mode="lines",
        line={"color": "rgba(0,0,0,0)"},  # Invisible line
        hoverinfo="skip",
        showlegend=False,
    )

    fig = go.Figure()
    fig.add_trace(placeholder)

    # Date axis with the built-in rangeslider switched off; compact layout
    x_axis = {"rangeslider": {"visible": False}, "type": "date"}
    y_axis = {"visible": False}  # Hide y-axis since it's dummy
    fig.update_layout(
        xaxis=x_axis,
        yaxis=y_axis,
        margin={"t": 20, "b": 20, "l": 20, "r": 20},
        height=100,
    )

    return fig
450
+
451
def _top_share_table(values, label, top_n):
    """Return the ``top_n`` most frequent values with their share in percent."""
    return (
        values.value_counts(normalize=True)
        .mul(100)
        # Name the index explicitly so the first column is `label` regardless
        # of what reset_index() would call it by default (that default
        # differs between pandas versions).
        .rename_axis(label)
        .reset_index(name="% of total")
        .head(top_n)
    )


def create_leaderboard(df, start_time, end_time, top_n=10):
    """Build three side-by-side leaderboard tables for a time range.

    Rows of ``df`` whose "time" lies in ``[start_time, end_time]``
    (inclusive) are kept, and for each of the "country", "developer" and
    "model" columns the ``top_n`` most frequent values are listed with
    their share of the kept rows in percent.

    Args:
        df: DataFrame with "time", "country", "developer" and "model"
            columns. It is not modified.
        start_time: Inclusive lower bound, anything ``pd.to_datetime`` accepts.
        end_time: Inclusive upper bound, anything ``pd.to_datetime`` accepts.
        top_n: Maximum number of rows per table.

    Returns:
        go.Figure: A 1x3 grid of tables, or an empty ``go.Figure()`` when no
        row falls inside the range.
    """
    # Convert into a local Series instead of writing back into the caller's frame
    times = pd.to_datetime(df["time"])

    # Filter time range
    mask = (times >= pd.to_datetime(start_time)) & (times <= pd.to_datetime(end_time))
    df_filtered = df.loc[mask]

    if df_filtered.empty:
        return go.Figure()

    # (source column, header of the first table column), left to right
    panels = (
        ("country", "Country"),
        ("developer", "Developer"),
        ("model", "Model"),
    )

    # Create subplot grid with 3 columns
    fig = make_subplots(
        rows=1,
        cols=3,
        subplot_titles=("Top Countries", "Top Developers", "Top Models"),
        specs=[[{"type": "table"} for _ in panels]],
    )

    for position, (column, label) in enumerate(panels, start=1):
        table = _top_share_table(df_filtered[column], label, top_n)
        fig.add_trace(
            go.Table(
                header=dict(
                    values=list(table.columns),
                    fill_color="lightgrey",
                    align="left",
                ),
                cells=dict(
                    values=[table[col] for col in table.columns],
                    fill_color="white",
                    align="left",
                ),
            ),
            row=1,
            col=position,
        )

    fig.update_layout(
        height=400,
        showlegend=False,
        title_text=f"Leaderboards ({start_time} → {end_time})"
    )

    return fig
539
+