emsesc committed on
Commit
6054b77
·
1 Parent(s): b74c315

leaderboard, tree, time slider, need to clean up

Browse files
app.py CHANGED
@@ -1,5 +1,6 @@
1
  from dash import Dash, html, dcc, Input, Output
2
  import pandas as pd
 
3
  from graphs.model_market_share import create_stacked_area_chart, create_world_map, create_range_slider
4
  from graphs.leaderboard import create_leaderboard
5
  from graphs.model_characteristics import create_concentration_chart, create_line_plot
@@ -89,6 +90,26 @@ slider = create_range_slider(
89
  model_topk_df
90
  )
91
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
92
  # Model Characteristics Tab
93
  language_concentration_area = create_concentration_chart(
94
  language_concentration_df, 'time', 'metric', 'value', LANG_SEGMENT_ORDER, PALETTE_0
@@ -111,7 +132,11 @@ tree_map = generate_model_treemap(
111
  )
112
 
113
  # App layout
114
- app.layout = html.Div(
 
 
 
 
115
  [
116
  html.Div(
117
  [
@@ -128,8 +153,8 @@ app.layout = html.Div(
128
  'padding': '4px 14px',
129
  'fontSize': 13,
130
  'color': 'white',
131
- 'backgroundColor': '#2563eb',
132
  'border': 'none',
 
133
  'borderRadius': '18px',
134
  'textDecoration': 'none',
135
  'fontWeight': 'bold',
@@ -154,7 +179,7 @@ app.layout = html.Div(
154
  dcc.Tab(label='Model Market Share', children=[
155
  html.Div([
156
  html.Div(children='Select time range to update all graphs below:', style={'fontSize': 16, 'marginBottom': 6, 'marginTop': 20}),
157
- dcc.Graph(figure=slider, id='time-slider', style={'height': '100px'}),
158
  html.Div(
159
  id='output-container-range-slider',
160
  style={
@@ -210,7 +235,7 @@ app.layout = html.Div(
210
  )
211
  ],
212
  style={'fontFamily': 'Inter', 'backgroundColor': '#f7f7fa', 'minHeight': '100vh'}
213
- )
214
 
215
  # Callbacks for interactivity
216
 
@@ -218,40 +243,32 @@ app.layout = html.Div(
218
  # On slider change, update output text
219
  @app.callback(
220
  Output('output-container-range-slider', 'children'),
221
- [Input('time-slider', 'relayoutData')]
222
  )
223
- def update_output(relayout_data):
224
- def format_date(date_str):
225
- date = pd.to_datetime(date_str)
226
- return date.strftime('%b {S}, %Y').replace('{S}', str(date.day) + (
227
- 'th' if 11 <= date.day <= 13 else {1: 'st', 2: 'nd', 3: 'rd'}.get(date.day % 10, 'th')
228
- ))
229
-
230
- if relayout_data and 'xaxis.range[0]' in relayout_data and 'xaxis.range[1]' in relayout_data:
231
- start_time = format_date(relayout_data['xaxis.range[0]'])
232
- end_time = format_date(relayout_data['xaxis.range[1]'])
233
- else:
234
- # Earliest and latest dates in the dataset
235
- start_time = format_date(model_topk_df['time'].min())
236
- end_time = format_date(model_topk_df['time'].max())
237
- return f'{start_time} to {end_time}'
238
 
239
  # On slider change, update world map
240
  @app.callback(
241
  Output('world-map-with-slider', 'figure'),
242
- [Input('time-slider', 'relayoutData')]
243
  )
244
- def update_map(relayout_data):
245
- if relayout_data and 'xaxis.range[0]' in relayout_data and 'xaxis.range[1]' in relayout_data:
246
- start_time = pd.to_datetime(relayout_data['xaxis.range[0]']).strftime('%Y-%m-%d')
247
- end_time = pd.to_datetime(relayout_data['xaxis.range[1]']).strftime('%Y-%m-%d')
248
  updated_fig = create_world_map(
249
- country_concentration_df, "time", "metric", "value", start_time=start_time, end_time=end_time
 
250
  )
251
  updated_fig.update_layout(font_family="Inter")
252
  return updated_fig
253
- else:
254
- return world_map
255
 
256
  # On slider change, update leaderboard
257
  # @app.callback(
@@ -273,20 +290,21 @@ def update_map(relayout_data):
273
  # On slider change, update stacked area chart
274
  @app.callback(
275
  Output('stacked-area-chart', 'figure'),
276
- [Input('time-slider', 'relayoutData')]
277
  )
278
- def update_stacked_area(relayout_data):
279
- if relayout_data and 'xaxis.range[0]' in relayout_data and 'xaxis.range[1]' in relayout_data:
280
- start_time = pd.to_datetime(relayout_data['xaxis.range[0]']).strftime('%Y-%m-%d')
281
- end_time = pd.to_datetime(relayout_data['xaxis.range[1]']).strftime('%Y-%m-%d')
282
  updated_fig = create_stacked_area_chart(
283
- model_topk_df, model_gini_df, model_hhi_df, TEMP_MODEL_EVENTS, PALETTE_0,
 
284
  start_time=start_time, end_time=end_time
285
  )
286
  updated_fig.update_layout(font_family="Inter")
287
  return updated_fig
288
- else:
289
- return model_market_share_area
290
 
291
  # Model Characteristics Tab
292
  # On dropdown change, update graph
 
1
  from dash import Dash, html, dcc, Input, Output
2
  import pandas as pd
3
+ import dash_mantine_components as dmc
4
  from graphs.model_market_share import create_stacked_area_chart, create_world_map, create_range_slider
5
  from graphs.leaderboard import create_leaderboard
6
  from graphs.model_characteristics import create_concentration_chart, create_line_plot
 
90
  model_topk_df
91
  )
92
 
93
# Range slider spanning the dataset's full time extent. Slider values are Unix
# timestamps in seconds; the callbacks convert them back to dates with
# pd.to_datetime(..., unit='s').
# The bounds are computed once instead of re-scanning the column for every
# property that needs them.
_time_min = model_topk_df['time'].min()
_time_max = model_topk_df['time'].max()
_time_min_ts = _time_min.timestamp()
_time_max_ts = _time_max.timestamp()

time_slider = dmc.RangeSlider(
    id="time-slider",
    min=_time_min_ts,
    max=_time_max_ts,
    value=[_time_min_ts, _time_max_ts],
    step=24 * 60 * 60,  # one day, expressed in seconds
    color="blue",
    size="md",
    radius="xl",
    marks=[
        {"value": _time_min_ts, "label": _time_min.strftime("%b %Y")},
        {"value": _time_max_ts, "label": _time_max.strftime("%b %Y")},
    ],
    style={"width": "70%", "margin": "0 auto"},
    labelAlwaysOn=False,
)
112
+
113
  # Model Characteristics Tab
114
  language_concentration_area = create_concentration_chart(
115
  language_concentration_df, 'time', 'metric', 'value', LANG_SEGMENT_ORDER, PALETTE_0
 
132
  )
133
 
134
  # App layout
135
+ app.layout = dmc.MantineProvider(
136
+ theme={"colorScheme": "light",
137
+ "primaryColor": "blue",
138
+ "fontFamily": "Inter, sans-serif"},
139
+ children=[html.Div(
140
  [
141
  html.Div(
142
  [
 
153
  'padding': '4px 14px',
154
  'fontSize': 13,
155
  'color': 'white',
 
156
  'border': 'none',
157
+ 'backgroundColor': '#228BE6',
158
  'borderRadius': '18px',
159
  'textDecoration': 'none',
160
  'fontWeight': 'bold',
 
179
  dcc.Tab(label='Model Market Share', children=[
180
  html.Div([
181
  html.Div(children='Select time range to update all graphs below:', style={'fontSize': 16, 'marginBottom': 6, 'marginTop': 20}),
182
+ time_slider,
183
  html.Div(
184
  id='output-container-range-slider',
185
  style={
 
235
  )
236
  ],
237
  style={'fontFamily': 'Inter', 'backgroundColor': '#f7f7fa', 'minHeight': '100vh'}
238
+ )])
239
 
240
  # Callbacks for interactivity
241
 
 
243
  # On slider change, update output text
244
  @app.callback(
245
  Output('output-container-range-slider', 'children'),
246
+ [Input('time-slider', 'value')]
247
  )
248
def update_output(value):
    """Return the caption describing the slider's currently selected range.

    Args:
        value: The ``time-slider`` value, a two-item sequence of Unix
            timestamps in seconds.

    Returns:
        ``"Selected time range: <start> to <end>"`` with both dates formatted
        as ``"%b %d, %Y"``, or the prompt ``"Select a time range"`` when
        ``value`` is empty, ``None`` or not exactly two items long.
    """
    if not value or len(value) != 2:
        return "Select a time range"

    # Both endpoints go through the same seconds-to-date conversion.
    start_time, end_time = (
        pd.to_datetime(ts, unit='s').strftime("%b %d, %Y") for ts in value
    )
    return f"Selected time range: {start_time} to {end_time}"
 
 
 
 
 
 
 
 
 
254
 
255
  # On slider change, update world map
256
  @app.callback(
257
  Output('world-map-with-slider', 'figure'),
258
+ Input('time-slider', 'value')
259
  )
260
def update_world_map(value):
    """Redraw the world map for the time range selected on the slider.

    Args:
        value: The ``time-slider`` value, a two-item sequence of Unix
            timestamps in seconds.

    Returns:
        A figure from ``create_world_map`` restricted to the selected range,
        or the existing ``world_map`` object when ``value`` is empty, ``None``
        or not exactly two items long.
    """
    if not value or len(value) != 2:
        return world_map

    # create_world_map receives the bounds as 'YYYY-MM-DD' strings.
    start_time, end_time = (
        pd.to_datetime(ts, unit='s').strftime('%Y-%m-%d') for ts in value
    )
    updated_fig = create_world_map(
        country_concentration_df, "time", "metric", "value",
        start_time=start_time, end_time=end_time,
    )
    updated_fig.update_layout(font_family="Inter")
    return updated_fig
271
+
272
 
273
  # On slider change, update leaderboard
274
  # @app.callback(
 
290
  # On slider change, update stacked area chart
291
  @app.callback(
292
  Output('stacked-area-chart', 'figure'),
293
+ Input('time-slider', 'value')
294
  )
295
def update_stacked_area(value):
    """Redraw the stacked area chart for the time range selected on the slider.

    Args:
        value: The ``time-slider`` value, a two-item sequence of Unix
            timestamps in seconds.

    Returns:
        A figure from ``create_stacked_area_chart`` restricted to the selected
        range, or the existing ``model_market_share_area`` object when
        ``value`` is empty, ``None`` or not exactly two items long.
    """
    if not value or len(value) != 2:
        return model_market_share_area

    # create_stacked_area_chart receives the bounds as 'YYYY-MM-DD' strings.
    start_time, end_time = (
        pd.to_datetime(ts, unit='s').strftime('%Y-%m-%d') for ts in value
    )
    updated_fig = create_stacked_area_chart(
        model_topk_df, model_gini_df, model_hhi_df,
        TEMP_MODEL_EVENTS, PALETTE_0,
        start_time=start_time, end_time=end_time,
    )
    updated_fig.update_layout(font_family="Inter")
    return updated_fig
307
+
308
 
309
  # Model Characteristics Tab
310
  # On dropdown change, update graph
graphs/__pycache__/model_market_share.cpython-39.pyc CHANGED
Binary files a/graphs/__pycache__/model_market_share.cpython-39.pyc and b/graphs/__pycache__/model_market_share.cpython-39.pyc differ
 
graphs/leaderboard.py CHANGED
@@ -1,5 +1,6 @@
1
  import pandas as pd
2
- from dash import html
 
3
 
4
  def create_leaderboard(filtered_df, country_df, developer_df, model_df, start_time=None, end_time=None, top_n=10):
5
  country_icon_map = {
@@ -33,21 +34,24 @@ def create_leaderboard(filtered_df, country_df, developer_df, model_df, start_ti
33
  for df in [country_df, developer_df, model_df]:
34
  df["time"] = pd.to_datetime(df["time"])
35
 
 
 
 
36
  # Merge country info for developers/models
37
  developer_df = developer_df.merge(
38
- filtered_df[["country", "author", "org_or_user", "model", "downloads"]].drop_duplicates(subset=["author"]),
39
  left_on="metric", right_on="author", how="left"
40
  ).drop(columns=["metric"])
41
 
42
  model_df = model_df.merge(
43
- filtered_df[["country", "author", "downloads", "org_or_user", "model", "merged_modality"]].drop_duplicates(subset=["model"]),
44
  left_on="metric", right_on="model", how="left"
45
  ).drop(columns=["metric"])
46
 
47
  # Rename metric columns
48
  # country_df = country_df.rename(columns={"metric": "country"})
49
  country_df = country_df.merge(
50
- filtered_df[["country", "downloads"]].drop_duplicates(subset=["country"]),
51
  left_on="metric", right_on="country", how="left"
52
  ).drop(columns=["metric"])
53
 
@@ -74,17 +78,25 @@ def create_leaderboard(filtered_df, country_df, developer_df, model_df, start_ti
74
  total_value = top["Total Value"].sum()
75
  top["% of total"] = top["Total Value"] / total_value * 100 if total_value else 0
76
 
 
 
 
 
 
77
  # All relevant metadata columns
78
- meta_cols = ["country", "author", "downloads", "org_or_user", "merged_modality"]
79
  # Collect all metadata per top n for each category (country, author, model)
80
  meta_map = {}
 
81
  for name in top["Name"]:
82
  name_data = df[df[group_col] == name]
83
  meta_map[name] = {}
 
84
  for col in meta_cols:
85
  if col in name_data.columns:
86
  unique_vals = name_data[col].unique()
87
  meta_map[name][col] = list(unique_vals)
 
88
 
89
  # Function to build metadata chips
90
  def build_metadata(nm):
@@ -111,17 +123,49 @@ def create_leaderboard(filtered_df, country_df, developer_df, model_df, start_ti
111
  # Modality
112
  for m in meta.get("merged_modality", []):
113
  chips.append(("", m))
 
 
 
 
 
 
 
 
 
 
 
 
 
114
  return chips
115
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
116
  # Apply metadata builder to top dataframe
117
  top["Metadata"] = top["Name"].map(build_metadata)
 
 
 
118
 
119
- return top[["Name", "Metadata", "% of total"]]
120
 
121
  # Build leaderboards
122
- top_countries = get_top_n_leaderboard(country_df, "country", top_n)
123
- top_developers = get_top_n_leaderboard(developer_df, "author", top_n)
124
- top_models = get_top_n_leaderboard(model_df, "model", top_n)
125
 
126
  # Chip renderer
127
  def chip(text, bg_color="#F0F0F0"):
@@ -202,9 +246,37 @@ def create_leaderboard(filtered_df, country_df, developer_df, model_df, start_ti
202
  )
203
  ]
204
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
205
 
206
  # Table renderer
207
- def render_table(df, title, chip_color="#F0F0F0", bar_color="#4CAF50"):
208
  return html.Div([
209
  html.H4(title, style={"textAlign": "left", "marginBottom": "10px", "fontSize": "20px"}),
210
  html.Table([
@@ -222,14 +294,15 @@ def create_leaderboard(filtered_df, country_df, developer_df, model_df, start_ti
222
  html.Td(progress_bar(row["% of total"], bar_color), style={"textAlign": "center"})
223
  ]) for idx, row in df.iterrows()
224
  ])
225
- ], style={"borderCollapse": "collapse", "width": "100%"})
 
226
  ], style={"marginBottom": "20px"})
227
 
228
  # Layout with 3 stacked tables
229
  layout = html.Div([
230
- render_table(top_countries, "Top Countries", chip_color="#FCE8E6", bar_color="#FF6F61"),
231
- render_table(top_developers, "Top Developers", chip_color="#E6F4EA", bar_color="#4CAF50"),
232
- render_table(top_models, "Top Models", chip_color="#E8F0FE", bar_color="#2196F3"),
233
  ])
234
 
235
  return layout
 
1
  import pandas as pd
2
+ from dash import html, dcc
3
+ import base64
4
 
5
  def create_leaderboard(filtered_df, country_df, developer_df, model_df, start_time=None, end_time=None, top_n=10):
6
  country_icon_map = {
 
34
  for df in [country_df, developer_df, model_df]:
35
  df["time"] = pd.to_datetime(df["time"])
36
 
37
+ # change any value that does not equal "org" to "user"
38
+ filtered_df["org_or_user"] = filtered_df["org_or_user"].where(filtered_df["org_or_user"] == "org", "user")
39
+
40
  # Merge country info for developers/models
41
  developer_df = developer_df.merge(
42
+ filtered_df[["country", "author", "org_or_user", "model", "downloads", "estimated_parameters"]].drop_duplicates(subset=["author"]),
43
  left_on="metric", right_on="author", how="left"
44
  ).drop(columns=["metric"])
45
 
46
  model_df = model_df.merge(
47
+ filtered_df[["country", "author", "downloads", "org_or_user", "model", "merged_modality", "estimated_parameters"]].drop_duplicates(subset=["model"]),
48
  left_on="metric", right_on="model", how="left"
49
  ).drop(columns=["metric"])
50
 
51
  # Rename metric columns
52
  # country_df = country_df.rename(columns={"metric": "country"})
53
  country_df = country_df.merge(
54
+ filtered_df[["country", "downloads", "estimated_parameters"]].drop_duplicates(subset=["country"]),
55
  left_on="metric", right_on="country", how="left"
56
  ).drop(columns=["metric"])
57
 
 
78
  total_value = top["Total Value"].sum()
79
  top["% of total"] = top["Total Value"] / total_value * 100 if total_value else 0
80
 
81
+ # Create a downloadable version of the leaderboard
82
+ download_top = top.copy()
83
+ download_top["Total Value"] = download_top["Total Value"].astype(int)
84
+ download_top["% of total"] = download_top["% of total"].round(2)
85
+
86
  # All relevant metadata columns
87
+ meta_cols = ["country", "author", "downloads", "org_or_user", "merged_modality", "estimated_parameters"]
88
  # Collect all metadata per top n for each category (country, author, model)
89
  meta_map = {}
90
+ download_map = {}
91
  for name in top["Name"]:
92
  name_data = df[df[group_col] == name]
93
  meta_map[name] = {}
94
+ download_map[name] = {}
95
  for col in meta_cols:
96
  if col in name_data.columns:
97
  unique_vals = name_data[col].unique()
98
  meta_map[name][col] = list(unique_vals)
99
+ download_map[name][col] = list(unique_vals)
100
 
101
  # Function to build metadata chips
102
  def build_metadata(nm):
 
123
  # Modality
124
  for m in meta.get("merged_modality", []):
125
  chips.append(("", m))
126
+
127
+ # Estimated Parameters
128
+ for p in meta.get("estimated_parameters", []):
129
+ if pd.notna(p): # Check if p is not NaN
130
+ if p >= 1e9:
131
+ p_str = f"{p/1e9:.1f}B"
132
+ elif p >= 1e6:
133
+ p_str = f"{p/1e6:.1f}M"
134
+ elif p >= 1e3:
135
+ p_str = f"{p/1e3:.1f}K"
136
+ else:
137
+ p_str = str(p)
138
+ chips.append(("⚙️", p_str))
139
  return chips
140
 
141
+ # Function to create downloadable dataframe
142
def build_download_metadata(nm):
    """Flatten the collected metadata of one leaderboard entry for CSV export.

    Reads ``download_map`` and ``meta_cols`` from the enclosing scope.

    Args:
        nm: Entry name used as the key into ``download_map``.

    Returns:
        A dict mapping each column of ``meta_cols`` that has at least one
        value for ``nm`` to those values joined with ``", "``. Columns that
        are missing or empty are left out entirely, and an unknown ``nm``
        yields an empty dict.
    """
    meta = download_map.get(nm, {})
    # Missing and empty columns are skipped, so every emitted column has
    # values and no empty-string fallback is needed.
    return {
        col: ", ".join(str(v) for v in meta[col])
        for col in meta_cols
        if meta.get(col)
    }
156
+
157
  # Apply metadata builder to top dataframe
158
  top["Metadata"] = top["Name"].map(build_metadata)
159
+ download_info_list = [build_download_metadata(nm) for nm in download_top["Name"]]
160
+ download_info_df = pd.DataFrame(download_info_list)
161
+ download_top = pd.concat([download_top, download_info_df], axis=1)
162
 
163
+ return top[["Name", "Metadata", "% of total"]], download_top
164
 
165
  # Build leaderboards
166
+ top_countries, download_top_countries = get_top_n_leaderboard(country_df, "country", top_n)
167
+ top_developers, download_top_developers = get_top_n_leaderboard(developer_df, "author", top_n)
168
+ top_models, download_top_models = get_top_n_leaderboard(model_df, "model", top_n)
169
 
170
  # Chip renderer
171
  def chip(text, bg_color="#F0F0F0"):
 
246
  )
247
  ]
248
  )
249
+
250
+ # Helper to convert DataFrame to CSV and encode for download
251
def df_to_download_link(df, filename):
    """Build a right-aligned "Download CSV" link that embeds ``df`` as a data URI.

    Args:
        df: DataFrame to serialise; it is written without its index.
        filename: Base name (no extension) used for the link's element id
            (``download-<filename>``) and the suggested file name
            (``<filename>.csv``).

    Returns:
        An ``html.Div`` wrapping the ``html.A`` download link.
    """
    csv_bytes = df.to_csv(index=False).encode("utf-8")
    b64 = base64.b64encode(csv_bytes).decode("ascii")
    return html.Div(
        html.A(
            "Download CSV",
            id=f"download-{filename}",
            download=f"{filename}.csv",
            # Declare the charset so non-ASCII cell values are decoded as
            # UTF-8 rather than left to the browser's default.
            href=f"data:text/csv;charset=utf-8;base64,{b64}",
            target="_blank",
            style={
                "display": "inline-block",
                "marginBottom": "10px",
                "marginRight": "15px",
                "marginTop": "30px",
                "padding": "6px 16px",
                "backgroundColor": "#2196F3",
                "color": "white",
                "borderRadius": "6px",
                "textDecoration": "none",
                "fontWeight": "bold",
                "fontSize": "14px",
            },
        ),
        style={"textAlign": "right"},
    )
277
 
278
  # Table renderer
279
+ def render_table(df, download_df, title, chip_color="#F0F0F0", bar_color="#4CAF50", filename="data"):
280
  return html.Div([
281
  html.H4(title, style={"textAlign": "left", "marginBottom": "10px", "fontSize": "20px"}),
282
  html.Table([
 
294
  html.Td(progress_bar(row["% of total"], bar_color), style={"textAlign": "center"})
295
  ]) for idx, row in df.iterrows()
296
  ])
297
+ ], style={"borderCollapse": "collapse", "width": "100%"}),
298
+ df_to_download_link(download_df, filename),
299
  ], style={"marginBottom": "20px"})
300
 
301
  # Layout with 3 stacked tables
302
  layout = html.Div([
303
+ render_table(top_countries, download_top_countries, "Top Countries", chip_color="#FCE8E6", bar_color="#FF6F61", filename="top_countries"),
304
+ render_table(top_developers, download_top_developers, "Top Developers", chip_color="#E6F4EA", bar_color="#4CAF50", filename="top_developers"),
305
+ render_table(top_models, download_top_models, "Top Models", chip_color="#E8F0FE", bar_color="#2196F3", filename="top_models"),
306
  ])
307
 
308
  return layout
graphs/model_market_share.py CHANGED
@@ -285,54 +285,6 @@ def create_world_map(
285
  row=1,
286
  col=1,
287
  )
288
-
289
- # Country center coordinates for labels
290
- # country_centers = {
291
- # "USA": {"lat": 39.8, "lon": -98.5},
292
- # "CHN": {"lat": 35.8, "lon": 104.2},
293
- # "DEU": {"lat": 51.2, "lon": 10.4},
294
- # "GBR": {"lat": 55.4, "lon": -3.4},
295
- # "FRA": {"lat": 46.6, "lon": 2.2},
296
- # "JPN": {"lat": 36.2, "lon": 138.3},
297
- # "IND": {"lat": 20.6, "lon": 78.9},
298
- # "CAN": {"lat": 56.1, "lon": -106.3},
299
- # "RUS": {"lat": 61.5, "lon": 105.3},
300
- # "BRA": {"lat": -14.2, "lon": -51.9},
301
- # "AUS": {"lat": -25.3, "lon": 133.8},
302
- # "KOR": {"lat": 35.9, "lon": 127.8},
303
- # }
304
-
305
- # # Add initial labels using scattergeo instead of annotations
306
- # label_lons = []
307
- # label_lats = []
308
- # label_texts = []
309
-
310
- # for _, country in top_countries.iterrows():
311
- # country_code = country["country_code"]
312
- # if country_code in country_centers:
313
- # center = country_centers[country_code]
314
- # label_lons.append(center["lon"])
315
- # label_lats.append(center["lat"])
316
- # label_texts.append(f"{country['percentage']:.1f}%")
317
-
318
- # # Add text labels as a scattergeo trace
319
- # fig.add_trace(
320
- # go.Scattergeo(
321
- # lon=label_lons,
322
- # lat=label_lats,
323
- # text=label_texts,
324
- # mode="text",
325
- # textfont=dict(
326
- # color="#ffffff", size=13, family="Inter, system-ui, sans-serif"
327
- # ),
328
- # textposition="middle center",
329
- # showlegend=False,
330
- # hoverinfo="skip",
331
- # geo="geo",
332
- # ),
333
- # row=1,
334
- # col=1,
335
- # )
336
 
337
  # Update layout
338
  fig.update_layout(
 
285
  row=1,
286
  col=1,
287
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
288
 
289
  # Update layout
290
  fig.update_layout(
graphs/tree.py CHANGED
@@ -1,8 +1,29 @@
1
  import plotly.express as px
2
  import pandas as pd
3
 
 
 
 
 
 
 
 
 
4
  def generate_model_treemap(df, parent_col='merged_derived_from', child_col='model', value_col='downloads'):
5
- df[parent_col] = str(df[parent_col][0])
 
 
 
 
 
 
 
 
 
 
 
 
 
6
 
7
  fig = px.treemap(
8
  df,
@@ -12,4 +33,110 @@ def generate_model_treemap(df, parent_col='merged_derived_from', child_col='mode
12
  color=value_col,
13
  color_continuous_scale='Viridis'
14
  )
 
 
 
 
 
 
15
  return fig
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import plotly.express as px
2
  import pandas as pd
3
 
4
+ PALETTE_0 = [
5
+ "#335C67",
6
+ "#FFF3B0",
7
+ "#E09F3E",
8
+ "#9E2A2B",
9
+ "#540B0E"
10
+ ]
11
+
12
  def generate_model_treemap(df, parent_col='merged_derived_from', child_col='model', value_col='downloads'):
13
+ # filtered_df[parent_col] = filtered_df[parent_col].apply(lambda x: str(x[0]) if isinstance(x, list) and x else None)
14
+
15
+ df = pd.read_pickle('data_frames/filtered_tree_df.pkl')
16
+ # Filter out nan, No parent, and Unsure
17
+ df = df[~df[parent_col].isin([None, "['Unsure']", 'nan'])]
18
+
19
+ # Find all models that act as a parent
20
+ parent_models = set(df[parent_col].dropna())
21
+
22
+ # Assign empty parent only if row has no parent and is not itself a parent
23
+ df[parent_col] = df[parent_col].where(
24
+ df[parent_col].notna() | df[child_col].isin(parent_models),
25
+ other=""
26
+ )
27
 
28
  fig = px.treemap(
29
  df,
 
33
  color=value_col,
34
  color_continuous_scale='Viridis'
35
  )
36
+
37
+ fig.update_layout(
38
+ height=1200, # make the plot tall
39
+ margin=dict(t=50, l=25, r=25, b=25) # add some breathing room
40
+ )
41
+
42
  return fig
43
+
44
+ # def generate_model_treemap(df, parent_col='merged_derived_from', child_col='model', value_col='downloads'):
45
+ # # iterate over the rows and stringify the lists in 'merged_derived_from'
46
+
47
+ # df.to_pickle('filtered_tree_df.pkl')
48
+
49
+ # fig = px.icicle(
50
+ # df,
51
+ # path=[parent_col, child_col],
52
+ # values=value_col,
53
+ # hover_data=['author', 'estimated_parameters', 'created'],
54
+ # color=value_col,
55
+ # color_continuous_scale='Viridis'
56
+ # )
57
+
58
+ # fig.update_layout(
59
+ # height=1400,
60
+ # margin=dict(t=50, l=25, r=25, b=25)
61
+ # )
62
+ # return fig
63
+
64
+
65
+ # import plotly.graph_objects as go
66
+ # import networkx as nx
67
+ # import pandas as pd
68
+
69
+ # def generate_model_treemap(df, parent_col='merged_derived_from', child_col='model',
70
+ # value_col='downloads', top_n=1000):
71
+
72
+ # # Fill missing parents
73
+ # df[parent_col] = str(df[parent_col][0])
74
+
75
+ # # Keep only top_n by downloads
76
+ # df = df.sort_values(value_col, ascending=False).head(top_n)
77
+
78
+ # # Build directed graph
79
+ # G = nx.DiGraph()
80
+ # for _, row in df.iterrows():
81
+ # parent = row[parent_col]
82
+ # child = row[child_col]
83
+ # G.add_edge(parent, child, weight=row.get(value_col, 1))
84
+
85
+ # # Layout positions (smaller k → tighter graph)
86
+ # pos = nx.spring_layout(G, k=0.3, seed=42)
87
+
88
+ # # Edges
89
+ # edge_x, edge_y = [], []
90
+ # for parent, child in G.edges():
91
+ # x0, y0 = pos[parent]
92
+ # x1, y1 = pos[child]
93
+ # edge_x += [x0, x1, None]
94
+ # edge_y += [y0, y1, None]
95
+
96
+ # edge_trace = go.Scatter(
97
+ # x=edge_x, y=edge_y,
98
+ # line=dict(width=0.8, color="#888"),
99
+ # hoverinfo="none",
100
+ # mode="lines"
101
+ # )
102
+
103
+ # # Nodes
104
+ # node_x, node_y, sizes, texts = [], [], [], []
105
+ # for node in G.nodes():
106
+ # x, y = pos[node]
107
+ # node_x.append(x)
108
+ # node_y.append(y)
109
+ # downloads = df.loc[df[child_col] == node, value_col].sum()
110
+ # sizes.append(max(10, downloads**0.3))
111
+ # texts.append(f"{node}<br>Downloads: {downloads}")
112
+
113
+ # node_trace = go.Scatter(
114
+ # x=node_x, y=node_y,
115
+ # mode="markers+text",
116
+ # text=[n for n in G.nodes()],
117
+ # textposition="top center",
118
+ # hovertext=texts,
119
+ # hoverinfo="text",
120
+ # marker=dict(
121
+ # showscale=True,
122
+ # colorscale="Viridis",
123
+ # color=sizes,
124
+ # size=sizes,
125
+ # colorbar=dict(
126
+ # thickness=15,
127
+ # title=f"{value_col} (scaled)",
128
+ # xanchor="left",
129
+ # ),
130
+ # line_width=2
131
+ # )
132
+ # )
133
+
134
+ # return go.Figure(data=[edge_trace, node_trace],
135
+ # layout=go.Layout(
136
+ # title=f"Model Tree (Top {top_n} by {value_col})",
137
+ # showlegend=False,
138
+ # hovermode="closest",
139
+ # margin=dict(b=20, l=5, r=5, t=40),
140
+ # xaxis=dict(showgrid=False, zeroline=False, showticklabels=False),
141
+ # yaxis=dict(showgrid=False, zeroline=False, showticklabels=False)
142
+ # ))
requirements.txt CHANGED
@@ -1,4 +1,5 @@
1
  pandas
2
  dash
3
  plotly
4
- gunicorn
 
 
1
  pandas
2
  dash
3
  plotly
4
+ gunicorn
5
+ dash-mantine-components