File size: 5,858 Bytes
927a4de
 
 
aaa721d
927a4de
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
aaa721d
927a4de
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
aaa721d
 
 
 
 
 
 
 
 
 
 
 
 
927a4de
 
 
 
 
aaa721d
927a4de
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
aaa721d
927a4de
 
 
 
 
 
 
 
 
 
aaa721d
927a4de
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
import duckdb
import pandas as pd

from config import HF_PARQUET_URL_1, HF_PARQUET_URL_2


def create_fresh_duckdb_with_views(parquet_url_1: str = HF_PARQUET_URL_1, parquet_url_2: str = HF_PARQUET_URL_2):
    """Return a fresh in-memory DuckDB connection with parquet-backed views configured.

    Views created on the new connection:

    * ``all_downloads``    -> ``read_parquet(parquet_url_1)``
    * ``one_year_rolling`` -> ``read_parquet(parquet_url_2)``

    Setup is best-effort. Each statement is attempted on its own, so a single
    failure (for example ``INSTALL httpfs`` failing, or a ``SET`` option this
    DuckDB build does not accept) does not prevent the remaining statements
    from running. View-creation failures are printed rather than silently
    discarded. A later query against a view that could not be created will
    raise at query time.

    Args:
        parquet_url_1: Parquet path/URL backing the ``all_downloads`` view.
        parquet_url_2: Parquet path/URL backing the ``one_year_rolling`` view.

    Returns:
        An open DuckDB connection; the caller is responsible for closing it.
    """
    local_con = duckdb.connect(database=":memory:", read_only=False)

    def _try_execute(statement):
        """Run ``statement`` on ``local_con``; return the raised exception, or None on success."""
        try:
            local_con.execute(statement)
        except Exception as exc:  # best-effort setup: hand the error back, never raise
            return exc
        return None

    def _sql_string_literal(value):
        """Quote ``value`` as a SQL string literal, doubling embedded single quotes."""
        return "'" + str(value).replace("'", "''") + "'"

    # Optional extension setup; failures are ignored so the rest of setup still runs.
    for statement in ("INSTALL httpfs;", "LOAD httpfs;"):
        _try_execute(statement)

    # NOTE(review): caches are presumably disabled so each fresh connection
    # re-reads the latest remote parquet data/metadata — confirm intent.
    for statement in (
        "SET enable_http_metadata_cache = false;",
        "SET enable_object_cache = false;",
    ):
        _try_execute(statement)

    views = (
        ("all_downloads", parquet_url_1),
        ("one_year_rolling", parquet_url_2),
    )
    for view_name, parquet_url in views:
        # Each view is created independently so one bad source does not hide the other.
        error = _try_execute(
            f"CREATE OR REPLACE VIEW {view_name} AS "
            f"SELECT * FROM read_parquet({_sql_string_literal(parquet_url)})"
        )
        if error is not None:
            print(f"Error creating DuckDB view {view_name}: {error}")

    return local_con


def get_last_updated():
    """Format the newest ``time`` value found in the ``all_downloads`` view.

    Returns:
        The maximum timestamp rendered with ``"%b %d, %Y"`` (e.g. "Jan 05, 2024").
        Returns "N/A" when the maximum is NULL/NaT, or when any step (connection,
        query, or conversion) raises.
    """
    unavailable = "N/A"
    try:
        connection = create_fresh_duckdb_with_views()
        try:
            frame = connection.execute("SELECT MAX(time) as max_time FROM all_downloads").fetchdf()
        finally:
            # Always release the connection, even if the query fails.
            connection.close()

        newest = frame["max_time"].iloc[0]
        if not pd.isnull(newest):
            return pd.to_datetime(newest).strftime("%b %d, %Y")
        return unavailable
    except Exception:
        return unavailable


def _normalized_country_expr(column):
    """Return a SQL CASE expression that normalizes the country labels stored in ``column``.

    'HF' and 'United States of America' map to 'United States of America';
    'International', 'Online' and 'Online?' map to 'International/Online';
    every other value passes through unchanged.
    """
    return (
        "CASE\n"
        f"            WHEN {column} IN ('HF', 'United States of America') THEN 'United States of America'\n"
        f"            WHEN {column} IN ('International', 'Online', 'Online?') THEN 'International/Online'\n"
        f"            ELSE {column}\n"
        "        END"
    )


def _sql_string_literal(value):
    """Render ``value`` (via ``str``) as a single-quoted SQL literal, doubling embedded quotes."""
    return "'" + str(value).replace("'", "''") + "'"


def build_leaderboard_query(
    group_col,
    top_n,
    start_str=None,
    end_str=None,
    date_str=None,
    view="all_downloads",
    derived_org_toggle=False,
):
    """Build the SQL query string for leaderboard data.

    The query yields one row per (model, group_key) with that model's
    ``total_downloads`` and its ``percent_of_total``. Rows with no downloads
    are dropped, and the result is ordered by downloads descending and limited
    to ``top_n * 10`` rows.

    Args:
        group_col: Column to group by. ``"org_country_single"`` is grouped via
            the normalized country expression; any other value is inserted
            verbatim as a SQL expression (trusted input only).
        top_n: Leaderboard size; the row limit is ``top_n * 10``.
        start_str: Window start (exclusive lower bound on the subtracted
            baseline). Used only when ``date_str`` is None.
        end_str: Window end (inclusive). Used only when ``date_str`` is None.
        date_str: If given, switches to "all-time up to this date" mode and
            ``start_str``/``end_str`` are ignored.
        view: Source view/table name, inserted verbatim (trusted input only).
        derived_org_toggle: Use ``derived_org_country_single`` instead of
            ``org_country_single`` for the reported country column.

    Returns:
        The SQL query as a string.
    """
    is_alltime = date_str is not None

    if group_col == "org_country_single":
        group_expr = _normalized_country_expr("org_country_single")
    else:
        group_expr = group_col

    # Time bounds are rendered as properly escaped SQL string literals.
    if is_alltime:
        cutoff = _sql_string_literal(date_str)
        base_where = f"WHERE time <= {cutoff}"
        downloads_calc = f"COALESCE(MAX(CASE WHEN time <= {cutoff} THEN downloadsAllTime END), 0) AS total_downloads"
    else:
        base_where = ""
        # NOTE(review): assumes downloadsAllTime is a cumulative counter, so the
        # window total is (max up to end) minus (max before start) — confirm.
        downloads_calc = f"""COALESCE(MAX(CASE WHEN time <= {_sql_string_literal(end_str)} THEN downloadsAllTime END), 0)
            - COALESCE(MAX(CASE WHEN time < {_sql_string_literal(start_str)} THEN downloadsAllTime END), 0)
            AS total_downloads"""

    # Determine which org_country column to report.
    country_column = "derived_org_country_single" if derived_org_toggle else "org_country_single"
    org_country_case = _normalized_country_expr(country_column)

    return f"""
    WITH base_data AS (
        SELECT
            {group_expr} AS group_key,
            {org_country_case} AS org_country_single,
            author,
            derived_author,
            merged_country_groups_single,
            merged_modality,
            model,
            time,
            downloadsAllTime
        FROM {view}
        {base_where}
    ),

    model_metrics AS (
        SELECT
            model,
            group_key,
            ANY_VALUE(org_country_single) AS org_country_single,
            ANY_VALUE(author) AS author,
            ANY_VALUE(derived_author) AS derived_author,
            ANY_VALUE(merged_country_groups_single) AS merged_country_groups_single,
            ANY_VALUE(merged_modality) AS merged_modality,
            {downloads_calc}
        FROM base_data
        GROUP BY model, group_key
    ),

    total_downloads_cte AS (
        SELECT SUM(total_downloads) AS total_downloads_all FROM model_metrics
    )

    SELECT
        mm.model,
        mm.group_key,
        mm.org_country_single,
        mm.author,
        mm.derived_author,
        mm.merged_country_groups_single,
        mm.merged_modality,
        mm.total_downloads,
        CASE WHEN td.total_downloads_all = 0 THEN 0 ELSE ROUND(mm.total_downloads * 100.0 / td.total_downloads_all, 2) END AS percent_of_total
    FROM model_metrics mm
    CROSS JOIN total_downloads_cte td
    WHERE mm.total_downloads > 0
    ORDER BY mm.total_downloads DESC
    LIMIT {top_n * 10};
    """


def get_top_n_from_duckdb(group_col, top_n=10, time_filter=None, view="all_downloads", derived_org_toggle=False):
    """Run the leaderboard query on a fresh DuckDB connection and return model-level rows.

    Args:
        group_col: Grouping column forwarded to ``build_leaderboard_query``.
        top_n: Leaderboard size forwarded to ``build_leaderboard_query``.
        time_filter: Optional pair of epoch-second bounds ``(start, end)``. When
            it is not a truthy length-2 sequence, the window spans from
            1970-01-01 up to the current time.
        view: Source view name forwarded to ``build_leaderboard_query``.
        derived_org_toggle: Forwarded to ``build_leaderboard_query``.

    Returns:
        A DataFrame of per-model rows including ``total_downloads``. An empty
        DataFrame is returned, after printing the error, when the query fails.
    """
    windowed = bool(time_filter) and len(time_filter) == 2
    if not windowed:
        lower = pd.to_datetime("1970-01-01")
        upper = pd.Timestamp.now()
    else:
        lower = pd.to_datetime(time_filter[0], unit="s")
        upper = pd.to_datetime(time_filter[1], unit="s")

    sql = build_leaderboard_query(
        group_col,
        top_n,
        str(lower),
        str(upper),
        view=view,
        derived_org_toggle=derived_org_toggle,
    )

    connection = create_fresh_duckdb_with_views()
    try:
        frame = connection.execute(sql).fetchdf()
    except Exception as err:
        print(f"Error querying DuckDB: {err}")
        frame = pd.DataFrame()
    finally:
        # The connection is per-call; always close it.
        connection.close()
    return frame