import logging

import duckdb
import pandas as pd

from config import HF_PARQUET_URL_1, HF_PARQUET_URL_2

logger = logging.getLogger(__name__)


def _sql_literal(value) -> str:
    """Return *value* rendered as a single-quoted SQL string literal.

    Embedded single quotes are doubled so the value cannot terminate the
    literal early. Values without quotes render exactly as ``'value'``.
    """
    return "'" + str(value).replace("'", "''") + "'"


def _country_case(column: str) -> str:
    """Return the SQL CASE expression that normalises a country column.

    'HF' is folded into 'United States of America', and 'International',
    'Online' and 'Online?' are folded into 'International/Online'; any other
    value passes through unchanged.
    """
    return f"""CASE
                WHEN {column} IN ('HF', 'United States of America') THEN 'United States of America'
                WHEN {column} IN ('International', 'Online', 'Online?') THEN 'International/Online'
                ELSE {column}
            END"""


def create_fresh_duckdb_with_views(
    parquet_url_1: str = HF_PARQUET_URL_1,
    parquet_url_2: str = HF_PARQUET_URL_2,
):
    """Return a fresh DuckDB connection with parquet-backed views configured.

    Every configuration step is best-effort: a failure is logged and the
    connection is still returned, so callers must be prepared for queries
    against a missing view to fail.

    Args:
        parquet_url_1: Parquet location backing the ``all_downloads`` view.
        parquet_url_2: Parquet location backing the ``one_year_rolling`` view.

    Returns:
        An in-memory DuckDB connection. The caller is responsible for closing it.
    """
    local_con = duckdb.connect(database=":memory:", read_only=False)

    # httpfs is what lets read_parquet() reach remote URLs. Installation can
    # fail (e.g. no network access) while the extension is still usable, so a
    # failure here is reported but not fatal.
    try:
        local_con.execute("INSTALL httpfs;")
        local_con.execute("LOAD httpfs;")
    except Exception as exc:
        logger.warning("Could not install/load the DuckDB httpfs extension: %s", exc)

    # Cache settings are optional tuning; not every DuckDB build may accept them.
    try:
        local_con.execute("SET enable_http_metadata_cache = false;")
        local_con.execute("SET enable_object_cache = false;")
    except Exception as exc:
        logger.debug("Could not disable DuckDB HTTP/object caches: %s", exc)

    views = (
        ("all_downloads", parquet_url_1),
        ("one_year_rolling", parquet_url_2),
    )
    for view_name, parquet_url in views:
        # Each view is created independently so one bad URL does not prevent
        # the other view from being available.
        try:
            local_con.execute(
                f"""
                CREATE OR REPLACE VIEW {view_name} AS
                SELECT * FROM read_parquet({_sql_literal(parquet_url)})
                """
            )
        except Exception as exc:
            logger.warning("Could not create DuckDB view %r: %s", view_name, exc)

    return local_con


def get_last_updated():
    """Return the latest timestamp available in the all_downloads view.

    Returns:
        The maximum ``time`` value formatted like ``"Jan 05, 2025"``, or
        ``"N/A"`` when the value is null or anything goes wrong.
    """
    try:
        conn = create_fresh_duckdb_with_views()
        try:
            result = conn.execute(
                "SELECT MAX(time) as max_time FROM all_downloads"
            ).fetchdf()
        finally:
            conn.close()

        max_time = result["max_time"].iloc[0]
        if pd.isnull(max_time):
            return "N/A"
        return pd.to_datetime(max_time).strftime("%b %d, %Y")
    except Exception as exc:
        # Deliberate best-effort: callers always get a displayable string.
        logger.warning("Could not determine the last-updated date: %s", exc)
        return "N/A"


def build_leaderboard_query(
    group_col,
    top_n,
    start_str=None,
    end_str=None,
    date_str=None,
    view="all_downloads",
    derived_org_toggle=False,
):
    """Build the SQL query string for leaderboard data.

    The query yields one row per (model, group key) with its download total
    and its percentage of the grand total, ordered by downloads descending
    and limited to ``top_n * 10`` rows.

    Args:
        group_col: Column (or SQL expression) used as the grouping key. The
            value ``"org_country_single"`` is replaced by the normalised
            country CASE expression.
        top_n: Number of groups the caller wants; the row limit is ten times
            this value.
        start_str: Lower bound of the window. The largest ``downloadsAllTime``
            strictly before this time is subtracted from the total. ``None``
            means nothing is subtracted.
        end_str: Upper bound of the window (inclusive). ``None`` means no
            upper bound.
        date_str: When given, switches to all-time mode: rows after this time
            are filtered out and ``start_str``/``end_str`` are ignored.
        view: Name of the view or table to read from.
        derived_org_toggle: When true, the ``org_country_single`` output column
            is computed from ``derived_org_country_single`` instead.

    Returns:
        The SQL query as a string.

    Note:
        ``group_col`` and ``view`` are interpolated into the SQL as identifiers
        and cannot be escaped as literals; they must come from trusted code,
        never directly from user input.
    """
    is_alltime = date_str is not None

    if group_col == "org_country_single":
        group_expr = _country_case("org_country_single")
    else:
        group_expr = group_col

    if is_alltime:
        cutoff = _sql_literal(date_str)
        base_where = f"WHERE time <= {cutoff}"
        downloads_calc = (
            f"COALESCE(MAX(CASE WHEN time <= {cutoff} THEN downloadsAllTime END), 0)"
            " AS total_downloads"
        )
    else:
        base_where = ""
        # Windowed total = largest counter value at/before the end of the
        # window minus the largest counter value strictly before its start.
        # A missing bound is treated as open instead of being rendered as the
        # literal string 'None', which would produce an invalid comparison.
        if end_str is None:
            upper = "COALESCE(MAX(downloadsAllTime), 0)"
        else:
            upper = (
                "COALESCE(MAX(CASE WHEN time <= "
                f"{_sql_literal(end_str)} THEN downloadsAllTime END), 0)"
            )
        if start_str is None:
            lower = "0"
        else:
            lower = (
                "COALESCE(MAX(CASE WHEN time < "
                f"{_sql_literal(start_str)} THEN downloadsAllTime END), 0)"
            )
        downloads_calc = f"""{upper}
                - {lower}
                AS total_downloads"""

    # Determine which org_country column to use
    org_country_case = _country_case(
        "derived_org_country_single" if derived_org_toggle else "org_country_single"
    )

    # int() guards against a string top_n being repeated ten times by `*`
    # and keeps arbitrary text out of the LIMIT clause.
    row_limit = int(top_n) * 10

    return f"""
    WITH base_data AS (
        SELECT
            {group_expr} AS group_key,
            {org_country_case} AS org_country_single,
            author,
            derived_author,
            merged_country_groups_single,
            merged_modality,
            model,
            time,
            downloadsAllTime
        FROM {view}
        {base_where}
    ),

    model_metrics AS (
        SELECT
            model,
            group_key,
            ANY_VALUE(org_country_single) AS org_country_single,
            ANY_VALUE(author) AS author,
            ANY_VALUE(derived_author) AS derived_author,
            ANY_VALUE(merged_country_groups_single) AS merged_country_groups_single,
            ANY_VALUE(merged_modality) AS merged_modality,
            {downloads_calc}
        FROM base_data
        GROUP BY model, group_key
    ),

    total_downloads_cte AS (
        SELECT SUM(total_downloads) AS total_downloads_all
        FROM model_metrics
    )

    SELECT
        mm.model,
        mm.group_key,
        mm.org_country_single,
        mm.author,
        mm.derived_author,
        mm.merged_country_groups_single,
        mm.merged_modality,
        mm.total_downloads,
        CASE
            WHEN td.total_downloads_all = 0 THEN 0
            ELSE ROUND(mm.total_downloads * 100.0 / td.total_downloads_all, 2)
        END AS percent_of_total
    FROM model_metrics mm
    CROSS JOIN total_downloads_cte td
    WHERE mm.total_downloads > 0
    ORDER BY mm.total_downloads DESC
    LIMIT {row_limit};
    """


def get_top_n_from_duckdb(
    group_col,
    top_n=10,
    time_filter=None,
    view="all_downloads",
    derived_org_toggle=False,
):
    """Query DuckDB directly to get model-level rows with per-model total_downloads.

    Args:
        group_col: Grouping column forwarded to ``build_leaderboard_query``.
        top_n: Number of groups wanted; the query returns up to ten times
            this many rows.
        time_filter: Optional two-element sequence of Unix timestamps in
            seconds, ``(start, end)``. Any other value selects the range from
            1970-01-01 up to the current time.
        view: Name of the view to query.
        derived_org_toggle: Forwarded to ``build_leaderboard_query``.

    Returns:
        A DataFrame with the query result, or an empty DataFrame if the query
        fails.
    """
    if time_filter and len(time_filter) == 2:
        start = pd.to_datetime(time_filter[0], unit="s")
        end = pd.to_datetime(time_filter[1], unit="s")
    else:
        start = pd.to_datetime("1970-01-01")
        end = pd.Timestamp.now()

    query = build_leaderboard_query(
        group_col,
        top_n,
        str(start),
        str(end),
        view=view,
        derived_org_toggle=derived_org_toggle,
    )

    conn_local = create_fresh_duckdb_with_views()
    try:
        return conn_local.execute(query).fetchdf()
    except Exception as exc:
        print(f"Error querying DuckDB: {exc}")
        return pd.DataFrame()
    finally:
        conn_local.close()