"""Dash app: interactive leaderboards of open models on Hugging Face.

The dataset is a remote parquet file exposed to DuckDB as the view
``filtered_df``. Leaderboard callbacks push time filtering and aggregation
into DuckDB so only the top-N rows are materialized in pandas.
"""

import logging
import time

import dash_mantine_components as dmc
import duckdb
import pandas as pd
from dash import Dash, html, dcc, Input, Output, State, ctx, no_update

from graphs.leaderboard import (
    create_leaderboard,
    get_top_n_leaderboard,
    render_table_content,
)

logger = logging.getLogger(__name__)

# Initialize the app
app = Dash()
server = app.server

# DuckDB connection (global). It is used directly only at import time;
# callbacks run on server worker threads and query through per-call cursors.
con = duckdb.connect(database=":memory:", read_only=False)

# Load parquet file from Hugging Face using DuckDB
HF_DATASET_ID = "emsesc/open_model_evolution_data"
hf_parquet_url = (
    f"https://huggingface.co/datasets/{HF_DATASET_ID}"
    "/resolve/main/filtered_df.parquet"
)

print(f"Attempting to connect to dataset from Hugging Face Hub: {HF_DATASET_ID}")
try:
    overall_start_time = time.perf_counter()

    # Install and load httpfs extension for remote file access
    con.execute("INSTALL httpfs;")
    con.execute("LOAD httpfs;")

    # Create a view that references the remote parquet file (nothing is
    # downloaded until the view is queried).
    con.execute(f"""
        CREATE OR REPLACE VIEW filtered_df AS
        SELECT * FROM read_parquet('{hf_parquet_url}')
    """)

    # Get column list and basic info
    columns = con.execute("DESCRIBE filtered_df").fetchdf()
    print("Columns:", columns["column_name"].tolist())

    # Get time range for slider
    time_range = con.execute(
        "SELECT MIN(time) as min_time, MAX(time) as max_time FROM filtered_df"
    ).fetchdf()
    start_dt = pd.to_datetime(time_range["min_time"].iloc[0])
    end_dt = pd.to_datetime(time_range["max_time"].iloc[0])
    if pd.isna(start_dt) or pd.isna(end_dt):
        # An empty dataset would otherwise fail later with an opaque NaT error.
        raise ValueError("dataset has no non-null 'time' values")

    print(
        "Successfully connected to dataset in "
        f"{time.perf_counter() - overall_start_time:.2f}s."
    )
except Exception as e:
    print(f"Failed to load dataset. Error: {e}")
    raise


def _build_year_marks(start, end):
    """Return RangeSlider marks for the range [start, end].

    The first and last marks carry a "Mon YYYY" label (e.g. "Jan 2020",
    "Dec 2024"); in between there is one "YYYY" mark per January 1st after the
    start year, skipping any that coincides exactly with the end mark.
    """
    end_ts = int(end.timestamp())
    marks = [{"value": int(start.timestamp()), "label": start.strftime("%b %Y")}]
    for year in range(start.year + 1, end.year + 1):
        year_ts = int(pd.Timestamp(year=year, month=1, day=1).timestamp())
        if year_ts != end_ts:
            marks.append({"value": year_ts, "label": str(year)})
    marks.append({"value": end_ts, "label": end.strftime("%b %Y")})
    return marks


# Slider values are Unix timestamps in seconds, stepping one day at a time.
start_ts = int(start_dt.timestamp())
end_ts = int(end_dt.timestamp())
marks = _build_year_marks(start_dt, end_dt)

time_slider = dmc.RangeSlider(
    id="time-slider",
    min=start_ts,
    max=end_ts,
    value=[start_ts, end_ts],
    step=24 * 60 * 60,
    color="#AC482A",
    size="md",
    radius="xl",
    marks=marks,
    style={"width": "70%", "margin": "0 auto"},
    labelAlwaysOn=False,
)

# Shared styles for the three leaderboard tabs.
_TAB_STYLE = {
    "backgroundColor": "transparent",
    "border": "none",
    "padding": "10px 18px",
    "color": "#6B7280",
    "fontWeight": "500",
}
_TAB_SELECTED_STYLE = {
    "backgroundColor": "transparent",
    "border": "none",
    "padding": "10px 18px",
    "fontWeight": "700",
    "borderBottom": "3px solid #082030",
}
_SECTION_LABEL_STYLE = {"fontWeight": "700", "marginBottom": 8, "fontSize": 14}


def _leaderboard_tab(label, kind):
    """Build one leaderboard tab; ``label`` doubles as the tab's value."""
    return dcc.Tab(
        label=label,
        value=label,
        style=_TAB_STYLE,
        selected_style=_TAB_SELECTED_STYLE,
        children=[create_leaderboard(con, kind)],
    )


def _build_header():
    """Build the dark top banner with the title and partner logo links."""
    title_block = html.Div(
        [
            html.Div(
                children="Visualizing the Open Model Ecosystem",
                style={"fontSize": 22, "fontWeight": "700", "lineHeight": "1.1"},
            ),
            html.Div(
                children="An interactive dashboard to explore trends in open models on Hugging Face",
                style={"fontSize": 13, "marginTop": 6, "opacity": 0.9},
            ),
        ],
        style={
            "display": "flex",
            "flexDirection": "column",
            "justifyContent": "center",
        },
    )
    logo_links = html.Div(
        [
            html.A(
                children=[
                    html.Img(
                        src="assets/images/dpi-logo.svg",
                        style={
                            "height": "28px",
                            "verticalAlign": "middle",
                            "paddingRight": "8px",
                        },
                    ),
                    "Data Provenance Initiative",
                ],
                href="https://www.dataprovenance.org/",
                target="_blank",
                # Prevent the opened page from accessing window.opener.
                rel="noopener noreferrer",
                style={
                    "display": "inline-block",
                    "padding": "6px 14px",
                    "fontSize": 13,
                    "color": "#082030",
                    "backgroundColor": "#ffffff",
                    "borderRadius": "18px",
                    "fontWeight": "700",
                    "textDecoration": "none",
                    "marginRight": "12px",
                },
            ),
            html.A(
                children=[
                    html.Img(
                        src="assets/images/Hf-logo-with-title.svg",
                        style={"height": "30px", "verticalAlign": "middle"},
                    )
                ],
                href="https://huggingface.co/",
                target="_blank",
                rel="noopener noreferrer",
                style={
                    "display": "inline-flex",
                    "padding": "6px 14px",
                    "alignItems": "center",
                    "backgroundColor": "#ffffff",
                    "borderRadius": "18px",
                    "textDecoration": "none",
                },
            ),
        ],
        style={"display": "flex", "alignItems": "center"},
    )
    return html.Div(
        [
            html.Div(
                [title_block, logo_links],
                style={
                    "marginLeft": "50px",
                    "marginRight": "50px",
                    "display": "flex",
                    "justifyContent": "space-between",
                    "alignItems": "center",
                    "padding": "18px 24px",
                    "gap": "24px",
                },
            ),
        ],
        style={"backgroundColor": "#082030", "color": "white", "width": "100%"},
    )


def _build_filters_panel():
    """Build the filter row: download-window selector and time-range slider."""
    window_selector = html.Div(
        [
            html.Div("Select Window", style=_SECTION_LABEL_STYLE),
            # NOTE(review): no callback reads "segmented" yet, so this control
            # currently has no effect on the leaderboards.
            dmc.SegmentedControl(
                id="segmented",
                value="all-downloads",
                color="#AC482A",
                transitionDuration=200,
                data=[
                    {"value": "all-downloads", "label": "All Downloads"},
                    {"value": "filtered-downloads", "label": "Filtered Downloads"},
                ],
                mb=10,
            ),
            html.Span(
                id="global-toggle-status",
                style={
                    "marginLeft": "8px",
                    "display": "inline-block",
                    "marginTop": 6,
                },
            ),
        ],
        style={"flex": 1, "minWidth": "220px"},
    )
    time_range_selector = html.Div(
        [html.Div("Select Time Range", style=_SECTION_LABEL_STYLE), time_slider],
        style={"flex": 2, "minWidth": "320px"},
    )
    return html.Div(
        children=[window_selector, time_range_selector],
        style={
            "display": "flex",
            "gap": "24px",
            "padding": "32px",
            "alignItems": "flex-start",
            "marginLeft": "100px",
            "marginRight": "100px",
            "backgroundColor": "#FFFBF9",
            "borderRadius": "18px",
        },
    )


# App layout
app.layout = dmc.MantineProvider(
    theme={
        "colorScheme": "light",
        "primaryColor": "blue",
        "fontFamily": "Inter, sans-serif",
    },
    children=[
        html.Div(
            [
                _build_header(),
                # Title
                html.Div(
                    children="Model Leaderboard",
                    style={
                        "fontSize": 40,
                        "fontWeight": "700",
                        "textAlign": "center",
                        "marginTop": 20,
                        "marginBottom": 20,
                    },
                ),
                # Button. NOTE(review): no callback or link is attached to
                # "my-button" yet, so clicking it does nothing.
                html.Div(
                    children=[
                        html.Button(
                            "Read the paper",
                            id="my-button",
                            style={
                                "padding": "10px 20px",
                                "fontSize": 16,
                                "margin": "0 auto",
                                "display": "block",
                                "backgroundColor": "#AC482A",
                                "color": "white",
                                "border": "none",
                                "borderRadius": "5px",
                                "cursor": "pointer",
                            },
                        ),
                    ],
                    style={"textAlign": "center", "marginBottom": 20},
                ),
                # TODO: replace placeholder description copy.
                html.Div(
                    children="Lorem Ipsum is simply dummy text of the printing and typesetting industry. Lorem Ipsum has been the industry's standard dummy text ever since the 1500s...",
                    style={
                        "fontSize": 14,
                        "marginTop": 18,
                        "marginBottom": 12,
                        "marginLeft": 100,
                        "marginRight": 100,
                        "textAlign": "center",
                    },
                ),
                # Main content (filters + tabs)
                _build_filters_panel(),
                html.Div(
                    [
                        dcc.Tabs(
                            id="leaderboard-tabs",
                            value="Countries",
                            children=[
                                _leaderboard_tab("Countries", "countries"),
                                _leaderboard_tab("Developers", "developers"),
                                _leaderboard_tab("Models", "models"),
                            ],
                        ),
                    ],
                    style={
                        "borderRadius": "18px",
                        "padding": "32px",
                        "marginTop": "12px",
                        "marginBottom": "64px",
                        "marginLeft": "50px",
                        "marginRight": "50px",
                    },
                ),
            ],
            style={
                "fontFamily": "Inter",
                "backgroundColor": "#ffffff",
                "minHeight": "100vh",
            },
        )
    ],
)


# -- helper utilities to consolidate duplicated callback logic --

# Columns materialized by the base CTE. org_country_single is normalized there.
_BASE_COLUMNS = frozenset(
    {
        "org_country_single",
        "author",
        "merged_country_groups_single",
        "merged_modality",
        "downloads",
        "estimated_parameters",
        "model",
    }
)


def _get_filtered_top_n_from_duckdb(slider_value, group_col, top_n):
    """Return the top ``top_n`` groups of ``group_col`` by summed downloads.

    Aggregation runs inside DuckDB so only ``top_n`` rows are transferred.

    Args:
        slider_value: ``[start, end]`` Unix timestamps in seconds, or a falsy
            value / wrong-length sequence to disable time filtering.
        group_col: Column of ``filtered_df`` to group by. It must be a plain
            identifier because it is interpolated into the SQL text.
        top_n: Maximum number of groups to return.

    Returns:
        A DataFrame with columns ``name``, ``total_downloads``,
        ``percent_of_total`` (NULL when the filtered total is zero) and one
        representative value each of ``org_country_single``, ``author``,
        ``merged_country_groups_single``, ``merged_modality`` and ``model``.

    Raises:
        ValueError: If ``group_col`` is not a valid identifier.
    """
    if not isinstance(group_col, str) or not group_col.isidentifier():
        raise ValueError(f"Invalid group column: {group_col!r}")
    top_n = int(top_n)

    # Time bounds are bound as query parameters rather than formatted in.
    where_clause = ""
    params = []
    if slider_value and len(slider_value) == 2:
        start = pd.to_datetime(slider_value[0], unit="s")
        end = pd.to_datetime(slider_value[1], unit="s")
        where_clause = "WHERE time >= ? AND time <= ?"
        params = [start.to_pydatetime(), end.to_pydatetime()]

    # The grouping column must appear exactly once in base_data. Re-selecting
    # org_country_single alongside its normalized alias produced two columns
    # of the same name, so the Countries board did not group on the
    # normalized value.
    extra_group_col = "" if group_col in _BASE_COLUMNS else f"{group_col},"

    query = f"""
        WITH base_data AS (
            SELECT
                {extra_group_col}
                CASE
                    WHEN org_country_single = 'HF' THEN 'United States of America'
                    WHEN org_country_single IN ('International', 'Online')
                        THEN 'International/Online'
                    ELSE org_country_single
                END AS org_country_single,
                author,
                merged_country_groups_single,
                merged_modality,
                downloads,
                estimated_parameters,
                model
            FROM filtered_df
            {where_clause}
        ),
        -- Total downloads across all rows in the time range
        total_downloads_cte AS (
            SELECT SUM(downloads) AS total_downloads_all FROM base_data
        ),
        -- Per-group totals and their share of all downloads
        top_items AS (
            SELECT
                b.{group_col} AS name,
                SUM(b.downloads) AS total_downloads,
                ROUND(
                    SUM(b.downloads) * 100.0 / NULLIF(t.total_downloads_all, 0),
                    2
                ) AS percent_of_total,
                -- One representative non-null metadata value per group
                ANY_VALUE(b.org_country_single) AS org_country_single,
                ANY_VALUE(b.author) AS author,
                ANY_VALUE(b.merged_country_groups_single) AS merged_country_groups_single,
                ANY_VALUE(b.merged_modality) AS merged_modality,
                ANY_VALUE(b.model) AS model
            FROM base_data b
            CROSS JOIN total_downloads_cte t
            GROUP BY b.{group_col}, t.total_downloads_all
        )
        SELECT *
        FROM top_items
        ORDER BY total_downloads DESC NULLS LAST, name
        LIMIT {top_n};
    """
    logger.debug("Executing DuckDB query for filtered top N:\n%s\nparams=%s", query, params)

    # A DuckDB connection must not be shared across threads; callbacks run on
    # server worker threads, so each query gets its own cursor.
    with con.cursor() as cur:
        return cur.execute(query, params).fetchdf()


# Toggle-button state machine. The label advertises the *next* size, so the
# size currently on screen is recovered from the label itself.
_LABEL_WHILE_SHOWING = {10: "▼ Show Top 50", 50: "▼ Show Top 100", 100: "▲ Show Less"}
_NEXT_TOP_N = {10: 50, 50: 100, 100: 10}


def _top_n_for_label(label):
    """Return how many rows are currently displayed, given the toggle label."""
    text = label if isinstance(label, str) else str(label)
    if "Show Top 100" in text:
        return 50
    if "Show Less" in text:
        return 100
    # "▼ Show Top 50", "▼ Show More", or anything unrecognized: default view.
    return 10


def _leaderboard_callback_logic(
    n_clicks,
    slider_value,
    current_label,
    group_col,
    filename,
    default_label="▼ Show Top 50",
    chip_color="#F0F9FF",
):
    """Compute a leaderboard table and its toggle label for one callback run.

    The page size cycles 10 -> 50 -> 100 -> 10, but only when the toggle
    button fired the callback. Slider moves and the initial page-load call
    keep the current size and label.

    Returns:
        ``(table_children, toggle_label)`` for the two callback outputs.
    """
    if current_label is None:
        current_label = default_label
    current_top_n = _top_n_for_label(current_label)

    # The same callback is also fired by the time slider and on page load
    # (triggered_id is None); those must not advance the toggle state.
    toggled = bool(n_clicks) and ctx.triggered_id not in (None, "time-slider")
    if toggled:
        top_n = _NEXT_TOP_N[current_top_n]
        new_label = _LABEL_WHILE_SHOWING[top_n]
    else:
        top_n, new_label = current_top_n, current_label

    # Get filtered and aggregated data directly from DuckDB
    df_filtered = _get_filtered_top_n_from_duckdb(slider_value, group_col, top_n)
    logger.debug("CALLBACK LOGIC - Filtered DataFrame:\n%s", df_filtered.head())

    # Process the already-filtered data
    df, download_df = get_top_n_leaderboard(df_filtered, group_col, top_n)
    return (
        render_table_content(df, download_df, chip_color=chip_color, filename=filename),
        new_label,
    )


# -- end helpers --


# Callbacks for interactivity (modularized)
@app.callback(
    Output("top_countries-table", "children"),
    Output("top_countries-toggle", "children"),
    Input("top_countries-toggle", "n_clicks"),
    Input("time-slider", "value"),
    State("top_countries-toggle", "children"),
)
def update_top_countries(n_clicks, slider_value, current_label):
    """Refresh the countries leaderboard."""
    return _leaderboard_callback_logic(
        n_clicks,
        slider_value,
        current_label,
        group_col="org_country_single",
        filename="top_countries",
        default_label="▼ Show Top 50",
        chip_color="#F0F9FF",
    )


@app.callback(
    Output("top_developers-table", "children"),
    Output("top_developers-toggle", "children"),
    Input("top_developers-toggle", "n_clicks"),
    Input("time-slider", "value"),
    State("top_developers-toggle", "children"),
)
def update_top_developers(n_clicks, slider_value, current_label):
    """Refresh the developers leaderboard."""
    return _leaderboard_callback_logic(
        n_clicks,
        slider_value,
        current_label,
        group_col="author",
        filename="top_developers",
        default_label="▼ Show More",
        chip_color="#F0F9FF",
    )


@app.callback(
    Output("top_models-table", "children"),
    Output("top_models-toggle", "children"),
    Input("top_models-toggle", "n_clicks"),
    Input("time-slider", "value"),
    State("top_models-toggle", "children"),
)
def update_top_models(n_clicks, slider_value, current_label):
    """Refresh the models leaderboard."""
    return _leaderboard_callback_logic(
        n_clicks,
        slider_value,
        current_label,
        group_col="model",
        filename="top_models",
        default_label="▼ Show More",
        chip_color="#F0F9FF",
    )


@app.callback(Output("time-slider", "label"), Input("time-slider", "value"))
def update_range_labels(values):
    """Format the slider's selected [start, end] seconds as "Mon YYYY" labels."""
    if not values or len(values) != 2:
        return no_update
    start_label = pd.to_datetime(values[0], unit="s").strftime("%b %Y")
    end_label = pd.to_datetime(values[1], unit="s").strftime("%b %Y")
    return [start_label, end_label]


# Run the app
if __name__ == "__main__":
    app.run(debug=True)