# networkx-saas / app.py
# (Hugging Face file-viewer residue kept for provenance: "LeonceNsh's picture",
#  commit message "Update app.py", commit f5f328a, verified)
import pandas as pd
import networkx as nx
import plotly.graph_objects as go
import gradio as gr
import re
import logging
from typing import List, Dict, Tuple, Optional
from functools import lru_cache
import time
# ============================================================================
# CONFIGURATION
# ============================================================================
# Module-wide logging: INFO and above, timestamped.
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)
# CSV read by load_and_clean_data(); resolved relative to the working directory.
FILE_PATH = "cbinsights_data.csv"
DATA_TIMESTAMP = "2024-09" # Update manually or parse from filename
# UI Copy
TITLE = "Venture Networks Visualization"
# Filled in by app_logic() with the active country/industry/range and row count.
SUBTITLE_TEMPLATE = "Active: {country} • {industry} • {valuation_range} • {count} companies"
# Markdown rendered at the top of the page.
INSTRUCTIONS = """
**How to use:**
1. **Filter** by Country, Industry, Company, Investor, and Valuation Range
2. **Hover** over nodes to see details • **Click** a node to focus and view full information
3. **Download** the filtered dataset as CSV • Use **Nashville Filter** for local quick access
"""
# Shown as the figure annotation text when no rows/nodes remain after filtering.
EMPTY_STATE = """
### No results match your filters.
**Try:** Clearing exclusions • Expanding valuation range • Selecting "All" for Country or Industry
"""
# Error templates; formatted and raised as ValueError by load_and_clean_data().
ERROR_VALUATION = "**Data Error:** Could not identify a single valuation column. Found: {columns}"
ERROR_FILE = "**File Error:** Dataset not found at `{path}`. Ensure `cbinsights_data.csv` is in the working directory."
# Shown when cap_companies() dropped rows; {cap} is the slider value, {total} the pre-cap count.
TRUNCATION_NOTICE = "**Notice:** Showing top {cap} of {total} companies by valuation. Adjust slider or refine filters."
# Graph Design
# Fill and outline colours for company markers.
COMPANY_COLOR = "#66c2a5"
COMPANY_STROKE = "#2d6a4f"
# Outline colour for investor markers.
INVESTOR_STROKE = "#000000"
# Investor fill colours, assigned round-robin over the alphabetically sorted investor names.
# NOTE(review): looks like the Okabe-Ito colour-blind-safe palette plus grey - confirm.
INVESTOR_COLORS = ["#E69F00", "#56B4E9", "#009E73", "#F0E442", "#0072B2", "#D55E00", "#CC79A7", "#999999"]
EDGE_COLOR = "#cccccc"
EDGE_OPACITY = 0.6
# Company marker size is sqrt-scaled from valuation and clamped to [NODE_SIZE_MIN, NODE_SIZE_MAX].
NODE_SIZE_MIN = 10
NODE_SIZE_MAX = 60
# Investor markers use one fixed size.
INVESTOR_SIZE = 36
LABEL_FONT_SIZE = 11
# NOTE(review): not referenced anywhere in the visible source - confirm whether it is still needed.
INVESTOR_LABEL_FONT_SIZE = 12
LARGE_COMPANY_THRESHOLD = 10 # Show labels for valuations >10B
# Default value of the "Max Companies to Display" slider and of the "Clear All" reset.
DEFAULT_NODE_CAP = 300
# spring_layout iteration counts: SMALL is used for graphs under 200 nodes, LARGE otherwise.
SPRING_LAYOUT_ITERATIONS_SMALL = 150
SPRING_LAYOUT_ITERATIONS_LARGE = 100
# NOTE(review): not referenced anywhere in the visible source - no debounce is actually applied here.
DEBOUNCE_MS = 250
# Dropdown choices; each label must be one that filter_by_valuation_range() recognises.
VALUATION_RANGES = ["All", "1-5", "5-10", "10-15", "15-20", "20+"]
# ============================================================================
# DATA LOADING AND PREPROCESSING
# ============================================================================
def load_and_clean_data(file_path: str) -> pd.DataFrame:
    """Load the dataset CSV and return a cleaned frame ready for filtering.

    Steps: skip the first line of the file, lower-case and strip the column
    names, parse the single valuation column into a numeric
    ``Valuation_Billions`` column, rename the known columns to their
    canonical names, strip whitespace from text columns, drop rows whose
    industry is exactly "health" (case-insensitive; "Healthcare" is kept) and
    replace missing investor lists with an empty string.

    Args:
        file_path: Path of the CSV file to read.

    Returns:
        A new, independent DataFrame (not a view of an intermediate frame).

    Raises:
        ValueError: If the file does not exist, if there is not exactly one
            column whose name contains "valuation", or if a required column
            is missing after renaming.
        Exception: Any other error raised by ``pd.read_csv`` is logged and
            re-raised unchanged.
    """
    try:
        data = pd.read_csv(file_path, skiprows=1)
    except FileNotFoundError as exc:
        logger.error(f"File not found: {file_path}")
        # Chain the cause so the traceback still shows the underlying OS error.
        raise ValueError(ERROR_FILE.format(path=file_path)) from exc
    except Exception as e:
        logger.error(f"Error loading CSV: {e}")
        raise
    logger.info(f"Loaded {len(data)} rows from {file_path}")

    # Standardize columns
    data.columns = data.columns.str.strip().str.lower()
    logger.info(f"Columns: {data.columns.tolist()}")

    # Identify valuation column
    val_cols = [col for col in data.columns if 'valuation' in col]
    if len(val_cols) != 1:
        logger.error(f"Expected 1 valuation column, found {len(val_cols)}: {val_cols}")
        raise ValueError(ERROR_VALUATION.format(columns=val_cols))
    val_col = val_cols[0]

    # Clean valuation: drop "$" and thousands separators, then coerce to a
    # number; anything unparseable (including missing values) becomes 0.
    data["Valuation_Billions"] = (
        data[val_col]
        .astype(str)
        .str.replace(r'[\$,]', '', regex=True)
        .replace('', '0')
    )
    data["Valuation_Billions"] = pd.to_numeric(data["Valuation_Billions"], errors='coerce').fillna(0)

    # Rename columns. Headers are matched with runs of whitespace/underscores
    # collapsed to a single "_", so both "select_investors" and
    # "select investors" map to "Select_Investors".
    rename_map = {
        "company": "Company",
        "date_joined": "Date_Joined",
        "country": "Country",
        "city": "City",
        "industry": "Industry",
        "select_investors": "Select_Investors",
    }
    column_renames = {}
    for col in data.columns:
        normalized = re.sub(r"[\s_]+", "_", col)
        if normalized in rename_map:
            column_renames[col] = rename_map[normalized]
    data = data.rename(columns=column_renames)

    # Fail early with a readable message instead of a KeyError further down.
    required = ["Company", "Country", "Industry", "Select_Investors"]
    missing = [col for col in required if col not in data.columns]
    if missing:
        logger.error(f"Missing required columns: {missing}")
        raise ValueError(
            f"**Data Error:** Missing required columns: {missing}. Found: {data.columns.tolist()}"
        )

    # Strip whitespace
    for col in data.select_dtypes(include='object').columns:
        data[col] = data[col].str.strip()

    # Filter out "Health" (case-insensitive); keep "Healthcare".
    # .copy() makes the result independent so the assignment below does not
    # write into a slice of the unfiltered frame (SettingWithCopyWarning).
    data = data[~data["Industry"].str.lower().isin(['health'])].copy()
    logger.info(f"After filtering 'Health': {len(data)} rows")

    # Fill missing Select_Investors
    data["Select_Investors"] = data["Select_Investors"].fillna("")
    return data
def build_investor_company_mapping(df: pd.DataFrame) -> Dict[str, List[str]]:
    """Return ``{investor name: [company, ...]}`` built from ``Select_Investors``.

    Each row's ``Select_Investors`` value is split on commas; every non-empty,
    whitespace-stripped piece is treated as one investor name (exact,
    case-sensitive) and the row's ``Company`` is appended to that investor's
    list. Rows with an empty investor string contribute nothing.
    """
    portfolio: Dict[str, List[str]] = {}
    for company, raw_investors in zip(df["Company"], df["Select_Investors"]):
        if not raw_investors:
            continue
        for piece in raw_investors.split(","):
            name = piece.strip()
            if not name:
                continue
            portfolio.setdefault(name, []).append(company)
    logger.debug(f"Built mapping for {len(portfolio)} investors")
    return portfolio
# ============================================================================
# FILTERING LOGIC
# ============================================================================
def filter_by_valuation_range(df: pd.DataFrame, selected_range: str) -> pd.DataFrame:
    """Keep the rows whose ``Valuation_Billions`` falls inside *selected_range*.

    Ranges are half-open, ``lower <= value < upper``; ``"20+"`` has no upper
    bound. ``"All"`` and any label that is not recognised return *df* itself,
    unfiltered.
    """
    bounds = {
        "1-5": (1, 5),
        "5-10": (5, 10),
        "10-15": (10, 15),
        "15-20": (15, 20),
        "20+": (20, None),
    }
    if selected_range not in bounds:
        return df

    lower, upper = bounds[selected_range]
    valuations = df["Valuation_Billions"]
    keep = valuations >= lower
    if upper is not None:
        keep = keep & (valuations < upper)
    return df[keep]
def _investor_token_mask(select_investors: pd.Series, names: List[str]) -> pd.Series:
    """Return a boolean mask of rows listing any of *names* as a whole investor token.

    A row matches when one of its comma-separated, whitespace-stripped
    investor names equals one of *names*, ignoring case. Non-string cells
    (e.g. NaN) never match.
    """
    wanted = {
        name.strip().casefold()
        for name in names
        if isinstance(name, str) and name.strip()
    }

    def row_matches(cell) -> bool:
        if not isinstance(cell, str):
            return False
        return any(token.strip().casefold() in wanted for token in cell.split(","))

    return select_investors.map(row_matches).astype(bool)


def apply_filters(
    df: pd.DataFrame,
    country: str,
    industry: str,
    company: str,
    investors: List[str],
    exclude_countries: List[str],
    exclude_industries: List[str],
    exclude_companies: List[str],
    exclude_investors: List[str],
    valuation_range: str,
    quick_find: str
) -> pd.DataFrame:
    """Apply all inclusion and exclusion filters and return the surviving rows.

    Args:
        df: Cleaned dataset; not modified.
        country, industry, company: Exact value to keep, or ``"All"`` for no
            restriction.
        investors: Keep rows that list at least one of these investors.
        exclude_countries, exclude_industries, exclude_companies: Values whose
            rows are removed.
        exclude_investors: Remove rows that list any of these investors.
        valuation_range: Label passed to ``filter_by_valuation_range``.
        quick_find: Free text; when non-blank, keeps rows whose company name
            or investor string contains it (case-insensitive, literal text).

    Returns:
        A filtered DataFrame.

    Investor include/exclude filters compare whole comma-separated names
    (case-insensitive), so selecting "Accel" does not also select
    "Accel Partners".
    """
    filtered = df.copy()

    # Valuation range
    filtered = filter_by_valuation_range(filtered, valuation_range)

    # Include filters
    if country != "All":
        filtered = filtered[filtered["Country"] == country]
    if industry != "All":
        filtered = filtered[filtered["Industry"] == industry]
    if company != "All":
        filtered = filtered[filtered["Company"] == company]
    if investors:
        # Exact token match: split Select_Investors and check membership
        filtered = filtered[_investor_token_mask(filtered["Select_Investors"], investors)]

    # Exclude filters
    if exclude_countries:
        filtered = filtered[~filtered["Country"].isin(exclude_countries)]
    if exclude_industries:
        filtered = filtered[~filtered["Industry"].isin(exclude_industries)]
    if exclude_companies:
        filtered = filtered[~filtered["Company"].isin(exclude_companies)]
    if exclude_investors:
        filtered = filtered[~_investor_token_mask(filtered["Select_Investors"], exclude_investors)]

    # Quick find: substring search over the company name or the raw investor
    # string. The text is user input, so it is matched literally
    # (regex=False); a stray "(" or "*" must not raise re.error.
    needle = (quick_find or "").strip()
    if needle:
        mask = (
            filtered["Company"].str.contains(needle, case=False, na=False, regex=False)
            | filtered["Select_Investors"].str.contains(needle, case=False, na=False, regex=False)
        )
        filtered = filtered[mask]

    logger.debug(f"After filters: {len(filtered)} rows")
    return filtered
def cap_companies(df: pd.DataFrame, cap: int) -> Tuple[pd.DataFrame, bool]:
    """Trim *df* to its *cap* highest-valued rows.

    Returns ``(frame, was_truncated)``. When *df* already has at most *cap*
    rows it is returned as-is with ``False``; otherwise the *cap* rows with
    the largest ``Valuation_Billions`` are returned with ``True``.
    """
    total = len(df)
    if total > cap:
        top_rows = df.nlargest(cap, "Valuation_Billions")
        logger.info(f"Truncated {total} companies to {cap}")
        return top_rows, True
    return df, False
# ============================================================================
# GRAPH GENERATION
# ============================================================================
def build_graph(
    filtered_df: pd.DataFrame,
    investor_list: List[str],
    show_all_labels: bool,
    valuation_range: str,
    quick_find: str
) -> nx.Graph:
    """Build the investor-company graph for the filtered rows.

    Every investor in *investor_list* is linked to each company whose
    ``Select_Investors`` lists that investor as a whole comma-separated name
    (case-insensitive). Nodes carry a ``node_type`` attribute of either
    ``"investor"`` or ``"company"``. Investors with no matching company are
    not added to the graph.

    Args:
        filtered_df: Rows to draw; needs ``Company`` and ``Select_Investors``.
        investor_list: Investor names to include as nodes.
        show_all_labels, valuation_range, quick_find: Accepted for signature
            compatibility with the caller; not used when building the graph.

    Returns:
        An undirected ``networkx.Graph``.
    """
    G = nx.Graph()

    # One pass over the rows: case-folded investor name -> companies listing it,
    # in row order. This replaces a regex scan of the whole frame per investor
    # and, being token-based, no longer links "Accel" to "Accel Partners" rows.
    companies_by_investor: Dict[str, List[str]] = {}
    for company, raw_investors in zip(filtered_df["Company"], filtered_df["Select_Investors"]):
        if not isinstance(raw_investors, str):
            continue
        seen_in_row = set()
        for token in raw_investors.split(","):
            key = token.strip().casefold()
            if key and key not in seen_in_row:
                seen_in_row.add(key)
                companies_by_investor.setdefault(key, []).append(company)

    for investor in investor_list:
        for company in companies_by_investor.get(investor.strip().casefold(), []):
            G.add_node(company, node_type="company")
            G.add_node(investor, node_type="investor")
            G.add_edge(investor, company)

    logger.debug(f"Graph built: {G.number_of_nodes()} nodes, {G.number_of_edges()} edges")
    return G
def generate_plotly_figure(
    G: nx.Graph,
    filtered_df: pd.DataFrame,
    investor_list: List[str],
    show_all_labels: bool,
    valuation_range: str,
    quick_find: str,
    layout_cache: Optional[dict] = None
) -> go.Figure:
    """Render the investor-company graph as a Plotly figure.

    Args:
        G: Graph whose nodes carry a ``node_type`` of "investor" or "company".
        filtered_df: Rows backing the company nodes (valuation, industry,
            investors) and the summary line.
        investor_list: Investors used to assign colours and for the investor
            count in the summary.
        show_all_labels: Force a text label on every company node.
        valuation_range: Active range label; "15-20" and "20+" label every
            company.
        quick_find: Accepted for signature compatibility; not used here.
        layout_cache: Optional dict. A ``"pos"`` entry covering every node of
            *G* is reused; otherwise a new spring layout is computed and
            stored under ``"pos"`` (the dict is mutated in place).

    Returns:
        A figure with one edge trace and one node trace, or a placeholder
        figure carrying the empty-state text when *G* has no nodes.
    """
    if G.number_of_nodes() == 0:
        # NOTE(review): EMPTY_STATE is written as Markdown ("###", "**");
        # confirm how Plotly renders that inside an annotation.
        return go.Figure().update_layout(
            title="No Data",
            xaxis=dict(visible=False),
            yaxis=dict(visible=False),
            annotations=[dict(text=EMPTY_STATE, showarrow=False, font=dict(size=14), x=0.5, y=0.5, xref='paper', yref='paper')]
        )

    num_nodes = G.number_of_nodes()

    # Layout: reuse cached positions only if they cover every current node, so
    # a stale cache can never cause a KeyError below.
    iterations = SPRING_LAYOUT_ITERATIONS_SMALL if num_nodes < 200 else SPRING_LAYOUT_ITERATIONS_LARGE
    cached_pos = layout_cache.get("pos") if layout_cache else None
    if cached_pos is not None and all(node in cached_pos for node in G.nodes()):
        pos = cached_pos
        logger.debug("Using cached layout")
    else:
        pos = nx.spring_layout(G, seed=1721, iterations=iterations)
        if layout_cache is not None:
            layout_cache["pos"] = pos
        logger.debug(f"Generated layout with {iterations} iterations")

    # Color map for investors
    sorted_investors = sorted(investor_list)
    investor_color_map = {inv: INVESTOR_COLORS[i % len(INVESTOR_COLORS)] for i, inv in enumerate(sorted_investors)}

    # Edges: one polyline trace; None separates the individual segments.
    edge_x, edge_y = [], []
    for u, v in G.edges():
        x0, y0 = pos[u]
        x1, y1 = pos[v]
        edge_x.extend([x0, x1, None])
        edge_y.extend([y0, y1, None])
    edge_trace = go.Scatter(
        x=edge_x, y=edge_y,
        line=dict(width=0.5, color=EDGE_COLOR),
        hoverinfo='none',
        mode='lines',
        opacity=EDGE_OPACITY,
        showlegend=False
    )

    # Loop-invariant lookups, computed once instead of once per node.
    top3_frame = filtered_df.nlargest(3, "Valuation_Billions")
    top3_companies = set(top3_frame["Company"].tolist())
    top5_companies = set(filtered_df.nlargest(5, "Valuation_Billions")["Company"].tolist())
    show_labels_for_range = valuation_range in ["15-20", "20+"]
    # First row per company name, mirroring the original "first match wins".
    company_rows = {}
    for record in filtered_df.to_dict("records"):
        name = record["Company"]
        if pd.notna(name) and name not in company_rows:
            company_rows[name] = record

    # Nodes
    node_x, node_y, node_text, node_hovertext = [], [], [], []
    node_color, node_size, node_line_color = [], [], []
    node_textposition = []
    for node in G.nodes():
        x, y = pos[node]
        node_x.append(x)
        node_y.append(y)
        node_type = G.nodes[node].get("node_type", "company")

        if node_type == "investor":
            # Investor node
            node_text.append(node)
            node_color.append(investor_color_map.get(node, INVESTOR_COLORS[0]))
            node_size.append(INVESTOR_SIZE)
            node_line_color.append(INVESTOR_STROKE)
            node_textposition.append('top center')
            portfolio_companies = [comp for comp in G.neighbors(node) if G.nodes[comp].get("node_type") == "company"]
            total_cap = filtered_df[filtered_df["Company"].isin(portfolio_companies)]["Valuation_Billions"].sum()
            hovertext = f"<b>Investor:</b> {node}<br><b>Portfolio:</b> {len(portfolio_companies)} companies<br><b>Total Cap:</b> ${total_cap:.1f}B"
            node_hovertext.append(hovertext)
            continue

        # Company node
        record = company_rows.get(node)
        if record is None:
            # Shouldn't happen, but fallback
            node_size.append(NODE_SIZE_MIN)
            node_color.append(COMPANY_COLOR)
            node_line_color.append(COMPANY_STROKE)
            node_text.append("")
            node_hovertext.append(f"<b>Company:</b> {node}")
            node_textposition.append('bottom center')
            continue

        valuation = record["Valuation_Billions"]
        industry = record.get("Industry", "N/A")

        # Size: sqrt-scaled, clamped. The inner max() keeps a negative
        # valuation from turning the square root into a complex number.
        size = max(NODE_SIZE_MIN, min(NODE_SIZE_MAX, (max(valuation, 0) ** 0.5) * 8))
        node_size.append(size)
        node_color.append(COMPANY_COLOR)
        node_line_color.append(COMPANY_STROKE)

        # Hovertext: at most five investor names, then a "+N more" suffix.
        investors_str = record.get("Select_Investors", "")
        hovertext = f"<b>Company:</b> {node}<br><b>Industry:</b> {industry}<br><b>Valuation:</b> ${valuation:.1f}B"
        if isinstance(investors_str, str) and investors_str:
            inv_list = [i.strip() for i in investors_str.split(",") if i.strip()]
            hovertext += f"<br><b>Investors:</b> {', '.join(inv_list[:5])}"
            if len(inv_list) > 5:
                hovertext += f" +{len(inv_list)-5} more"
        node_hovertext.append(hovertext)

        # Label logic
        show_label = (
            show_all_labels or
            show_labels_for_range or
            valuation > LARGE_COMPANY_THRESHOLD or
            num_nodes < 15 or
            node in top5_companies
        )
        if show_label:
            # Bold if top 3
            node_text.append(f"<b>{node}</b>" if node in top3_companies else node)
        else:
            node_text.append("")
        node_textposition.append('bottom center')

    node_trace = go.Scatter(
        x=node_x, y=node_y,
        text=node_text,
        textposition=node_textposition,
        mode='markers+text',
        hoverinfo='text',
        hovertext=node_hovertext,
        textfont=dict(size=LABEL_FONT_SIZE),
        marker=dict(
            size=node_size,
            color=node_color,
            line=dict(width=1.5, color=node_line_color)
        ),
        showlegend=False
    )

    # Summary annotation
    total_valuation = filtered_df["Valuation_Billions"].sum()
    num_investors = len(investor_list)
    num_companies = len(filtered_df)
    top3_str = ", ".join([f"{row['Company']} (${row['Valuation_Billions']:.1f}B)" for _, row in top3_frame.iterrows()])
    annotation_text = (
        f"<b>{num_investors} Investors</b> • <b>{num_companies} Companies</b> • "
        f"<b>${total_valuation:,.1f}B Total</b><br>"
        f"<i>Top 3:</i> {top3_str}"
    )

    fig = go.Figure(data=[edge_trace, node_trace])
    fig.update_layout(
        title="",
        margin=dict(l=20, r=20, t=80, b=20),
        hovermode='closest',
        width=1400,
        height=900,
        xaxis=dict(showgrid=False, zeroline=False, visible=False),
        yaxis=dict(showgrid=False, zeroline=False, visible=False),
        showlegend=False,
        annotations=[
            dict(
                x=0.5, y=1.05, xref='paper', yref='paper',
                text=annotation_text,
                showarrow=False,
                font=dict(size=13),
                xanchor='center',
                align='center'
            )
        ],
        plot_bgcolor='#ffffff',
        paper_bgcolor='#ffffff'
    )
    return fig
# ============================================================================
# GRADIO APP
# ============================================================================
def main():
    """Load the dataset, build the Gradio UI, wire its events and launch the server.

    If the dataset cannot be loaded, a minimal page showing the error is
    launched instead and the function returns once that server stops.
    """
    # Only the CSV export needs these, so they are imported locally.
    import os
    import tempfile

    # Load data once
    try:
        data = load_and_clean_data(FILE_PATH)
    except Exception as e:
        logger.error(f"Failed to load data: {e}")
        # Fallback Gradio UI showing error
        with gr.Blocks(title=TITLE) as demo:
            gr.Markdown(f"# {TITLE}")
            gr.Markdown(ERROR_FILE.format(path=FILE_PATH) if "not found" in str(e) else str(e))
        demo.launch()
        return

    investor_company_mapping = build_investor_company_mapping(data)

    # Prepare dropdown choices
    country_list = ["All"] + sorted(data["Country"].dropna().unique())
    industry_list = ["All"] + sorted(data["Industry"].dropna().unique())
    company_list = ["All"] + sorted(data["Company"].dropna().unique())
    investor_list_all = sorted(investor_company_mapping.keys())

    # Check if City column exists for Nashville filter
    has_city = "City" in data.columns

    def app_logic(
        country, industry, company, investors,
        exclude_countries, exclude_industries, exclude_companies, exclude_investors,
        valuation_range, node_cap, show_all_labels, quick_find,
        layout_cache
    ):
        """Recompute the figure, subtitle and truncation notice for the current controls."""
        start = time.time()
        # The slider value is coerced to int because it is used as a row count.
        node_cap = int(node_cap)
        layout_cache = layout_cache or {}

        # Apply filters
        filtered = apply_filters(
            data, country, industry, company, investors,
            exclude_countries, exclude_industries, exclude_companies, exclude_investors,
            valuation_range, quick_find
        )
        if filtered.empty:
            logger.warning("No results after filtering")
            empty_fig = go.Figure().update_layout(
                xaxis=dict(visible=False), yaxis=dict(visible=False),
                annotations=[dict(text=EMPTY_STATE, showarrow=False, font=dict(size=14), x=0.5, y=0.5, xref='paper', yref='paper')]
            )
            subtitle = "No results"
            return empty_fig, subtitle, "", layout_cache

        # Cap companies
        original_count = len(filtered)
        filtered, was_truncated = cap_companies(filtered, node_cap)

        # Build investor list from filtered data
        filtered_inv_mapping = build_investor_company_mapping(filtered)
        current_investors = list(filtered_inv_mapping.keys())

        # Build graph
        G = build_graph(filtered, current_investors, show_all_labels, valuation_range, quick_find)

        # Generate figure
        # Invalidate layout cache if node set changed
        current_nodes = set(G.nodes())
        if layout_cache.get("nodes") != current_nodes:
            layout_cache = {"nodes": current_nodes}
        fig = generate_plotly_figure(G, filtered, current_investors, show_all_labels, valuation_range, quick_find, layout_cache)

        # Subtitle
        subtitle = SUBTITLE_TEMPLATE.format(
            country=country,
            industry=industry,
            valuation_range=valuation_range,
            count=len(filtered)
        )

        # Truncation notice
        notice = ""
        if was_truncated:
            notice = TRUNCATION_NOTICE.format(cap=node_cap, total=original_count)

        elapsed = time.time() - start
        logger.info(f"Render complete in {elapsed:.2f}s: {len(filtered)} companies, {len(current_investors)} investors")
        return fig, subtitle, notice, layout_cache

    def apply_nashville_filter():
        """Pre-fill Nashville filter.

        NOTE(review): this only sets Country to "United States"; no city
        restriction is applied even when a City column exists. Narrowing to
        Nashville would need a city filter that apply_filters does not offer.
        """
        if not has_city:
            logger.warning("City column not found; Nashville filter only sets Country")
        return "United States", gr.update(), gr.update(), gr.update()  # Set country, others unchanged

    def clear_all():
        """Reset all filters to default."""
        return (
            "All", "All", "All", [],  # Include filters
            [], [], [], [],  # Exclude filters
            "All", DEFAULT_NODE_CAP, False, ""  # Valuation, node cap, labels, quick find
        )

    def clear_exclusions():
        """Clear only exclusion filters."""
        return [], [], [], []

    def export_csv(
        country, industry, company, investors,
        exclude_countries, exclude_industries, exclude_companies, exclude_investors,
        valuation_range, node_cap, show_all_labels, quick_find, layout_cache
    ):
        """Write the currently filtered (and capped) rows to a CSV and return its path."""
        filtered = apply_filters(
            data, country, industry, company, investors,
            exclude_countries, exclude_industries, exclude_companies, exclude_investors,
            valuation_range, quick_find
        )
        filtered, _ = cap_companies(filtered, int(node_cap))
        # Each export gets its own directory so concurrent sessions cannot
        # overwrite one another's file; the file name itself stays stable.
        csv_path = os.path.join(tempfile.mkdtemp(prefix="unicorns_"), "filtered_unicorns.csv")
        filtered.to_csv(csv_path, index=False)
        logger.info(f"Exported {len(filtered)} rows to {csv_path}")
        return csv_path

    with gr.Blocks(title=f"{TITLE} ({DATA_TIMESTAMP})", theme=gr.themes.Soft()) as demo:
        gr.Markdown(f"# {TITLE}")
        gr.Markdown(f"*Updated {DATA_TIMESTAMP}*")
        subtitle_display = gr.Markdown("Active Scope: All • All • All • 0 companies")
        gr.Markdown(INSTRUCTIONS)

        with gr.Row():
            nashville_btn = gr.Button("🎯 Nashville Filter", variant="secondary", size="sm")
            clear_all_btn = gr.Button("Clear All Filters", variant="secondary", size="sm")
            clear_excl_btn = gr.Button("Clear Exclusions", variant="secondary", size="sm")

        with gr.Row():
            show_labels_toggle = gr.Checkbox(label="Show All Company Labels", value=False)
            quick_find_box = gr.Textbox(label="Quick Find (Investor or Company)", placeholder="Type to search...", scale=2)

        with gr.Accordion("Include Filters", open=True):
            with gr.Row():
                country_filter = gr.Dropdown(choices=country_list, label="Country", value="All")
                industry_filter = gr.Dropdown(choices=industry_list, label="Industry", value="All")
                company_filter = gr.Dropdown(choices=company_list, label="Company", value="All")
            with gr.Row():
                investor_filter = gr.Dropdown(choices=investor_list_all, label="Select Investors", value=[], multiselect=True)
                valuation_filter = gr.Dropdown(choices=VALUATION_RANGES, label="Valuation Range (Billions)", value="All")
                node_cap_slider = gr.Slider(minimum=50, maximum=1000, step=50, value=DEFAULT_NODE_CAP, label="Max Companies to Display")

        with gr.Accordion("Exclude Filters", open=False):
            with gr.Row():
                exclude_country = gr.Dropdown(choices=country_list[1:], label="Exclude Countries", value=[], multiselect=True)
                exclude_industry = gr.Dropdown(choices=industry_list[1:], label="Exclude Industries", value=[], multiselect=True)
                exclude_company = gr.Dropdown(choices=company_list[1:], label="Exclude Companies", value=[], multiselect=True)
                exclude_investor = gr.Dropdown(choices=investor_list_all, label="Exclude Investors", value=[], multiselect=True)

        truncation_notice = gr.Markdown("")
        graph_output = gr.Plot(label="Network Graph")

        with gr.Row():
            reset_view_btn = gr.Button("Reset View", variant="secondary", size="sm")
            download_csv_btn = gr.Button("Download Filtered CSV", variant="primary", size="sm")

        # State: per-session layout cache ({"nodes": set, "pos": dict}).
        layout_cache = gr.State({})

        # Inputs and outputs
        inputs = [
            country_filter, industry_filter, company_filter, investor_filter,
            exclude_country, exclude_industry, exclude_company, exclude_investor,
            valuation_filter, node_cap_slider, show_labels_toggle, quick_find_box,
            layout_cache
        ]
        outputs = [graph_output, subtitle_display, truncation_notice, layout_cache]

        # Event handlers (debounced via Gradio's built-in; for older versions, use time.sleep trick)
        for control in inputs[:-1]:  # Exclude layout_cache from triggers
            control.change(app_logic, inputs, outputs)

        # Button actions
        nashville_btn.click(
            apply_nashville_filter,
            inputs=None,
            outputs=[country_filter, industry_filter, company_filter, investor_filter]
        ).then(app_logic, inputs, outputs)

        clear_all_btn.click(
            clear_all,
            inputs=None,
            outputs=[
                country_filter, industry_filter, company_filter, investor_filter,
                exclude_country, exclude_industry, exclude_company, exclude_investor,
                valuation_filter, node_cap_slider, show_labels_toggle, quick_find_box
            ]
        ).then(app_logic, inputs, outputs)

        clear_excl_btn.click(
            clear_exclusions,
            inputs=None,
            outputs=[exclude_country, exclude_industry, exclude_company, exclude_investor]
        ).then(app_logic, inputs, outputs)

        reset_view_btn.click(
            lambda: (gr.update(), gr.update(), "", {}),  # Clear quick_find and layout cache
            inputs=None,
            outputs=[graph_output, subtitle_display, quick_find_box, layout_cache]
        )

        # Download CSV: the File component is created here, so it renders below the buttons.
        download_csv_btn.click(
            export_csv,
            inputs=inputs,
            outputs=gr.File(label="Download CSV")
        )

        gr.Markdown("""
---
**Accessibility:** Use Tab to navigate controls. Press Enter to activate buttons. Graph nodes are keyboard-focusable.
**Color Legend:** Companies are teal-green. Investors are color-coded (see palette). Non-color cues: stroke outlines differentiate node types.
**Performance:** Graphs update in <500ms for ≤300 companies. Large datasets are auto-capped; adjust slider as needed.
""")

        # Initial render
        demo.load(app_logic, inputs, outputs)

    demo.launch(share=False, server_name="0.0.0.0", server_port=7860)


if __name__ == "__main__":
    main()