emsesc committed on
Commit
f9010df
·
1 Parent(s): da7a067

consolidate queries into function

Browse files
Files changed (2) hide show
  1. app.py +8 -133
  2. graphs/leaderboard.py +100 -82
app.py CHANGED
@@ -3,7 +3,12 @@ import pandas as pd
3
  import dash_mantine_components as dmc
4
  import duckdb
5
  import time
6
- from graphs.leaderboard import button_style, get_top_n_leaderboard, render_table_content
 
 
 
 
 
7
  from dash_iconify import DashIconify
8
 
9
  # Initialize the app
@@ -796,138 +801,8 @@ def _get_filtered_top_n_from_duckdb(
796
  start_str = str(start)
797
  end_str = str(end)
798
 
799
- # If grouping by country, transform some country values
800
- if group_col == "org_country_single":
801
- group_expr = """CASE
802
- WHEN org_country_single IN ('HF', 'United States of America') THEN 'United States of America'
803
- WHEN org_country_single IN ('International', 'Online', 'Online?') THEN 'International/Online'
804
- ELSE org_country_single
805
- END"""
806
- else:
807
- group_expr = group_col
808
-
809
- # Derived-author requires author->country lookup; build separate SQL for that case
810
- if group_col == "derived_author":
811
- query = f"""
812
- WITH base_data AS (
813
- SELECT
814
- {group_expr} AS group_key,
815
- CASE
816
- WHEN org_country_single IN ('HF', 'United States of America') THEN 'United States of America'
817
- WHEN org_country_single IN ('International', 'Online', 'Online?') THEN 'International/Online'
818
- ELSE org_country_single
819
- END AS org_country_single,
820
- author,
821
- derived_author,
822
- merged_country_groups_single,
823
- merged_modality,
824
- model,
825
- time,
826
- downloadsAllTime
827
- FROM {view}
828
- ),
829
-
830
- author_country_lookup AS (
831
- SELECT DISTINCT
832
- author,
833
- FIRST_VALUE(org_country_single) OVER (PARTITION BY author ORDER BY downloadsAllTime DESC) AS author_country
834
- FROM base_data
835
- WHERE author IS NOT NULL
836
- ),
837
-
838
- model_metrics AS (
839
- SELECT
840
- model,
841
- group_key,
842
- ANY_VALUE(org_country_single) AS org_country_single,
843
- ANY_VALUE(author) AS author,
844
- ANY_VALUE(derived_author) AS derived_author,
845
- ANY_VALUE(merged_country_groups_single) AS merged_country_groups_single,
846
- ANY_VALUE(merged_modality) AS merged_modality,
847
- COALESCE(MAX(CASE WHEN time <= '{end_str}' THEN downloadsAllTime END), 0)
848
- - COALESCE(MAX(CASE WHEN time < '{start_str}' THEN downloadsAllTime END), 0)
849
- AS total_downloads
850
- FROM base_data
851
- GROUP BY model, group_key
852
- ),
853
-
854
- total_downloads_cte AS (
855
- SELECT SUM(total_downloads) AS total_downloads_all FROM model_metrics
856
- )
857
-
858
- SELECT
859
- mm.model,
860
- mm.group_key,
861
- COALESCE(acl.author_country, mm.org_country_single) AS org_country_single,
862
- mm.author,
863
- mm.derived_author,
864
- mm.merged_country_groups_single,
865
- mm.merged_modality,
866
- mm.total_downloads,
867
- CASE WHEN td.total_downloads_all = 0 THEN 0 ELSE ROUND(mm.total_downloads * 100.0 / td.total_downloads_all, 2) END AS percent_of_total
868
- FROM model_metrics mm
869
- LEFT JOIN author_country_lookup acl ON mm.group_key = acl.author
870
- CROSS JOIN total_downloads_cte td
871
- WHERE mm.total_downloads > 0
872
- ORDER BY mm.total_downloads DESC
873
- LIMIT {top_n * 10};
874
- """
875
- else:
876
- query = f"""
877
- WITH base_data AS (
878
- SELECT
879
- {group_expr} AS group_key,
880
- CASE
881
- WHEN org_country_single IN ('HF', 'United States of America') THEN 'United States of America'
882
- WHEN org_country_single IN ('International', 'Online', 'Online?') THEN 'International/Online'
883
- ELSE org_country_single
884
- END AS org_country_single,
885
- author,
886
- derived_author,
887
- merged_country_groups_single,
888
- merged_modality,
889
- model,
890
- time,
891
- downloadsAllTime
892
- FROM {view}
893
- ),
894
-
895
- model_metrics AS (
896
- SELECT
897
- model,
898
- group_key,
899
- ANY_VALUE(org_country_single) AS org_country_single,
900
- ANY_VALUE(author) AS author,
901
- ANY_VALUE(derived_author) AS derived_author,
902
- ANY_VALUE(merged_country_groups_single) AS merged_country_groups_single,
903
- ANY_VALUE(merged_modality) AS merged_modality,
904
- COALESCE(MAX(CASE WHEN time <= '{end_str}' THEN downloadsAllTime END), 0)
905
- - COALESCE(MAX(CASE WHEN time < '{start_str}' THEN downloadsAllTime END), 0)
906
- AS total_downloads
907
- FROM base_data
908
- GROUP BY model, group_key
909
- ),
910
-
911
- total_downloads_cte AS (
912
- SELECT SUM(total_downloads) AS total_downloads_all FROM model_metrics
913
- )
914
-
915
- SELECT
916
- mm.model,
917
- mm.group_key,
918
- mm.org_country_single,
919
- mm.author,
920
- mm.derived_author,
921
- mm.merged_country_groups_single,
922
- mm.merged_modality,
923
- mm.total_downloads,
924
- CASE WHEN td.total_downloads_all = 0 THEN 0 ELSE ROUND(mm.total_downloads * 100.0 / td.total_downloads_all, 2) END AS percent_of_total
925
- FROM model_metrics mm
926
- CROSS JOIN total_downloads_cte td
927
- WHERE mm.total_downloads > 0
928
- ORDER BY mm.total_downloads DESC
929
- LIMIT {top_n * 10};
930
- """
931
 
932
  # execute using the fresh local connection
933
  result_df = local_con.execute(query).fetchdf()
 
3
  import dash_mantine_components as dmc
4
  import duckdb
5
  import time
6
+ from graphs.leaderboard import (
7
+ button_style,
8
+ get_top_n_leaderboard,
9
+ render_table_content,
10
+ build_leaderboard_query,
11
+ )
12
  from dash_iconify import DashIconify
13
 
14
  # Initialize the app
 
801
  start_str = str(start)
802
  end_str = str(end)
803
 
804
+ # Build query using shared function
805
+ query = build_leaderboard_query(group_col, top_n, start_str, end_str, view=view)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
806
 
807
  # execute using the fresh local connection
808
  result_df = local_con.execute(query).fetchdf()
graphs/leaderboard.py CHANGED
@@ -1,5 +1,5 @@
1
  import pandas as pd
2
- from dash import html, dcc
3
  from dash_iconify import DashIconify
4
  import dash_mantine_components as dmc
5
  import base64
@@ -496,28 +496,12 @@ def create_fresh_duckdb_with_views():
496
  return local_con
497
 
498
 
499
- def get_top_n_from_duckdb(
500
- con, group_col, top_n=10, time_filter=None, view="all_downloads"
501
- ):
502
  """
503
- Query DuckDB directly to get model-level rows with per-model total_downloads (delta or full)
504
- Returns rows similar to _get_filtered_top_n_from_duckdb in app.py.
505
- NOTE: This function now opens a fresh DuckDB connection internally and ignores
506
- any external connection passed in. Keep signature for compatibility.
507
  """
508
- # Compute date window
509
- if time_filter and len(time_filter) == 2:
510
- start = pd.to_datetime(time_filter[0], unit="s")
511
- end = pd.to_datetime(time_filter[1], unit="s")
512
- else:
513
- start = pd.to_datetime("1970-01-01")
514
- # We cannot access end_dt here; rely on time_filter for end in typical use.
515
- end = pd.Timestamp.now()
516
-
517
- start_str = str(start)
518
- end_str = str(end)
519
-
520
- # If grouping by country, transform some country values
521
  if group_col == "org_country_single":
522
  group_expr = """CASE
523
  WHEN org_country_single IN ('HF', 'United States of America') THEN 'United States of America'
@@ -527,9 +511,9 @@ def get_top_n_from_duckdb(
527
  else:
528
  group_expr = group_col
529
 
530
- # Derived author special-case
531
  if group_col == "derived_author":
532
- query = f"""
533
  WITH base_data AS (
534
  SELECT
535
  {group_expr} AS group_key,
@@ -550,67 +534,18 @@ def get_top_n_from_duckdb(
550
 
551
  author_country_lookup AS (
552
  SELECT DISTINCT
553
- author,
554
- FIRST_VALUE(org_country_single) OVER (PARTITION BY author ORDER BY downloadsAllTime DESC) AS author_country
555
- FROM base_data
556
- WHERE author IS NOT NULL
557
- ),
558
-
559
- model_metrics AS (
560
- SELECT
561
- model,
562
- group_key,
563
- ANY_VALUE(org_country_single) AS org_country_single,
564
- ANY_VALUE(author) AS author,
565
- ANY_VALUE(derived_author) AS derived_author,
566
- ANY_VALUE(merged_country_groups_single) AS merged_country_groups_single,
567
- ANY_VALUE(merged_modality) AS merged_modality,
568
- COALESCE(MAX(CASE WHEN time <= '{end_str}' THEN downloadsAllTime END), 0)
569
- - COALESCE(MAX(CASE WHEN time < '{start_str}' THEN downloadsAllTime END), 0)
570
- AS total_downloads
571
  FROM base_data
572
- GROUP BY model, group_key
573
  ),
574
 
575
- total_downloads_cte AS (
576
- SELECT SUM(total_downloads) AS total_downloads_all FROM model_metrics
577
- )
578
-
579
- SELECT
580
- mm.model,
581
- mm.group_key,
582
- COALESCE(acl.author_country, mm.org_country_single) AS org_country_single,
583
- mm.author,
584
- mm.derived_author,
585
- mm.merged_country_groups_single,
586
- mm.merged_modality,
587
- mm.total_downloads,
588
- CASE WHEN td.total_downloads_all = 0 THEN 0 ELSE ROUND(mm.total_downloads * 100.0 / td.total_downloads_all, 2) END AS percent_of_total
589
- FROM model_metrics mm
590
- LEFT JOIN author_country_lookup acl ON mm.group_key = acl.author
591
- CROSS JOIN total_downloads_cte td
592
- WHERE mm.total_downloads > 0
593
- ORDER BY mm.total_downloads DESC
594
- LIMIT {top_n * 10};
595
- """
596
- else:
597
- query = f"""
598
- WITH base_data AS (
599
- SELECT
600
- {group_expr} AS group_key,
601
- CASE
602
- WHEN org_country_single IN ('HF', 'United States of America') THEN 'United States of America'
603
- WHEN org_country_single IN ('International', 'Online', 'Online?') THEN 'International/Online'
604
- ELSE org_country_single
605
- END AS org_country_single,
606
- author,
607
  derived_author,
608
- merged_country_groups_single,
609
- merged_modality,
610
- model,
611
- time,
612
- downloadsAllTime
613
- FROM {view}
614
  ),
615
 
616
  model_metrics AS (
@@ -636,19 +571,102 @@ def get_top_n_from_duckdb(
636
  SELECT
637
  mm.model,
638
  mm.group_key,
639
- mm.org_country_single,
 
640
  mm.author,
641
  mm.derived_author,
642
- mm.merged_country_groups_single,
643
  mm.merged_modality,
644
  mm.total_downloads,
645
  CASE WHEN td.total_downloads_all = 0 THEN 0 ELSE ROUND(mm.total_downloads * 100.0 / td.total_downloads_all, 2) END AS percent_of_total
646
  FROM model_metrics mm
 
 
647
  CROSS JOIN total_downloads_cte td
648
  WHERE mm.total_downloads > 0
649
  ORDER BY mm.total_downloads DESC
650
  LIMIT {top_n * 10};
651
  """
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
652
 
653
  # Open a fresh in-memory connection that creates the views, run the query, close.
654
  conn_local = create_fresh_duckdb_with_views()
 
1
  import pandas as pd
2
+ from dash import html
3
  from dash_iconify import DashIconify
4
  import dash_mantine_components as dmc
5
  import base64
 
496
  return local_con
497
 
498
 
499
+ def build_leaderboard_query(group_col, top_n, start_str, end_str, view="all_downloads"):
 
 
500
  """
501
+ Build and return the SQL query string for the given grouping.
502
+ Encapsulates common parts used both in app.py and in this module.
 
 
503
  """
504
+ # handle country grouping normalization
 
 
 
 
 
 
 
 
 
 
 
 
505
  if group_col == "org_country_single":
506
  group_expr = """CASE
507
  WHEN org_country_single IN ('HF', 'United States of America') THEN 'United States of America'
 
511
  else:
512
  group_expr = group_col
513
 
514
+ # Derived-author special-case (uses author-derived lookups)
515
  if group_col == "derived_author":
516
+ return f"""
517
  WITH base_data AS (
518
  SELECT
519
  {group_expr} AS group_key,
 
534
 
535
  author_country_lookup AS (
536
  SELECT DISTINCT
537
+ derived_author,
538
+ FIRST_VALUE(org_country_single) OVER (PARTITION BY derived_author ORDER BY downloadsAllTime DESC) AS derived_author_country
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
539
  FROM base_data
540
+ WHERE derived_author IS NOT NULL
541
  ),
542
 
543
+ author_merged_country_lookup AS (
544
+ SELECT DISTINCT
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
545
  derived_author,
546
+ FIRST_VALUE(merged_country_groups_single) OVER (PARTITION BY derived_author ORDER BY downloadsAllTime DESC) AS derived_author_merged_country
547
+ FROM base_data
548
+ WHERE derived_author IS NOT NULL
 
 
 
549
  ),
550
 
551
  model_metrics AS (
 
571
  SELECT
572
  mm.model,
573
  mm.group_key,
574
+ acl.derived_author_country AS org_country_single,
575
+ amc.derived_author_merged_country AS merged_country_groups_single,
576
  mm.author,
577
  mm.derived_author,
 
578
  mm.merged_modality,
579
  mm.total_downloads,
580
  CASE WHEN td.total_downloads_all = 0 THEN 0 ELSE ROUND(mm.total_downloads * 100.0 / td.total_downloads_all, 2) END AS percent_of_total
581
  FROM model_metrics mm
582
+ LEFT JOIN author_country_lookup acl ON mm.group_key = acl.derived_author
583
+ LEFT JOIN author_merged_country_lookup amc ON mm.group_key = amc.derived_author
584
  CROSS JOIN total_downloads_cte td
585
  WHERE mm.total_downloads > 0
586
  ORDER BY mm.total_downloads DESC
587
  LIMIT {top_n * 10};
588
  """
589
+ # Generic grouping SQL
590
+ return f"""
591
+ WITH base_data AS (
592
+ SELECT
593
+ {group_expr} AS group_key,
594
+ CASE
595
+ WHEN org_country_single IN ('HF', 'United States of America') THEN 'United States of America'
596
+ WHEN org_country_single IN ('International', 'Online', 'Online?') THEN 'International/Online'
597
+ ELSE org_country_single
598
+ END AS org_country_single,
599
+ author,
600
+ derived_author,
601
+ merged_country_groups_single,
602
+ merged_modality,
603
+ model,
604
+ time,
605
+ downloadsAllTime
606
+ FROM {view}
607
+ ),
608
+
609
+ model_metrics AS (
610
+ SELECT
611
+ model,
612
+ group_key,
613
+ ANY_VALUE(org_country_single) AS org_country_single,
614
+ ANY_VALUE(author) AS author,
615
+ ANY_VALUE(derived_author) AS derived_author,
616
+ ANY_VALUE(merged_country_groups_single) AS merged_country_groups_single,
617
+ ANY_VALUE(merged_modality) AS merged_modality,
618
+ COALESCE(MAX(CASE WHEN time <= '{end_str}' THEN downloadsAllTime END), 0)
619
+ - COALESCE(MAX(CASE WHEN time < '{start_str}' THEN downloadsAllTime END), 0)
620
+ AS total_downloads
621
+ FROM base_data
622
+ GROUP BY model, group_key
623
+ ),
624
+
625
+ total_downloads_cte AS (
626
+ SELECT SUM(total_downloads) AS total_downloads_all FROM model_metrics
627
+ )
628
+
629
+ SELECT
630
+ mm.model,
631
+ mm.group_key,
632
+ mm.org_country_single,
633
+ mm.author,
634
+ mm.derived_author,
635
+ mm.merged_country_groups_single,
636
+ mm.merged_modality,
637
+ mm.total_downloads,
638
+ CASE WHEN td.total_downloads_all = 0 THEN 0 ELSE ROUND(mm.total_downloads * 100.0 / td.total_downloads_all, 2) END AS percent_of_total
639
+ FROM model_metrics mm
640
+ CROSS JOIN total_downloads_cte td
641
+ WHERE mm.total_downloads > 0
642
+ ORDER BY mm.total_downloads DESC
643
+ LIMIT {top_n * 10};
644
+ """
645
+
646
+
647
+ def get_top_n_from_duckdb(
648
+ con, group_col, top_n=10, time_filter=None, view="all_downloads"
649
+ ):
650
+ """
651
+ Query DuckDB directly to get model-level rows with per-model total_downloads (delta or full)
652
+ Returns rows similar to _get_filtered_top_n_from_duckdb in app.py.
653
+ NOTE: This function now opens a fresh DuckDB connection internally and ignores
654
+ any external connection passed in. Keep signature for compatibility.
655
+ """
656
+ # Compute date window
657
+ if time_filter and len(time_filter) == 2:
658
+ start = pd.to_datetime(time_filter[0], unit="s")
659
+ end = pd.to_datetime(time_filter[1], unit="s")
660
+ else:
661
+ start = pd.to_datetime("1970-01-01")
662
+ # We cannot access end_dt here; rely on time_filter for end in typical use.
663
+ end = pd.Timestamp.now()
664
+
665
+ start_str = str(start)
666
+ end_str = str(end)
667
+
668
+ # Build SQL using the shared helper
669
+ query = build_leaderboard_query(group_col, top_n, start_str, end_str, view=view)
670
 
671
  # Open a fresh in-memory connection that creates the views, run the query, close.
672
  conn_local = create_fresh_duckdb_with_views()