Spaces:
Sleeping
Sleeping
| import base64 | |
| import pandas as pd | |
| from dash import html | |
| from dash_iconify import DashIconify | |
| import dash_mantine_components as dmc | |
| import countryflag | |
| from config import COMPANY_ICON_MAP, COUNTRY_EMOJI_FALLBACK, META_COLS_MAP | |
| from data_utils import build_leaderboard_query, create_fresh_duckdb_with_views | |
| from helpers import format_large_number | |
| # ============================= | |
| # Leaderboard Data Fetching | |
| # ============================= | |
def get_filtered_top_n_from_duckdb(
    slider_value, group_col, top_n, view="all_downloads", derived_org_toggle=False
):
    """
    Run the date-windowed leaderboard query on a throwaway DuckDB connection.

    Args:
        slider_value: Two-element sequence of epoch-second timestamps
            (window start, window end). Any falsy value, or a sequence of a
            different length, selects the window 1970-01-01 .. now instead.
        group_col: Grouping column, forwarded to build_leaderboard_query.
        top_n: Number of entries requested, forwarded to build_leaderboard_query.
        view: Name of the view to query, forwarded to build_leaderboard_query.
        derived_org_toggle: Forwarded unchanged to build_leaderboard_query.

    Returns:
        The query result as a pandas DataFrame. The column layout is decided
        by build_leaderboard_query; other code in this module reads the
        group_col and total_downloads columns from it.
    """
    # A new connection per call; it is always closed in the finally clause.
    connection = create_fresh_duckdb_with_views()
    try:
        if slider_value and len(slider_value) == 2:
            window_start, window_end = (
                pd.to_datetime(slider_value[position], unit="s")
                for position in (0, 1)
            )
        else:
            # No usable slider selection: cover everything up to the present.
            window_start, window_end = pd.to_datetime("1970-01-01"), pd.Timestamp.now()

        query = build_leaderboard_query(
            group_col,
            top_n,
            str(window_start),
            str(window_end),
            view=view,
            derived_org_toggle=derived_org_toggle,
        )
        return connection.execute(query).fetchdf()
    finally:
        connection.close()
def get_filtered_top_n_alltime_from_duckdb(
    slider_value, group_col, top_n, view="all_downloads", derived_org_toggle=False
):
    """
    Run the all-time (single cut-off date) leaderboard query on a throwaway
    DuckDB connection.

    Args:
        slider_value: Epoch-second timestamp of the cut-off date, or None to
            use the current time.
        group_col: Grouping column, forwarded to build_leaderboard_query.
        top_n: Number of entries requested, forwarded to build_leaderboard_query.
        view: Name of the view to query, forwarded to build_leaderboard_query.
        derived_org_toggle: Forwarded unchanged to build_leaderboard_query.

    Returns:
        The query result as a pandas DataFrame. The column layout is decided
        by build_leaderboard_query (called here with date_str instead of a
        start/end pair).
    """
    # A new connection per call; it is always closed in the finally clause.
    connection = create_fresh_duckdb_with_views()
    try:
        cutoff = (
            pd.Timestamp.now()
            if slider_value is None
            else pd.to_datetime(slider_value, unit="s")
        )
        query = build_leaderboard_query(
            group_col,
            top_n,
            date_str=str(cutoff),
            view=view,
            derived_org_toggle=derived_org_toggle,
        )
        return connection.execute(query).fetchdf()
    finally:
        connection.close()
def _no_data_message():
    """Build the placeholder shown when a leaderboard has no rows to display."""
    return html.Div(
        "No data found in this time range. Try broadening the download date range.",
        style={"padding": "18px", "fontSize": "16px", "color": "#082030"},
    )


def leaderboard_callback_logic(
    n_clicks,
    slider_value,
    current_label,
    group_col,
    filename,
    default_label="▼ Show Top 50",
    chip_color="#F0F9FF",
    view="all_downloads",
    derived_author_toggle=True,
    is_alltime=False,
):
    """
    Core logic for handling leaderboard updates based on user interactions.

    The expand button cycles 10 -> 50 -> 100 -> 10 rows; the label currently
    on the button decides which step comes next.

    Args:
        n_clicks: Click count of the expand button. 0 or None means the button
            has not been clicked, so the default 10 rows are shown and the
            label is left as it is.
        slider_value: Date selection, forwarded to the DuckDB query helper
            (a start/end pair in windowed mode, a single value in all-time mode).
        current_label: Label currently on the expand button; None is treated
            as default_label.
        group_col: Column the leaderboard is grouped by.
        filename: Base name used for the CSV download link.
        default_label: Label assumed when current_label is None.
        chip_color: Background colour of the metadata chips.
        view: Name of the view to query.
        derived_author_toggle: Forwarded to the query helper (as
            derived_org_toggle) and to get_top_n_leaderboard.
        is_alltime: If True use the all-time query, otherwise the windowed one.

    Returns:
        Tuple (table_content, new_label) for the callback. table_content is a
        "no data" message when either the query or the aggregation is empty.
    """
    # Normalize label on first load
    if current_label is None:
        current_label = default_label

    # Determine top_n and next label. `not n_clicks` also covers None, so a
    # callback fired before any click does not advance the cycle.
    if not n_clicks:
        top_n, new_label = 10, current_label
    elif "Show Top 50" in current_label:
        top_n, new_label = 50, "▼ Show Top 100"
    elif "Show Top 100" in current_label:
        top_n, new_label = 100, "▲ Show Less"
    else:
        top_n, new_label = 10, "▼ Show Top 50"

    # Get filtered and aggregated data directly from DuckDB
    fetch_rows = (
        get_filtered_top_n_alltime_from_duckdb
        if is_alltime
        else get_filtered_top_n_from_duckdb
    )
    df_filtered = fetch_rows(
        slider_value, group_col, top_n, view=view, derived_org_toggle=derived_author_toggle
    )

    # If the SQL query returned no rows, ask user to broaden date range
    if df_filtered is None or df_filtered.empty:
        return _no_data_message(), new_label

    # Process the already-filtered data
    df, download_df = get_top_n_leaderboard(
        df_filtered, group_col, top_n, derived_author_toggle=derived_author_toggle
    )

    # If processing produced no rows, ask user to broaden date range
    if df is None or getattr(df, "empty", False):
        return _no_data_message(), new_label

    table = render_table_content(
        df, download_df, chip_color=chip_color, filename=filename
    )
    return table, new_label
| # ============================= | |
| # UI Rendering Components | |
| # ============================= | |
| # Chip renderer | |
def chip(text, bg_color="#F0F0F0"):
    """Return an html.Span that shows `text` as a rounded pill filled with `bg_color`."""
    pill_style = {
        "backgroundColor": bg_color,
        "padding": "4px 10px",
        "borderRadius": "12px",
        "margin": "2px",
        "display": "inline-flex",
        "alignItems": "center",
        "fontSize": "14px",
    }
    return html.Span(text, style=pill_style)
| # Progress bar for % of total | |
def progress_bar(percent, bar_color="#AC482A"):
    """
    Return a horizontal bar whose filled part is `percent` % wide.

    The result is a grey track (html.Div) holding two children: the coloured
    fill, and a caption with the value formatted to one decimal place that is
    centred on top of the track.
    """
    track_style = {
        "position": "relative",
        "backgroundColor": "#E0E0E0",
        "borderRadius": "8px",
        "height": "20px",
        "width": "100%",
        "overflow": "hidden",
    }
    fill = html.Div(
        style={
            "backgroundColor": bar_color,
            "width": f"{percent}%",
            "height": "100%",
            "borderRadius": "8px",
            "transition": "width 0.5s",
        }
    )
    # Absolutely positioned so the text stays centred whatever the fill width.
    caption = html.Div(
        f"{percent:.1f}%",
        style={
            "position": "absolute",
            "top": 0,
            "left": "50%",
            "transform": "translateX(-50%)",
            "color": "black",
            "fontWeight": "bold",
            "fontSize": "12px",
            "lineHeight": "20px",
            "textAlign": "center",
        },
    )
    return html.Div(style=track_style, children=[fill, caption])
| # Helper to convert DataFrame to CSV and encode for download | |
def df_to_download_link(df, filename):
    """
    Return a right-aligned download icon that saves `df` as `<filename>.csv`.

    The DataFrame is serialised to CSV (without the index), base64-encoded and
    embedded in the link's href as a data: URI, so no server round trip is
    involved in the download.
    """
    encoded_csv = base64.b64encode(df.to_csv(index=False).encode()).decode()

    icon_button = dmc.ActionIcon(
        DashIconify(icon="mdi:download", width=24),
        size="lg",
        color="#082030",
    )
    anchor = html.A(
        children=icon_button,
        id=f"download-{filename}",
        download=f"{filename}.csv",
        href=f"data:text/csv;base64,{encoded_csv}",
        target="_blank",
        title="Download CSV",
        style={
            "padding": "6px 12px",
            "display": "inline-flex",
            "alignItems": "center",
            "justifyContent": "center",
        },
    )
    return html.Div(anchor, style={"textAlign": "right"})
| # Helper to get popover content for each metadata type | |
def get_metadata_popover_content(icon, name, meta_type):
    """
    Return the hover text for a metadata chip.

    Known meta types ("country", "author", "downloads", "modality") yield
    "<label>: <name>"; any other meta_type returns `name` unchanged.
    `icon` is accepted but not used.
    """
    labels = {
        "country": "Country",
        "author": "Author/Organization",
        "downloads": "Total downloads",
        "modality": "Modality",
    }
    if meta_type in labels:
        return f"{labels[meta_type]}: {name}"
    return name
| # Chip renderer with hovercard | |
def chip_with_hovercard(text, bg_color="#F0F0F0", meta_type=None, icon=None):
    """
    Return a pill-shaped chip wrapped in a dmc.HoverCard.

    The chip shows `text`; the hover card shows the string produced by
    get_metadata_popover_content(icon, text, meta_type).
    """
    tooltip = get_metadata_popover_content(icon, text, meta_type)

    pill = html.Span(
        text,
        style={
            "backgroundColor": bg_color,
            "padding": "4px 10px",
            "borderRadius": "12px",
            "margin": "2px",
            "display": "inline-flex",
            "alignItems": "center",
            "fontSize": "14px",
            "cursor": "pointer",
            "transition": "background-color 0.15s",
        },
        # Class hook for the hover effect (styled outside this function)
        className="chip-hover-darken",
    )
    target = dmc.HoverCardTarget(pill)
    dropdown = dmc.HoverCardDropdown(dmc.Text(tooltip, size="sm"))

    return dmc.HoverCard(
        width="auto",
        shadow="md",
        position="top",
        children=[target, dropdown],
    )
| # Render multiple chips in one row, each with popover | |
def _image_chip_with_hovercard(image_src, name, meta_type, chip_color):
    """Build a hover-card chip whose icon is an image file rather than text."""
    pill = html.Span(
        [
            html.Img(
                src=image_src,
                style={"height": "18px", "marginRight": "6px"},
            ),
            name,
        ],
        style={
            "backgroundColor": chip_color,
            "padding": "4px 10px",
            "borderRadius": "12px",
            "margin": "2px",
            "display": "inline-flex",
            # "center" matches the text chips; the previous value "left" is
            # not a valid align-items keyword.
            "alignItems": "center",
            "fontSize": "14px",
            "cursor": "pointer",
        },
    )
    return dmc.HoverCard(
        width=220,
        shadow="md",
        position="top",
        children=[
            dmc.HoverCardTarget(pill),
            dmc.HoverCardDropdown(
                dmc.Text(
                    get_metadata_popover_content(image_src, name, meta_type),
                    size="sm",
                )
            ),
        ],
    )


def render_chips(metadata_list, chip_color):
    """
    Render a wrapping row of metadata chips, each with a hover card.

    Args:
        metadata_list: Iterable of (icon, name, meta_type) tuples. An icon that
            is a string ending in .png/.jpg/.jpeg/.svg is rendered as an image
            next to the name; any other icon is treated as text (e.g. an
            emoji) and prefixed to the name.
        chip_color: Background colour applied to every chip.

    Returns:
        html.Div using a wrapping flex layout that contains the chips.
    """
    image_suffixes = (".png", ".jpg", ".jpeg", ".svg")
    chips = []
    for icon, name, meta_type in metadata_list:
        if isinstance(icon, str) and icon.endswith(image_suffixes):
            chips.append(_image_chip_with_hovercard(icon, name, meta_type, chip_color))
        else:
            # Only prefix the icon when there is one, so chips without an icon
            # do not get a stray leading space in their label and hover text.
            label = f"{icon} {name}" if icon else f"{name}"
            chips.append(chip_with_hovercard(label, chip_color, meta_type, icon))
    return html.Div(
        chips, style={"display": "flex", "flexWrap": "wrap", "justifyContent": "left"}
    )
def render_table_content(
    df, download_df, chip_color, bar_color="#AC482A", filename="data"
):
    """
    Render the leaderboard: a CSV download link above a scrollable table.

    Args:
        df: DataFrame with one row per leaderboard entry. Reads the columns
            "Name", "Metadata" (list of (icon, name, meta_type) tuples passed
            to render_chips) and "% of total" (passed to progress_bar).
        download_df: DataFrame offered for download via df_to_download_link.
        chip_color: Background colour of the metadata chips.
        bar_color: Fill colour of the percentage bars.
        filename: Base name of the downloaded CSV file.

    Returns:
        html.Div containing the download link and the table wrapper.
    """
    header_specs = [
        ("Rank", "rank-col", {}),
        ("Name", "name-col", {}),
        ("Metadata", "metadata-col", {"marginRight": "10px"}),
        ("% of Total", "percent-col", {}),
    ]
    header_row = html.Tr(
        [
            html.Th(
                label,
                className=css_class,
                style={"backgroundColor": "#F0F0F0", "textAlign": "left", **extra_style},
            )
            for label, css_class, extra_style in header_specs
        ]
    )

    # Rank comes from the row position, not the DataFrame index, so the
    # numbering is 1..N even if the caller passes a frame with another index.
    body_rows = [
        html.Tr(
            [
                html.Td(rank, style={"textAlign": "center"}),
                html.Td(row["Name"], className="name-cell", style={"textAlign": "left"}),
                html.Td(
                    render_chips(row["Metadata"], chip_color),
                    className="metadata-cell",
                    style={
                        "textAlign": "left",
                        "whiteSpace": "normal",
                        "wordBreak": "break-word",
                    },
                ),
                html.Td(
                    progress_bar(row["% of total"], bar_color),
                    className="percent-cell",
                    style={"textAlign": "center", "minWidth": "180px", "padding": "8px"},
                ),
            ]
        )
        for rank, (_, row) in enumerate(df.iterrows(), start=1)
    ]

    table = html.Table(
        [html.Thead(header_row), html.Tbody(body_rows)],
        # allow the table to be wider than its container (minWidth prevents squish)
        style={
            "borderCollapse": "collapse",
            "width": "100%",
            "minWidth": "980px",
            "tableLayout": "auto",
        },
        className="leaderboard-table",
    )

    # Horizontal scroll container so the table can be wider than the page.
    # Style keys must be camelCase; vendor prefixes start with a capital letter.
    scroll_wrapper = html.Div(
        table,
        className="leaderboard-scroll-wrapper",
        style={"overflowX": "auto", "WebkitOverflowScrolling": "touch", "width": "100%"},
    )

    return html.Div(
        [
            # Download button above the table
            df_to_download_link(download_df, filename),
            scroll_wrapper,
        ]
    )
| # Function to get top N leaderboard (now accepts pandas DataFrame from DuckDB query) | |
def get_top_n_leaderboard(filtered_df, group_col, top_n=10, derived_author_toggle=True):
    """
    Get top N entries for a leaderboard.

    Args:
        filtered_df: Pandas DataFrame of model-level rows. Must contain
            group_col and should contain total_downloads; if that column is
            missing, "downloads" is used instead, or 0 when neither exists.
            The frame passed in is never modified.
        group_col: Column to group by.
        top_n: Number of top entries to return.
        derived_author_toggle: If True, author chips are built from the
            "derived_author" metadata column; if False, from "author".

    Returns:
        tuple: (display_df, download_df)
            display_df: DataFrame with columns ["Name", "Metadata", "% of total"]
                for rendering. "Metadata" holds a list of
                (icon, text, meta_type) tuples per row.
            download_df: DataFrame for CSV export with columns "Name",
                "Total Value", "% of total", followed by one comma-joined text
                column per entry of META_COLS_MAP[group_col].
        Both frames are empty when there is no input or the downloads sum to 0.
    """
    if filtered_df is None or filtered_df.empty:
        return pd.DataFrame(), pd.DataFrame()

    # Ensure a total_downloads column exists. Work on a copy so the caller's
    # DataFrame does not silently gain a column.
    if "total_downloads" not in filtered_df.columns:
        filtered_df = filtered_df.copy()
        if "downloads" in filtered_df.columns:
            filtered_df["total_downloads"] = filtered_df["downloads"]
        else:
            filtered_df["total_downloads"] = 0

    # Overall total across all models in this filtered set
    total_all = filtered_df["total_downloads"].sum()
    if total_all == 0:
        return pd.DataFrame(), pd.DataFrame()

    # Sum per group, then keep the top N groups by summed downloads
    grouped = (
        filtered_df.groupby(group_col)["total_downloads"]
        .sum()
        .reset_index()
        .rename(columns={group_col: "Name", "total_downloads": "Total Value"})
    )
    top = grouped.nlargest(top_n, columns="Total Value").reset_index(drop=True)
    top["% of total"] = top["Total Value"].apply(
        lambda value: round(value * 100.0 / total_all, 2)
    )

    # Numeric version for the CSV download
    download_top = top.copy()
    download_top["Total Value"] = download_top["Total Value"].astype(int)
    download_top["% of total"] = download_top["% of total"].round(2)

    # Metadata columns relevant for this grouping
    meta_cols = META_COLS_MAP.get(group_col, [])
    present_meta_cols = [col for col in meta_cols if col in filtered_df.columns]

    # Unique non-null metadata values per top group, from the model-level rows
    meta_map = {}
    for name in top["Name"]:
        group_rows = filtered_df[filtered_df[group_col] == name]
        meta_map[name] = {
            col: list(group_rows[col].dropna().unique()) for col in present_meta_cols
        }

    # Group totals looked up once instead of filtering `top` for every row
    group_totals = dict(zip(top["Name"], top["Total Value"]))
    author_key = "derived_author" if derived_author_toggle else "author"

    def build_metadata(nm):
        """Return the list of (icon, text, meta_type) chips for group `nm`."""
        meta = meta_map.get(nm, {})
        chips = []

        # Countries
        for country in meta.get("org_country_single", []):
            if country == "United States of America":
                country = "USA"
            if country == "user":
                country = "User"
            # Best effort: any failure in the flag lookup falls back to the
            # configured emoji (or a globe).
            try:
                flag_emoji = countryflag.getflag(country)
                if not flag_emoji or flag_emoji == country:
                    flag_emoji = COUNTRY_EMOJI_FALLBACK.get(country, "🌍")
            except Exception:
                flag_emoji = COUNTRY_EMOJI_FALLBACK.get(country, "🌍")
            chips.append((flag_emoji, country, "country"))

        # Authors. `or ["User"]` also covers a present-but-empty list (all
        # values null), which previously raised IndexError on [0].
        country_groups = meta.get("merged_country_groups_single") or ["User"]
        default_author_icon = "🏢" if country_groups[0] != "User" else "👤"
        for author_name in meta.get(author_key, []):
            icon = COMPANY_ICON_MAP.get(author_name, "")
            if icon == "":
                icon = default_author_icon
            chips.append((icon, author_name, "author"))

        # Modality
        for modality in meta.get("merged_modality", []):
            if pd.notna(modality):
                chips.append(("", modality, "modality"))

        # Total downloads (aggregate value for this group)
        group_total = group_totals.get(nm)
        if group_total is not None:
            chips.append(("⬇️", format_large_number(int(group_total)), "downloads"))
        return chips

    # Display frame: only the columns the renderer needs, in the expected order
    display_df = top[["Name", "% of total"]].copy()
    display_df["Metadata"] = display_df["Name"].astype(object).apply(build_metadata)
    display_for_render = display_df[["Name", "Metadata", "% of total"]]

    # Metadata columns for the CSV. The totals already live in download_top,
    # so they are not repeated here (this used to create duplicate columns).
    meta_rows = [
        {
            col: ", ".join(
                str(value) for value in meta_map[nm].get(col, []) if pd.notna(value)
            )
            for col in meta_cols
        }
        for nm in download_top["Name"]
    ]
    download_info_df = pd.DataFrame(meta_rows, columns=list(meta_cols))
    download_top = pd.concat(
        [download_top.reset_index(drop=True), download_info_df.reset_index(drop=True)],
        axis=1,
    )
    return display_for_render, download_top
def get_top_n_from_duckdb(
    group_col, top_n=10, time_filter=None, view="all_downloads"
):
    """
    Run the date-windowed leaderboard query and return the rows as a DataFrame.

    Args:
        group_col: Grouping column, forwarded to build_leaderboard_query.
        top_n: Number of entries requested, forwarded to build_leaderboard_query.
        time_filter: Two-element sequence of epoch-second timestamps
            (start, end). Any falsy value, or a sequence of a different
            length, selects the window 1970-01-01 .. now instead.
        view: Name of the view to query, forwarded to build_leaderboard_query.

    Returns:
        The query result as a pandas DataFrame. If executing the query raises,
        the error is printed and an empty DataFrame is returned instead.

    NOTE: A fresh DuckDB connection is opened and closed inside this function
    on every call.
    """
    if time_filter and len(time_filter) == 2:
        window_start, window_end = (
            pd.to_datetime(time_filter[position], unit="s") for position in (0, 1)
        )
    else:
        # No usable window supplied: cover everything up to the present.
        window_start, window_end = pd.to_datetime("1970-01-01"), pd.Timestamp.now()

    # Build SQL using the shared helper
    query = build_leaderboard_query(
        group_col, top_n, str(window_start), str(window_end), view=view
    )

    # Fresh connection that creates the views; always closed afterwards.
    connection = create_fresh_duckdb_with_views()
    try:
        rows = connection.execute(query).fetchdf()
    except Exception as e:
        print(f"Error querying DuckDB: {e}")
        rows = pd.DataFrame()
    finally:
        connection.close()
    return rows