File size: 40,766 Bytes
bbf45d0
 
 
961c6fe
f30a36e
 
 
 
 
b06975a
f30a36e
 
961c6fe
 
b06975a
f30a36e
 
 
 
 
 
 
 
 
 
 
8c60635
 
f30a36e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8c60635
f30a36e
8c60635
afd7356
 
f30a36e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8c60635
afd7356
 
 
f30a36e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c0b7e37
 
 
961c6fe
f30a36e
addb03f
e65f153
 
 
f30a36e
addb03f
8c60635
 
 
9c451ee
8c60635
f30a36e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8c60635
 
 
 
 
 
 
 
 
 
 
0c6bf95
8c60635
f30a36e
 
 
e65f153
f30a36e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
addb03f
f30a36e
 
addb03f
f30a36e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9c451ee
f30a36e
e65f153
 
f0e2fd8
961c6fe
9c451ee
f30a36e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
961c6fe
f30a36e
 
 
8c60635
f30a36e
 
 
 
 
 
 
 
 
 
 
8c60635
bbf45d0
 
813c7cf
f30a36e
 
 
 
 
 
 
 
 
 
 
8c60635
f0e2fd8
f30a36e
d858aa5
eec69ec
f30a36e
eec69ec
f30a36e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c0b7e37
8c60635
 
f30a36e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
961c6fe
813c7cf
 
b06975a
813c7cf
961c6fe
8c60635
961c6fe
813c7cf
fa2c2d2
b06975a
 
fa2c2d2
8c60635
f30a36e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
813c7cf
 
f30a36e
813c7cf
4517d15
961c6fe
f30a36e
 
 
 
961c6fe
813c7cf
 
c0b7e37
813c7cf
f30a36e
 
 
 
813c7cf
8c60635
813c7cf
 
f30a36e
 
813c7cf
 
f30a36e
 
 
 
 
 
 
813c7cf
f30a36e
813c7cf
f30a36e
813c7cf
 
f30a36e
 
 
 
 
 
 
 
 
 
 
 
 
 
4d0811f
f30a36e
 
 
 
8c60635
813c7cf
d858aa5
addb03f
 
f0e2fd8
f30a36e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9c451ee
b06975a
f30a36e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
afd7356
961c6fe
b06975a
961c6fe
d858aa5
f30a36e
 
 
 
 
 
 
 
 
 
 
8c60635
961c6fe
 
813c7cf
 
 
0c6bf95
813c7cf
0c6bf95
f30a36e
 
 
 
 
 
 
 
 
 
 
 
 
47e0cf9
afd7356
8c60635
 
 
 
 
 
f30a36e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
bbf45d0
961c6fe
f30a36e
 
 
 
 
 
 
 
 
 
 
 
 
 
f0e2fd8
bbf45d0
f30a36e
 
 
 
 
 
 
 
 
 
 
 
 
 
bbf45d0
 
8c60635
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
import json
import os
import tempfile
import time
from collections.abc import Mapping

import duckdb
import gradio as gr
import pandas as pd
import plotly.express as px
import requests
from datasets import load_dataset
from gradio_rangeslider import RangeSlider
from huggingface_hub import logout as hf_logout

# --- Constants ---
# Choices for the "top K" selector: 5, 10, ..., 50.
TOP_K_CHOICES = list(range(5, 51, 5))
HF_DATASET_ID = "evijit/paperverse_daily_data"
PARQUET_FILENAME = "papers_with_semantic_taxonomy.parquet"
# Direct parquet file URL (public). Built from HF_DATASET_ID so the repo id is
# written in exactly one place and the two cannot drift apart.
PARQUET_URL = f"https://huggingface.co/datasets/{HF_DATASET_ID}/resolve/main/{PARQUET_FILENAME}"
TAXONOMY_JSON_PATH = "integrated_ml_taxonomy.json"

# Simple content filters derived from the dataset; every option other than
# "None" is matched against a boolean column computed when the data is loaded.
TAG_FILTER_CHOICES = [
    "None",
    "Has Code",
    "Has Media",
    "Has Organization",
]

# Load taxonomy from JSON file
def load_taxonomy(path=None):
    """Load the ML taxonomy JSON and derive dropdown choices from it.

    The file is read as a nested mapping ``category -> subcategory -> topics``
    (that is the shape the loops below walk).

    Args:
        path: JSON file to read. Defaults to ``TAXONOMY_JSON_PATH``.

    Returns:
        A dict with:
          - ``categories``: ``"All"`` followed by the sorted category names,
          - ``subcategories``: ``"All"`` followed by every distinct
            subcategory name, sorted,
          - ``topics``: ``"All"`` followed by every distinct topic, sorted,
          - ``taxonomy``: the parsed JSON.
        If anything goes wrong (missing file, bad JSON, unexpected shape), the
        error is printed and a fallback with only ``"All"`` choices and an
        empty taxonomy is returned, so the UI can still start.
    """
    try:
        if path is None:
            path = TAXONOMY_JSON_PATH
        # JSON is UTF-8 by specification; don't depend on the locale codec.
        with open(path, 'r', encoding='utf-8') as f:
            taxonomy = json.load(f)

        all_subcategories = set()
        all_topics = set()
        for subcats in taxonomy.values():
            for subcat, topics in subcats.items():
                all_subcategories.add(subcat)
                all_topics.update(topics)

        return {
            'categories': ["All"] + sorted(taxonomy.keys()),
            'subcategories': ["All"] + sorted(all_subcategories),
            'topics': ["All"] + sorted(all_topics),
            'taxonomy': taxonomy
        }
    except Exception as e:
        # Deliberate best-effort: the app keeps working without a taxonomy.
        print(f"Error loading taxonomy from JSON: {e}")
        return {
            'categories': ["All"],
            'subcategories': ["All"],
            'topics': ["All"],
            'taxonomy': {}
        }

# Loaded once at import time; consumers read the derived choice lists from here.
TAXONOMY_DATA = load_taxonomy()

def _first_non_null(*values):
    for v in values:
        if v is None:
            continue
        # treat empty strings as null-ish
        if isinstance(v, str) and v.strip() == "":
            continue
        return v
    return None


def _get_nested(row, *paths):
    """Try multiple dotted paths in a row that may contain dicts; return first non-null."""
    for path in paths:
        cur = row
        ok = True
        for key in path.split('.'):
            if isinstance(cur, dict) and key in cur:
                cur = cur[key]
            else:
                ok = False
                break
        if ok and cur is not None:
            return cur
    return None


def _download_parquet_df():
    """Download ``PARQUET_URL`` to a temporary file and read it with DuckDB.

    DuckDB is used instead of pyarrow to avoid pyarrow decoding issues. The
    temporary file is always removed, including when the download fails
    part-way.
    """
    tmp_path = None
    try:
        with requests.get(PARQUET_URL, stream=True, timeout=120) as resp:
            resp.raise_for_status()
            with tempfile.NamedTemporaryFile(suffix=".parquet", delete=False) as tmpf:
                # Record the name before writing so cleanup also runs if the
                # stream breaks mid-download.
                tmp_path = tmpf.name
                for chunk in resp.iter_content(chunk_size=1024 * 1024):
                    if chunk:
                        tmpf.write(chunk)
        # SQL string literal: a single quote is escaped by doubling it (temp
        # dirs can contain quotes, e.g. a Windows profile like O'Brien).
        escaped_path = tmp_path.replace("'", "''")
        return duckdb.query(f"SELECT * FROM read_parquet('{escaped_path}')").df()
    finally:
        if tmp_path is not None:
            try:
                os.remove(tmp_path)
            except OSError:
                pass


def _load_dataset_anonymously(**kwargs):
    """Call ``load_dataset`` for ``HF_DATASET_ID`` without credentials.

    ``token=None`` is tried first; on ``TypeError`` (presumably an older
    ``datasets`` release without the ``token`` keyword) ``use_auth_token=False``
    is used instead.
    """
    try:
        return load_dataset(HF_DATASET_ID, token=None, **kwargs)
    except TypeError:
        return load_dataset(HF_DATASET_ID, use_auth_token=False, **kwargs)


def _dataset_to_dataframe(dataset_obj):
    """Convert a single-split dataset, or the first split of a split dict, to pandas."""
    try:
        return dataset_obj.to_pandas()
    except AttributeError:
        first_split = list(dataset_obj.keys())[0]
        return dataset_obj[first_split].to_pandas()


def _load_raw_dataframe():
    """Fetch the raw dataset: direct parquet first, then the ``datasets`` loader."""
    try:
        print(f"Trying direct parquet download: {PARQUET_URL}")
        df = _download_parquet_df()
        print("Loaded DataFrame from direct parquet download via DuckDB.")
        return df
    except Exception as direct_e:
        print(f"Direct parquet load failed: {direct_e}. Falling back to datasets loader...")

    # Force anonymous access in case an invalid cached token is present.
    # Clear any token environment variables that could inject a bad
    # Authorization header (note: this affects the whole process).
    for env_key in ("HF_TOKEN", "HUGGINGFACE_HUB_TOKEN", "HF_HUB_TOKEN"):
        if os.environ.pop(env_key, None) is not None:
            print(f"Cleared env var: {env_key}")

    # Prefer the explicit train split; fall back to all splits.
    try:
        dataset_obj = _load_dataset_anonymously(split="train")
    except Exception:
        dataset_obj = _load_dataset_anonymously()
    return _dataset_to_dataframe(dataset_obj)


def _normalize_dataframe(df):
    """Add and clean the columns the visualization relies on.

    Modifies ``df`` in place and returns it.
    """
    # organization: prefer the flat organization_name column, otherwise look in
    # the nested paper_organization / organization structs; missing or empty
    # names become 'Unknown' (same treatment as the primary_* columns below).
    if 'organization_name' in df.columns:
        org_series = df['organization_name']
    else:
        org_series = df.apply(
            lambda r: _first_non_null(
                _get_nested(r, 'paper_organization.name'),
                _get_nested(r, 'paper_organization.fullname'),
                _get_nested(r, 'organization.name'),
                _get_nested(r, 'organization.fullname'),
            ),
            axis=1,
        )
    df['organization'] = org_series.fillna('Unknown').replace({'': 'Unknown'})

    # Organization avatar/logo URL from the nested structs, or None.
    def _get_avatar(row):
        for avatar_path in ('paper_organization.avatar', 'organization.avatar'):
            avatar = _get_nested(row, avatar_path)
            if isinstance(avatar, str) and avatar.strip():
                return avatar
        return None

    df['organization_avatar'] = df.apply(_get_avatar, axis=1)

    # id for each paper row: first available id column, else title + position.
    positions = pd.Series(range(len(df)), index=df.index).astype(str)
    id_col = next((c for c in ('paper_id', 'paper_discussionId', 'key') if c in df.columns), None)
    if id_col is not None:
        df['id'] = df[id_col].astype(str)
    elif 'paper_title' in df.columns:
        df['id'] = df['paper_title'].astype(str) + '_' + positions
    elif 'title' in df.columns:
        df['id'] = df['title'].astype(str) + '_' + positions
    else:
        df['id'] = positions

    # Numeric metrics used for aggregation; unparsable or missing -> 0.0.
    for metric in ('paper_upvotes', 'numComments', 'paper_githubStars'):
        if metric in df.columns:
            df[metric] = pd.to_numeric(df[metric], errors='coerce').fillna(0.0)
        else:
            df[metric] = 0.0

    def _is_link(value):
        """True for a non-blank string other than the 'N/A' placeholder."""
        if not isinstance(value, str):
            return False
        text = value.strip()
        return bool(text) and text.lower() != 'n/a'

    def _has_code(row):
        # A GitHub repo or a project page counts as "has code".
        return _is_link(row.get('paper_githubRepo')) or _is_link(row.get('paper_projectPage'))

    def _has_media(row):
        for col in ('paper_mediaUrls', 'mediaUrls'):
            value = row.get(col)
            if isinstance(value, str):
                # Some providers store arrays as strings like "[...]".
                text = value.strip()
                if text.startswith('[') and len(text) > 2:
                    return True
            elif (
                value is not None
                and not isinstance(value, (bytes, dict))
                and hasattr(value, '__len__')
                and len(value) > 0
            ):
                # List columns may arrive as Python lists or as array objects
                # depending on the loader, so test for a non-empty sequence.
                return True
        return False

    df['has_code'] = df.apply(_has_code, axis=1)
    df['has_media'] = df.apply(_has_media, axis=1)
    df['has_organization'] = df['organization'].astype(str).str.strip().ne('Unknown')

    # Parsed publishedAt for date filtering (unparsable -> NaT).
    if 'publishedAt' in df.columns:
        df['publishedAt_dt'] = pd.to_datetime(df['publishedAt'], errors='coerce')
    else:
        df['publishedAt_dt'] = pd.NaT

    # Ensure topic hierarchy columns exist and are non-empty strings.
    for col_name, default_val in [
        ('primary_category', 'Unknown'),
        ('primary_subcategory', 'Unknown'),
        ('primary_topic', 'Unknown'),
    ]:
        if col_name not in df.columns:
            df[col_name] = default_val
        else:
            df[col_name] = df[col_name].fillna(default_val).astype(str).replace({'': default_val})

    # Treemap leaf label: the paper title (paper_title, then title, else 'Untitled').
    def _pick_title(row):
        for col in ('paper_title', 'title'):
            value = row.get(col)
            if value is None or not pd.api.types.is_scalar(value) or pd.isna(value):
                continue
            if str(value).strip():
                return str(value)
        return 'Untitled'

    df['paper_label'] = df.apply(_pick_title, axis=1).astype(str)
    # Topic chain for hover details.
    df['topic_chain'] = (
        df['primary_category'].astype(str) + ' > ' +
        df['primary_subcategory'].astype(str) + ' > ' +
        df['primary_topic'].astype(str)
    )

    # Ensure link fields exist for hover details ('N/A' placeholder).
    for link_col in ['paper_githubRepo', 'paper_projectPage']:
        if link_col not in df.columns:
            df[link_col] = 'N/A'
        else:
            df[link_col] = df[link_col].fillna('N/A').replace({'': 'N/A'})

    return df


def load_datasets_data():
    """Load the PaperVerse Daily dataset and normalize the columns the app uses.

    The direct parquet download is tried first; if it fails, ``datasets`` is
    used anonymously. If the load fails with an error whose message contains
    "Invalid credentials" or "401 Client Error", the cached Hub token is
    logged out and the load is retried once. Both successful paths return a
    normalized frame.

    Returns:
        ``(df, success, message)``. On failure ``df`` is an empty DataFrame and
        ``success`` is False.
    """
    start_time = time.time()
    print(f"Attempting to load dataset from Hugging Face Hub: {HF_DATASET_ID}")
    try:
        df = _normalize_dataframe(_load_raw_dataframe())
        msg = f"Successfully loaded dataset in {time.time() - start_time:.2f}s."
        print(msg)
        return df, True, msg
    except Exception as e:
        # If we encountered invalid credentials, try logging out programmatically and retry once anonymously
        if "Invalid credentials" in str(e) or "401 Client Error" in str(e):
            try:
                print("Encountered auth error; attempting to clear cached token and retry anonymously...")
                hf_logout()
                df = _normalize_dataframe(_dataset_to_dataframe(_load_dataset_anonymously()))
                msg = f"Successfully loaded dataset after clearing token in {time.time() - start_time:.2f}s."
                print(msg)
                return df, True, msg
            except Exception as e2:
                err_msg = f"Failed to load dataset after retry. Error: {e2} (initial: {e})"
                print(err_msg)
                return pd.DataFrame(), False, err_msg
        err_msg = f"Failed to load dataset. Error: {e}"
        print(err_msg)
        return pd.DataFrame(), False, err_msg

def make_treemap_data(df, count_by, top_k=25, tag_filter=None, skip_cats=None, group_by='organization', date_range=None):
    """
    Filter data and prepare it for a multi-level treemap.

    Args:
        df: Papers DataFrame; it is never modified.
        count_by: Metric column used for tile sizes. It is created as 0.0 when
            missing and coerced to numeric (unparsable -> 0.0).
        top_k: Number of organizations / topic triples to keep.
        tag_filter: "Has Code", "Has Media" or "Has Organization" keeps only
            rows whose matching boolean column is True. Any other value,
            including None and "None", applies no tag filter.
        skip_cats: Names removed after the top-k selection. These are
            organization names when grouping by organization, and
            primary_topic values otherwise.
        group_by: 'organization' groups by organization. Any other value
            groups by (primary_category, primary_subcategory, primary_topic).
        date_range: Optional (min_timestamp, max_timestamp) in seconds since
            the Unix epoch (UTC), inclusive, applied to publishedAt_dt.

    Returns:
        A new DataFrame with a constant "root" column ("papers"), or an empty
        DataFrame when no rows survive filtering.

    Organization grouping keeps the individual papers of the top_k
    organizations, ranked by summed count_by. All other organizations collapse
    into a single "Other" row carrying their combined total; the row is added
    only when that total is > 0. Topic grouping keeps only papers whose topic
    triple is in the top_k and adds no "Other" row.
    """
    if df is None or df.empty:
        return pd.DataFrame()

    filtered_df = df

    # Apply date range filter
    if date_range is not None and 'publishedAt_dt' in filtered_df.columns:
        min_ts, max_ts = date_range
        dates = filtered_df['publishedAt_dt']
        # Epoch seconds are UTC. Match the column's awareness: a tz-naive
        # bound compared with a tz-aware column (or vice versa) raises.
        col_tz = getattr(dates.dtype, 'tz', None)
        if col_tz is None:
            min_date = pd.to_datetime(min_ts, unit='s')
            max_date = pd.to_datetime(max_ts, unit='s')
        else:
            min_date = pd.to_datetime(min_ts, unit='s', utc=True).tz_convert(col_tz)
            max_date = pd.to_datetime(max_ts, unit='s', utc=True).tz_convert(col_tz)
        filtered_df = filtered_df[(dates >= min_date) & (dates <= max_date)]

    col_map = {
        "Has Code": "has_code",
        "Has Media": "has_media",
        "Has Organization": "has_organization",
    }

    flag_col = col_map.get(tag_filter)
    if flag_col is not None and flag_col in filtered_df.columns:
        filtered_df = filtered_df[filtered_df[flag_col]]

    if filtered_df.empty:
        return pd.DataFrame()

    # Own copy: the count_by coercion below must never touch the caller's frame.
    filtered_df = filtered_df.copy()
    if count_by not in filtered_df.columns:
        filtered_df[count_by] = 0.0
    filtered_df[count_by] = pd.to_numeric(filtered_df[count_by], errors='coerce').fillna(0.0)

    if group_by == 'organization':
        all_org_totals = filtered_df.groupby("organization")[count_by].sum()
        top_org_names = all_org_totals.nlargest(top_k, keep='first').index.tolist()

        final_df_for_plot = filtered_df[filtered_df['organization'].isin(top_org_names)].copy()
        other_total = all_org_totals[~all_org_totals.index.isin(top_org_names)].sum()

        if other_total > 0:
            other_row = pd.DataFrame([{
                'organization': 'Other',
                'paper_label': 'Other',
                'primary_category': 'Other',
                'primary_subcategory': 'Other',
                'primary_topic': 'Other',
                'topic_chain': 'Other > Other > Other',
                'paper_githubRepo': 'N/A',
                'paper_projectPage': 'N/A',
                'organization_avatar': None,
                count_by: other_total
            }])
            final_df_for_plot = pd.concat([final_df_for_plot, other_row], ignore_index=True)

        if skip_cats:
            final_df_for_plot = final_df_for_plot[~final_df_for_plot['organization'].isin(skip_cats)]

        final_df_for_plot["root"] = "papers"
        return final_df_for_plot

    # Topic grouping: apply top-k to topic triples, then the skip list.
    topic_cols = ['primary_category', 'primary_subcategory', 'primary_topic']
    topic_totals = filtered_df.groupby(topic_cols)[count_by].sum()
    top_topics = topic_totals.nlargest(top_k, keep='first').index.tolist()

    # Vectorized membership test of each row's (category, subcategory, topic).
    in_top = pd.MultiIndex.from_frame(filtered_df[topic_cols]).isin(top_topics)
    top_topics_df = filtered_df[in_top].copy()

    # Apply skip filter (skip by primary_topic name)
    if skip_cats:
        top_topics_df = top_topics_df[~top_topics_df['primary_topic'].isin(skip_cats)]

    top_topics_df["root"] = "papers"
    return top_topics_df

def create_treemap(treemap_data, count_by, title=None, path=None, metric_label=None):
    """Generate the Plotly treemap figure from the prepared data.

    Args:
        treemap_data: Prepared DataFrame. Must contain the ``path`` columns, the
            ``count_by`` column, and the ``primary_category``,
            ``primary_subcategory``, ``primary_topic``, ``paper_githubRepo`` and
            ``paper_projectPage`` columns shown on hover. The caller's frame is
            not modified.
        count_by: Name of the numeric column that sizes the tiles.
        title: Optional figure title.
        path: Hierarchy of columns, outermost first; defaults to
            ``["root", "organization", "paper_label"]``.
        metric_label: Human-readable metric name used in the hover text; falls
            back to ``count_by``.

    Returns:
        A Plotly figure. When the frame is empty or its ``count_by`` total is not
        positive, a single-tile "No data matches filters" placeholder is returned.
    """
    if treemap_data.empty or treemap_data[count_by].sum() <= 0:
        fig = px.treemap(names=["No data matches filters"], parents=[""], values=[1])
        fig.update_layout(title="No data matches the selected filters", margin=dict(t=50, l=25, r=25, b=25))
        return fig
    if path is None:
        path = ["root", "organization", "paper_label"]

    # The hovertemplate below reads customdata positionally; the indices follow the
    # order of these hover_data keys (0-2 topic hierarchy, 3 GitHub, 4 project page).
    # Only columns named in path/values/hover_data reach the figure, so no extra
    # columns are added to the caller's DataFrame.
    hover_columns = {
        'primary_category': True,
        'primary_subcategory': True,
        'primary_topic': True,
        'paper_githubRepo': True,
        'paper_projectPage': True,
    }

    fig = px.treemap(
        treemap_data,
        path=path,
        values=count_by,
        hover_data=hover_columns,
        title=title,
        color_discrete_sequence=px.colors.qualitative.Plotly,
    )
    fig.update_layout(margin=dict(t=50, l=25, r=25, b=25))
    display_metric = metric_label if metric_label else count_by

    # Clean hover without organization avatar (images shown in search details instead)
    fig.update_traces(
        textinfo="label+value",
        hovertemplate=(
            "<b>%{label}</b><br>"
            + "%{value:,} " + display_metric +
            "<br><br><b>Topic Hierarchy:</b><br>"
            + "%{customdata[0]} > %{customdata[1]} > %{customdata[2]}<br>"
            + "<br><b>Links:</b><br>"
            + "GitHub: %{customdata[3]}<br>"
            + "Project: %{customdata[4]}"
            + "<extra></extra>"
        ),
    )
    return fig

# --- Gradio UI Blocks ---
# Main app layout. The Blocks context stays open for the rest of this section:
# the callbacks and event wiring defined further down are attached to `demo`.
with gr.Blocks(
    title="πŸ“š PaperVerse Daily Explorer", 
    fill_width=True,
    css="""
        /* Hide the timestamp numbers on the range slider */
        #date-range-slider-wrapper .head,
        #date-range-slider-wrapper div[data-testid="range-slider"] > span {
            display: none !important;
        }
    """
) as demo:
    # Per-session state: the loaded DataFrame, whether loading succeeded (its
    # change event toggles the Generate button), and the dataset's date bounds.
    datasets_data_state = gr.State(pd.DataFrame())
    loading_complete_state = gr.State(False)
    date_range_state = gr.State(None)  # Store min/max timestamps
    
    with gr.Row():
        gr.Markdown("# πŸ“š PaperVerse Daily Explorer")

    with gr.Tabs():
        with gr.Tab("πŸ“Š Treemap Visualization"):
            with gr.Row():
                # Left column: plot controls.
                with gr.Column(scale=1):
                    # Choices are (display label, DataFrame column used as the metric).
                    count_by_dropdown = gr.Dropdown(
                        label="Metric",
                        choices=[
                            ("Upvotes", "paper_upvotes"),
                            ("Comments", "numComments"),
                        ],
                        value="paper_upvotes",
                    )
                    group_by_dropdown = gr.Dropdown(
                        label="Group by",
                        choices=[("Organization", "organization"), ("Topic", "topic")],
                        value="organization",
                    )
                    # Boolean content filters; each keeps only rows whose flag column is true.
                    gr.Markdown("**Filters**")
                    filter_code = gr.Checkbox(label="Has Code", value=False)
                    filter_media = gr.Checkbox(label="Has Media", value=False)
                    filter_org = gr.Checkbox(label="Has Organization", value=False)
                    
                    gr.Markdown("**Date Range**")
                    # Placeholder 0-100 bounds; load_and_generate_initial_plot replaces
                    # them with the dataset's publish-date range in Unix seconds.
                    date_range_slider = RangeSlider(
                        minimum=0,
                        maximum=100,
                        value=(0, 100),
                        label="Paper Release Date Range",
                        interactive=True,
                        elem_id="date-range-slider-wrapper"
                    )
                    date_range_display = gr.Markdown("Loading date range...")
                    
                    # The top-k and skip labels/values are rewritten by
                    # _toggle_labels_by_grouping when the grouping changes; the
                    # taxonomy dropdown choices are filled in on page load.
                    top_k_dropdown = gr.Dropdown(label="Number of Top Organizations", choices=TOP_K_CHOICES, value=25)
                    category_filter_dropdown = gr.Dropdown(label="Primary Category", choices=["All"], value="All")
                    subcategory_filter_dropdown = gr.Dropdown(label="Primary Subcategory", choices=["All"], value="All")
                    topic_filter_dropdown = gr.Dropdown(label="Primary Topic", choices=["All"], value="All")
                    skip_cats_textbox = gr.Textbox(label="Organizations to Skip", value="unaffiliated, Other")
                    # Disabled until the dataset has loaded successfully.
                    generate_plot_button = gr.Button(value="Generate Plot", variant="primary", interactive=False)

                # Right column: the treemap, plot statistics and dataset info.
                with gr.Column(scale=3):
                    plot_output = gr.Plot()
                    status_message_md = gr.Markdown("Initializing...")
                    data_info_md = gr.Markdown("")
        
        with gr.Tab("πŸ” Paper Search"):
            with gr.Column():
                # NOTE(review): this heading appears to contain a U+FFFD replacement
                # character where an emoji was intended (an encoding error);
                # confirm and restore the intended icon.
                gr.Markdown("### οΏ½ Search Papers and Organizations")
                with gr.Row():
                    search_item = gr.Textbox(
                        label="Search Organization or Paper", 
                        placeholder="Type organization name or paper title to see details...",
                        scale=4
                    )
                    search_button = gr.Button("Show Details", scale=1, variant="secondary")
                selected_info_html = gr.HTML(value="<p style='color: gray;'>Enter an organization name or paper title above to see details</p>")
    
    def _update_button_interactivity(is_loaded_flag):
        """Build a component update that makes the button clickable iff loading succeeded."""
        button_update = gr.update(interactive=is_loaded_flag)
        return button_update
    
    def _format_date_range(date_range_tuple, date_range_value):
        """Describe the slider's current selection as a readable date span.

        Args:
            date_range_tuple: The dataset's full ``(min_ts, max_ts)`` bounds, or
                None when no dates could be computed.
            date_range_value: The slider's current ``(start, end)`` values, which
                are Unix timestamps in seconds.

        Returns:
            Markdown for the date-range display.
        """
        if date_range_tuple is None:
            return "Date range unavailable"
        # The full bounds are only unpacked (which requires a 2-item pair); the
        # text shows the selected sub-range.
        _full_min, _full_max = date_range_tuple
        start_ts, end_ts = date_range_value

        start_label, end_label = (
            pd.to_datetime(ts, unit='s').strftime('%B %d, %Y') for ts in (start_ts, end_ts)
        )
        return f"**Selected Range:** {start_label} to {end_label}"

    def _toggle_labels_by_grouping(group_by_value):
        """Relabel the top-k dropdown and skip textbox for the selected grouping.

        Topic mode clears the skip list; any other mode restores the default
        organizations to skip.

        Returns:
            ``(top_k_dropdown update, skip_cats_textbox update)``.
        """
        # Keyed by "is topic mode": (top-k label, skip label, skip textbox value).
        label_table = {
            True: ("Number of Top Topics", "Topics to Skip", ""),
            False: ("Number of Top Organizations", "Organizations to Skip", "unaffiliated, Other"),
        }
        top_k_text, skip_text, skip_default = label_table[group_by_value == 'topic']
        top_k_update = gr.update(label=top_k_text)
        skip_update = gr.update(label=skip_text, value=skip_default)
        return (top_k_update, skip_update)

    def load_and_generate_initial_plot(progress=gr.Progress()):
        """Load the dataset and build the initial plot when the page loads.

        Unexpected exceptions while loading are caught and reported in the info
        panel, and the dataset is replaced by an empty DataFrame. A failure while
        drawing the initial plot is also caught, so the loaded data and dropdown
        choices are still delivered to the UI.

        Args:
            progress: Gradio progress tracker.

        Returns:
            An 11-tuple in the order of the ``demo.load`` outputs: dataset
            DataFrame, load-success flag, data-info markdown, plot-status markdown,
            initial figure, category/subcategory/topic dropdown updates,
            date-slider update, date-range markdown, and the ``(min_ts, max_ts)``
            tuple (or None when no valid publish dates exist).
        """
        # Fallback slider bounds and labels, used whenever no valid publish dates exist.
        fallback_date_range = (0, 100, "Date range unavailable", None)
        min_ts, max_ts, date_range_text, date_range_tuple = fallback_date_range

        progress(0, desc=f"Loading dataset '{HF_DATASET_ID}'...")
        # --- Part 1: Data Loading ---
        try:
            current_df, load_success_flag, status_msg_from_load = load_datasets_data()
            if load_success_flag:
                progress(0.5, desc="Processing data...")
                date_display = "Pre-processed (date unavailable)"
                # Check emptiness first: .iloc[0] on an empty frame raises IndexError.
                if (
                    not current_df.empty
                    and 'data_download_timestamp' in current_df.columns
                    and pd.notna(current_df['data_download_timestamp'].iloc[0])
                ):
                    ts = pd.to_datetime(current_df['data_download_timestamp'].iloc[0], utc=True)
                    date_display = ts.strftime('%B %d, %Y, %H:%M:%S %Z')

                # Slider bounds: earliest/latest publish dates as integer epoch seconds.
                if 'publishedAt_dt' in current_df.columns:
                    valid_dates = current_df['publishedAt_dt'].dropna()
                    if not valid_dates.empty:
                        min_date = valid_dates.min()
                        max_date = valid_dates.max()
                        min_ts = int(min_date.timestamp())
                        max_ts = int(max_date.timestamp())
                        date_range_tuple = (min_ts, max_ts)
                        date_range_text = f"**Full Range:** {min_date.strftime('%B %d, %Y')} to {max_date.strftime('%B %d, %Y')}"

                data_info_text = (f"### Data Information\n- Source: `{HF_DATASET_ID}`\n"
                                  f"- Status: {status_msg_from_load}\n"
                                  f"- Total records loaded: {len(current_df):,}\n"
                                  f"- Data as of: {date_display}\n")
            else:
                data_info_text = f"### Data Load Failed\n- {status_msg_from_load}"
        except Exception as e:
            status_msg_from_load = f"An unexpected error occurred: {str(e)}"
            data_info_text = f"### Critical Error\n- {status_msg_from_load}"
            load_success_flag = False
            current_df = pd.DataFrame()  # Ensure df is empty on failure
            # Discard any partially computed date range.
            min_ts, max_ts, date_range_text, date_range_tuple = fallback_date_range
            print(f"Critical error in load_and_generate_initial_plot: {e}")

        # --- Part 2: Generate Initial Plot ---
        progress(0.6, desc="Generating initial plot...")
        # Defaults matching the initial values of the UI controls.
        default_metric = "paper_upvotes"
        default_k = 25
        default_group_by = "organization"
        default_skip_cats = "unaffiliated, Other"

        # Use taxonomy from JSON instead of calculating from dataset
        cat_choices = TAXONOMY_DATA['categories']
        subcat_choices = TAXONOMY_DATA['subcategories']
        topic_choices = TAXONOMY_DATA['topics']

        # Reuse the regular plot controller: all checkboxes off, all taxonomy
        # filters "All", and date_range=None for the first render. A plotting
        # failure must not discard the loaded data, so it is reported instead.
        try:
            initial_plot, initial_status = ui_generate_plot_controller(
                default_metric, False, False, False, default_k, default_group_by,
                "All", "All", "All", default_skip_cats, None, current_df, progress,
            )
        except Exception as e:
            print(f"Initial plot generation failed: {e}")
            initial_plot = create_treemap(pd.DataFrame(), default_metric)
            initial_status = f"Could not generate the initial plot: {e}"

        return (
            current_df,
            load_success_flag,
            data_info_text,
            initial_status,
            initial_plot,
            gr.update(choices=cat_choices, value="All"),
            gr.update(choices=subcat_choices, value="All"),
            gr.update(choices=topic_choices, value="All"),
            gr.update(minimum=min_ts, maximum=max_ts, value=(min_ts, max_ts)),
            date_range_text,
            date_range_tuple,
        )

    def ui_generate_plot_controller(metric_choice, has_code, has_media, has_org,
                                   k_orgs, group_by_choice,
                                   category_choice, subcategory_choice, topic_choice,
                                   skip_cats_input, date_range, df_current_datasets, progress=gr.Progress()):
        """Filter the dataset per the UI controls, then render the treemap and a stats summary.

        Args:
            metric_choice: Metric column that sizes the tiles (e.g. "paper_upvotes").
            has_code, has_media, has_org: When truthy, keep only rows whose
                ``has_code`` / ``has_media`` / ``has_organization`` column is true.
            k_orgs: Number of top groups passed to ``make_treemap_data``.
            group_by_choice: "topic" for the topic hierarchy; anything else groups
                by organization.
            category_choice, subcategory_choice, topic_choice: Taxonomy filters;
                "All" or an empty value applies no filter.
            skip_cats_input: Comma-separated names to exclude.
            date_range: Slider value forwarded to ``make_treemap_data`` (None on
                the initial load).
            df_current_datasets: The loaded dataset (may be None or empty).
            progress: Gradio progress tracker.

        Returns:
            ``(plotly figure, status markdown)``.
        """
        if df_current_datasets is None or df_current_datasets.empty:
            return create_treemap(pd.DataFrame(), metric_choice), "Dataset data is not loaded. Cannot generate plot."

        progress(0.1, desc="Aggregating data...")
        stripped_names = (piece.strip() for piece in skip_cats_input.split(','))
        cats_to_skip = [name for name in stripped_names if name]

        df_filtered = df_current_datasets.copy()

        # Content checkboxes: each enabled flag keeps rows where its column is true.
        checkbox_filters = (
            (has_code, 'has_code'),
            (has_media, 'has_media'),
            (has_org, 'has_organization'),
        )
        for is_enabled, flag_column in checkbox_filters:
            if is_enabled:
                df_filtered = df_filtered[df_filtered[flag_column]]

        # Taxonomy dropdowns: "All" (or an empty selection) means "no filter".
        taxonomy_filters = (
            (category_choice, 'primary_category'),
            (subcategory_choice, 'primary_subcategory'),
            (topic_choice, 'primary_topic'),
        )
        for selected, taxonomy_column in taxonomy_filters:
            if selected and selected != 'All':
                df_filtered = df_filtered[df_filtered[taxonomy_column] == selected]

        treemap_df = make_treemap_data(df_filtered, metric_choice, k_orgs, None, cats_to_skip, group_by_choice, date_range)

        progress(0.7, desc="Generating plot...")
        metric_name = {
            "paper_upvotes": "Upvotes",
            "numComments": "Comments",
        }.get(metric_choice, metric_choice)

        is_topic_view = group_by_choice == "topic"
        if is_topic_view:
            grouping_word = "Topic"
            path = ["root", "primary_category", "primary_subcategory", "primary_topic", "paper_label"]
        else:
            grouping_word = "Organization"
            path = ["root", "organization", "paper_label"]
        chart_title = f"PaperVerse Daily - {metric_name} by {grouping_word}"

        plotly_fig = create_treemap(
            treemap_df,
            metric_choice,
            chart_title,
            path=path,
            metric_label=metric_name,
        )

        if treemap_df.empty:
            return plotly_fig, "No data matches the selected filters. Please try different options."

        metric_total = treemap_df[metric_choice].sum()
        # The synthetic "Other" bucket is not a real paper.
        paper_count = treemap_df.loc[treemap_df['paper_label'] != 'Other', 'paper_label'].nunique()
        if is_topic_view:
            triplet_count = len(treemap_df[["primary_category", "primary_subcategory", "primary_topic"]].drop_duplicates())
            group_line = f"**Topics Shown**: {triplet_count:,} unique triplets"
        else:
            group_line = f"**Organizations Shown**: {treemap_df['organization'].nunique():,}"

        stats_lines = [
            f"## Plot Statistics\n- {group_line}\n",
            f"- **Individual Papers Shown**: {paper_count:,}\n",
            f"- **Total {metric_name} in plot**: {int(metric_total):,}",
        ]
        return plotly_fig, "".join(stats_lines)

    # --- Event Wiring ---

    # Page load: fetch the dataset, draw the default plot and seed the taxonomy
    # dropdowns and the date slider. Keep this list in the same order as the
    # tuple returned by load_and_generate_initial_plot.
    _initial_load_outputs = [
        datasets_data_state,
        loading_complete_state,
        data_info_md,
        status_message_md,
        plot_output,
        category_filter_dropdown,
        subcategory_filter_dropdown,
        topic_filter_dropdown,
        date_range_slider,
        date_range_display,
        date_range_state,
    ]
    demo.load(fn=load_and_generate_initial_plot, inputs=[], outputs=_initial_load_outputs)

    # Enable or disable "Generate Plot" whenever the loaded flag changes.
    loading_complete_state.change(
        fn=_update_button_interactivity,
        inputs=loading_complete_state,
        outputs=generate_plot_button,
    )

    # Relabel the top-k dropdown and the skip textbox for the chosen grouping.
    group_by_dropdown.change(
        fn=_toggle_labels_by_grouping,
        inputs=group_by_dropdown,
        outputs=[top_k_dropdown, skip_cats_textbox],
    )

    # Keep the readable date label in sync with the slider, without a spinner.
    date_range_slider.change(
        fn=_format_date_range,
        inputs=[date_range_state, date_range_slider],
        outputs=date_range_display,
        show_progress="hidden",
    )

    def handle_search_details(search_text, df_current):
        """Search for an organization or paper and show detailed information.

        The query is matched as a case-insensitive literal substring (not a regular
        expression) against ``organization``, ``paper_label`` and, when present,
        ``paper_title``. Every user- or dataset-supplied value is HTML-escaped
        before it is embedded in the returned markup.

        Args:
            search_text: Raw text from the search box; may be None or blank.
            df_current: The loaded papers DataFrame; may be None or empty.

        Returns:
            An HTML fragment: a prompt, a "no results" notice, a panel listing up
            to 20 matches, or an error message.
        """
        import html

        max_shown = 20  # cap on rendered results to keep the panel manageable
        image_placeholders = ['none', 'null', 'n/a', '']
        link_placeholders = ['n/a', 'none']

        def esc(value):
            # Escape for HTML text content and quoted attribute values (quotes included).
            return html.escape(str(value), quote=True)

        def is_present(value, placeholders):
            # True for a non-blank string that is not a placeholder such as 'N/A'.
            return isinstance(value, str) and bool(value.strip()) and value.strip().lower() not in placeholders

        if not search_text or not search_text.strip():
            return "<p style='color: gray;'>Please enter a search term</p>"

        if df_current is None or df_current.empty:
            return "<p style='color: gray;'>No data available</p>"

        search_text = search_text.strip()
        safe_query = esc(search_text)

        def column_matches(column):
            # regex=False: match the query literally, so input like "C++" or "("
            # neither raises a regex error nor matches unintended text.
            return df_current[column].str.contains(search_text, case=False, na=False, regex=False)

        try:
            mask = column_matches('organization') | column_matches('paper_label')
            if 'paper_title' in df_current.columns:
                mask = mask | column_matches('paper_title')
            matching_rows = df_current[mask]

            if matching_rows.empty:
                return f"<p style='color: orange;'>No results found for: <b>{safe_query}</b></p><p style='color: gray;'>Try searching for an organization name (e.g., 'Qwen', 'Meta') or paper title keyword</p>"

            # Build the info panel HTML showing all matching results
            num_results = len(matching_rows)
            html_parts = [
                "<div style='padding: 15px; border: 1px solid #ddd; border-radius: 8px; background: #f9f9f9; max-height: 600px; overflow-y: auto;'>",
                f"<h3 style='margin: 0 0 15px 0; color: #333;'>πŸ” Found {num_results} result{'s' if num_results > 1 else ''} for: <span style='color: #0366d6;'>{safe_query}</span></h3>"
            ]

            for idx, (_, row) in enumerate(matching_rows.head(max_shown).iterrows()):
                # Add separator between results
                if idx > 0:
                    html_parts.append("<hr style='margin: 15px 0; border: none; border-top: 1px solid #ddd;'/>")

                html_parts.append("<div style='margin-bottom: 10px; overflow: auto;'>")

                # Organization logo (precomputed column), floated left
                org_avatar = row.get('organization_avatar')
                if is_present(org_avatar, image_placeholders):
                    html_parts.append(f"<img src='{esc(org_avatar)}' style='max-width: 60px; max-height: 60px; border-radius: 50%; margin-bottom: 8px; float: left; margin-right: 12px; border: 2px solid #ddd;' onerror=\"this.style.display='none'\"/>")

                # Paper thumbnail, floated right
                paper_thumbnail = row.get('thumbnail')
                if is_present(paper_thumbnail, image_placeholders):
                    html_parts.append(f"<img src='{esc(paper_thumbnail)}' style='max-width: 120px; max-height: 120px; border-radius: 8px; margin-bottom: 8px; float: right; margin-left: 12px; border: 1px solid #ddd;' onerror=\"this.style.display='none'\"/>")

                # Organization name
                org_name = row.get('organization', 'Unknown')
                html_parts.append(f"<p style='margin: 0 0 5px 0; font-weight: bold; color: #333;'>🏒 {esc(org_name)}</p>")

                # Paper title (falls back to a 'title' field, then a placeholder)
                paper_title = row.get('paper_title', row.get('title', 'Untitled'))
                html_parts.append(f"<p style='margin: 0 0 5px 0; color: #555; font-size: 0.95em;'>πŸ“„ {esc(paper_title)}</p>")

                # Topic hierarchy
                category = esc(row.get('primary_category', 'Unknown'))
                subcategory = esc(row.get('primary_subcategory', 'Unknown'))
                topic = esc(row.get('primary_topic', 'Unknown'))
                html_parts.append(f"<p style='margin: 0 0 5px 0; font-size: 0.9em; color: #666;'><b>Topics:</b> {category} β†’ {subcategory} β†’ {topic}</p>")

                # Metrics
                upvotes = row.get('paper_upvotes', 0)
                comments = row.get('numComments', 0)
                html_parts.append(f"<p style='margin: 0 0 5px 0; font-size: 0.9em;'><b>Metrics:</b> ⬆️ {upvotes:,} upvotes | πŸ’¬ {comments:,} comments</p>")

                # Links (escaped; values come from the dataset, not from this app)
                github = row.get('paper_githubRepo')
                project = row.get('paper_projectPage')

                links = []
                if is_present(github, link_placeholders):
                    links.append(f"<a href='{esc(github)}' target='_blank' style='color: #0366d6; margin-right: 15px;'>πŸ”— GitHub</a>")

                if is_present(project, link_placeholders):
                    links.append(f"<a href='{esc(project)}' target='_blank' style='color: #0366d6;'>πŸ”— Project</a>")

                if links:
                    html_parts.append(f"<p style='margin: 0; font-size: 0.9em;'>{' '.join(links)}</p>")

                html_parts.append("<div style='clear: both;'></div>")
                html_parts.append("</div>")

            if num_results > max_shown:
                html_parts.append(f"<p style='margin-top: 15px; color: #666; font-style: italic;'>Showing first {max_shown} of {num_results} results. Refine your search for fewer results.</p>")

            html_parts.append("</div>")

            return "".join(html_parts)

        except Exception as e:
            return f"<p style='color: red;'>Error displaying details: {esc(e)}</p>"

    # "Generate Plot": these inputs are passed positionally, so their order must
    # match the parameters of ui_generate_plot_controller.
    _plot_controller_inputs = [
        count_by_dropdown,
        filter_code,
        filter_media,
        filter_org,
        top_k_dropdown,
        group_by_dropdown,
        category_filter_dropdown,
        subcategory_filter_dropdown,
        topic_filter_dropdown,
        skip_cats_textbox,
        date_range_slider,
        datasets_data_state,
    ]
    generate_plot_button.click(
        fn=ui_generate_plot_controller,
        inputs=_plot_controller_inputs,
        outputs=[plot_output, status_message_md],
    )

    # The details panel refreshes from the "Show Details" button and from
    # pressing Enter in the search box; both run the same handler.
    for _search_event in (search_button.click, search_item.submit):
        _search_event(
            fn=handle_search_details,
            inputs=[search_item, datasets_data_state],
            outputs=[selected_info_html],
        )

if __name__ == "__main__":
    # Script entry point: enable Gradio's request queue, then start the server.
    print("Application starting...")
    queued_app = demo.queue()
    queued_app.launch()