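# NOTE: the following lines are page residue from the web file viewer this
# source was copied from; they are kept as comments so the module parses.
# emsesc's picture
# buggy duckdb
# 6ba1ddc
# raw
# history blame
# 23.6 kB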
from dash import Dash, html, dcc, Input, Output, State
import pandas as pd
import dash_mantine_components as dmc
import duckdb
import time
from graphs.leaderboard import (
create_leaderboard,
get_top_n_leaderboard,
render_table_content,
)
# Initialize the app
app = Dash()
# Expose the app's underlying server object as a module-level name
# (presumably so an external WSGI host can import it — TODO confirm deployment).
server = app.server
# DuckDB connection (global): one in-memory database used by the layout
# builders and the callbacks defined further down in this module.
con = duckdb.connect(database=":memory:", read_only=False)

# The leaderboard data is a parquet file on the Hugging Face Hub; DuckDB reads
# it remotely through a view, so nothing is downloaded up front.
HF_DATASET_ID = "emsesc/open_model_evolution_data"
hf_parquet_url = "https://huggingface.co/datasets/emsesc/open_model_evolution_data/resolve/main/filtered_df.parquet"
print(f"Attempting to connect to dataset from Hugging Face Hub: {HF_DATASET_ID}")

try:
    overall_start_time = time.time()

    # httpfs is required before read_parquet() can be pointed at an https URL.
    con.execute("INSTALL httpfs;")
    con.execute("LOAD httpfs;")

    # `filtered_df` is the view name every later query selects from.
    con.execute(
        f"""
        CREATE OR REPLACE VIEW filtered_df AS
        SELECT * FROM read_parquet('{hf_parquet_url}')
        """
    )

    # Log the schema once at start-up.
    columns = con.execute("DESCRIBE filtered_df").fetchdf()
    print("Columns:", columns["column_name"].tolist())

    # The earliest and latest `time` values bound the range slider.
    time_range = con.execute(
        "SELECT MIN(time) as min_time, MAX(time) as max_time FROM filtered_df"
    ).fetchdf()
    start_dt, end_dt = (
        pd.to_datetime(time_range[bound].iloc[0]) for bound in ("min_time", "max_time")
    )

    msg = f"Successfully connected to dataset in {time.time() - overall_start_time:.2f}s."
    print(msg)
except Exception as e:
    # Report the failure, then re-raise: the app cannot start without the data.
    err_msg = f"Failed to load dataset. Error: {e}"
    print(err_msg)
    raise
# Slider positions are POSIX timestamps in seconds, derived from the dataset's
# first and last `time` values.
start_ts = int(start_dt.timestamp())
end_ts = int(end_dt.timestamp())

# Marks: a "Mon YYYY" label at each end of the range (e.g. "Jan 2020") and a
# bare year label (e.g. "2021") at every 1 January after the start year,
# skipping one that would land exactly on the end mark.
marks = [
    {"value": start_ts, "label": start_dt.strftime("%b %Y")},
    *(
        {"value": jan_first, "label": str(year)}
        for year, jan_first in (
            (y, int(pd.Timestamp(year=y, month=1, day=1).timestamp()))
            for y in range(start_dt.year + 1, end_dt.year + 1)
        )
        if jan_first != end_ts
    ),
    {"value": end_ts, "label": end_dt.strftime("%b %Y")},
]

# Range slider for the time window; initially spans the whole dataset.
time_slider = dmc.RangeSlider(
    id="time-slider",
    min=start_ts,
    max=end_ts,
    value=[start_ts, end_ts],
    step=24 * 60 * 60,  # one day, expressed in seconds
    marks=marks,
    labelAlwaysOn=False,
    color="#AC482A",
    size="md",
    radius="xl",
    style={"width": "70%", "margin": "0 auto"},
)
# App layout: header bar, page title, paper button, intro text, filter panel
# (download window + time range) and the tabbed leaderboards.
app.layout = dmc.MantineProvider(
    theme={
        "colorScheme": "light",
        "primaryColor": "blue",
        "fontFamily": "Inter, sans-serif",
    },
    children=[
        html.Div(
            [
                # Header bar: title block on the left, external links on the right
                html.Div(
                    [
                        html.Div(
                            [
                                # Dashboard title and tagline
                                html.Div(
                                    [
                                        html.Div(
                                            "Visualizing the Open Model Ecosystem",
                                            style={
                                                "fontSize": 22,
                                                "fontWeight": "700",
                                                "lineHeight": "1.1",
                                            },
                                        ),
                                        html.Div(
                                            "An interactive dashboard to explore trends in open models on Hugging Face",
                                            style={
                                                "fontSize": 13,
                                                "marginTop": 6,
                                                "opacity": 0.9,
                                            },
                                        ),
                                    ],
                                    style={
                                        "display": "flex",
                                        "flexDirection": "column",
                                        "justifyContent": "center",
                                    },
                                ),
                                # Link pills (both open in a new tab)
                                html.Div(
                                    [
                                        html.A(
                                            [
                                                html.Img(
                                                    src="assets/images/dpi-logo.svg",
                                                    style={
                                                        "height": "28px",
                                                        "verticalAlign": "middle",
                                                        "paddingRight": "8px",
                                                    },
                                                ),
                                                "Data Provenance Initiative",
                                            ],
                                            href="https://www.dataprovenance.org/",
                                            target="_blank",
                                            style={
                                                "display": "inline-block",
                                                "padding": "6px 14px",
                                                "fontSize": 13,
                                                "color": "#082030",
                                                "backgroundColor": "#ffffff",
                                                "borderRadius": "18px",
                                                "fontWeight": "700",
                                                "textDecoration": "none",
                                                "marginRight": "12px",
                                            },
                                        ),
                                        html.A(
                                            [
                                                html.Img(
                                                    src="assets/images/Hf-logo-with-title.svg",
                                                    style={
                                                        "height": "30px",
                                                        "verticalAlign": "middle",
                                                    },
                                                )
                                            ],
                                            href="https://huggingface.co/",
                                            target="_blank",
                                            style={
                                                "display": "inline-flex",
                                                "padding": "6px 14px",
                                                "alignItems": "center",
                                                "backgroundColor": "#ffffff",
                                                "borderRadius": "18px",
                                                "textDecoration": "none",
                                            },
                                        ),
                                    ],
                                    style={"display": "flex", "alignItems": "center"},
                                ),
                            ],
                            style={
                                "marginLeft": "50px",
                                "marginRight": "50px",
                                "display": "flex",
                                "justifyContent": "space-between",
                                "alignItems": "center",
                                "padding": "18px 24px",
                                "gap": "24px",
                            },
                        ),
                    ],
                    style={
                        "backgroundColor": "#082030",
                        "color": "white",
                        "width": "100%",
                    },
                ),
                # Page title
                html.Div(
                    "Model Leaderboard",
                    style={
                        "fontSize": 40,
                        "fontWeight": "700",
                        "textAlign": "center",
                        "marginTop": 20,
                        "marginBottom": 20,
                    },
                ),
                # "Read the paper" button
                html.Div(
                    [
                        html.Button(
                            "Read the paper",
                            id="my-button",
                            style={
                                "padding": "10px 20px",
                                "fontSize": 16,
                                "margin": "0 auto",
                                "display": "block",
                                "backgroundColor": "#AC482A",
                                "color": "white",
                                "border": "none",
                                "borderRadius": "5px",
                                "cursor": "pointer",
                            },
                        ),
                    ],
                    style={"textAlign": "center", "marginBottom": 20},
                ),
                # Intro / description text (currently placeholder copy)
                html.Div(
                    "Lorem Ipsum is simply dummy text of the printing and typesetting industry. Lorem Ipsum has been the industry's standard dummy text ever since the 1500s...",
                    style={
                        "fontSize": 14,
                        "marginTop": 18,
                        "marginBottom": 12,
                        "marginLeft": 100,
                        "marginRight": 100,
                        "textAlign": "center",
                    },
                ),
                # Filter panel: download window selector + time range slider
                html.Div(
                    [
                        html.Div(
                            [
                                html.Div(
                                    "Select Window",
                                    style={
                                        "fontWeight": "700",
                                        "marginBottom": 8,
                                        "fontSize": 14,
                                    },
                                ),
                                dmc.SegmentedControl(
                                    id="segmented",
                                    value="all-downloads",
                                    color="#AC482A",
                                    transitionDuration=200,
                                    data=[
                                        {
                                            "value": "all-downloads",
                                            "label": "All Downloads",
                                        },
                                        {
                                            "value": "filtered-downloads",
                                            "label": "Filtered Downloads",
                                        },
                                    ],
                                    mb=10,
                                ),
                                html.Span(
                                    id="global-toggle-status",
                                    style={
                                        "marginLeft": "8px",
                                        "display": "inline-block",
                                        "marginTop": 6,
                                    },
                                ),
                            ],
                            style={"flex": 1, "minWidth": "220px"},
                        ),
                        html.Div(
                            [
                                html.Div(
                                    "Select Time Range",
                                    style={
                                        "fontWeight": "700",
                                        "marginBottom": 8,
                                        "fontSize": 14,
                                    },
                                ),
                                time_slider,
                            ],
                            style={"flex": 2, "minWidth": "320px"},
                        ),
                    ],
                    style={
                        "display": "flex",
                        "gap": "24px",
                        "padding": "32px",
                        "alignItems": "flex-start",
                        "marginLeft": "100px",
                        "marginRight": "100px",
                        "backgroundColor": "#FFFBF9",
                        "borderRadius": "18px",
                    },
                ),
                # Leaderboard tabs: one tab per leaderboard, all styled alike
                html.Div(
                    [
                        dcc.Tabs(
                            id="leaderboard-tabs",
                            value="Countries",
                            children=[
                                dcc.Tab(
                                    label=tab_label,
                                    value=tab_label,
                                    style={
                                        "backgroundColor": "transparent",
                                        "border": "none",
                                        "padding": "10px 18px",
                                        "color": "#6B7280",
                                        "fontWeight": "500",
                                    },
                                    selected_style={
                                        "backgroundColor": "transparent",
                                        "border": "none",
                                        "padding": "10px 18px",
                                        "fontWeight": "700",
                                        "borderBottom": "3px solid #082030",
                                    },
                                    children=[create_leaderboard(con, board_key)],
                                )
                                for tab_label, board_key in (
                                    ("Countries", "countries"),
                                    ("Developers", "developers"),
                                    ("Models", "models"),
                                )
                            ],
                        ),
                    ],
                    style={
                        "borderRadius": "18px",
                        "padding": "32px",
                        "marginTop": "12px",
                        "marginBottom": "64px",
                        "marginLeft": "50px",
                        "marginRight": "50px",
                    },
                ),
            ],
            style={
                "fontFamily": "Inter",
                "backgroundColor": "#ffffff",
                "minHeight": "100vh",
            },
        )
    ],
)
# Callbacks for interactivity
# -- helper utilities to consolidate duplicated callback logic --
def _get_filtered_top_n_from_duckdb(slider_value, group_col, top_n):
    """
    Query DuckDB directly for the top N groups by downloads, with chip metadata.

    The aggregation runs inside DuckDB so that at most ``top_n`` rows are
    fetched into pandas.

    Args:
        slider_value: ``[start, end]`` timestamps in seconds (the time slider's
            value). Any falsy value, or a sequence whose length is not 2,
            disables the time filter.
        group_col: Name of the ``filtered_df`` column to group by. Must be a
            plain identifier because it is interpolated into the SQL text.
        top_n: Maximum number of rows to return (non-negative integer).

    Returns:
        pandas.DataFrame with the columns ``name``, ``total_downloads``,
        ``percent_of_total``, ``org_country_single``, ``author``,
        ``merged_country_groups_single``, ``merged_modality`` and ``model``,
        ordered by ``total_downloads`` descending.

    Raises:
        ValueError: If ``group_col`` is not a plain identifier, or ``top_n``
            cannot be interpreted as a non-negative integer.
    """
    # Column names cannot be bound as query parameters, so refuse anything that
    # is not a bare identifier before it reaches the SQL string.
    if not isinstance(group_col, str) or not group_col.isidentifier():
        raise ValueError(f"Invalid group column: {group_col!r}")
    try:
        limit = int(top_n)
    except (TypeError, ValueError) as err:
        raise ValueError(f"Invalid top_n: {top_n!r}") from err
    if limit < 0:
        raise ValueError(f"Invalid top_n: {top_n!r}")

    # Build time filter clause. Both bounds go through pd.to_datetime, so only
    # a formatted timestamp can end up inside the quoted literals.
    time_clause = ""
    if slider_value and len(slider_value) == 2:
        start = pd.to_datetime(slider_value[0], unit="s")
        end = pd.to_datetime(slider_value[1], unit="s")
        time_clause = f"WHERE time >= '{start}' AND time <= '{end}'"

    # Columns exposed by the base_data CTE as (output name, SQL expression).
    # The country column is normalised so equivalent raw values share a label.
    country_expr = """CASE
                    WHEN org_country_single = 'HF' THEN 'United States of America'
                    WHEN org_country_single = 'International' THEN 'International/Online'
                    WHEN org_country_single = 'Online' THEN 'International/Online'
                    ELSE org_country_single
                END AS org_country_single"""
    base_columns = [
        ("org_country_single", country_expr),
        ("author", "author"),
        ("merged_country_groups_single", "merged_country_groups_single"),
        ("merged_modality", "merged_modality"),
        ("downloads", "downloads"),
        ("estimated_parameters", "estimated_parameters"),
        ("model", "model"),
    ]
    # Select every column exactly once. Listing the grouping column again next
    # to a same-named entry would give the CTE two columns with one name, so
    # `b.<group_col>` below could resolve to the raw column and bypass the
    # country normalisation.
    known_names = {name for name, _ in base_columns}
    if group_col.casefold() not in known_names:
        base_columns.insert(0, (group_col, group_col))
    select_list = ",\n                ".join(expr for _, expr in base_columns)

    # Group by the target column, sum downloads, and keep one representative
    # value of each metadata column for the chips.
    query = f"""
        WITH base_data AS (
            SELECT
                {select_list}
            FROM filtered_df
            {time_clause}
        ),
        -- Compute the total downloads for all rows in the time range
        total_downloads_cte AS (
            SELECT SUM(downloads) AS total_downloads_all
            FROM base_data
        ),
        -- Compute per-group totals and their percentage of all downloads
        top_items AS (
            SELECT
                b.{group_col} AS name,
                SUM(b.downloads) AS total_downloads,
                ROUND(SUM(b.downloads) * 100.0 / t.total_downloads_all, 2) AS percent_of_total,
                -- Pick one metadata value per group for reference
                ANY_VALUE(b.org_country_single) AS org_country_single,
                ANY_VALUE(b.author) AS author,
                ANY_VALUE(b.merged_country_groups_single) AS merged_country_groups_single,
                ANY_VALUE(b.merged_modality) AS merged_modality,
                ANY_VALUE(b.model) AS model
            FROM base_data b
            CROSS JOIN total_downloads_cte t
            GROUP BY b.{group_col}, t.total_downloads_all
        )
        SELECT *
        FROM top_items
        ORDER BY total_downloads DESC
        LIMIT {limit};
    """
    print("Executing DuckDB query for filtered top N:")
    print(query)  # Print the query for debugging
    return con.execute(query).fetchdf()
def _leaderboard_callback_logic(
    n_clicks,
    slider_value,
    current_label,
    group_col,
    filename,
    default_label="▼ Show Top 50",
    chip_color="#F0F9FF",
):
    """
    Shared body of the leaderboard callbacks.

    Works out how many rows to show, queries DuckDB for them and renders the
    table. The toggle button cycles 10 -> 50 -> 100 -> 10 rows; its label
    always names the *next* step, so the label is also the record of how many
    rows are currently displayed.

    Args:
        n_clicks: Click count of the toggle button (``None`` or 0 before the
            first click).
        slider_value: Current value of the time slider, passed to the query.
        current_label: Current text of the toggle button, or ``None``.
        group_col: Column to group the leaderboard by.
        filename: Passed through to ``render_table_content``.
        default_label: Label assumed when ``current_label`` is ``None``.
        chip_color: Passed through to ``render_table_content``.

    Returns:
        Tuple ``(table_content, new_label)`` for the table and toggle outputs.
    """
    # Local import keeps the new dependency inside the one function using it.
    from dash import ctx

    # Normalize label on first load
    if current_label is None:
        current_label = default_label

    # The callbacks are triggered by the time slider as well as by the toggle
    # button. Only a real button click may advance the cycle; otherwise every
    # slider move after the first click would also change the row count.
    triggered_id = ctx.triggered_id
    toggle_clicked = (
        bool(n_clicks)
        and isinstance(triggered_id, str)
        and triggered_id.endswith("-toggle")
    )

    if toggle_clicked:
        # Advance to the step the button was advertising.
        if "Show Top 50" in current_label:
            top_n, new_label = 50, "▼ Show Top 100"
        elif "Show Top 100" in current_label:
            top_n, new_label = 100, "▲ Show Less"
        else:
            top_n, new_label = 10, "▼ Show Top 50"
    else:
        # Initial load or slider change: keep the label and re-derive the
        # number of rows currently on screen from it.
        new_label = current_label
        if "Show Top 100" in current_label:
            top_n = 50
        elif "Show Less" in current_label:
            top_n = 100
        else:
            top_n = 10

    # Get filtered and aggregated data directly from DuckDB
    df_filtered = _get_filtered_top_n_from_duckdb(slider_value, group_col, top_n)
    print("CALLBACK LOGIC - Filtered DataFrame:")
    print(df_filtered.head())  # Print first 5 rows for debugging

    # Process the already-filtered data
    df, download_df = get_top_n_leaderboard(df_filtered, group_col, top_n)
    return render_table_content(
        df, download_df, chip_color=chip_color, filename=filename
    ), new_label
# -- end helpers --
# Callbacks for interactivity (modularized)
@app.callback(
    Output("top_countries-table", "children"),
    Output("top_countries-toggle", "children"),
    Input("top_countries-toggle", "n_clicks"),
    Input("time-slider", "value"),
    State("top_countries-toggle", "children"),
)
def update_top_countries(n_clicks, slider_value, current_label):
    """Re-render the countries leaderboard table and its toggle label.

    Fires on a toggle click or a time-slider change and delegates to
    ``_leaderboard_callback_logic``, grouping by ``org_country_single``.
    """
    leaderboard_options = {
        "group_col": "org_country_single",
        "filename": "top_countries",
        "default_label": "▼ Show Top 50",
        "chip_color": "#F0F9FF",
    }
    return _leaderboard_callback_logic(
        n_clicks, slider_value, current_label, **leaderboard_options
    )
@app.callback(
    Output("top_developers-table", "children"),
    Output("top_developers-toggle", "children"),
    Input("top_developers-toggle", "n_clicks"),
    Input("time-slider", "value"),
    State("top_developers-toggle", "children"),
)
def update_top_developers(n_clicks, slider_value, current_label):
    """Re-render the developers leaderboard table and its toggle label.

    Fires on a toggle click or a time-slider change and delegates to
    ``_leaderboard_callback_logic``, grouping by ``author``.
    """
    return _leaderboard_callback_logic(
        n_clicks,
        slider_value,
        current_label,
        group_col="author",
        filename="top_developers",
        # Must be a label the shared toggle logic recognises ("Show Top 50" /
        # "Show Top 100"); any other text makes the first click fall through
        # to the 10-row branch instead of expanding the table.
        default_label="▼ Show Top 50",
        chip_color="#F0F9FF",
    )
@app.callback(
    Output("top_models-table", "children"),
    Output("top_models-toggle", "children"),
    Input("top_models-toggle", "n_clicks"),
    Input("time-slider", "value"),
    State("top_models-toggle", "children"),
)
def update_top_models(n_clicks, slider_value, current_label):
    """Re-render the models leaderboard table and its toggle label.

    Fires on a toggle click or a time-slider change and delegates to
    ``_leaderboard_callback_logic``, grouping by ``model``.
    """
    return _leaderboard_callback_logic(
        n_clicks,
        slider_value,
        current_label,
        group_col="model",
        filename="top_models",
        # Must be a label the shared toggle logic recognises ("Show Top 50" /
        # "Show Top 100"); any other text makes the first click fall through
        # to the 10-row branch instead of expanding the table.
        default_label="▼ Show Top 50",
        chip_color="#F0F9FF",
    )
@app.callback(Output("time-slider", "label"), Input("time-slider", "value"))
def update_range_labels(values):
    """Turn the slider's two values (seconds) into "Mon YYYY" labels.

    Returns a two-item list: the label for the start handle, then the end.
    """
    range_start, range_end = values[0], values[1]
    return [
        pd.to_datetime(seconds, unit="s").strftime("%b %Y")
        for seconds in (range_start, range_end)
    ]
# Run the app (only when this file is executed as a script, not when imported)
if __name__ == "__main__":
    # NOTE(review): debug=True is presumably meant for local development only —
    # confirm it is not used for a public deployment.
    app.run(debug=True)