# open-model-evolution / data_utils.py
# Last commit aaa721d by emsesc: "fixed author attribution issue"
import logging

import duckdb
import pandas as pd

from config import HF_PARQUET_URL_1, HF_PARQUET_URL_2
def create_fresh_duckdb_with_views(parquet_url_1: str = HF_PARQUET_URL_1, parquet_url_2: str = HF_PARQUET_URL_2):
    """Return a fresh in-memory DuckDB connection with parquet-backed views.

    Two views are created on the new connection:

    * ``all_downloads``    -> ``read_parquet(parquet_url_1)``
    * ``one_year_rolling`` -> ``read_parquet(parquet_url_2)``

    Setup is best-effort: a failure while loading ``httpfs``, changing the
    cache settings, or creating the views is logged and does not raise. The
    connection is returned in every case, so a caller whose query then fails
    on a missing view can find the root cause in the log.

    Args:
        parquet_url_1: Location of the parquet file backing ``all_downloads``.
        parquet_url_2: Location of the parquet file backing ``one_year_rolling``.

    Returns:
        An open DuckDB connection. The caller is responsible for closing it.
    """
    log = logging.getLogger(__name__)
    local_con = duckdb.connect(database=":memory:", read_only=False)

    # Optional tuning steps. Each group is attempted independently so that a
    # failure in one (e.g. a setting unknown to the installed DuckDB version)
    # does not prevent the other or the view creation below.
    # NOTE(review): the caches are presumably disabled so that every fresh
    # connection re-reads the remote parquet files - confirm.
    optional_steps = (
        ("load the httpfs extension", ("INSTALL httpfs;", "LOAD httpfs;")),
        (
            "disable the HTTP metadata/object caches",
            (
                "SET enable_http_metadata_cache = false;",
                "SET enable_object_cache = false;",
            ),
        ),
    )
    for description, statements in optional_steps:
        try:
            for statement in statements:
                local_con.execute(statement)
        except Exception as exc:
            log.warning("DuckDB setup: could not %s: %s", description, exc)

    view_sources = (
        ("all_downloads", parquet_url_1),
        ("one_year_rolling", parquet_url_2),
    )
    try:
        for view_name, url in view_sources:
            # Double embedded single quotes so the URL cannot terminate the
            # SQL string literal early.
            escaped_url = str(url).replace("'", "''")
            local_con.execute(
                f"""
                CREATE OR REPLACE VIEW {view_name} AS
                SELECT * FROM read_parquet('{escaped_url}')
                """
            )
    except Exception as exc:
        # Still best-effort: hand back the connection, but leave a trace.
        log.warning("DuckDB setup: could not create parquet-backed views: %s", exc)

    return local_con
def get_last_updated():
    """Return the newest ``time`` in the ``all_downloads`` view as display text.

    The value is formatted like ``"Jan 05, 2024"``. ``"N/A"`` is returned when
    the view holds no timestamp or when anything goes wrong along the way.
    """
    unavailable = "N/A"
    try:
        connection = create_fresh_duckdb_with_views()
        try:
            frame = connection.execute("SELECT MAX(time) as max_time FROM all_downloads").fetchdf()
        finally:
            # Release the connection whether or not the query succeeded.
            connection.close()

        latest = frame["max_time"].iloc[0]
        if not pd.isnull(latest):
            return pd.to_datetime(latest).strftime("%b %d, %Y")
        return unavailable
    except Exception:
        return unavailable
def _sql_string_literal(value):
    """Render *value* as a single-quoted SQL string literal, doubling embedded quotes."""
    return "'" + str(value).replace("'", "''") + "'"


def _country_case_sql(column):
    """Return the CASE expression that normalizes the country labels stored in *column*."""
    return f"""CASE
            WHEN {column} IN ('HF', 'United States of America') THEN 'United States of America'
            WHEN {column} IN ('International', 'Online', 'Online?') THEN 'International/Online'
            ELSE {column}
        END"""


def build_leaderboard_query(
    group_col,
    top_n,
    start_str=None,
    end_str=None,
    date_str=None,
    view="all_downloads",
    derived_org_toggle=False,
):
    """Build the SQL query string for leaderboard data.

    The query returns one row per (model, group_key) with its
    ``total_downloads`` and ``percent_of_total`` (share of the summed
    ``total_downloads`` of all models), keeps only rows with positive
    downloads, and orders them by downloads descending.

    Two modes are supported:

    * All-time (``date_str`` given): rows are restricted to
      ``time <= date_str`` and ``total_downloads`` is the largest
      ``downloadsAllTime`` seen up to that date. ``start_str`` and
      ``end_str`` are ignored.
    * Ranged (``date_str`` is None): ``total_downloads`` is the largest
      ``downloadsAllTime`` with ``time <= end_str`` minus the largest one
      with ``time < start_str``. A bound left as ``None`` is treated as
      open (no upper limit / nothing subtracted) rather than being rendered
      as the literal ``'None'``.

    Args:
        group_col: Column (or SQL expression) used as ``group_key``. The value
            ``"org_country_single"`` is replaced by a CASE expression that
            merges equivalent country labels.
        top_n: Leaderboard size; the query's LIMIT is ``top_n * 10``.
        start_str: Inclusive lower time bound for ranged mode, or None.
        end_str: Inclusive upper time bound for ranged mode, or None.
        date_str: Cut-off date; when not None, selects all-time mode.
        view: Name of the view or table to read from.
        derived_org_toggle: When true, the output column
            ``org_country_single`` is computed from
            ``derived_org_country_single`` instead of ``org_country_single``.

    Returns:
        The SQL text as a string.

    Note:
        ``group_col`` and ``view`` are inserted verbatim as SQL identifiers and
        must therefore come from trusted code, never from raw user input. The
        date strings are emitted as quoted literals with embedded single
        quotes doubled.
    """
    if group_col == "org_country_single":
        group_expr = _country_case_sql("org_country_single")
    else:
        group_expr = group_col

    # Determine which org_country column feeds the org_country_single output.
    org_country_case = _country_case_sql(
        "derived_org_country_single" if derived_org_toggle else "org_country_single"
    )

    if date_str is not None:
        cutoff = _sql_string_literal(date_str)
        base_where = f"WHERE time <= {cutoff}"
        downloads_calc = (
            f"COALESCE(MAX(CASE WHEN time <= {cutoff} THEN downloadsAllTime END), 0)"
            " AS total_downloads"
        )
    else:
        base_where = ""
        if end_str is None:
            # Open upper bound: take the latest cumulative count available.
            upper_total = "MAX(downloadsAllTime)"
        else:
            upper_total = (
                f"MAX(CASE WHEN time <= {_sql_string_literal(end_str)} THEN downloadsAllTime END)"
            )
        downloads_calc = f"COALESCE({upper_total}, 0)"
        if start_str is not None:
            # Subtract the cumulative count reached before the window opened.
            downloads_calc += (
                f"\n            - COALESCE(MAX(CASE WHEN time < {_sql_string_literal(start_str)}"
                " THEN downloadsAllTime END), 0)"
            )
        downloads_calc += "\n            AS total_downloads"

    # NOTE(review): the LIMIT is ten times top_n, presumably so callers can
    # aggregate models into groups and still fill top_n entries - confirm.
    return f"""
    WITH base_data AS (
        SELECT
            {group_expr} AS group_key,
            {org_country_case} AS org_country_single,
            author,
            derived_author,
            merged_country_groups_single,
            merged_modality,
            model,
            time,
            downloadsAllTime
        FROM {view}
        {base_where}
    ),
    model_metrics AS (
        SELECT
            model,
            group_key,
            ANY_VALUE(org_country_single) AS org_country_single,
            ANY_VALUE(author) AS author,
            ANY_VALUE(derived_author) AS derived_author,
            ANY_VALUE(merged_country_groups_single) AS merged_country_groups_single,
            ANY_VALUE(merged_modality) AS merged_modality,
            {downloads_calc}
        FROM base_data
        GROUP BY model, group_key
    ),
    total_downloads_cte AS (
        SELECT SUM(total_downloads) AS total_downloads_all FROM model_metrics
    )
    SELECT
        mm.model,
        mm.group_key,
        mm.org_country_single,
        mm.author,
        mm.derived_author,
        mm.merged_country_groups_single,
        mm.merged_modality,
        mm.total_downloads,
        CASE WHEN td.total_downloads_all = 0 THEN 0 ELSE ROUND(mm.total_downloads * 100.0 / td.total_downloads_all, 2) END AS percent_of_total
    FROM model_metrics mm
    CROSS JOIN total_downloads_cte td
    WHERE mm.total_downloads > 0
    ORDER BY mm.total_downloads DESC
    LIMIT {top_n * 10};
    """
def get_top_n_from_duckdb(group_col, top_n=10, time_filter=None, view="all_downloads", derived_org_toggle=False):
    """Run the leaderboard query on a fresh DuckDB connection.

    ``time_filter`` is used only when it is truthy and has exactly two items,
    which are read as epoch seconds for the window start and end. Otherwise
    the window runs from 1970-01-01 to the current time.

    Returns the query result as a DataFrame. If executing the query fails,
    the error is printed and an empty DataFrame is returned. The connection
    is closed in every case.
    """
    if not time_filter or len(time_filter) != 2:
        window_start = pd.to_datetime("1970-01-01")
        window_end = pd.Timestamp.now()
    else:
        window_start = pd.to_datetime(time_filter[0], unit="s")
        window_end = pd.to_datetime(time_filter[1], unit="s")

    sql = build_leaderboard_query(
        group_col,
        top_n,
        str(window_start),
        str(window_end),
        view=view,
        derived_org_toggle=derived_org_toggle,
    )

    connection = create_fresh_duckdb_with_views()
    try:
        frame = connection.execute(sql).fetchdf()
    except Exception as exc:
        print(f"Error querying DuckDB: {exc}")
        frame = pd.DataFrame()
    finally:
        connection.close()
    return frame