import gradio as gr
import pandas as pd
import plotly.express as px
import time
import os
import tempfile
import requests
import duckdb
import json
from datasets import load_dataset
from huggingface_hub import logout as hf_logout
from gradio_rangeslider import RangeSlider

# --- Constants ---
TOP_K_CHOICES = list(range(5, 51, 5))
HF_DATASET_ID = "evijit/paperverse_daily_data"
# Direct parquet file URL (public)
PARQUET_URL = "https://huggingface.co/datasets/evijit/paperverse_daily_data/resolve/main/papers_with_semantic_taxonomy.parquet"
TAXONOMY_JSON_PATH = "integrated_ml_taxonomy.json"

# Simple content filters derived from the new dataset
TAG_FILTER_CHOICES = [
    "None",
    "Has Code",
    "Has Media",
    "Has Organization",
]


def load_taxonomy():
    """Load the ML taxonomy from JSON file.

    Returns a dict with:
      - 'categories' / 'subcategories' / 'topics': sorted choice lists for the
        UI dropdowns, each prefixed with "All";
      - 'taxonomy': the raw {category: {subcategory: [topics]}} mapping.

    On any failure (missing file, bad JSON) it logs and falls back to
    "All"-only choices so the UI can still render.
    """
    try:
        with open(TAXONOMY_JSON_PATH, 'r', encoding='utf-8') as f:
            taxonomy = json.load(f)

        # Extract choices for dropdowns
        categories = sorted(taxonomy.keys())

        # Flatten subcategories and topics across all categories.
        all_subcategories = set()
        all_topics = set()
        for category, subcats in taxonomy.items():
            for subcat, topics in subcats.items():
                all_subcategories.add(subcat)
                all_topics.update(topics)

        return {
            'categories': ["All"] + categories,
            'subcategories': ["All"] + sorted(all_subcategories),
            'topics': ["All"] + sorted(all_topics),
            'taxonomy': taxonomy,
        }
    except Exception as e:
        print(f"Error loading taxonomy from JSON: {e}")
        return {
            'categories': ["All"],
            'subcategories': ["All"],
            'topics': ["All"],
            'taxonomy': {},
        }


TAXONOMY_DATA = load_taxonomy()


def _first_non_null(*values):
    """Return the first argument that is neither None nor a blank string."""
    for v in values:
        if v is None:
            continue
        # treat empty strings as null-ish
        if isinstance(v, str) and v.strip() == "":
            continue
        return v
    return None


def _get_nested(row, *paths):
    """Try multiple dotted paths in a row that may contain dicts; return first non-null.

    BUGFIX: the original only descended through plain dicts, so when `row`
    is a pandas Series (which is what `df.apply(..., axis=1)` passes), the
    very first key lookup failed and the function always returned None.
    Series lookups are now supported at every level of the path.
    """
    for path in paths:
        cur = row
        ok = True
        for key in path.split('.'):
            if isinstance(cur, dict) and key in cur:
                cur = cur[key]
            elif isinstance(cur, pd.Series) and key in cur.index:
                cur = cur[key]
            else:
                ok = False
                break
        if ok and cur is not None:
            return cur
    return None
None: return cur return None def load_datasets_data(): """Load the PaperVerse Daily dataset from the Hugging Face Hub and normalize columns used by the app.""" start_time = time.time() print(f"Attempting to load dataset from Hugging Face Hub: {HF_DATASET_ID}") try: # First try: direct parquet download (avoids any auth header issues) try: print(f"Trying direct parquet download: {PARQUET_URL}") with requests.get(PARQUET_URL, stream=True, timeout=120) as resp: resp.raise_for_status() with tempfile.NamedTemporaryFile(suffix=".parquet", delete=False) as tmpf: for chunk in resp.iter_content(chunk_size=1024 * 1024): if chunk: tmpf.write(chunk) tmp_path = tmpf.name try: # Use DuckDB to read parquet to avoid pyarrow decoding issues df = duckdb.query(f"SELECT * FROM read_parquet('{tmp_path}')").df() finally: try: os.remove(tmp_path) except Exception: pass print("Loaded DataFrame from direct parquet download via DuckDB.") except Exception as direct_e: print(f"Direct parquet load failed: {direct_e}. 
Falling back to datasets loader...") # Force anonymous access in case an invalid cached token is present # Clear any token environment variables that could inject a bad Authorization header for env_key in ("HF_TOKEN", "HUGGINGFACE_HUB_TOKEN", "HF_HUB_TOKEN"): if os.environ.pop(env_key, None) is not None: print(f"Cleared env var: {env_key}") # Prefer explicit train split when available try: dataset_obj = load_dataset(HF_DATASET_ID, split="train", token=None) except TypeError: dataset_obj = load_dataset(HF_DATASET_ID, split="train", use_auth_token=False) except Exception: # Fallback: load all splits and pick the first available try: dataset_obj = load_dataset(HF_DATASET_ID, token=None) except TypeError: dataset_obj = load_dataset(HF_DATASET_ID, use_auth_token=False) # Handle both Dataset and DatasetDict try: # If it's a Dataset (single split), this will work df = dataset_obj.to_pandas() except AttributeError: # Otherwise assume DatasetDict and take the first split first_split = list(dataset_obj.keys())[0] df = dataset_obj[first_split].to_pandas() # --- Normalize expected columns for the visualization --- # organization: prefer top-level organization_name, then paper_organization.name/fullname, else Unknown if 'organization_name' in df.columns: org_series = df['organization_name'] else: # try nested dicts commonly produced by HF datasets org_series = df.apply( lambda r: _first_non_null( _get_nested(r, 'paper_organization.name'), _get_nested(r, 'paper_organization.fullname'), _get_nested(r, 'organization.name'), _get_nested(r, 'organization.fullname') ), axis=1 ) df['organization'] = org_series.fillna('Unknown') # Extract organization avatar/logo if 'organization_name' in df.columns: # Try to get avatar from paper_organization or organization struct def _get_avatar(row): for path in ['paper_organization.avatar', 'organization.avatar']: av = _get_nested(row, path) if av and isinstance(av, str) and av.strip(): return av return None org_avatar_series = 
df.apply(_get_avatar, axis=1) else: org_avatar_series = pd.Series([None] * len(df)) df['organization_avatar'] = org_avatar_series # id for each paper row cand_cols = [ 'paper_id', 'paper_discussionId', 'key' ] id_val = None for c in cand_cols: if c in df.columns: id_val = df[c] break if id_val is None: # fallback to title + index if 'paper_title' in df.columns: df['id'] = df['paper_title'].astype(str) + '_' + df.reset_index().index.astype(str) elif 'title' in df.columns: df['id'] = df['title'].astype(str) + '_' + df.reset_index().index.astype(str) else: df['id'] = df.reset_index().index.astype(str) else: df['id'] = id_val.astype(str) # numeric metrics used for aggregation def _to_num(col_name): if col_name in df.columns: return pd.to_numeric(df[col_name], errors='coerce').fillna(0.0) return pd.Series([0.0] * len(df)) df['paper_upvotes'] = _to_num('paper_upvotes') df['numComments'] = _to_num('numComments') df['paper_githubStars'] = _to_num('paper_githubStars') # computed boolean filters def _has_code(row): # Check for GitHub repo try: gh = row['paper_githubRepo'] if 'paper_githubRepo' in row and pd.notna(row['paper_githubRepo']) else None if isinstance(gh, str) and len(gh.strip()) > 0: return True except Exception: pass # Check for project page try: pp = row.get('paper_projectPage') if isinstance(row, dict) else row.get('paper_projectPage', None) if isinstance(pp, str) and len(str(pp).strip()) > 0 and str(pp).strip().lower() != 'n/a': return True except Exception: pass return False def _has_media(row): for c in ['paper_mediaUrls', 'mediaUrls']: try: v = row[c] if isinstance(v, list) and len(v) > 0: return True # some providers store arrays as strings like "[... 
]" if isinstance(v, str) and v.strip().startswith('[') and len(v.strip()) > 2: return True except Exception: continue return False df['has_code'] = df.apply(_has_code, axis=1) df['has_media'] = df.apply(_has_media, axis=1) df['has_organization'] = df['organization'].astype(str).str.strip().ne('Unknown') # Process publishedAt field for date filtering if 'publishedAt' in df.columns: df['publishedAt_dt'] = pd.to_datetime(df['publishedAt'], errors='coerce') else: df['publishedAt_dt'] = pd.NaT # Ensure topic hierarchy columns exist and are strings for col_name, default_val in [ ('primary_category', 'Unknown'), ('primary_subcategory', 'Unknown'), ('primary_topic', 'Unknown'), ]: if col_name not in df.columns: df[col_name] = default_val else: df[col_name] = df[col_name].fillna(default_val).astype(str).replace({'': default_val}) # Create a human-friendly paper label for treemap leaves: "
Enter an organization name or paper title above to see details
") def _update_button_interactivity(is_loaded_flag): return gr.update(interactive=is_loaded_flag) def _format_date_range(date_range_tuple, date_range_value): """Convert slider values to readable date range text""" if date_range_tuple is None: return "Date range unavailable" min_ts, max_ts = date_range_tuple selected_min, selected_max = date_range_value # Convert slider values to timestamps # The slider values are already timestamps min_date = pd.to_datetime(selected_min, unit='s') max_date = pd.to_datetime(selected_max, unit='s') return f"**Selected Range:** {min_date.strftime('%B %d, %Y')} to {max_date.strftime('%B %d, %Y')}" def _toggle_labels_by_grouping(group_by_value): # Update labels based on grouping mode if group_by_value == 'topic': top_k_label = "Number of Top Topics" skip_label = "Topics to Skip" skip_value = "" # Clear skip box for topics else: top_k_label = "Number of Top Organizations" skip_label = "Organizations to Skip" skip_value = "unaffiliated, Other" # Default orgs to skip return ( gr.update(label=top_k_label), gr.update(label=skip_label, value=skip_value) ) ## CHANGE: New combined function to load data and generate the initial plot on startup. 
def load_and_generate_initial_plot(progress=gr.Progress()):
    """Load the dataset on startup and produce the initial plot plus UI state.

    Returns a tuple matching the demo.load outputs list: (dataframe state,
    loaded flag, data-info markdown, status markdown, plot figure, three
    taxonomy dropdown updates, date-slider update, date-range text, and the
    (min_ts, max_ts) tuple kept in state).

    CHANGE: the failure/fallback values for the date-range state were
    previously assigned identically in three separate branches; they are now
    initialized once up front (behavior unchanged). The unused local
    `default_tag` was removed.
    """
    progress(0, desc=f"Loading dataset '{HF_DATASET_ID}'...")

    # --- Part 1: Data Loading ---
    # Fallback state used whenever loading fails or no dates are present.
    min_ts = 0
    max_ts = 100
    date_range_text = "Date range unavailable"
    date_range_tuple = None

    try:
        current_df, load_success_flag, status_msg_from_load = load_datasets_data()
        if load_success_flag:
            progress(0.5, desc="Processing data...")
            date_display = "Pre-processed (date unavailable)"
            if 'data_download_timestamp' in current_df.columns and pd.notna(current_df['data_download_timestamp'].iloc[0]):
                ts = pd.to_datetime(current_df['data_download_timestamp'].iloc[0], utc=True)
                date_display = ts.strftime('%B %d, %Y, %H:%M:%S %Z')

            # Calculate date range from publishedAt_dt
            if 'publishedAt_dt' in current_df.columns:
                valid_dates = current_df['publishedAt_dt'].dropna()
                if len(valid_dates) > 0:
                    min_date = valid_dates.min()
                    max_date = valid_dates.max()
                    min_ts = int(min_date.timestamp())
                    max_ts = int(max_date.timestamp())
                    date_range_tuple = (min_ts, max_ts)
                    date_range_text = f"**Full Range:** {min_date.strftime('%B %d, %Y')} to {max_date.strftime('%B %d, %Y')}"

            data_info_text = (f"### Data Information\n- Source: `{HF_DATASET_ID}`\n"
                              f"- Status: {status_msg_from_load}\n"
                              f"- Total records loaded: {len(current_df):,}\n"
                              f"- Data as of: {date_display}\n")
        else:
            data_info_text = f"### Data Load Failed\n- {status_msg_from_load}"
    except Exception as e:
        status_msg_from_load = f"An unexpected error occurred: {str(e)}"
        data_info_text = f"### Critical Error\n- {status_msg_from_load}"
        load_success_flag = False
        current_df = pd.DataFrame()  # Ensure df is empty on failure
        print(f"Critical error in load_and_generate_initial_plot: {e}")

    # --- Part 2: Generate Initial Plot ---
    progress(0.6, desc="Generating initial plot...")

    # Defaults matching UI definitions
    default_metric = "paper_upvotes"
    default_k = 25
    default_group_by = "organization"
    default_skip_cats = "unaffiliated, Other"

    # Use taxonomy from JSON instead of calculating from dataset
    cat_choices = TAXONOMY_DATA['categories']
    subcat_choices = TAXONOMY_DATA['subcategories']
    topic_choices = TAXONOMY_DATA['topics']

    # Reuse the existing controller function for plotting (date range is None
    # for the initial load, i.e. no date filtering).
    initial_plot, initial_status = ui_generate_plot_controller(
        default_metric, False, False, False, default_k, default_group_by,
        "All", "All", "All", default_skip_cats, None, current_df, progress
    )

    # Also update taxonomy dropdown choices
    return (
        current_df,
        load_success_flag,
        data_info_text,
        initial_status,
        initial_plot,
        gr.update(choices=cat_choices, value="All"),
        gr.update(choices=subcat_choices, value="All"),
        gr.update(choices=topic_choices, value="All"),
        gr.update(minimum=min_ts, maximum=max_ts, value=(min_ts, max_ts)),
        date_range_text,
        date_range_tuple,
    )
def ui_generate_plot_controller(metric_choice, has_code, has_media, has_org, k_orgs,
                                group_by_choice, category_choice, subcategory_choice,
                                topic_choice, skip_cats_input, date_range,
                                df_current_datasets, progress=gr.Progress()):
    """Filter the loaded dataframe per the UI controls and build the treemap.

    Returns (plotly figure, status/stats markdown). With no data loaded it
    returns an empty treemap and an explanatory message.
    """
    if df_current_datasets is None or df_current_datasets.empty:
        return create_treemap(pd.DataFrame(), metric_choice), "Dataset data is not loaded. Cannot generate plot."

    progress(0.1, desc="Aggregating data...")
    cats_to_skip = [c.strip() for c in skip_cats_input.split(',') if c.strip()]

    # Content filters (checkboxes) — each enabled flag narrows the frame.
    df_filtered = df_current_datasets.copy()
    for enabled, flag_col in ((has_code, 'has_code'),
                              (has_media, 'has_media'),
                              (has_org, 'has_organization')):
        if enabled:
            df_filtered = df_filtered[df_filtered[flag_col]]

    # Taxonomy filters — "All" (or empty) means no restriction.
    for selection, tax_col in ((category_choice, 'primary_category'),
                               (subcategory_choice, 'primary_subcategory'),
                               (topic_choice, 'primary_topic')):
        if selection and selection != 'All':
            df_filtered = df_filtered[df_filtered[tax_col] == selection]

    treemap_df = make_treemap_data(df_filtered, metric_choice, k_orgs, None,
                                   cats_to_skip, group_by_choice, date_range)

    progress(0.7, desc="Generating plot...")
    title_labels = {
        "paper_upvotes": "Upvotes",
        "numComments": "Comments",
    }
    metric_label = title_labels.get(metric_choice, metric_choice)

    if group_by_choice == "topic":
        chart_title = f"PaperVerse Daily - {metric_label} by Topic"
        path = ["root", "primary_category", "primary_subcategory", "primary_topic", "paper_label"]
    else:
        chart_title = f"PaperVerse Daily - {metric_label} by Organization"
        path = ["root", "organization", "paper_label"]

    plotly_fig = create_treemap(
        treemap_df,
        metric_choice,
        chart_title,
        path=path,
        metric_label=metric_label,
    )

    if treemap_df.empty:
        plot_stats_md = "No data matches the selected filters. Please try different options."
    else:
        total_value_in_plot = treemap_df[metric_choice].sum()
        # "Other" is the synthetic bucket leaf; exclude it from the paper count.
        total_items_in_plot = treemap_df[treemap_df['paper_label'] != 'Other']['paper_label'].nunique()
        if group_by_choice == "topic":
            group_count = treemap_df[["primary_category", "primary_subcategory",
                                      "primary_topic"]].drop_duplicates().shape[0]
            group_line = f"**Topics Shown**: {group_count:,} unique triplets"
        else:
            group_line = f"**Organizations Shown**: {treemap_df['organization'].nunique():,}"
        plot_stats_md = (
            f"## Plot Statistics\n- {group_line}\n"
            f"- **Individual Papers Shown**: {total_items_in_plot:,}\n"
            f"- **Total {metric_label} in plot**: {int(total_value_in_plot):,}"
        )

    return plotly_fig, plot_stats_md


# --- Event Wiring ---
## CHANGE: Updated demo.load to call the new function and to add plot_output to the outputs list.
demo.load(
    fn=load_and_generate_initial_plot,
    inputs=[],
    outputs=[
        datasets_data_state,
        loading_complete_state,
        data_info_md,
        status_message_md,
        plot_output,
        category_filter_dropdown,
        subcategory_filter_dropdown,
        topic_filter_dropdown,
        date_range_slider,
        date_range_display,
        date_range_state,
    ],
)

loading_complete_state.change(
    fn=_update_button_interactivity,
    inputs=loading_complete_state,
    outputs=generate_plot_button,
)

# Update labels based on grouping mode
group_by_dropdown.change(
    fn=_toggle_labels_by_grouping,
    inputs=group_by_dropdown,
    outputs=[top_k_dropdown, skip_cats_textbox],
)

# Update date range display when slider changes
date_range_slider.change(
    fn=_format_date_range,
    inputs=[date_range_state, date_range_slider],
    outputs=date_range_display,
    show_progress="hidden",
)
" if df_current is None or df_current.empty: return "No data available
" search_text = search_text.strip() try: # Try to find matching rows by organization or paper title (case-insensitive partial match) matching_rows = df_current[ df_current['organization'].str.contains(search_text, case=False, na=False) | df_current['paper_label'].str.contains(search_text, case=False, na=False) | (df_current['paper_title'].str.contains(search_text, case=False, na=False) if 'paper_title' in df_current.columns else False) ] if matching_rows.empty: return f"No results found for: {search_text}
Try searching for an organization name (e.g., 'Qwen', 'Meta') or paper title keyword
" # Build the info panel HTML showing all matching results num_results = len(matching_rows) html_parts = [ f"🏢 {org_name}
") # Paper title paper_title = row.get('paper_title', row.get('title', 'Untitled')) html_parts.append(f"📄 {paper_title}
") # Topic hierarchy category = row.get('primary_category', 'Unknown') subcategory = row.get('primary_subcategory', 'Unknown') topic = row.get('primary_topic', 'Unknown') html_parts.append(f"Topics: {category} → {subcategory} → {topic}
") # Metrics upvotes = row.get('paper_upvotes', 0) comments = row.get('numComments', 0) html_parts.append(f"Metrics: ⬆️ {upvotes:,} upvotes | 💬 {comments:,} comments
") # Links github = row.get('paper_githubRepo') project = row.get('paper_projectPage') links = [] if github and isinstance(github, str) and github.strip() and github.strip().lower() not in ['n/a', 'none']: links.append(f"🔗 GitHub") if project and isinstance(project, str) and project.strip() and project.strip().lower() not in ['n/a', 'none']: links.append(f"🔗 Project") if links: html_parts.append(f"{' '.join(links)}
") html_parts.append("") html_parts.append("Showing first 20 of {num_results} results. Refine your search for fewer results.
") html_parts.append("Error displaying details: {str(e)}
" generate_plot_button.click( fn=ui_generate_plot_controller, inputs=[ count_by_dropdown, filter_code, filter_media, filter_org, top_k_dropdown, group_by_dropdown, category_filter_dropdown, subcategory_filter_dropdown, topic_filter_dropdown, skip_cats_textbox, date_range_slider, datasets_data_state, ], outputs=[plot_output, status_message_md] ) # Handle search button for showing details search_button.click( fn=handle_search_details, inputs=[search_item, datasets_data_state], outputs=[selected_info_html] ) # Also trigger on Enter key in search box search_item.submit( fn=handle_search_details, inputs=[search_item, datasets_data_state], outputs=[selected_info_html] ) if __name__ == "__main__": print("Application starting...") demo.queue().launch()