"""PaperVerse Daily Explorer.

A Gradio app that loads the PaperVerse Daily dataset from the Hugging Face
Hub, normalizes the columns the UI relies on, and renders:

* a treemap of papers grouped by organization or by topic hierarchy, and
* a search panel listing details for matching organizations / papers.
"""

import html
import json
import math
import os
import tempfile
import time

import duckdb
import gradio as gr
import pandas as pd
import plotly.express as px
import requests
from datasets import load_dataset
from gradio_rangeslider import RangeSlider
from huggingface_hub import logout as hf_logout

# --- Constants ---
TOP_K_CHOICES = list(range(5, 51, 5))
HF_DATASET_ID = "evijit/paperverse_daily_data"
# Direct parquet file URL (public)
PARQUET_URL = "https://huggingface.co/datasets/evijit/paperverse_daily_data/resolve/main/papers_with_semantic_taxonomy.parquet"
TAXONOMY_JSON_PATH = "integrated_ml_taxonomy.json"

# Simple content filters derived from the new dataset
TAG_FILTER_CHOICES = [
    "None",
    "Has Code",
    "Has Media",
    "Has Organization",
]

# Human-readable names for the metric columns offered in the UI.
METRIC_LABELS = {
    "paper_upvotes": "Upvotes",
    "numComments": "Comments",
}

# Default content of the "Organizations to Skip" textbox.
DEFAULT_SKIP_ORGS = "unaffiliated, Other"

# Upper bound on the number of rows rendered in the search panel.
MAX_SEARCH_RESULTS = 20

# Columns that may hold an organization struct (dict) per row, in priority order.
_ORG_STRUCT_COLUMNS = ('paper_organization', 'organization')

# The three columns that together identify a topic in the hierarchy.
_TOPIC_COLUMNS = ['primary_category', 'primary_subcategory', 'primary_topic']

# Lower-cased strings that stand for "no value" in link / image fields.
_PLACEHOLDER_STRINGS = frozenset({'', 'none', 'null', 'n/a'})


def load_taxonomy():
    """Load the ML taxonomy from ``TAXONOMY_JSON_PATH``.

    The JSON is read as ``{category: {subcategory: [topic, ...]}}``.

    Returns:
        dict with keys ``categories``, ``subcategories`` and ``topics`` (each a
        sorted list prefixed with ``"All"``) and ``taxonomy`` (the raw mapping).
        If the file cannot be read or does not have the expected shape, every
        list contains only ``"All"`` and ``taxonomy`` is empty.
    """
    try:
        with open(TAXONOMY_JSON_PATH, 'r', encoding='utf-8') as f:
            taxonomy = json.load(f)

        categories = sorted(taxonomy.keys())
        all_subcategories = set()
        all_topics = set()
        for subcats in taxonomy.values():
            for subcat, topics in subcats.items():
                all_subcategories.add(subcat)
                all_topics.update(topics)

        return {
            'categories': ["All"] + categories,
            'subcategories': ["All"] + sorted(all_subcategories),
            'topics': ["All"] + sorted(all_topics),
            'taxonomy': taxonomy,
        }
    except (OSError, ValueError, AttributeError, TypeError) as e:
        # Missing file, invalid JSON / encoding, or an unexpected structure:
        # the app still starts, just without taxonomy choices.
        print(f"Error loading taxonomy from JSON: {e}")
        return {
            'categories': ["All"],
            'subcategories': ["All"],
            'topics': ["All"],
            'taxonomy': {},
        }


TAXONOMY_DATA = load_taxonomy()


def _first_non_null(*values):
    """Return the first value that is neither ``None`` nor a blank string."""
    for v in values:
        if v is None:
            continue
        # treat empty strings as null-ish
        if isinstance(v, str) and v.strip() == "":
            continue
        return v
    return None


def _get_nested(row, *paths):
    """Try multiple dotted paths in a row; return the first non-``None`` hit.

    ``row`` may be a ``dict`` or a pandas ``Series`` (as passed by
    ``DataFrame.apply(axis=1)``); nested levels are followed through dicts and
    Series alike. A path that runs into a missing key or a non-container value
    is skipped.
    """
    for path in paths:
        cur = row
        found = True
        for key in path.split('.'):
            # ``key in series`` tests the index labels, mirroring dict lookup.
            if isinstance(cur, (dict, pd.Series)) and key in cur:
                cur = cur[key]
            else:
                found = False
                break
        if found and cur is not None:
            return cur
    return None


# --- Dataset loading ---

def _download_parquet_frame():
    """Download ``PARQUET_URL`` to a temp file and read it through DuckDB.

    The temp file is removed on every exit path, including a failed download.
    """
    tmp_path = None
    try:
        with requests.get(PARQUET_URL, stream=True, timeout=120) as resp:
            resp.raise_for_status()
            with tempfile.NamedTemporaryFile(suffix=".parquet", delete=False) as tmpf:
                # Record the path before writing so cleanup also covers a
                # download that dies mid-stream.
                tmp_path = tmpf.name
                for chunk in resp.iter_content(chunk_size=1024 * 1024):
                    if chunk:
                        tmpf.write(chunk)
        # Use DuckDB to read parquet to avoid pyarrow decoding issues.
        # Single quotes are doubled so the path is a valid SQL string literal.
        sql_path = tmp_path.replace("'", "''")
        return duckdb.query(f"SELECT * FROM read_parquet('{sql_path}')").df()
    finally:
        if tmp_path is not None:
            try:
                os.remove(tmp_path)
            except OSError:
                # Best-effort cleanup; a leftover temp file is not fatal.
                pass


def _load_all_splits_anonymously():
    """Load every split of the dataset without sending an auth token."""
    try:
        # token=False disables authentication; an unspecified token would fall
        # back to whatever token is cached locally.
        return load_dataset(HF_DATASET_ID, token=False)
    except TypeError:
        # Older ``datasets`` releases use ``use_auth_token`` instead.
        return load_dataset(HF_DATASET_ID, use_auth_token=False)


def _dataset_to_frame(dataset_obj):
    """Convert a ``Dataset`` (or the first split of a ``DatasetDict``) to pandas."""
    try:
        # If it's a Dataset (single split), this will work
        return dataset_obj.to_pandas()
    except AttributeError:
        # Otherwise assume DatasetDict and take the first split
        first_split = next(iter(dataset_obj.keys()))
        return dataset_obj[first_split].to_pandas()


def _load_via_datasets_library():
    """Fallback loader: fetch the dataset with ``datasets.load_dataset``.

    Side effect: removes HF token environment variables from ``os.environ`` so
    that a bad token cannot be injected into the request.
    """
    for env_key in ("HF_TOKEN", "HUGGINGFACE_HUB_TOKEN", "HF_HUB_TOKEN"):
        if os.environ.pop(env_key, None) is not None:
            print(f"Cleared env var: {env_key}")

    # Prefer explicit train split when available
    try:
        dataset_obj = load_dataset(HF_DATASET_ID, split="train", token=False)
    except TypeError:
        dataset_obj = load_dataset(HF_DATASET_ID, split="train", use_auth_token=False)
    except Exception:
        # Fallback: load all splits and pick the first available
        dataset_obj = _load_all_splits_anonymously()
    return _dataset_to_frame(dataset_obj)


def _is_real_link(value):
    """Return True if ``value`` is a string that is not blank or a placeholder."""
    return isinstance(value, str) and value.strip().lower() not in _PLACEHOLDER_STRINGS


def _has_media_value(value):
    """Return True if ``value`` looks like a non-empty collection of media URLs."""
    if isinstance(value, str):
        # some providers store arrays as strings like "[...]"
        text = value.strip()
        return text.startswith('[') and len(text) > 2
    # List-likes other than plain ``list`` (e.g. arrays) are accepted too;
    # NOTE(review): which container type the loaders produce is not visible here.
    if pd.api.types.is_list_like(value) and not isinstance(value, dict):
        return len(value) > 0
    return False


def _flag_column(df, column, predicate):
    """Apply ``predicate`` to ``df[column]``; all-False if the column is missing."""
    if column not in df.columns:
        return pd.Series(False, index=df.index)
    return df[column].map(predicate).astype(bool)


def _org_struct_field(df, *fields):
    """Per row, return the first non-blank ``<struct>.<field>`` value.

    Struct columns are tried in ``_ORG_STRUCT_COLUMNS`` order and, within each,
    ``fields`` in the order given. Rows without a hit get ``None``.
    """
    present = [col for col in _ORG_STRUCT_COLUMNS if col in df.columns]
    if not present or df.empty:
        return pd.Series([None] * len(df), index=df.index, dtype=object)
    paths = [f"{col}.{field}" for col in present for field in fields]
    return df[present].apply(
        lambda row: _first_non_null(*(_get_nested(row, p) for p in paths)),
        axis=1,
    )


def _non_blank_text(df, column):
    """Return ``df[column]`` as strings, with null / blank entries set to NaN."""
    if column not in df.columns:
        return pd.Series([None] * len(df), index=df.index, dtype=object)
    raw = df[column]
    text = raw.astype(str)
    return text.where(raw.notna() & text.str.strip().ne(''))


def _normalize_papers_frame(df):
    """Add / clean the columns the visualization and search rely on.

    Mutates ``df`` in place and returns it. Columns written: ``organization``,
    ``organization_avatar``, ``id``, ``paper_upvotes``, ``numComments``,
    ``paper_githubStars``, ``has_code``, ``has_media``, ``has_organization``,
    ``publishedAt_dt``, the three ``primary_*`` topic columns, ``paper_label``,
    ``topic_chain``, ``paper_githubRepo`` and ``paper_projectPage``.
    """
    # Read everything that comes from the organization structs *before*
    # df['organization'] is overwritten with plain names below.
    avatar_series = _org_struct_field(df, 'avatar')
    if 'organization_name' in df.columns:
        org_series = df['organization_name']
    else:
        org_series = _org_struct_field(df, 'name', 'fullname')

    df['organization'] = org_series.fillna('Unknown')
    df['organization_avatar'] = avatar_series.map(
        lambda v: v if isinstance(v, str) else None
    )

    # id for each paper row: first available id column, else title + position
    row_numbers = pd.Series(range(len(df)), index=df.index).astype(str)
    id_column = next(
        (c for c in ('paper_id', 'paper_discussionId', 'key') if c in df.columns),
        None,
    )
    if id_column is not None:
        df['id'] = df[id_column].astype(str)
    elif 'paper_title' in df.columns:
        df['id'] = df['paper_title'].astype(str) + '_' + row_numbers
    elif 'title' in df.columns:
        df['id'] = df['title'].astype(str) + '_' + row_numbers
    else:
        df['id'] = row_numbers

    # numeric metrics used for aggregation (missing / unparsable -> 0.0)
    for metric in ('paper_upvotes', 'numComments', 'paper_githubStars'):
        if metric in df.columns:
            df[metric] = pd.to_numeric(df[metric], errors='coerce').fillna(0.0)
        else:
            df[metric] = 0.0

    # computed boolean filters
    df['has_code'] = (
        _flag_column(df, 'paper_githubRepo', _is_real_link)
        | _flag_column(df, 'paper_projectPage', _is_real_link)
    )
    df['has_media'] = (
        _flag_column(df, 'paper_mediaUrls', _has_media_value)
        | _flag_column(df, 'mediaUrls', _has_media_value)
    )
    df['has_organization'] = df['organization'].astype(str).str.strip().ne('Unknown')

    # Process publishedAt field for date filtering
    if 'publishedAt' in df.columns:
        df['publishedAt_dt'] = pd.to_datetime(df['publishedAt'], errors='coerce')
    else:
        df['publishedAt_dt'] = pd.NaT

    # Ensure topic hierarchy columns exist and are strings
    for col_name in _TOPIC_COLUMNS:
        if col_name not in df.columns:
            df[col_name] = 'Unknown'
        else:
            df[col_name] = df[col_name].fillna('Unknown').astype(str).replace({'': 'Unknown'})

    # Human-friendly label for treemap leaves: the paper title
    # (``paper_title``, else ``title``, else "Untitled").
    df['paper_label'] = (
        _non_blank_text(df, 'paper_title')
        .fillna(_non_blank_text(df, 'title'))
        .fillna('Untitled')
        .astype(str)
    )

    # Build a Topic Chain for hover details
    df['topic_chain'] = (
        df['primary_category'].astype(str) + ' > '
        + df['primary_subcategory'].astype(str) + ' > '
        + df['primary_topic'].astype(str)
    )

    # Ensure link fields exist for hover details
    for link_col in ('paper_githubRepo', 'paper_projectPage'):
        if link_col not in df.columns:
            df[link_col] = 'N/A'
        else:
            df[link_col] = df[link_col].fillna('N/A').replace({'': 'N/A'})

    return df


def load_datasets_data():
    """Load the PaperVerse Daily dataset and normalize columns used by the app.

    Tries a direct parquet download first, then the ``datasets`` loader. If the
    failure looks like an authentication error, the cached Hub token is cleared
    and one more anonymous attempt is made. The frame is normalized on every
    successful path.

    Returns:
        tuple ``(df, success, message)``; on failure ``df`` is an empty frame.
    """
    start_time = time.time()
    print(f"Attempting to load dataset from Hugging Face Hub: {HF_DATASET_ID}")
    try:
        # First try: direct parquet download (avoids any auth header issues)
        try:
            print(f"Trying direct parquet download: {PARQUET_URL}")
            df = _download_parquet_frame()
            print("Loaded DataFrame from direct parquet download via DuckDB.")
        except Exception as direct_e:
            print(f"Direct parquet load failed: {direct_e}. Falling back to datasets loader...")
            df = _load_via_datasets_library()

        df = _normalize_papers_frame(df)
        msg = f"Successfully loaded dataset in {time.time() - start_time:.2f}s."
        print(msg)
        return df, True, msg
    except Exception as e:
        # If we encountered invalid credentials, try logging out
        # programmatically and retry once anonymously.
        if "Invalid credentials" in str(e) or "401 Client Error" in str(e):
            try:
                print("Encountered auth error; attempting to clear cached token and retry anonymously...")
                hf_logout()
                df = _normalize_papers_frame(
                    _dataset_to_frame(_load_all_splits_anonymously())
                )
                msg = f"Successfully loaded dataset after clearing token in {time.time() - start_time:.2f}s."
                print(msg)
                return df, True, msg
            except Exception as e2:
                err_msg = f"Failed to load dataset after retry. Error: {e2} (initial: {e})"
                print(err_msg)
                return pd.DataFrame(), False, err_msg
        err_msg = f"Failed to load dataset. Error: {e}"
        print(err_msg)
        return pd.DataFrame(), False, err_msg


# --- Treemap data preparation ---

def _filter_by_date_range(frame, date_range):
    """Keep rows whose ``publishedAt_dt`` lies within ``date_range`` (inclusive).

    ``date_range`` is ``(min_ts, max_ts)`` in seconds since the epoch. Rows
    with a missing date never satisfy the comparison and are dropped.
    """
    min_ts, max_ts = date_range
    published = frame['publishedAt_dt']
    min_date = pd.to_datetime(min_ts, unit='s', utc=True)
    max_date = pd.to_datetime(max_ts, unit='s', utc=True)
    if published.dt.tz is None:
        # Naive column: compare against naive bounds (same UTC wall time).
        # Aware columns compare directly against the UTC-aware bounds.
        min_date = min_date.tz_localize(None)
        max_date = max_date.tz_localize(None)
    return frame[(published >= min_date) & (published <= max_date)]


def make_treemap_data(df, count_by, top_k=25, tag_filter=None, skip_cats=None, group_by='organization', date_range=None):
    """Filter data and prepare it for a multi-level treemap.

    Args:
        df: normalized papers frame.
        count_by: metric column to aggregate; created as 0.0 if missing.
        top_k: number of top groups (organizations or topic triplets) to keep.
        tag_filter: optional entry of ``TAG_FILTER_CHOICES`` to filter by.
        skip_cats: organization names (or, in topic mode, ``primary_topic``
            values) to drop from the result.
        group_by: ``'organization'``; any other value selects topic grouping.
        date_range: tuple of (min_timestamp, max_timestamp) in seconds since
            epoch, or ``None`` for no date filter.

    Returns:
        A new frame with a constant ``root`` column. In organization mode the
        individual papers of the top K organizations are preserved and all
        other organizations are summed into a single "Other" row. Returns an
        empty frame when nothing matches. ``df`` itself is not modified.
    """
    if df is None or df.empty:
        return pd.DataFrame()

    filtered_df = df

    # Apply date range filter
    if date_range is not None and 'publishedAt_dt' in filtered_df.columns:
        filtered_df = _filter_by_date_range(filtered_df, date_range)

    col_map = {
        "Has Code": "has_code",
        "Has Media": "has_media",
        "Has Organization": "has_organization",
    }
    flag_col = col_map.get(tag_filter) if tag_filter and tag_filter != "None" else None
    if flag_col is not None and flag_col in filtered_df.columns:
        filtered_df = filtered_df[filtered_df[flag_col]]

    if filtered_df.empty:
        return pd.DataFrame()

    if count_by in filtered_df.columns:
        metric_values = pd.to_numeric(filtered_df[count_by], errors='coerce').fillna(0.0)
    else:
        metric_values = 0.0
    # ``assign`` returns a new frame, so the caller's data is never written to.
    filtered_df = filtered_df.assign(**{count_by: metric_values})

    top_k = int(top_k)

    if group_by == 'organization':
        org_totals = filtered_df.groupby("organization")[count_by].sum()
        top_org_names = org_totals.nlargest(top_k, keep='first').index
        plot_df = filtered_df[filtered_df['organization'].isin(top_org_names)]
        other_total = org_totals[~org_totals.index.isin(top_org_names)].sum()

        if other_total > 0:
            other_row = pd.DataFrame([{
                'organization': 'Other',
                'paper_label': 'Other',
                'primary_category': 'Other',
                'primary_subcategory': 'Other',
                'primary_topic': 'Other',
                'topic_chain': 'Other > Other > Other',
                'paper_githubRepo': 'N/A',
                'paper_projectPage': 'N/A',
                'organization_avatar': None,
                count_by: other_total,
            }])
            plot_df = pd.concat([plot_df, other_row], ignore_index=True)

        if skip_cats:
            plot_df = plot_df[~plot_df['organization'].isin(skip_cats)]
        return plot_df.assign(root="papers")

    # Topic grouping: apply top-k to topic combinations and handle skip list
    topic_totals = filtered_df.groupby(_TOPIC_COLUMNS)[count_by].sum()
    top_topics = set(topic_totals.nlargest(top_k, keep='first').index)
    row_keys = zip(*(filtered_df[col] for col in _TOPIC_COLUMNS))
    in_top = pd.Series([key in top_topics for key in row_keys], index=filtered_df.index)
    plot_df = filtered_df[in_top]

    # Apply skip filter (skip by primary_topic name)
    if skip_cats:
        plot_df = plot_df[~plot_df['primary_topic'].isin(skip_cats)]
    return plot_df.assign(root="papers")


def create_treemap(treemap_data, count_by, title=None, path=None, metric_label=None):
    """Generate the Plotly treemap figure from the prepared data.

    Args:
        treemap_data: frame produced by ``make_treemap_data`` (not modified).
        count_by: column used for tile sizes.
        title: optional figure title.
        path: hierarchy columns; defaults to root > organization > paper.
        metric_label: text shown after the value in the hover box; defaults to
            ``count_by``.

    Returns:
        A Plotly figure; a placeholder figure when there is nothing to plot.
    """
    nothing_to_plot = (
        treemap_data is None
        or treemap_data.empty
        or count_by not in treemap_data.columns
        or treemap_data[count_by].sum() <= 0
    )
    if nothing_to_plot:
        fig = px.treemap(names=["No data matches filters"], parents=[""], values=[1])
        fig.update_layout(title="No data matches the selected filters", margin=dict(t=50, l=25, r=25, b=25))
        return fig

    if path is None:
        path = ["root", "organization", "paper_label"]

    # The hover template below addresses these columns by position in
    # customdata, so all of them must exist; fill any missing one on a copy.
    hover_columns = [
        'primary_category',
        'primary_subcategory',
        'primary_topic',
        'paper_githubRepo',
        'paper_projectPage',
    ]
    missing = {col: 'N/A' for col in hover_columns if col not in treemap_data.columns}
    plot_frame = treemap_data.assign(**missing) if missing else treemap_data

    fig = px.treemap(
        plot_frame,
        path=path,
        values=count_by,
        hover_data={col: True for col in hover_columns},
        title=title,
        color_discrete_sequence=px.colors.qualitative.Plotly,
    )
    fig.update_layout(margin=dict(t=50, l=25, r=25, b=25))

    display_metric = metric_label if metric_label else count_by
    # Clean hover without organization avatar (images shown in search details instead)
    fig.update_traces(
        textinfo="label+value",
        hovertemplate=(
            "<b>%{label}</b><br>"
            + "%{value:,} " + display_metric
            + "<br><br><b>Topic Hierarchy:</b><br>"
            + "%{customdata[0]} > %{customdata[1]} > %{customdata[2]}<br>"
            + "<br><b>Links:</b><br>"
            + "GitHub: %{customdata[3]}<br>"
            + "Project: %{customdata[4]}"
            + "<extra></extra>"
        ),
    )
    return fig


# --- UI callbacks ---

def _update_button_interactivity(is_loaded_flag):
    """Enable the "Generate Plot" button once loading has completed."""
    return gr.update(interactive=is_loaded_flag)


def _format_date_range(date_range_tuple, date_range_value):
    """Convert slider values to readable date range text.

    ``date_range_tuple`` is the dataset's (min, max) bounds or ``None`` when no
    dates are available; ``date_range_value`` is the slider's current
    (low, high) pair, already expressed in seconds since the epoch.
    """
    if date_range_tuple is None or not date_range_value:
        return "Date range unavailable"
    selected_min, selected_max = date_range_value
    min_date = pd.to_datetime(selected_min, unit='s')
    max_date = pd.to_datetime(selected_max, unit='s')
    return f"**Selected Range:** {min_date.strftime('%B %d, %Y')} to {max_date.strftime('%B %d, %Y')}"


def _toggle_labels_by_grouping(group_by_value):
    """Relabel the top-k dropdown and skip textbox for the chosen grouping."""
    if group_by_value == 'topic':
        top_k_label = "Number of Top Topics"
        skip_label = "Topics to Skip"
        skip_value = ""  # Clear skip box for topics
    else:
        top_k_label = "Number of Top Organizations"
        skip_label = "Organizations to Skip"
        skip_value = DEFAULT_SKIP_ORGS  # Default orgs to skip
    return (
        gr.update(label=top_k_label),
        gr.update(label=skip_label, value=skip_value),
    )


def _published_date_bounds(df):
    """Derive slider bounds from ``publishedAt_dt``.

    Returns:
        ``(min_ts, max_ts, display_text, bounds)`` where ``bounds`` is
        ``(min_ts, max_ts)`` or ``None`` when no valid date exists (in which
        case the placeholder range 0..100 is returned).
    """
    if 'publishedAt_dt' in df.columns:
        valid_dates = df['publishedAt_dt'].dropna()
        if len(valid_dates) > 0:
            min_date = valid_dates.min()
            max_date = valid_dates.max()
            # Round outwards so both extreme papers stay inside the full range.
            min_ts = math.floor(min_date.timestamp())
            max_ts = math.ceil(max_date.timestamp())
            text = f"**Full Range:** {min_date.strftime('%B %d, %Y')} to {max_date.strftime('%B %d, %Y')}"
            return min_ts, max_ts, text, (min_ts, max_ts)
    return 0, 100, "Date range unavailable", None


def load_and_generate_initial_plot(progress=gr.Progress()):
    """Load the dataset and build everything the UI needs on startup.

    Returns, in the order expected by the ``demo.load`` outputs: the data
    frame, the load-success flag, the data-info markdown, the plot status
    markdown, the initial figure, updates for the three taxonomy dropdowns,
    an update for the date slider, the date-range text and the date bounds.
    """
    progress(0, desc=f"Loading dataset '{HF_DATASET_ID}'...")

    # --- Part 1: Data Loading ---
    min_ts, max_ts, date_range_text, date_range_tuple = 0, 100, "Date range unavailable", None
    try:
        current_df, load_success_flag, status_msg_from_load = load_datasets_data()
        if load_success_flag:
            progress(0.5, desc="Processing data...")
            date_display = "Pre-processed (date unavailable)"
            if (
                not current_df.empty
                and 'data_download_timestamp' in current_df.columns
                and pd.notna(current_df['data_download_timestamp'].iloc[0])
            ):
                ts = pd.to_datetime(current_df['data_download_timestamp'].iloc[0], utc=True)
                date_display = ts.strftime('%B %d, %Y, %H:%M:%S %Z')

            # Calculate date range from publishedAt_dt
            min_ts, max_ts, date_range_text, date_range_tuple = _published_date_bounds(current_df)

            data_info_text = (f"### Data Information\n- Source: `{HF_DATASET_ID}`\n"
                              f"- Status: {status_msg_from_load}\n"
                              f"- Total records loaded: {len(current_df):,}\n"
                              f"- Data as of: {date_display}\n")
        else:
            data_info_text = f"### Data Load Failed\n- {status_msg_from_load}"
    except Exception as e:
        # Top-level UI boundary: report the problem instead of crashing startup.
        status_msg_from_load = f"An unexpected error occurred: {str(e)}"
        data_info_text = f"### Critical Error\n- {status_msg_from_load}"
        load_success_flag = False
        current_df = pd.DataFrame()  # Ensure df is empty on failure
        min_ts, max_ts, date_range_text, date_range_tuple = 0, 100, "Date range unavailable", None
        print(f"Critical error in load_and_generate_initial_plot: {e}")

    # --- Part 2: Generate Initial Plot ---
    progress(0.6, desc="Generating initial plot...")

    # Defaults matching UI definitions; date range is None for the initial load.
    initial_plot, initial_status = ui_generate_plot_controller(
        "paper_upvotes", False, False, False, 25, "organization",
        "All", "All", "All", DEFAULT_SKIP_ORGS, None, current_df, progress
    )

    # Taxonomy choices come from the JSON file rather than from the dataset.
    return (
        current_df,
        load_success_flag,
        data_info_text,
        initial_status,
        initial_plot,
        gr.update(choices=TAXONOMY_DATA['categories'], value="All"),
        gr.update(choices=TAXONOMY_DATA['subcategories'], value="All"),
        gr.update(choices=TAXONOMY_DATA['topics'], value="All"),
        gr.update(minimum=min_ts, maximum=max_ts, value=(min_ts, max_ts)),
        date_range_text,
        date_range_tuple,
    )


def ui_generate_plot_controller(metric_choice, has_code, has_media, has_org,
                                k_orgs, group_by_choice,
                                category_choice, subcategory_choice, topic_choice,
                                skip_cats_input, date_range, df_current_datasets, progress=gr.Progress()):
    """Apply the UI filters and return ``(figure, statistics_markdown)``.

    ``skip_cats_input`` is a comma-separated string; ``date_range`` is the
    slider value in epoch seconds or ``None``. The input frame is not modified.
    """
    if df_current_datasets is None or df_current_datasets.empty:
        return create_treemap(pd.DataFrame(), metric_choice), "Dataset data is not loaded. Cannot generate plot."

    progress(0.1, desc="Aggregating data...")
    cats_to_skip = [cat.strip() for cat in (skip_cats_input or "").split(',') if cat.strip()]

    # Apply content filters (checkboxes)
    df_filtered = df_current_datasets
    for enabled, flag_col in (
        (has_code, 'has_code'),
        (has_media, 'has_media'),
        (has_org, 'has_organization'),
    ):
        if enabled and flag_col in df_filtered.columns:
            df_filtered = df_filtered[df_filtered[flag_col].astype(bool)]

    # Apply taxonomy filters
    for choice, topic_col in (
        (category_choice, 'primary_category'),
        (subcategory_choice, 'primary_subcategory'),
        (topic_choice, 'primary_topic'),
    ):
        if choice and choice != 'All' and topic_col in df_filtered.columns:
            df_filtered = df_filtered[df_filtered[topic_col] == choice]

    # Without any valid date the slider still holds its placeholder range,
    # which would filter out every row; skip the date filter in that case.
    has_dates = (
        'publishedAt_dt' in df_current_datasets.columns
        and df_current_datasets['publishedAt_dt'].notna().any()
    )
    effective_date_range = date_range if has_dates else None

    treemap_df = make_treemap_data(df_filtered, metric_choice, k_orgs, None, cats_to_skip, group_by_choice, effective_date_range)

    progress(0.7, desc="Generating plot...")
    metric_label = METRIC_LABELS.get(metric_choice, metric_choice)
    if group_by_choice == "topic":
        chart_title = f"PaperVerse Daily - {metric_label} by Topic"
        path = ["root", "primary_category", "primary_subcategory", "primary_topic", "paper_label"]
    else:
        chart_title = f"PaperVerse Daily - {metric_label} by Organization"
        path = ["root", "organization", "paper_label"]

    plotly_fig = create_treemap(
        treemap_df,
        metric_choice,
        chart_title,
        path=path,
        metric_label=metric_label,
    )

    if treemap_df.empty:
        return plotly_fig, "No data matches the selected filters. Please try different options."

    total_value_in_plot = treemap_df[metric_choice].sum()
    total_items_in_plot = treemap_df[treemap_df['paper_label'] != 'Other']['paper_label'].nunique()
    if group_by_choice == "topic":
        group_count = treemap_df[_TOPIC_COLUMNS].drop_duplicates().shape[0]
        group_line = f"**Topics Shown**: {group_count:,} unique triplets"
    else:
        group_line = f"**Organizations Shown**: {treemap_df['organization'].nunique():,}"
    plot_stats_md = (
        f"## Plot Statistics\n- {group_line}\n"
        f"- **Individual Papers Shown**: {total_items_in_plot:,}\n"
        f"- **Total {metric_label} in plot**: {int(total_value_in_plot):,}"
    )
    return plotly_fig, plot_stats_md


# --- Search panel rendering ---

def _display_text(value, default):
    """Return ``value`` as stripped text, or ``default`` if null / blank."""
    if value is None or (pd.api.types.is_scalar(value) and pd.isna(value)):
        return default
    text = str(value).strip()
    return text if text else default


def _metric_as_int(value):
    """Coerce a metric cell to ``int``; unparsable or missing values become 0."""
    number = pd.to_numeric(value, errors='coerce')
    return 0 if pd.isna(number) else int(number)


def _safe_url(value):
    """Return an attribute-safe http(s) URL, or ``None`` if ``value`` is unusable.

    Placeholders ("n/a", "none", ...) and non-http(s) schemes are rejected so
    that dataset content cannot inject e.g. ``javascript:`` links.
    """
    if not isinstance(value, str):
        return None
    url = value.strip()
    lowered = url.lower()
    if lowered in _PLACEHOLDER_STRINGS or not lowered.startswith(('http://', 'https://')):
        return None
    return html.escape(url, quote=True)


def _render_search_result(row):
    """Return the HTML fragments describing one matching paper row."""
    parts = ["<div style='margin-bottom: 10px; overflow: auto;'>"]

    # Organization logo if available (precomputed column)
    avatar_url = _safe_url(row.get('organization_avatar'))
    if avatar_url:
        parts.append(f"<img src='{avatar_url}' style='max-width: 60px; max-height: 60px; border-radius: 50%; margin-bottom: 8px; float: left; margin-right: 12px; border: 2px solid #ddd;' onerror=\"this.style.display='none'\"/>")

    # Paper thumbnail if available
    thumbnail_url = _safe_url(row.get('thumbnail'))
    if thumbnail_url:
        parts.append(f"<img src='{thumbnail_url}' style='max-width: 120px; max-height: 120px; border-radius: 8px; margin-bottom: 8px; float: right; margin-left: 12px; border: 1px solid #ddd;' onerror=\"this.style.display='none'\"/>")

    # Organization name
    org_name = html.escape(_display_text(row.get('organization'), 'Unknown'))
    parts.append(f"<p style='margin: 0 0 5px 0; font-weight: bold; color: #333;'>🏢 {org_name}</p>")

    # Paper title: first non-blank of the raw and normalized title columns
    title_candidates = (
        _display_text(row.get(col), '') for col in ('paper_title', 'title', 'paper_label')
    )
    paper_title = html.escape(next((t for t in title_candidates if t), 'Untitled'))
    parts.append(f"<p style='margin: 0 0 5px 0; color: #555; font-size: 0.95em;'>📄 {paper_title}</p>")

    # Topic hierarchy
    category = html.escape(_display_text(row.get('primary_category'), 'Unknown'))
    subcategory = html.escape(_display_text(row.get('primary_subcategory'), 'Unknown'))
    topic = html.escape(_display_text(row.get('primary_topic'), 'Unknown'))
    parts.append(f"<p style='margin: 0 0 5px 0; font-size: 0.9em; color: #666;'><b>Topics:</b> {category} → {subcategory} → {topic}</p>")

    # Metrics (stored as floats; shown as whole numbers)
    upvotes = _metric_as_int(row.get('paper_upvotes', 0))
    comments = _metric_as_int(row.get('numComments', 0))
    parts.append(f"<p style='margin: 0 0 5px 0; font-size: 0.9em;'><b>Metrics:</b> ⬆️ {upvotes:,} upvotes | 💬 {comments:,} comments</p>")

    # Links
    links = []
    github_url = _safe_url(row.get('paper_githubRepo'))
    if github_url:
        links.append(f"<a href='{github_url}' target='_blank' style='color: #0366d6; margin-right: 15px;'>🔗 GitHub</a>")
    project_url = _safe_url(row.get('paper_projectPage'))
    if project_url:
        links.append(f"<a href='{project_url}' target='_blank' style='color: #0366d6;'>🔗 Project</a>")
    if links:
        parts.append(f"<p style='margin: 0; font-size: 0.9em;'>{' '.join(links)}</p>")

    parts.append("<div style='clear: both;'></div>")
    parts.append("</div>")
    return parts


def handle_search_details(search_text, df_current):
    """Search for an organization or paper and show detailed information.

    Performs a case-insensitive *literal* substring match against the
    ``organization``, ``paper_label`` and (if present) ``paper_title`` columns
    and renders up to ``MAX_SEARCH_RESULTS`` rows. All user- and
    dataset-provided text is HTML-escaped.

    Returns:
        An HTML string (results, a hint, or an error message).
    """
    if not search_text or not search_text.strip():
        return "<p style='color: gray;'>Please enter a search term</p>"
    if df_current is None or df_current.empty:
        return "<p style='color: gray;'>No data available</p>"

    needle = search_text.strip()
    safe_needle = html.escape(needle)
    try:
        # regex=False: the query is plain text, so characters such as "+" or
        # "(" must not be interpreted as a regular expression.
        mask = (
            df_current['organization'].str.contains(needle, case=False, na=False, regex=False)
            | df_current['paper_label'].str.contains(needle, case=False, na=False, regex=False)
        )
        if 'paper_title' in df_current.columns:
            mask = mask | df_current['paper_title'].str.contains(needle, case=False, na=False, regex=False)
        matching_rows = df_current[mask]

        if matching_rows.empty:
            return f"<p style='color: orange;'>No results found for: <b>{safe_needle}</b></p><p style='color: gray;'>Try searching for an organization name (e.g., 'Qwen', 'Meta') or paper title keyword</p>"

        # Build the info panel HTML showing the matching results
        num_results = len(matching_rows)
        html_parts = [
            "<div style='padding: 15px; border: 1px solid #ddd; border-radius: 8px; background: #f9f9f9; max-height: 600px; overflow-y: auto;'>",
            f"<h3 style='margin: 0 0 15px 0; color: #333;'>🔍 Found {num_results} result{'s' if num_results > 1 else ''} for: <span style='color: #0366d6;'>{safe_needle}</span></h3>",
        ]

        # Limit the number of rendered rows to avoid too much content
        display_rows = matching_rows.head(MAX_SEARCH_RESULTS)
        for idx, (_, row) in enumerate(display_rows.iterrows()):
            if idx > 0:
                html_parts.append("<hr style='margin: 15px 0; border: none; border-top: 1px solid #ddd;'/>")
            html_parts.extend(_render_search_result(row))

        if num_results > MAX_SEARCH_RESULTS:
            html_parts.append(f"<p style='margin-top: 15px; color: #666; font-style: italic;'>Showing first {MAX_SEARCH_RESULTS} of {num_results} results. Refine your search for fewer results.</p>")

        html_parts.append("</div>")
        return "".join(html_parts)
    except Exception as e:
        # UI boundary: show the problem in the panel rather than failing the event.
        return f"<p style='color: red;'>Error displaying details: {html.escape(str(e))}</p>"


# --- Gradio UI Blocks ---
with gr.Blocks(
    title="📚 PaperVerse Daily Explorer",
    fill_width=True,
    css="""
    /* Hide the timestamp numbers on the range slider */
    #date-range-slider-wrapper .head,
    #date-range-slider-wrapper div[data-testid="range-slider"] > span {
        display: none !important;
    }
    """
) as demo:
    datasets_data_state = gr.State(pd.DataFrame())
    loading_complete_state = gr.State(False)
    date_range_state = gr.State(None)  # Store min/max timestamps

    with gr.Row():
        gr.Markdown("# 📚 PaperVerse Daily Explorer")

    with gr.Tabs():
        with gr.Tab("📊 Treemap Visualization"):
            with gr.Row():
                with gr.Column(scale=1):
                    count_by_dropdown = gr.Dropdown(
                        label="Metric",
                        choices=[
                            ("Upvotes", "paper_upvotes"),
                            ("Comments", "numComments"),
                        ],
                        value="paper_upvotes",
                    )
                    group_by_dropdown = gr.Dropdown(
                        label="Group by",
                        choices=[("Organization", "organization"), ("Topic", "topic")],
                        value="organization",
                    )

                    gr.Markdown("**Filters**")
                    filter_code = gr.Checkbox(label="Has Code", value=False)
                    filter_media = gr.Checkbox(label="Has Media", value=False)
                    filter_org = gr.Checkbox(label="Has Organization", value=False)

                    gr.Markdown("**Date Range**")
                    date_range_slider = RangeSlider(
                        minimum=0,
                        maximum=100,
                        value=(0, 100),
                        label="Paper Release Date Range",
                        interactive=True,
                        elem_id="date-range-slider-wrapper"
                    )
                    date_range_display = gr.Markdown("Loading date range...")

                    top_k_dropdown = gr.Dropdown(label="Number of Top Organizations", choices=TOP_K_CHOICES, value=25)
                    category_filter_dropdown = gr.Dropdown(label="Primary Category", choices=["All"], value="All")
                    subcategory_filter_dropdown = gr.Dropdown(label="Primary Subcategory", choices=["All"], value="All")
                    topic_filter_dropdown = gr.Dropdown(label="Primary Topic", choices=["All"], value="All")
                    skip_cats_textbox = gr.Textbox(label="Organizations to Skip", value=DEFAULT_SKIP_ORGS)
                    generate_plot_button = gr.Button(value="Generate Plot", variant="primary", interactive=False)

                with gr.Column(scale=3):
                    plot_output = gr.Plot()
                    status_message_md = gr.Markdown("Initializing...")
                    data_info_md = gr.Markdown("")

        with gr.Tab("🔍 Paper Search"):
            with gr.Column():
                gr.Markdown("### 🔍 Search Papers and Organizations")
                with gr.Row():
                    search_item = gr.Textbox(
                        label="Search Organization or Paper",
                        placeholder="Type organization name or paper title to see details...",
                        scale=4
                    )
                    search_button = gr.Button("Show Details", scale=1, variant="secondary")
                selected_info_html = gr.HTML(value="<p style='color: gray;'>Enter an organization name or paper title above to see details</p>")

    # --- Event Wiring ---

    # On page load: fetch the data and render the initial plot in one step.
    demo.load(
        fn=load_and_generate_initial_plot,
        inputs=[],
        outputs=[
            datasets_data_state,
            loading_complete_state,
            data_info_md,
            status_message_md,
            plot_output,
            category_filter_dropdown,
            subcategory_filter_dropdown,
            topic_filter_dropdown,
            date_range_slider,
            date_range_display,
            date_range_state,
        ]
    )

    loading_complete_state.change(
        fn=_update_button_interactivity,
        inputs=loading_complete_state,
        outputs=generate_plot_button
    )

    # Update labels based on grouping mode
    group_by_dropdown.change(
        fn=_toggle_labels_by_grouping,
        inputs=group_by_dropdown,
        outputs=[top_k_dropdown, skip_cats_textbox],
    )

    # Update date range display when slider changes
    date_range_slider.change(
        fn=_format_date_range,
        inputs=[date_range_state, date_range_slider],
        outputs=date_range_display,
        show_progress="hidden"
    )

    generate_plot_button.click(
        fn=ui_generate_plot_controller,
        inputs=[
            count_by_dropdown,
            filter_code,
            filter_media,
            filter_org,
            top_k_dropdown,
            group_by_dropdown,
            category_filter_dropdown,
            subcategory_filter_dropdown,
            topic_filter_dropdown,
            skip_cats_textbox,
            date_range_slider,
            datasets_data_state,
        ],
        outputs=[plot_output, status_message_md]
    )

    # Handle search button for showing details
    search_button.click(
        fn=handle_search_details,
        inputs=[search_item, datasets_data_state],
        outputs=[selected_info_html]
    )

    # Also trigger on Enter key in search box
    search_item.submit(
        fn=handle_search_details,
        inputs=[search_item, datasets_data_state],
        outputs=[selected_info_html]
    )


if __name__ == "__main__":
    print("Application starting...")
    demo.queue().launch()