Spaces:
Runtime error
Runtime error
| import pandas as pd | |
| import networkx as nx | |
| import plotly.graph_objects as go | |
| import gradio as gr | |
| import re | |
| import logging | |
| from typing import List, Dict, Tuple, Optional | |
| from functools import lru_cache | |
| import time | |
# ============================================================================
# CONFIGURATION
# ============================================================================

# --- Logging ----------------------------------------------------------------
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)

# --- Data source ------------------------------------------------------------
FILE_PATH = "cbinsights_data.csv"
DATA_TIMESTAMP = "2024-09"  # Update manually or parse from filename

# --- UI copy ----------------------------------------------------------------
TITLE = "Venture Networks Visualization"
SUBTITLE_TEMPLATE = "Active: {country} • {industry} • {valuation_range} • {count} companies"
INSTRUCTIONS = """
**How to use:**
1. **Filter** by Country, Industry, Company, Investor, and Valuation Range
2. **Hover** over nodes to see details • **Click** a node to focus and view full information
3. **Download** the filtered dataset as CSV • Use **Nashville Filter** for local quick access
"""
EMPTY_STATE = """
### No results match your filters.
**Try:** Clearing exclusions • Expanding valuation range • Selecting "All" for Country or Industry
"""
ERROR_VALUATION = "**Data Error:** Could not identify a single valuation column. Found: {columns}"
ERROR_FILE = "**File Error:** Dataset not found at `{path}`. Ensure `cbinsights_data.csv` is in the working directory."
TRUNCATION_NOTICE = "**Notice:** Showing top {cap} of {total} companies by valuation. Adjust slider or refine filters."

# --- Graph design: colours --------------------------------------------------
COMPANY_COLOR = "#66c2a5"
COMPANY_STROKE = "#2d6a4f"
INVESTOR_STROKE = "#000000"
# Investor fill colours are assigned round-robin from this palette.
INVESTOR_COLORS = [
    "#E69F00",
    "#56B4E9",
    "#009E73",
    "#F0E442",
    "#0072B2",
    "#D55E00",
    "#CC79A7",
    "#999999",
]
EDGE_COLOR = "#cccccc"
EDGE_OPACITY = 0.6

# --- Graph design: sizes and labels -----------------------------------------
NODE_SIZE_MIN = 10
NODE_SIZE_MAX = 60
INVESTOR_SIZE = 36
LABEL_FONT_SIZE = 11
INVESTOR_LABEL_FONT_SIZE = 12
LARGE_COMPANY_THRESHOLD = 10  # Show labels for valuations >10B

# --- Rendering limits -------------------------------------------------------
DEFAULT_NODE_CAP = 300
SPRING_LAYOUT_ITERATIONS_SMALL = 150
SPRING_LAYOUT_ITERATIONS_LARGE = 100
DEBOUNCE_MS = 250

# Choices offered by the valuation dropdown (billions of USD).
VALUATION_RANGES = ["All", "1-5", "5-10", "10-15", "15-20", "20+"]
| # ============================================================================ | |
| # DATA LOADING AND PREPROCESSING | |
| # ============================================================================ | |
def load_and_clean_data(file_path: str) -> pd.DataFrame:
    """Load the CSV, standardize columns, drop 'Health' rows, parse valuation.

    Processing steps:
      1. Read the CSV, skipping its first row.
      2. Normalise header names: strip, lower-case, and collapse internal
         whitespace to underscores so that a header such as
         ``"Select Investors"`` matches the ``select_investors`` rename key.
      3. Locate the single column whose name contains ``valuation`` and
         derive a numeric ``Valuation_Billions`` column from it (``$`` and
         ``,`` removed; unparseable values become 0).
      4. Rename the known columns to their canonical capitalised names.
      5. Strip whitespace from every object-typed column.
      6. Drop rows whose Industry is exactly "health" (case-insensitive).
      7. Replace missing ``Select_Investors`` values with ``""``.

    Args:
        file_path: Path of the CSV file to read.

    Returns:
        The cleaned DataFrame.

    Raises:
        ValueError: If the file does not exist, if there is not exactly one
            valuation column, or if a required column is missing.
        Exception: Any other error raised by ``pd.read_csv`` is logged and
            re-raised unchanged.
    """
    try:
        data = pd.read_csv(file_path, skiprows=1)
    except FileNotFoundError as err:
        logger.error(f"File not found: {file_path}")
        raise ValueError(ERROR_FILE.format(path=file_path)) from err
    except Exception as e:
        logger.error(f"Error loading CSV: {e}")
        raise
    logger.info(f"Loaded {len(data)} rows from {file_path}")

    # Standardize columns. Whitespace runs become underscores; without this
    # the underscore-style keys in rename_map below never match headers that
    # contain spaces.
    data.columns = (
        data.columns.str.strip()
        .str.lower()
        .str.replace(r'\s+', '_', regex=True)
    )
    logger.info(f"Columns: {data.columns.tolist()}")

    # Identify valuation column
    val_cols = [col for col in data.columns if 'valuation' in col]
    if len(val_cols) != 1:
        logger.error(f"Expected 1 valuation column, found {len(val_cols)}: {val_cols}")
        raise ValueError(ERROR_VALUATION.format(columns=val_cols))
    val_col = val_cols[0]

    # Clean valuation: drop currency symbols / thousands separators, then
    # coerce anything that still is not a number to 0.
    data["Valuation_Billions"] = (
        data[val_col]
        .astype(str)
        .str.replace(r'[\$,]', '', regex=True)
        .replace('', '0')
    )
    data["Valuation_Billions"] = pd.to_numeric(data["Valuation_Billions"], errors='coerce').fillna(0)

    # Rename columns
    rename_map = {
        "company": "Company",
        "date_joined": "Date_Joined",
        "country": "Country",
        "city": "City",
        "industry": "Industry",
        "select_investors": "Select_Investors"
    }
    data.rename(columns=rename_map, inplace=True)

    # Fail early with a readable message instead of a bare KeyError later on.
    required = ("Company", "Country", "Industry", "Select_Investors")
    missing = [col for col in required if col not in data.columns]
    if missing:
        logger.error(f"Missing required columns: {missing}")
        raise ValueError(
            f"**Data Error:** Missing required column(s): {missing}. "
            f"Found: {data.columns.tolist()}"
        )

    # Strip whitespace
    for col in data.select_dtypes(include='object').columns:
        data[col] = data[col].str.strip()

    # Filter out "Health" (case-insensitive); keep "Healthcare".
    # .copy() so the assignment below writes to an independent frame rather
    # than to a slice of the unfiltered one.
    data = data[~data["Industry"].str.lower().isin(['health'])].copy()
    logger.info(f"After filtering 'Health': {len(data)} rows")

    # Fill missing Select_Investors
    data["Select_Investors"] = data["Select_Investors"].fillna("")
    return data
def build_investor_company_mapping(df: pd.DataFrame) -> Dict[str, List[str]]:
    """Return a dict mapping each investor name to the companies listing it.

    Each row's ``Select_Investors`` value is split on commas; every
    non-empty, whitespace-trimmed piece is treated as one investor and the
    row's ``Company`` is appended to that investor's list.
    """
    companies_by_investor: Dict[str, List[str]] = {}
    for record in df.to_dict("records"):
        company_name = record["Company"]
        raw_investors = record["Select_Investors"]
        if not raw_investors:
            # Nothing listed for this company.
            continue
        for piece in raw_investors.split(","):
            investor_name = piece.strip()
            if not investor_name:
                continue
            companies_by_investor.setdefault(investor_name, []).append(company_name)
    logger.debug(f"Built mapping for {len(companies_by_investor)} investors")
    return companies_by_investor
| # ============================================================================ | |
| # FILTERING LOGIC | |
| # ============================================================================ | |
def filter_by_valuation_range(df: pd.DataFrame, selected_range: str) -> pd.DataFrame:
    """Return the rows of *df* whose valuation (billions) is in the range.

    Ranges are half-open: the lower bound is inclusive and the upper bound
    exclusive; "20+" has no upper bound. "All" and any unrecognised label
    return *df* unchanged.
    """
    # label -> (inclusive lower bound, exclusive upper bound or None)
    bounds_by_label = {
        "1-5": (1, 5),
        "5-10": (5, 10),
        "10-15": (10, 15),
        "15-20": (15, 20),
        "20+": (20, None),
    }
    if selected_range not in bounds_by_label:
        return df

    lower, upper = bounds_by_label[selected_range]
    valuations = df["Valuation_Billions"]
    if upper is None:
        return df[valuations >= lower]
    return df[(valuations >= lower) & (valuations < upper)]
def _investor_token_mask(investor_column: pd.Series, names: List[str]) -> Optional[pd.Series]:
    """Build a boolean mask of rows listing any of *names* as a whole investor token.

    A cell is split on commas and each trimmed piece is compared
    case-insensitively against *names*. Returns ``None`` when *names*
    contains no usable (non-blank) entry, meaning "nothing to match".
    """
    wanted = {name.strip().casefold() for name in names if isinstance(name, str) and name.strip()}
    if not wanted:
        return None

    def lists_wanted_investor(cell) -> bool:
        # Non-string cells (e.g. NaN) list no investors.
        if not isinstance(cell, str):
            return False
        return any(piece.strip().casefold() in wanted for piece in cell.split(","))

    return investor_column.map(lists_wanted_investor).astype(bool)


def apply_filters(
    df: pd.DataFrame,
    country: str,
    industry: str,
    company: str,
    investors: List[str],
    exclude_countries: List[str],
    exclude_industries: List[str],
    exclude_companies: List[str],
    exclude_investors: List[str],
    valuation_range: str,
    quick_find: str
) -> pd.DataFrame:
    """Apply all inclusion and exclusion filters and return the matching rows.

    Args:
        df: Cleaned dataset; it is copied, never modified.
        country, industry, company: Exact value to keep, or "All" to skip.
        investors: Keep rows listing at least one of these investors.
        exclude_countries, exclude_industries, exclude_companies: Values
            whose rows are removed.
        exclude_investors: Remove rows listing any of these investors.
        valuation_range: Label passed to ``filter_by_valuation_range``.
        quick_find: Free text; keeps rows whose Company or Select_Investors
            contains it (case-insensitive, literal substring).

    Investor include/exclude matching is done on whole comma-separated
    tokens (case-insensitive), so selecting "Accel" does not also match
    "Accel Partners".

    Returns:
        A new DataFrame containing only the rows that pass every filter.
    """
    filtered = df.copy()

    # Valuation range
    filtered = filter_by_valuation_range(filtered, valuation_range)

    # Include filters
    if country != "All":
        filtered = filtered[filtered["Country"] == country]
    if industry != "All":
        filtered = filtered[filtered["Industry"] == industry]
    if company != "All":
        filtered = filtered[filtered["Company"] == company]
    if investors:
        include_mask = _investor_token_mask(filtered["Select_Investors"], investors)
        if include_mask is not None:
            filtered = filtered[include_mask]

    # Exclude filters
    if exclude_countries:
        filtered = filtered[~filtered["Country"].isin(exclude_countries)]
    if exclude_industries:
        filtered = filtered[~filtered["Industry"].isin(exclude_industries)]
    if exclude_companies:
        filtered = filtered[~filtered["Company"].isin(exclude_companies)]
    if exclude_investors:
        exclude_mask = _investor_token_mask(filtered["Select_Investors"], exclude_investors)
        if exclude_mask is not None:
            filtered = filtered[~exclude_mask]

    # Quick find: free-text search over Company and the investor string.
    # regex=False so characters such as "(" or "+" typed by the user are
    # matched literally instead of being compiled as a pattern.
    qf = (quick_find or "").strip()
    if qf:
        mask = (
            filtered["Company"].str.contains(qf, case=False, na=False, regex=False) |
            filtered["Select_Investors"].str.contains(qf, case=False, na=False, regex=False)
        )
        filtered = filtered[mask]

    logger.debug(f"After filters: {len(filtered)} rows")
    return filtered
def cap_companies(df: pd.DataFrame, cap: int) -> Tuple[pd.DataFrame, bool]:
    """Limit *df* to the top N rows by ``Valuation_Billions``.

    Args:
        df: Rows to cap.
        cap: Maximum number of rows to keep. Coerced to a non-negative
            integer so a float (for example a UI slider value such as
            ``300.0``) can be passed safely.

    Returns:
        ``(capped_df, was_truncated)``. When *df* already fits within the
        cap it is returned as-is with ``False``.
    """
    limit = max(int(cap), 0)
    if len(df) <= limit:
        return df, False
    capped = df.nlargest(limit, "Valuation_Billions")
    logger.info(f"Truncated {len(df)} companies to {limit}")
    return capped, True
| # ============================================================================ | |
| # GRAPH GENERATION | |
| # ============================================================================ | |
def build_graph(
    filtered_df: pd.DataFrame,
    investor_list: List[str],
    show_all_labels: bool,
    valuation_range: str,
    quick_find: str
) -> nx.Graph:
    """Build the investor-company graph from filtered data and an investor list.

    Each company row contributes one edge per investor in *investor_list*
    that appears as a whole comma-separated token of the row's
    ``Select_Investors`` value (compared case-insensitively). Nodes carry
    a ``node_type`` attribute of ``"company"`` or ``"investor"``. Investors
    with no matching company are not added to the graph.

    ``show_all_labels``, ``valuation_range`` and ``quick_find`` are accepted
    for signature compatibility and are not used here.
    """
    # Index companies by investor token in a single pass over the rows,
    # instead of scanning the whole column once per investor.
    companies_by_token: Dict[str, List[str]] = {}
    for record in filtered_df.to_dict("records"):
        investors_str = record["Select_Investors"]
        if not isinstance(investors_str, str):
            continue  # e.g. NaN: no investors listed
        seen_tokens = set()
        for piece in investors_str.split(","):
            token = piece.strip().casefold()
            if token and token not in seen_tokens:
                seen_tokens.add(token)
                companies_by_token.setdefault(token, []).append(record["Company"])

    G = nx.Graph()
    for investor in investor_list:
        for company in companies_by_token.get(investor.strip().casefold(), []):
            G.add_node(company, node_type="company")
            G.add_node(investor, node_type="investor")
            G.add_edge(investor, company)

    logger.debug(f"Graph built: {G.number_of_nodes()} nodes, {G.number_of_edges()} edges")
    return G
def generate_plotly_figure(
    G: nx.Graph,
    filtered_df: pd.DataFrame,
    investor_list: List[str],
    show_all_labels: bool,
    valuation_range: str,
    quick_find: str,
    layout_cache: Optional[dict] = None
) -> go.Figure:
    """Generate the Plotly network figure for graph *G*.

    Args:
        G: Graph whose nodes carry a ``node_type`` attribute ("investor" or
            "company"; a missing attribute is treated as "company").
        filtered_df: Rows backing the company nodes; supplies valuation,
            industry and investor text for sizes, labels and hover text.
        investor_list: Investors used for the colour assignment and for the
            investor count in the summary line.
        show_all_labels: Force a text label on every company node.
        valuation_range: Selected range label; "15-20" and "20+" also force
            company labels.
        quick_find: Accepted for signature compatibility; not used here.
        layout_cache: Optional dict. A ``"pos"`` entry covering every node
            of *G* is reused; otherwise a new spring layout is computed and
            stored under ``"pos"`` (only when a dict was passed).

    Returns:
        A figure with one edge trace and one node trace, or a placeholder
        figure showing the empty-state text when *G* has no nodes.
    """
    num_nodes = G.number_of_nodes()
    if num_nodes == 0:
        return go.Figure().update_layout(
            title="No Data",
            xaxis=dict(visible=False),
            yaxis=dict(visible=False),
            annotations=[dict(text=EMPTY_STATE, showarrow=False, font=dict(size=14), x=0.5, y=0.5, xref='paper', yref='paper')]
        )

    # Layout
    iterations = SPRING_LAYOUT_ITERATIONS_SMALL if num_nodes < 200 else SPRING_LAYOUT_ITERATIONS_LARGE
    cached_pos = layout_cache.get("pos") if layout_cache else None
    # Only trust the cache when it has coordinates for every current node;
    # a stale cache would otherwise raise KeyError in the loops below.
    if cached_pos is not None and all(node in cached_pos for node in G.nodes()):
        pos = cached_pos
        logger.debug("Using cached layout")
    else:
        pos = nx.spring_layout(G, seed=1721, iterations=iterations)
        if layout_cache is not None:
            layout_cache["pos"] = pos
        logger.debug(f"Generated layout with {iterations} iterations")

    # Color map for investors (palette assigned round-robin in sorted order)
    sorted_investors = sorted(investor_list)
    investor_color_map = {inv: INVESTOR_COLORS[i % len(INVESTOR_COLORS)] for i, inv in enumerate(sorted_investors)}

    # Edges: one polyline, with None separating the individual segments
    edge_x, edge_y = [], []
    for u, v in G.edges():
        x0, y0 = pos[u]
        x1, y1 = pos[v]
        edge_x.extend([x0, x1, None])
        edge_y.extend([y0, y1, None])
    edge_trace = go.Scatter(
        x=edge_x, y=edge_y,
        line=dict(width=0.5, color=EDGE_COLOR),
        hoverinfo='none',
        mode='lines',
        opacity=EDGE_OPACITY,
        showlegend=False
    )

    # Per-company lookups, computed once instead of re-scanning the
    # DataFrame for every node. The first row wins for duplicate names.
    company_records = {}
    for record in filtered_df.to_dict("records"):
        company_records.setdefault(record["Company"], record)
    valuation_by_company = filtered_df.groupby("Company")["Valuation_Billions"].sum().to_dict()

    # Rankings are loop-invariant, so compute them before the node loop.
    top3_df = filtered_df.nlargest(3, "Valuation_Billions")
    top3_companies = set(top3_df["Company"].tolist())
    top5_companies = set(filtered_df.nlargest(5, "Valuation_Billions")["Company"].tolist())
    show_labels_for_range = valuation_range in ["15-20", "20+"]

    # Nodes
    node_x, node_y, node_text, node_hovertext = [], [], [], []
    node_color, node_size, node_line_color = [], [], []
    node_textposition = []

    for node in G.nodes():
        x, y = pos[node]
        node_x.append(x)
        node_y.append(y)
        node_type = G.nodes[node].get("node_type", "company")

        if node_type == "investor":
            # Investor node: fixed size, always labelled
            node_text.append(node)
            node_color.append(investor_color_map.get(node, INVESTOR_COLORS[0]))
            node_size.append(INVESTOR_SIZE)
            node_line_color.append(INVESTOR_STROKE)
            node_textposition.append('top center')
            portfolio_companies = [comp for comp in G.neighbors(node) if G.nodes[comp].get("node_type") == "company"]
            total_cap = sum(valuation_by_company.get(comp, 0.0) for comp in portfolio_companies)
            hovertext = f"<b>Investor:</b> {node}<br><b>Portfolio:</b> {len(portfolio_companies)} companies<br><b>Total Cap:</b> ${total_cap:.1f}B"
            node_hovertext.append(hovertext)
            continue

        # Company node
        record = company_records.get(node)
        if record is None:
            # Shouldn't happen, but fall back to a minimal unlabeled marker
            node_size.append(NODE_SIZE_MIN)
            node_color.append(COMPANY_COLOR)
            node_line_color.append(COMPANY_STROKE)
            node_text.append("")
            node_hovertext.append(f"<b>Company:</b> {node}")
            node_textposition.append('bottom center')
            continue

        valuation = record["Valuation_Billions"]
        industry = record.get("Industry", "N/A")

        # Size: sqrt-scaled, clamped. max(..., 0) keeps a negative valuation
        # from producing a complex number under the fractional power.
        size = max(NODE_SIZE_MIN, min(NODE_SIZE_MAX, (max(valuation, 0) ** 0.5) * 8))
        node_size.append(size)
        node_color.append(COMPANY_COLOR)
        node_line_color.append(COMPANY_STROKE)

        # Hovertext: list at most five investors, then a "+N more" suffix
        investors_str = record["Select_Investors"]
        hovertext = f"<b>Company:</b> {node}<br><b>Industry:</b> {industry}<br><b>Valuation:</b> ${valuation:.1f}B"
        if isinstance(investors_str, str) and investors_str:
            inv_list = [i.strip() for i in investors_str.split(",") if i.strip()]
            hovertext += f"<br><b>Investors:</b> {', '.join(inv_list[:5])}"
            if len(inv_list) > 5:
                hovertext += f" +{len(inv_list)-5} more"
        node_hovertext.append(hovertext)

        # Label logic
        show_label = (
            show_all_labels or
            show_labels_for_range or
            valuation > LARGE_COMPANY_THRESHOLD or
            num_nodes < 15 or
            node in top5_companies
        )
        if not show_label:
            node_text.append("")
        elif node in top3_companies:
            # Bold if top 3
            node_text.append(f"<b>{node}</b>")
        else:
            node_text.append(node)
        node_textposition.append('bottom center')

    node_trace = go.Scatter(
        x=node_x, y=node_y,
        text=node_text,
        textposition=node_textposition,
        mode='markers+text',
        hoverinfo='text',
        hovertext=node_hovertext,
        textfont=dict(size=LABEL_FONT_SIZE),
        marker=dict(
            size=node_size,
            color=node_color,
            line=dict(width=1.5, color=node_line_color)
        ),
        showlegend=False
    )

    # Summary annotation
    total_valuation = filtered_df["Valuation_Billions"].sum()
    num_investors = len(investor_list)
    num_companies = len(filtered_df)
    top3_str = ", ".join([f"{row['Company']} (${row['Valuation_Billions']:.1f}B)" for _, row in top3_df.iterrows()])
    annotation_text = (
        f"<b>{num_investors} Investors</b> • <b>{num_companies} Companies</b> • "
        f"<b>${total_valuation:,.1f}B Total</b><br>"
        f"<i>Top 3:</i> {top3_str}"
    )

    fig = go.Figure(data=[edge_trace, node_trace])
    fig.update_layout(
        title="",
        margin=dict(l=20, r=20, t=80, b=20),
        hovermode='closest',
        width=1400,
        height=900,
        xaxis=dict(showgrid=False, zeroline=False, visible=False),
        yaxis=dict(showgrid=False, zeroline=False, visible=False),
        showlegend=False,
        annotations=[
            dict(
                x=0.5, y=1.05, xref='paper', yref='paper',
                text=annotation_text,
                showarrow=False,
                font=dict(size=13),
                xanchor='center',
                align='center'
            )
        ],
        plot_bgcolor='#ffffff',
        paper_bgcolor='#ffffff'
    )
    return fig
| # ============================================================================ | |
| # GRADIO APP | |
| # ============================================================================ | |
def main():
    """Load the dataset, assemble the Gradio interface, and start the server.

    If the dataset cannot be loaded, a minimal page showing the error is
    launched instead. Otherwise the full UI is built: filter controls, the
    network plot, CSV export, and the event wiring that re-renders the plot
    whenever a control changes.
    """
    # Load data once
    try:
        data = load_and_clean_data(FILE_PATH)
    except Exception as e:
        logger.error(f"Failed to load data: {e}")
        # Fallback Gradio UI showing error
        with gr.Blocks(title=TITLE) as demo:
            gr.Markdown(f"# {TITLE}")
            gr.Markdown(ERROR_FILE.format(path=FILE_PATH) if "not found" in str(e) else str(e))
        demo.launch()
        return

    investor_company_mapping = build_investor_company_mapping(data)

    # Prepare dropdown choices
    country_list = ["All"] + sorted(data["Country"].dropna().unique())
    industry_list = ["All"] + sorted(data["Industry"].dropna().unique())
    company_list = ["All"] + sorted(data["Company"].dropna().unique())
    investor_list_all = sorted(investor_company_mapping.keys())

    # Check if City column exists for Nashville filter
    has_city = "City" in data.columns

    def app_logic(
        country, industry, company, investors,
        exclude_countries, exclude_industries, exclude_companies, exclude_investors,
        valuation_range, node_cap, show_all_labels, quick_find,
        layout_cache
    ):
        """Filter the data, build the graph and return the updated outputs.

        Returns ``(figure, subtitle, truncation_notice, layout_cache)``.
        """
        start = time.perf_counter()
        layout_cache = layout_cache or {}
        # The slider may deliver a float; the cap must be an integer.
        node_cap = int(node_cap)

        # Apply filters
        filtered = apply_filters(
            data, country, industry, company, investors,
            exclude_countries, exclude_industries, exclude_companies, exclude_investors,
            valuation_range, quick_find
        )
        if filtered.empty:
            logger.warning("No results after filtering")
            empty_fig = go.Figure().update_layout(
                xaxis=dict(visible=False), yaxis=dict(visible=False),
                annotations=[dict(text=EMPTY_STATE, showarrow=False, font=dict(size=14), x=0.5, y=0.5, xref='paper', yref='paper')]
            )
            subtitle = "No results"
            return empty_fig, subtitle, "", layout_cache

        # Cap companies
        original_count = len(filtered)
        filtered, was_truncated = cap_companies(filtered, node_cap)

        # Build investor list from filtered data
        filtered_inv_mapping = build_investor_company_mapping(filtered)
        current_investors = list(filtered_inv_mapping.keys())

        # Build graph
        G = build_graph(filtered, current_investors, show_all_labels, valuation_range, quick_find)

        # Invalidate layout cache if node set changed
        current_nodes = set(G.nodes())
        if layout_cache.get("nodes") != current_nodes:
            layout_cache = {"nodes": current_nodes}

        # Generate figure
        fig = generate_plotly_figure(G, filtered, current_investors, show_all_labels, valuation_range, quick_find, layout_cache)

        # Subtitle
        subtitle = SUBTITLE_TEMPLATE.format(
            country=country,
            industry=industry,
            valuation_range=valuation_range,
            count=len(filtered)
        )

        # Truncation notice
        notice = ""
        if was_truncated:
            notice = TRUNCATION_NOTICE.format(cap=node_cap, total=original_count)

        elapsed = time.perf_counter() - start
        logger.info(f"Render complete in {elapsed:.2f}s: {len(filtered)} companies, {len(current_investors)} investors")
        return fig, subtitle, notice, layout_cache

    def apply_nashville_filter():
        """Pre-fill Nashville filter.

        NOTE(review): this only sets Country to "United States"; no
        city-level filtering exists in apply_filters, so the City column is
        not used even when present.
        """
        if not has_city:
            logger.warning("City column not found; Nashville filter only sets Country")
        # Set country, others unchanged
        return "United States", gr.update(), gr.update(), gr.update()

    def clear_all():
        """Reset all filters to default."""
        return (
            "All", "All", "All", [],  # Include filters
            [], [], [], [],  # Exclude filters
            "All", DEFAULT_NODE_CAP, False, ""  # Valuation, node cap, labels, quick find
        )

    def clear_exclusions():
        """Clear only exclusion filters."""
        return [], [], [], []

    def export_csv(
        country, industry, company, investors,
        exclude_countries, exclude_industries, exclude_companies, exclude_investors,
        valuation_range, node_cap, show_all_labels, quick_find, layout_cache
    ):
        """Write the currently filtered (and capped) rows to a CSV and return its path."""
        import tempfile

        filtered = apply_filters(
            data, country, industry, company, investors,
            exclude_countries, exclude_industries, exclude_companies, exclude_investors,
            valuation_range, quick_find
        )
        filtered, _ = cap_companies(filtered, int(node_cap))
        # A unique temp file per export: a single fixed filename would be
        # overwritten when two exports run at the same time.
        with tempfile.NamedTemporaryFile(
            mode="w", prefix="filtered_unicorns_", suffix=".csv",
            delete=False, newline="", encoding="utf-8"
        ) as handle:
            filtered.to_csv(handle, index=False)
            csv_path = handle.name
        logger.info(f"Exported {len(filtered)} rows to {csv_path}")
        return csv_path

    with gr.Blocks(title=f"{TITLE} ({DATA_TIMESTAMP})", theme=gr.themes.Soft()) as demo:
        gr.Markdown(f"# {TITLE}")
        gr.Markdown(f"*Updated {DATA_TIMESTAMP}*")
        subtitle_display = gr.Markdown("Active Scope: All • All • All • 0 companies")
        gr.Markdown(INSTRUCTIONS)

        with gr.Row():
            nashville_btn = gr.Button("🎯 Nashville Filter", variant="secondary", size="sm")
            clear_all_btn = gr.Button("Clear All Filters", variant="secondary", size="sm")
            clear_excl_btn = gr.Button("Clear Exclusions", variant="secondary", size="sm")

        with gr.Row():
            show_labels_toggle = gr.Checkbox(label="Show All Company Labels", value=False)
            quick_find_box = gr.Textbox(label="Quick Find (Investor or Company)", placeholder="Type to search...", scale=2)

        with gr.Accordion("Include Filters", open=True):
            with gr.Row():
                country_filter = gr.Dropdown(choices=country_list, label="Country", value="All")
                industry_filter = gr.Dropdown(choices=industry_list, label="Industry", value="All")
                company_filter = gr.Dropdown(choices=company_list, label="Company", value="All")
            with gr.Row():
                investor_filter = gr.Dropdown(choices=investor_list_all, label="Select Investors", value=[], multiselect=True)
                valuation_filter = gr.Dropdown(choices=VALUATION_RANGES, label="Valuation Range (Billions)", value="All")
                node_cap_slider = gr.Slider(minimum=50, maximum=1000, step=50, value=DEFAULT_NODE_CAP, label="Max Companies to Display")

        with gr.Accordion("Exclude Filters", open=False):
            with gr.Row():
                exclude_country = gr.Dropdown(choices=country_list[1:], label="Exclude Countries", value=[], multiselect=True)
                exclude_industry = gr.Dropdown(choices=industry_list[1:], label="Exclude Industries", value=[], multiselect=True)
                exclude_company = gr.Dropdown(choices=company_list[1:], label="Exclude Companies", value=[], multiselect=True)
                exclude_investor = gr.Dropdown(choices=investor_list_all, label="Exclude Investors", value=[], multiselect=True)

        truncation_notice = gr.Markdown("")
        graph_output = gr.Plot(label="Network Graph")

        with gr.Row():
            reset_view_btn = gr.Button("Reset View", variant="secondary", size="sm")
            download_csv_btn = gr.Button("Download Filtered CSV", variant="primary", size="sm")

        # Explicit component that receives the exported file
        csv_download = gr.File(label="Download CSV")

        # State: per-session cache of the node set and computed layout
        layout_cache = gr.State({})

        # Inputs and outputs
        inputs = [
            country_filter, industry_filter, company_filter, investor_filter,
            exclude_country, exclude_industry, exclude_company, exclude_investor,
            valuation_filter, node_cap_slider, show_labels_toggle, quick_find_box,
            layout_cache
        ]
        outputs = [graph_output, subtitle_display, truncation_notice, layout_cache]

        # Re-render whenever a visible control changes
        for control in inputs[:-1]:  # Exclude layout_cache from triggers
            control.change(app_logic, inputs, outputs)

        # Button actions
        nashville_btn.click(
            apply_nashville_filter,
            inputs=None,
            outputs=[country_filter, industry_filter, company_filter, investor_filter]
        ).then(app_logic, inputs, outputs)

        clear_all_btn.click(
            clear_all,
            inputs=None,
            outputs=[
                country_filter, industry_filter, company_filter, investor_filter,
                exclude_country, exclude_industry, exclude_company, exclude_investor,
                valuation_filter, node_cap_slider, show_labels_toggle, quick_find_box
            ]
        ).then(app_logic, inputs, outputs)

        clear_excl_btn.click(
            clear_exclusions,
            inputs=None,
            outputs=[exclude_country, exclude_industry, exclude_company, exclude_investor]
        ).then(app_logic, inputs, outputs)

        # Clear quick_find and the layout cache, then re-render so the reset
        # takes effect even when quick_find was already empty (in which case
        # no change event would fire).
        reset_view_btn.click(
            lambda: (gr.update(), gr.update(), "", {}),
            inputs=None,
            outputs=[graph_output, subtitle_display, quick_find_box, layout_cache]
        ).then(app_logic, inputs, outputs)

        # Download CSV
        download_csv_btn.click(
            export_csv,
            inputs=inputs,
            outputs=csv_download
        )

        gr.Markdown("""
---
**Accessibility:** Use Tab to navigate controls. Press Enter to activate buttons. Graph nodes are keyboard-focusable.
**Color Legend:** Companies are teal-green. Investors are color-coded (see palette). Non-color cues: stroke outlines differentiate node types.
**Performance:** Graphs update in <500ms for ≤300 companies. Large datasets are auto-capped; adjust slider as needed.
""")

        # Initial render
        demo.load(app_logic, inputs, outputs)

    demo.launch(share=False, server_name="0.0.0.0", server_port=7860)


if __name__ == "__main__":
    main()