open-model-evolution / graphs /leaderboard.py
emsesc's picture
fixed author attribution issue
aaa721d
Raw
History Blame Contribute Delete
23.1 kB
import base64
import pandas as pd
from dash import html
from dash_iconify import DashIconify
import dash_mantine_components as dmc
import countryflag
from config import COMPANY_ICON_MAP, COUNTRY_EMOJI_FALLBACK, META_COLS_MAP
from data_utils import build_leaderboard_query, create_fresh_duckdb_with_views
from helpers import format_large_number
# =============================
# Leaderboard Data Fetching
# =============================
def get_filtered_top_n_from_duckdb(
    slider_value, group_col, top_n, view="all_downloads", derived_org_toggle=False
):
    """Fetch model-level leaderboard rows for a download-date window from DuckDB.

    A fresh DuckDB connection (with the parquet-backed views) is opened for
    every call and is always closed before returning, even on error.

    Args:
        slider_value: Optional ``(start, end)`` pair of epoch seconds. When it
            is falsy or does not have exactly two entries, the window runs
            from 1970-01-01 up to ``pd.Timestamp.now()``.
        group_col: Column used as the leaderboard grouping key.
        top_n: Number of top groups requested from the query builder.
        view: Name of the DuckDB view to query.
        derived_org_toggle: Forwarded unchanged to ``build_leaderboard_query``.

    Returns:
        The pandas DataFrame produced by the query. It is described as holding
        ``group_key``, model metadata columns, per-model ``total_downloads``
        for the window and ``percent_of_total``; the exact schema is defined by
        ``build_leaderboard_query``.
    """
    con = create_fresh_duckdb_with_views()
    try:
        if not (slider_value and len(slider_value) == 2):
            # No usable window supplied: cover the full range up to now.
            window = (pd.to_datetime("1970-01-01"), pd.Timestamp.now())
        else:
            window = (
                pd.to_datetime(slider_value[0], unit="s"),
                pd.to_datetime(slider_value[1], unit="s"),
            )
        start_str, end_str = (str(ts) for ts in window)
        sql = build_leaderboard_query(
            group_col,
            top_n,
            start_str,
            end_str,
            view=view,
            derived_org_toggle=derived_org_toggle,
        )
        return con.execute(sql).fetchdf()
    finally:
        con.close()
def get_filtered_top_n_alltime_from_duckdb(
    slider_value, group_col, top_n, view="all_downloads", derived_org_toggle=False
):
    """Fetch model-level rows with cumulative (all-time) downloads up to a date.

    A fresh DuckDB connection (with the parquet-backed views) is opened for
    every call and is always closed before returning, even on error.

    Args:
        slider_value: A single epoch-seconds value marking the cut-off date
            (all-time mode passes one value, not a range). ``None`` means now.
        group_col: Column used as the leaderboard grouping key.
        top_n: Number of top groups requested from the query builder.
        view: Name of the DuckDB view to query.
        derived_org_toggle: Forwarded unchanged to ``build_leaderboard_query``.

    Returns:
        The pandas DataFrame produced by the all-time query (``date_str``
        variant of ``build_leaderboard_query``). It is described as holding
        ``group_key``, model metadata columns, cumulative ``total_downloads``
        and ``percent_of_total``.
    """
    con = create_fresh_duckdb_with_views()
    try:
        cutoff = (
            pd.Timestamp.now()
            if slider_value is None
            else pd.to_datetime(slider_value, unit="s")
        )
        sql = build_leaderboard_query(
            group_col,
            top_n,
            date_str=str(cutoff),
            view=view,
            derived_org_toggle=derived_org_toggle,
        )
        return con.execute(sql).fetchdf()
    finally:
        con.close()
def _no_data_message():
    """Return the placeholder shown when there are no leaderboard rows to display."""
    return html.Div(
        "No data found in this time range. Try broadening the download date range.",
        style={"padding": "18px", "fontSize": "16px", "color": "#082030"},
    )


def leaderboard_callback_logic(
    n_clicks,
    slider_value,
    current_label,
    group_col,
    filename,
    default_label="▼ Show Top 50",
    chip_color="#F0F9FF",
    view="all_downloads",
    derived_author_toggle=True,
    is_alltime=False,
):
    """Core logic for handling leaderboard updates based on user interactions.

    The toggle button cycles the table size 10 -> 50 -> 100 -> 10, and its
    label always announces the next step ("▼ Show Top 50", "▼ Show Top 100",
    "▲ Show Less").

    Args:
        n_clicks: Click count of the toggle button. ``0`` and ``None`` (no
            click registered yet) both mean the initial render (top 10).
        slider_value: Date filter forwarded to the DuckDB fetcher: a
            ``(start, end)`` pair of epoch seconds in windowed mode, or a
            single epoch-seconds value when ``is_alltime`` is true.
        current_label: Current button label, or ``None`` on first load.
        group_col: Column to group and rank by.
        filename: Base name (without ``.csv``) for the CSV download link.
        default_label: Label used when ``current_label`` is ``None``.
        chip_color: Background colour for metadata chips.
        view: DuckDB view to query.
        derived_author_toggle: Attribute downloads to ``derived_author`` when
            true, to ``author`` otherwise.
        is_alltime: Use cumulative all-time totals instead of a date window.

    Returns:
        Tuple ``(table_content, new_label)`` for the Dash callback outputs.
    """
    # Normalize label on first load
    if current_label is None:
        current_label = default_label

    # An un-clicked button may be reported as None rather than 0; treat both
    # as the initial render. Otherwise None would fall through to the label
    # cycle and jump straight to the top 50.
    if not n_clicks:
        top_n, new_label = 10, current_label
    elif "Show Top 50" in current_label:
        top_n, new_label = 50, "▼ Show Top 100"
    elif "Show Top 100" in current_label:
        top_n, new_label = 100, "▲ Show Less"
    else:
        top_n, new_label = 10, "▼ Show Top 50"

    # Filter and aggregate in DuckDB; all-time mode uses cumulative totals.
    fetch = (
        get_filtered_top_n_alltime_from_duckdb
        if is_alltime
        else get_filtered_top_n_from_duckdb
    )
    df_filtered = fetch(
        slider_value, group_col, top_n, view=view, derived_org_toggle=derived_author_toggle
    )
    # The SQL query returned no rows: ask the user to broaden the date range.
    if df_filtered is None or df_filtered.empty:
        return _no_data_message(), new_label

    df, download_df = get_top_n_leaderboard(
        df_filtered, group_col, top_n, derived_author_toggle=derived_author_toggle
    )
    # Post-processing can still yield nothing (e.g. every total is zero).
    if df is None or (hasattr(df, "empty") and df.empty):
        return _no_data_message(), new_label

    return (
        render_table_content(df, download_df, chip_color=chip_color, filename=filename),
        new_label,
    )
# =============================
# UI Rendering Components
# =============================
# Chip renderer
def chip(text, bg_color="#F0F0F0"):
    """Render ``text`` as a rounded inline-flex pill on a ``bg_color`` background."""
    pill_style = dict(
        backgroundColor=bg_color,
        padding="4px 10px",
        borderRadius="12px",
        margin="2px",
        display="inline-flex",
        alignItems="center",
        fontSize="14px",
    )
    return html.Span(text, style=pill_style)
# Progress bar for % of total
def progress_bar(percent, bar_color="#AC482A"):
    """Render a horizontal bar filled to ``percent`` with a centred numeric label.

    Args:
        percent: Fill amount in percent. It is used directly as the CSS width
            and displayed with one decimal place.
        bar_color: CSS colour of the filled portion.

    Returns:
        An ``html.Div`` track containing the fill and the overlaid label.
    """
    track_style = {
        "position": "relative",
        "backgroundColor": "#E0E0E0",
        "borderRadius": "8px",
        "height": "20px",
        "width": "100%",
        "overflow": "hidden",
    }
    fill = html.Div(
        style={
            "backgroundColor": bar_color,
            "width": f"{percent}%",
            "height": "100%",
            "borderRadius": "8px",
            "transition": "width 0.5s",
        }
    )
    # Absolutely positioned so the text sits centred over both fill and track.
    caption = html.Div(
        f"{percent:.1f}%",
        style={
            "position": "absolute",
            "top": 0,
            "left": "50%",
            "transform": "translateX(-50%)",
            "color": "black",
            "fontWeight": "bold",
            "fontSize": "12px",
            "lineHeight": "20px",
            "textAlign": "center",
        },
    )
    return html.Div(style=track_style, children=[fill, caption])
# Helper to convert DataFrame to CSV and encode for download
def df_to_download_link(df, filename):
    """Build a right-aligned download icon that saves ``df`` as ``<filename>.csv``.

    The CSV (written without the index) is base64-encoded and embedded
    directly in the link's ``data:`` URL.
    """
    payload = base64.b64encode(df.to_csv(index=False).encode()).decode()
    anchor_style = {
        "padding": "6px 12px",
        "display": "inline-flex",
        "alignItems": "center",
        "justifyContent": "center",
    }
    download_anchor = html.A(
        children=dmc.ActionIcon(
            DashIconify(icon="mdi:download", width=24),
            size="lg",
            color="#082030",
        ),
        id=f"download-{filename}",
        download=f"{filename}.csv",
        href=f"data:text/csv;base64,{payload}",
        target="_blank",
        title="Download CSV",
        style=anchor_style,
    )
    return html.Div(download_anchor, style={"textAlign": "right"})
# Helper to get popover content for each metadata type
def get_metadata_popover_content(icon, name, meta_type):
    """Return the hover-card text for a metadata chip.

    Args:
        icon: Not used; accepted so every chip can pass its icon uniformly.
        name: The chip's display value.
        meta_type: ``"country"``, ``"author"``, ``"downloads"`` or
            ``"modality"``. Any other value (including ``None``) makes the
            function return ``name`` unchanged.
    """
    labelled = {
        "country": f"Country: {name}",
        "author": f"Author/Organization: {name}",
        "downloads": f"Total downloads: {name}",
        "modality": f"Modality: {name}",
    }
    if meta_type in labelled:
        return labelled[meta_type]
    return name
# Chip renderer with hovercard
def chip_with_hovercard(text, bg_color="#F0F0F0", meta_type=None, icon=None):
    """Render a text chip that shows a descriptive hover card on mouse-over.

    Args:
        text: Chip label (callers typically prefix it with an icon/emoji).
        bg_color: Chip background colour.
        meta_type: Metadata kind passed to ``get_metadata_popover_content`` to
            pick the hover-card wording.
        icon: Forwarded to ``get_metadata_popover_content``.

    Returns:
        A ``dmc.HoverCard`` wrapping the chip span and its dropdown text.
    """
    dropdown_text = get_metadata_popover_content(icon, text, meta_type)
    target_span = html.Span(
        text,
        style={
            "backgroundColor": bg_color,
            "padding": "4px 10px",
            "borderRadius": "12px",
            "margin": "2px",
            "display": "inline-flex",
            "alignItems": "center",
            "fontSize": "14px",
            "cursor": "pointer",
            "transition": "background-color 0.15s",
        },
        # CSS hook for the hover darkening effect.
        className="chip-hover-darken",
    )
    return dmc.HoverCard(
        width="auto",
        shadow="md",
        position="top",
        children=[
            dmc.HoverCardTarget(target_span),
            dmc.HoverCardDropdown(dmc.Text(dropdown_text, size="sm")),
        ],
    )
# Render multiple chips in one row, each with popover
def render_chips(metadata_list, chip_color):
    """Render metadata chips in one wrapping row, each with a hover card.

    Args:
        metadata_list: Iterable of ``(icon, name, meta_type)`` triples. An
            ``icon`` that is a string ending in .png/.jpg/.jpeg/.svg (matched
            case-insensitively) is shown as an ``<img>``. Any other icon is
            used as a text prefix via ``chip_with_hovercard``.
        chip_color: Background colour applied to every chip.

    Returns:
        An ``html.Div`` flex container holding one chip per triple.
    """
    image_suffixes = (".png", ".jpg", ".jpeg", ".svg")
    chips = []
    for icon, name, meta_type in metadata_list:
        # Case-insensitive so e.g. "logo.PNG" still renders as an image
        # instead of being printed as literal text.
        is_image_icon = isinstance(icon, str) and icon.lower().endswith(image_suffixes)
        if not is_image_icon:
            chips.append(
                chip_with_hovercard(f"{icon} {name}", chip_color, meta_type, icon)
            )
            continue

        image_chip = html.Span(
            [
                html.Img(src=icon, style={"height": "18px", "marginRight": "6px"}),
                name,
            ],
            style={
                "backgroundColor": chip_color,
                "padding": "4px 10px",
                "borderRadius": "12px",
                "margin": "2px",
                "display": "inline-flex",
                # "left" is not a valid align-items value and was dropped by
                # browsers; centre logo and text like the text-only chips.
                "alignItems": "center",
                "fontSize": "14px",
                "cursor": "pointer",
            },
        )
        chips.append(
            dmc.HoverCard(
                width=220,
                shadow="md",
                position="top",
                children=[
                    dmc.HoverCardTarget(image_chip),
                    dmc.HoverCardDropdown(
                        dmc.Text(
                            get_metadata_popover_content(icon, name, meta_type),
                            size="sm",
                        )
                    ),
                ],
            )
        )
    return html.Div(
        chips, style={"display": "flex", "flexWrap": "wrap", "justifyContent": "left"}
    )
def render_table_content(
    df, download_df, chip_color, bar_color="#AC482A", filename="data"
):
    """Render the leaderboard as a CSV download button above a scrollable table.

    Args:
        df: Display DataFrame with ``"Name"``, ``"Metadata"`` (a list of
            ``(icon, name, meta_type)`` triples per row) and ``"% of total"``
            columns. Rows are shown in their current order.
        download_df: DataFrame serialised into the CSV download link.
        chip_color: Background colour for metadata chips.
        bar_color: Fill colour of the percent-of-total progress bars.
        filename: Base name for the downloaded CSV file.

    Returns:
        An ``html.Div`` containing the download link and the table.
    """
    # Header cells share one style. The former marginRight on "Metadata" was
    # dropped: margins do not apply to table cells.
    header_style = {"backgroundColor": "#F0F0F0", "textAlign": "left"}
    header = html.Thead(
        html.Tr(
            [
                html.Th(title, className=css_class, style=dict(header_style))
                for title, css_class in (
                    ("Rank", "rank-col"),
                    ("Name", "name-col"),
                    ("Metadata", "metadata-col"),
                    ("% of Total", "percent-col"),
                )
            ]
        )
    )

    # Rank is the row's position, not its index label, so it reads 1..N even
    # when df arrives with a non-default index.
    body = html.Tbody(
        [
            html.Tr(
                [
                    html.Td(rank, style={"textAlign": "center"}),
                    html.Td(row["Name"], className="name-cell", style={"textAlign": "left"}),
                    html.Td(
                        render_chips(row["Metadata"], chip_color),
                        className="metadata-cell",
                        style={"textAlign": "left", "whiteSpace": "normal", "wordBreak": "break-word"},
                    ),
                    html.Td(
                        progress_bar(row["% of total"], bar_color),
                        className="percent-cell",
                        style={"textAlign": "center", "minWidth": "180px", "padding": "8px"},
                    ),
                ]
            )
            for rank, (_, row) in enumerate(df.iterrows(), start=1)
        ]
    )

    table = html.Table(
        [header, body],
        # minWidth lets the table grow wider than its container instead of
        # squishing; the wrapper below provides horizontal scrolling.
        style={"borderCollapse": "collapse", "width": "100%", "minWidth": "980px", "tableLayout": "auto"},
        className="leaderboard-table",
    )

    return html.Div(
        [
            # Download button above the table
            df_to_download_link(download_df, filename),
            html.Div(
                table,
                className="leaderboard-scroll-wrapper",
                # Dash/React style keys must be camelCase; React warns on and
                # does not reliably apply hyphenated names like
                # "-webkit-overflow-scrolling".
                style={"overflowX": "auto", "WebkitOverflowScrolling": "touch", "width": "100%"},
            ),
        ]
    )
# Function to get top N leaderboard (now accepts pandas DataFrame from DuckDB query)
def get_top_n_leaderboard(filtered_df, group_col, top_n=10, derived_author_toggle=True):
    """
    Get top N entries for a leaderboard.

    Args:
        filtered_df: Pandas DataFrame of model-level rows. Expected to contain:
            - group_col (the grouping key)
            - total_downloads (per-model downloads for the requested window).
              A legacy ``downloads`` column is used as a fallback; if neither
              exists, every row counts as 0. Non-numeric/missing values count
              as 0.
            - the metadata columns listed in ``META_COLS_MAP[group_col]``
              (e.g. org_country_single, author, derived_author,
              merged_country_groups_single, merged_modality, model)
            The input DataFrame is not modified.
        group_col: Column to group by.
        top_n: Number of top groups to return.
        derived_author_toggle: If True, author chips come from
            ``derived_author`` (model uploader); if False, from ``author``
            (original model creator).

    Returns:
        tuple: (display_df, download_df)
            display_df: columns ["Name", "Metadata", "% of total"]. "Metadata"
                is a list of (icon, label, meta_type) triples for rendering.
            download_df: "Name", "Total Value" (int), "% of total", then one
                string column per metadata column (distinct values joined
                with ", ").
            Both are empty DataFrames when the input is empty or the overall
            total is 0.
    """
    if filtered_df is None or filtered_df.empty:
        return pd.DataFrame(), pd.DataFrame()

    # Copy so adding/normalising columns never mutates the caller's DataFrame.
    df = filtered_df.copy()
    if "total_downloads" not in df.columns:
        # Fallback if older code still returned 'downloads' (unlikely).
        df["total_downloads"] = df["downloads"] if "downloads" in df.columns else 0
    # Ensure numeric totals; unparseable or missing values count as 0.
    df["total_downloads"] = pd.to_numeric(df["total_downloads"], errors="coerce").fillna(0)

    # Overall total across all models in this filtered set (denominator for %).
    total_all = df["total_downloads"].sum()
    if total_all == 0:
        return pd.DataFrame(), pd.DataFrame()

    # Sum per group, then keep the top N groups by summed downloads.
    grouped = (
        df.groupby(group_col)["total_downloads"]
        .sum()
        .reset_index()
        .rename(columns={group_col: "Name", "total_downloads": "Total Value"})
    )
    top = grouped.nlargest(top_n, columns="Total Value").reset_index(drop=True)
    top["% of total"] = top["Total Value"].apply(lambda v: round(v * 100.0 / total_all, 2))

    # Distinct non-null metadata values per top group, keyed by group name.
    # Used both for the display chips and for the CSV columns.
    meta_cols = META_COLS_MAP.get(group_col, [])
    available_cols = [col for col in meta_cols if col in df.columns]
    meta_map = {}
    for name in top["Name"]:
        name_data = df[df[group_col] == name]
        meta_map[name] = {
            col: list(name_data[col].dropna().unique()) for col in available_cols
        }
    totals_by_name = dict(zip(top["Name"], top["Total Value"]))
    author_key = "derived_author" if derived_author_toggle else "author"

    def build_metadata(nm):
        """Return the (icon, label, meta_type) chip triples for group ``nm``."""
        meta = meta_map.get(nm, {})
        chips = []
        # Countries
        for c in meta.get("org_country_single", []):
            if c == "United States of America":
                c = "USA"
            if c == "user":
                c = "User"
            try:
                flag_emoji = countryflag.getflag(c)
                if not flag_emoji or flag_emoji == c:
                    flag_emoji = COUNTRY_EMOJI_FALLBACK.get(c, "🌍")
            except Exception:
                flag_emoji = COUNTRY_EMOJI_FALLBACK.get(c, "🌍")
            chips.append((flag_emoji, c, "country"))
        # Authors, from the column selected by derived_author_toggle.
        # The country-group list can exist but be empty (all values null);
        # fall back to "User" then too instead of raising IndexError on [0].
        country_groups = meta.get("merged_country_groups_single") or ["User"]
        fallback_icon = "🏢" if country_groups[0] != "User" else "👤"
        for a in meta.get(author_key, []):
            icon = COMPANY_ICON_MAP.get(a, "")
            if icon == "":
                icon = fallback_icon
            chips.append((icon, a, "author"))
        # Modality
        for m in meta.get("merged_modality", []):
            if pd.notna(m):
                chips.append(("", m, "modality"))
        # Total downloads for the whole group
        group_total = totals_by_name.get(nm)
        if group_total is not None:
            chips.append(("⬇️", format_large_number(int(group_total)), "downloads"))
        return chips

    # Display frame: only the columns render_table_content needs, in order.
    display_df = top.rename(columns={"Total Value": "total_downloads"})
    display_df["Metadata"] = display_df["Name"].astype(object).apply(build_metadata)
    display_for_render = display_df[["Name", "Metadata", "% of total"]]

    # CSV frame: numeric totals, then one joined-string column per metadata
    # column. The totals are not appended a second time; that previously
    # produced duplicate "Total Value"/"% of total" columns in the CSV.
    download_top = top.copy()
    download_top["Total Value"] = download_top["Total Value"].astype(int)
    download_top["% of total"] = download_top["% of total"].round(2)
    for col in meta_cols:
        download_top[col] = [
            ", ".join(str(v) for v in meta_map.get(nm, {}).get(col, []) if pd.notna(v))
            for nm in download_top["Name"]
        ]
    return display_for_render, download_top
def get_top_n_from_duckdb(
    group_col, top_n=10, time_filter=None, view="all_downloads", derived_org_toggle=None
):
    """Query DuckDB for model-level leaderboard rows over a download-date window.

    Builds the same kind of query as ``get_filtered_top_n_from_duckdb`` (via
    ``build_leaderboard_query``) but is best-effort: if executing the query
    fails, the error is printed and an empty DataFrame is returned. A fresh
    DuckDB connection with the views is opened for each call and always
    closed.

    Args:
        group_col: Column used as the leaderboard grouping key.
        top_n: Number of top groups requested from the query builder.
        time_filter: Optional ``(start, end)`` pair of epoch seconds. When it
            is falsy or does not have exactly two entries, the window is
            1970-01-01 up to ``pd.Timestamp.now()``.
        view: Name of the DuckDB view to query.
        derived_org_toggle: When not ``None``, forwarded to
            ``build_leaderboard_query``. When ``None`` (the default) the
            builder's own default applies, as before this parameter existed.

    Returns:
        The query result as a pandas DataFrame, or an empty DataFrame if
        executing the query raised.
    """
    if time_filter and len(time_filter) == 2:
        start = pd.to_datetime(time_filter[0], unit="s")
        end = pd.to_datetime(time_filter[1], unit="s")
    else:
        # No explicit window: cover everything from the epoch up to now.
        start = pd.to_datetime("1970-01-01")
        end = pd.Timestamp.now()

    query_kwargs = {"view": view}
    # Only forward the toggle when the caller sets it, so existing callers
    # keep whatever default build_leaderboard_query applies.
    if derived_org_toggle is not None:
        query_kwargs["derived_org_toggle"] = derived_org_toggle
    query = build_leaderboard_query(group_col, top_n, str(start), str(end), **query_kwargs)

    conn_local = create_fresh_duckdb_with_views()
    try:
        return conn_local.execute(query).fetchdf()
    except Exception as e:
        # Best-effort helper: report the failure and return an empty frame
        # instead of propagating the error to the caller.
        print(f"Error querying DuckDB: {e}")
        return pd.DataFrame()
    finally:
        conn_local.close()