LeonceNsh committed on
Commit
f5f328a
·
verified ·
1 Parent(s): 00607ac

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +91 -91
app.py CHANGED
@@ -8,40 +8,40 @@ from typing import List, Dict, Tuple, Optional
8
  from functools import lru_cache
9
  import time
10
 
11
- ============================================================================
12
- CONFIGURATION
13
- ============================================================================
14
  logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
15
  logger = logging.getLogger(__name__)
16
 
17
  FILE_PATH = "cbinsights_data.csv"
18
- DATA_TIMESTAMP = "2024-09" Update manually or parse from filename
19
 
20
- UI Copy
21
  TITLE = "Venture Networks Visualization"
22
  SUBTITLE_TEMPLATE = "Active: {country} • {industry} • {valuation_range} • {count} companies"
23
  INSTRUCTIONS = """
24
- How to use:
25
- 1. Filter by Country, Industry, Company, Investor, and Valuation Range
26
- 2. Hover over nodes to see details • Click a node to focus and view full information
27
- 3. Download the filtered dataset as CSV • Use Nashville Filter for local quick access
28
  """
29
 
30
  EMPTY_STATE = """
31
- No results match your filters.
32
- Try: Clearing exclusions • Expanding valuation range • Selecting "All" for Country or Industry
33
  """
34
 
35
- ERROR_VALUATION = "Data Error: Could not identify a single valuation column. Found: {columns}"
36
- ERROR_FILE = "File Error: Dataset not found at `{path}`. Ensure `cbinsights_data.csv` is in the working directory."
37
- TRUNCATION_NOTICE = "Notice: Showing top {cap} of {total} companies by valuation. Adjust slider or refine filters."
38
 
39
- Graph Design
40
- COMPANY_COLOR = "66c2a5"
41
- COMPANY_STROKE = "2d6a4f"
42
- INVESTOR_STROKE = "000000"
43
- INVESTOR_COLORS = ["E69F00", "56B4E9", "009E73", "F0E442", "0072B2", "D55E00", "CC79A7", "999999"]
44
- EDGE_COLOR = "cccccc"
45
  EDGE_OPACITY = 0.6
46
 
47
  NODE_SIZE_MIN = 10
@@ -49,7 +49,7 @@ NODE_SIZE_MAX = 60
49
  INVESTOR_SIZE = 36
50
  LABEL_FONT_SIZE = 11
51
  INVESTOR_LABEL_FONT_SIZE = 12
52
- LARGE_COMPANY_THRESHOLD = 10 Show labels for valuations >10B
53
 
54
  DEFAULT_NODE_CAP = 300
55
  SPRING_LAYOUT_ITERATIONS_SMALL = 150
@@ -59,9 +59,9 @@ DEBOUNCE_MS = 250
59
  VALUATION_RANGES = ["All", "1-5", "5-10", "10-15", "15-20", "20+"]
60
 
61
 
62
- ============================================================================
63
- DATA LOADING AND PREPROCESSING
64
- ============================================================================
65
  def load_and_clean_data(file_path: str) -> pd.DataFrame:
66
  """Load CSV, standardize columns, filter Health, parse valuation."""
67
  try:
@@ -74,18 +74,18 @@ def load_and_clean_data(file_path: str) -> pd.DataFrame:
74
  logger.error(f"Error loading CSV: {e}")
75
  raise
76
 
77
- Standardize columns
78
  data.columns = data.columns.str.strip().str.lower()
79
  logger.info(f"Columns: {data.columns.tolist()}")
80
 
81
- Identify valuation column
82
  val_cols = [col for col in data.columns if 'valuation' in col]
83
  if len(val_cols) != 1:
84
  logger.error(f"Expected 1 valuation column, found {len(val_cols)}: {val_cols}")
85
  raise ValueError(ERROR_VALUATION.format(columns=val_cols))
86
  val_col = val_cols[0]
87
 
88
- Clean valuation
89
  data["Valuation_Billions"] = (
90
  data[val_col]
91
  .astype(str)
@@ -94,7 +94,7 @@ def load_and_clean_data(file_path: str) -> pd.DataFrame:
94
  )
95
  data["Valuation_Billions"] = pd.to_numeric(data["Valuation_Billions"], errors='coerce').fillna(0)
96
 
97
- Rename columns
98
  rename_map = {
99
  "company": "Company",
100
  "date_joined": "Date_Joined",
@@ -105,15 +105,15 @@ def load_and_clean_data(file_path: str) -> pd.DataFrame:
105
  }
106
  data.rename(columns=rename_map, inplace=True)
107
 
108
- Strip whitespace
109
  for col in data.select_dtypes(include='object').columns:
110
  data[col] = data[col].str.strip()
111
 
112
- Filter out "Health" (case-insensitive); keep "Healthcare"
113
  data = data[~data["Industry"].str.lower().isin(['health'])]
114
  logger.info(f"After filtering 'Health': {len(data)} rows")
115
 
116
- Fill missing Select_Investors
117
  data["Select_Investors"] = data["Select_Investors"].fillna("")
118
 
119
  return data
@@ -134,9 +134,9 @@ def build_investor_company_mapping(df: pd.DataFrame) -> Dict[str, List[str]]:
134
  return mapping
135
 
136
 
137
- ============================================================================
138
- FILTERING LOGIC
139
- ============================================================================
140
  def filter_by_valuation_range(df: pd.DataFrame, selected_range: str) -> pd.DataFrame:
141
  """Filter dataframe by valuation range (billions)."""
142
  if selected_range == "All":
@@ -170,10 +170,10 @@ def apply_filters(
170
  """Apply all inclusion and exclusion filters."""
171
  filtered = df.copy()
172
 
173
- Valuation range
174
  filtered = filter_by_valuation_range(filtered, valuation_range)
175
 
176
- Include filters
177
  if country != "All":
178
  filtered = filtered[filtered["Country"] == country]
179
  if industry != "All":
@@ -181,11 +181,11 @@ def apply_filters(
181
  if company != "All":
182
  filtered = filtered[filtered["Company"] == company]
183
  if investors:
184
- Exact token match: split Select_Investors and check membership
185
  pattern = '|'.join([re.escape(inv) for inv in investors])
186
  filtered = filtered[filtered["Select_Investors"].str.contains(pattern, case=False, na=False, regex=True)]
187
 
188
- Exclude filters
189
  if exclude_countries:
190
  filtered = filtered[~filtered["Country"].isin(exclude_countries)]
191
  if exclude_industries:
@@ -196,8 +196,8 @@ def apply_filters(
196
  pattern = '|'.join([re.escape(inv) for inv in exclude_investors])
197
  filtered = filtered[~filtered["Select_Investors"].str.contains(pattern, case=False, na=False, regex=True)]
198
 
199
- Quick find (highlight only; filter applied in graph rendering)
200
- For filtering, we match Company or any investor token
201
  if quick_find.strip():
202
  qf = quick_find.strip()
203
  mask = (
@@ -219,9 +219,9 @@ def cap_companies(df: pd.DataFrame, cap: int) -> Tuple[pd.DataFrame, bool]:
219
  return capped, True
220
 
221
 
222
- ============================================================================
223
- GRAPH GENERATION
224
- ============================================================================
225
  def build_graph(
226
  filtered_df: pd.DataFrame,
227
  investor_list: List[str],
@@ -261,7 +261,7 @@ def generate_plotly_figure(
261
  annotations=[dict(text=EMPTY_STATE, showarrow=False, font=dict(size=14), x=0.5, y=0.5, xref='paper', yref='paper')]
262
  )
263
 
264
- Layout
265
  iterations = SPRING_LAYOUT_ITERATIONS_SMALL if G.number_of_nodes() < 200 else SPRING_LAYOUT_ITERATIONS_LARGE
266
  if layout_cache and "pos" in layout_cache:
267
  pos = layout_cache["pos"]
@@ -272,11 +272,11 @@ def generate_plotly_figure(
272
  layout_cache["pos"] = pos
273
  logger.debug(f"Generated layout with {iterations} iterations")
274
 
275
- Color map for investors
276
  sorted_investors = sorted(investor_list)
277
  investor_color_map = {inv: INVESTOR_COLORS[i % len(INVESTOR_COLORS)] for i, inv in enumerate(sorted_investors)}
278
 
279
- Edges
280
  edge_x, edge_y = [], []
281
  for u, v in G.edges():
282
  x0, y0 = pos[u]
@@ -293,7 +293,7 @@ def generate_plotly_figure(
293
  showlegend=False
294
  )
295
 
296
- Nodes
297
  node_x, node_y, node_text, node_hovertext = [], [], [], []
298
  node_color, node_size, node_line_color = [], [], []
299
  node_textposition = []
@@ -311,7 +311,7 @@ def generate_plotly_figure(
311
  node_type = G.nodes[node].get("node_type", "company")
312
 
313
  if node_type == "investor":
314
- Investor node
315
  node_text.append(node)
316
  node_color.append(investor_color_map[node])
317
  node_size.append(INVESTOR_SIZE)
@@ -322,10 +322,10 @@ def generate_plotly_figure(
322
  hovertext = f"<b>Investor:</b> {node}<br><b>Portfolio:</b> {len(portfolio_companies)} companies<br><b>Total Cap:</b> ${total_cap:.1f}B"
323
  node_hovertext.append(hovertext)
324
  else:
325
- Company node
326
  row = filtered_df[filtered_df["Company"] == node]
327
  if row.empty:
328
- Shouldn't happen, but fallback
329
  node_size.append(NODE_SIZE_MIN)
330
  node_color.append(COMPANY_COLOR)
331
  node_line_color.append(COMPANY_STROKE)
@@ -338,13 +338,13 @@ def generate_plotly_figure(
338
  industry = row["Industry"].values[0] if "Industry" in row else "N/A"
339
  country = row["Country"].values[0] if "Country" in row else "N/A"
340
 
341
- Size: sqrt-scaled, clamped
342
- size = max(NODE_SIZE_MIN, min(NODE_SIZE_MAX, (valuation 0.5) * 8))
343
  node_size.append(size)
344
  node_color.append(COMPANY_COLOR)
345
  node_line_color.append(COMPANY_STROKE)
346
 
347
- Hovertext
348
  investors_str = row["Select_Investors"].values[0]
349
  hovertext = f"<b>Company:</b> {node}<br><b>Industry:</b> {industry}<br><b>Valuation:</b> ${valuation:.1f}B"
350
  if investors_str:
@@ -354,7 +354,7 @@ def generate_plotly_figure(
354
  hovertext += f" +{len(inv_list)-5} more"
355
  node_hovertext.append(hovertext)
356
 
357
- Label logic
358
  show_label = (
359
  show_all_labels or
360
  show_labels_for_range or
@@ -363,7 +363,7 @@ def generate_plotly_figure(
363
  node in top5_companies
364
  )
365
  if show_label:
366
- Bold if top 3
367
  top3 = set(filtered_df.nlargest(3, "Valuation_Billions")["Company"].tolist())
368
  if node in top3:
369
  node_text.append(f"<b>{node}</b>")
@@ -389,7 +389,7 @@ def generate_plotly_figure(
389
  showlegend=False
390
  )
391
 
392
- Summary annotation
393
  total_valuation = filtered_df["Valuation_Billions"].sum()
394
  num_investors = len(investor_list)
395
  num_companies = len(filtered_df)
@@ -422,41 +422,41 @@ def generate_plotly_figure(
422
  align='center'
423
  )
424
  ],
425
- plot_bgcolor='ffffff',
426
- paper_bgcolor='ffffff'
427
  )
428
 
429
  return fig
430
 
431
 
432
- ============================================================================
433
- GRADIO APP
434
- ============================================================================
435
  def main():
436
- Load data once
437
  try:
438
  data = load_and_clean_data(FILE_PATH)
439
  except Exception as e:
440
  logger.error(f"Failed to load data: {e}")
441
- Fallback Gradio UI showing error
442
  with gr.Blocks(title=TITLE) as demo:
443
- gr.Markdown(f" {TITLE}")
444
  gr.Markdown(ERROR_FILE.format(path=FILE_PATH) if "not found" in str(e) else str(e))
445
  demo.launch()
446
  return
447
 
448
  investor_company_mapping = build_investor_company_mapping(data)
449
 
450
- Prepare dropdown choices
451
  country_list = ["All"] + sorted(data["Country"].dropna().unique())
452
  industry_list = ["All"] + sorted(data["Industry"].dropna().unique())
453
  company_list = ["All"] + sorted(data["Company"].dropna().unique())
454
  investor_list_all = sorted(investor_company_mapping.keys())
455
 
456
- Check if City column exists for Nashville filter
457
  has_city = "City" in data.columns
458
 
459
- State for caching layout
460
  layout_cache_state = gr.State({})
461
 
462
  def app_logic(
@@ -467,7 +467,7 @@ def main():
467
  ):
468
  start = time.time()
469
 
470
- Apply filters
471
  filtered = apply_filters(
472
  data, country, industry, company, investors,
473
  exclude_countries, exclude_industries, exclude_companies, exclude_investors,
@@ -483,25 +483,25 @@ def main():
483
  subtitle = "No results"
484
  return empty_fig, subtitle, "", layout_cache
485
 
486
- Cap companies
487
  original_count = len(filtered)
488
  filtered, was_truncated = cap_companies(filtered, node_cap)
489
 
490
- Build investor list from filtered data
491
  filtered_inv_mapping = build_investor_company_mapping(filtered)
492
  current_investors = list(filtered_inv_mapping.keys())
493
 
494
- Build graph
495
  G = build_graph(filtered, current_investors, show_all_labels, valuation_range, quick_find)
496
 
497
- Generate figure
498
- Invalidate layout cache if node set changed
499
  current_nodes = set(G.nodes())
500
  if layout_cache.get("nodes") != current_nodes:
501
  layout_cache = {"nodes": current_nodes}
502
  fig = generate_plotly_figure(G, filtered, current_investors, show_all_labels, valuation_range, quick_find, layout_cache)
503
 
504
- Subtitle
505
  subtitle = SUBTITLE_TEMPLATE.format(
506
  country=country,
507
  industry=industry,
@@ -509,7 +509,7 @@ def main():
509
  count=len(filtered)
510
  )
511
 
512
- Truncation notice
513
  notice = ""
514
  if was_truncated:
515
  notice = TRUNCATION_NOTICE.format(cap=node_cap, total=original_count)
@@ -522,7 +522,7 @@ def main():
522
  def apply_nashville_filter():
523
  """Pre-fill Nashville filter."""
524
  if has_city:
525
- return "United States", gr.update(), gr.update(), gr.update() Set country, others unchanged
526
  else:
527
  logger.warning("City column not found; Nashville filter only sets Country")
528
  return "United States", gr.update(), gr.update(), gr.update()
@@ -530,9 +530,9 @@ def main():
530
  def clear_all():
531
  """Reset all filters to default."""
532
  return (
533
- "All", "All", "All", [], Include filters
534
- [], [], [], [], Exclude filters
535
- "All", DEFAULT_NODE_CAP, False, "" Valuation, node cap, labels, quick find
536
  )
537
 
538
  def clear_exclusions():
@@ -540,7 +540,7 @@ def main():
540
  return [], [], [], []
541
 
542
  with gr.Blocks(title=f"{TITLE} ({DATA_TIMESTAMP})", theme=gr.themes.Soft()) as demo:
543
- gr.Markdown(f" {TITLE}")
544
  gr.Markdown(f"*Updated {DATA_TIMESTAMP}*")
545
 
546
  subtitle_display = gr.Markdown("Active Scope: All • All • All • 0 companies")
@@ -580,10 +580,10 @@ def main():
580
  reset_view_btn = gr.Button("Reset View", variant="secondary", size="sm")
581
  download_csv_btn = gr.Button("Download Filtered CSV", variant="primary", size="sm")
582
 
583
- State
584
  layout_cache = gr.State({})
585
 
586
- Inputs and outputs
587
  inputs = [
588
  country_filter, industry_filter, company_filter, investor_filter,
589
  exclude_country, exclude_industry, exclude_company, exclude_investor,
@@ -592,11 +592,11 @@ def main():
592
  ]
593
  outputs = [graph_output, subtitle_display, truncation_notice, layout_cache]
594
 
595
- Event handlers (debounced via Gradio's built-in; for older versions, use time.sleep trick)
596
- for control in inputs[:-1]: Exclude layout_cache from triggers
597
  control.change(app_logic, inputs, outputs)
598
 
599
- Button actions
600
  nashville_btn.click(
601
  apply_nashville_filter,
602
  inputs=None,
@@ -620,12 +620,12 @@ def main():
620
  ).then(app_logic, inputs, outputs)
621
 
622
  reset_view_btn.click(
623
- lambda: (gr.update(), gr.update(), "", {}), Clear quick_find and layout cache
624
  inputs=None,
625
  outputs=[graph_output, subtitle_display, quick_find_box, layout_cache]
626
  )
627
 
628
- Download CSV (requires Gradio >=3.x File component; here we provide a placeholder)
629
  def export_csv(
630
  country, industry, company, investors,
631
  exclude_countries, exclude_industries, exclude_companies, exclude_investors,
@@ -650,12 +650,12 @@ def main():
650
 
651
  gr.Markdown("""
652
  ---
653
- Accessibility: Use Tab to navigate controls. Press Enter to activate buttons. Graph nodes are keyboard-focusable.
654
- Color Legend: Companies are teal-green. Investors are color-coded (see palette). Non-color cues: stroke outlines differentiate node types.
655
- Performance: Graphs update in <500ms for ≤300 companies. Large datasets are auto-capped; adjust slider as needed.
656
  """)
657
 
658
- Initial render
659
  demo.load(app_logic, inputs, outputs)
660
 
661
  demo.launch(share=False, server_name="0.0.0.0", server_port=7860)
 
8
  from functools import lru_cache
9
  import time
10
 
11
+ # ============================================================================
12
+ # CONFIGURATION
13
+ # ============================================================================
14
  logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
15
  logger = logging.getLogger(__name__)
16
 
17
  FILE_PATH = "cbinsights_data.csv"
18
+ DATA_TIMESTAMP = "2024-09" # Update manually or parse from filename
19
 
20
+ # UI Copy
21
  TITLE = "Venture Networks Visualization"
22
  SUBTITLE_TEMPLATE = "Active: {country} • {industry} • {valuation_range} • {count} companies"
23
  INSTRUCTIONS = """
24
+ **How to use:**
25
+ 1. **Filter** by Country, Industry, Company, Investor, and Valuation Range
26
+ 2. **Hover** over nodes to see details • **Click** a node to focus and view full information
27
+ 3. **Download** the filtered dataset as CSV • Use **Nashville Filter** for local quick access
28
  """
29
 
30
  EMPTY_STATE = """
31
+ ### No results match your filters.
32
+ **Try:** Clearing exclusions • Expanding valuation range • Selecting "All" for Country or Industry
33
  """
34
 
35
+ ERROR_VALUATION = "**Data Error:** Could not identify a single valuation column. Found: {columns}"
36
+ ERROR_FILE = "**File Error:** Dataset not found at `{path}`. Ensure `cbinsights_data.csv` is in the working directory."
37
+ TRUNCATION_NOTICE = "**Notice:** Showing top {cap} of {total} companies by valuation. Adjust slider or refine filters."
38
 
39
+ # Graph Design
40
+ COMPANY_COLOR = "#66c2a5"
41
+ COMPANY_STROKE = "#2d6a4f"
42
+ INVESTOR_STROKE = "#000000"
43
+ INVESTOR_COLORS = ["#E69F00", "#56B4E9", "#009E73", "#F0E442", "#0072B2", "#D55E00", "#CC79A7", "#999999"]
44
+ EDGE_COLOR = "#cccccc"
45
  EDGE_OPACITY = 0.6
46
 
47
  NODE_SIZE_MIN = 10
 
49
  INVESTOR_SIZE = 36
50
  LABEL_FONT_SIZE = 11
51
  INVESTOR_LABEL_FONT_SIZE = 12
52
+ LARGE_COMPANY_THRESHOLD = 10 # Show labels for valuations >10B
53
 
54
  DEFAULT_NODE_CAP = 300
55
  SPRING_LAYOUT_ITERATIONS_SMALL = 150
 
59
  VALUATION_RANGES = ["All", "1-5", "5-10", "10-15", "15-20", "20+"]
60
 
61
 
62
+ # ============================================================================
63
+ # DATA LOADING AND PREPROCESSING
64
+ # ============================================================================
65
  def load_and_clean_data(file_path: str) -> pd.DataFrame:
66
  """Load CSV, standardize columns, filter Health, parse valuation."""
67
  try:
 
74
  logger.error(f"Error loading CSV: {e}")
75
  raise
76
 
77
+ # Standardize columns
78
  data.columns = data.columns.str.strip().str.lower()
79
  logger.info(f"Columns: {data.columns.tolist()}")
80
 
81
+ # Identify valuation column
82
  val_cols = [col for col in data.columns if 'valuation' in col]
83
  if len(val_cols) != 1:
84
  logger.error(f"Expected 1 valuation column, found {len(val_cols)}: {val_cols}")
85
  raise ValueError(ERROR_VALUATION.format(columns=val_cols))
86
  val_col = val_cols[0]
87
 
88
+ # Clean valuation
89
  data["Valuation_Billions"] = (
90
  data[val_col]
91
  .astype(str)
 
94
  )
95
  data["Valuation_Billions"] = pd.to_numeric(data["Valuation_Billions"], errors='coerce').fillna(0)
96
 
97
+ # Rename columns
98
  rename_map = {
99
  "company": "Company",
100
  "date_joined": "Date_Joined",
 
105
  }
106
  data.rename(columns=rename_map, inplace=True)
107
 
108
+ # Strip whitespace
109
  for col in data.select_dtypes(include='object').columns:
110
  data[col] = data[col].str.strip()
111
 
112
+ # Filter out "Health" (case-insensitive); keep "Healthcare"
113
  data = data[~data["Industry"].str.lower().isin(['health'])]
114
  logger.info(f"After filtering 'Health': {len(data)} rows")
115
 
116
+ # Fill missing Select_Investors
117
  data["Select_Investors"] = data["Select_Investors"].fillna("")
118
 
119
  return data
 
134
  return mapping
135
 
136
 
137
+ # ============================================================================
138
+ # FILTERING LOGIC
139
+ # ============================================================================
140
  def filter_by_valuation_range(df: pd.DataFrame, selected_range: str) -> pd.DataFrame:
141
  """Filter dataframe by valuation range (billions)."""
142
  if selected_range == "All":
 
170
  """Apply all inclusion and exclusion filters."""
171
  filtered = df.copy()
172
 
173
+ # Valuation range
174
  filtered = filter_by_valuation_range(filtered, valuation_range)
175
 
176
+ # Include filters
177
  if country != "All":
178
  filtered = filtered[filtered["Country"] == country]
179
  if industry != "All":
 
181
  if company != "All":
182
  filtered = filtered[filtered["Company"] == company]
183
  if investors:
184
+ # Exact token match: split Select_Investors and check membership
185
  pattern = '|'.join([re.escape(inv) for inv in investors])
186
  filtered = filtered[filtered["Select_Investors"].str.contains(pattern, case=False, na=False, regex=True)]
187
 
188
+ # Exclude filters
189
  if exclude_countries:
190
  filtered = filtered[~filtered["Country"].isin(exclude_countries)]
191
  if exclude_industries:
 
196
  pattern = '|'.join([re.escape(inv) for inv in exclude_investors])
197
  filtered = filtered[~filtered["Select_Investors"].str.contains(pattern, case=False, na=False, regex=True)]
198
 
199
+ # Quick find (highlight only; filter applied in graph rendering)
200
+ # For filtering, we match Company or any investor token
201
  if quick_find.strip():
202
  qf = quick_find.strip()
203
  mask = (
 
219
  return capped, True
220
 
221
 
222
+ # ============================================================================
223
+ # GRAPH GENERATION
224
+ # ============================================================================
225
  def build_graph(
226
  filtered_df: pd.DataFrame,
227
  investor_list: List[str],
 
261
  annotations=[dict(text=EMPTY_STATE, showarrow=False, font=dict(size=14), x=0.5, y=0.5, xref='paper', yref='paper')]
262
  )
263
 
264
+ # Layout
265
  iterations = SPRING_LAYOUT_ITERATIONS_SMALL if G.number_of_nodes() < 200 else SPRING_LAYOUT_ITERATIONS_LARGE
266
  if layout_cache and "pos" in layout_cache:
267
  pos = layout_cache["pos"]
 
272
  layout_cache["pos"] = pos
273
  logger.debug(f"Generated layout with {iterations} iterations")
274
 
275
+ # Color map for investors
276
  sorted_investors = sorted(investor_list)
277
  investor_color_map = {inv: INVESTOR_COLORS[i % len(INVESTOR_COLORS)] for i, inv in enumerate(sorted_investors)}
278
 
279
+ # Edges
280
  edge_x, edge_y = [], []
281
  for u, v in G.edges():
282
  x0, y0 = pos[u]
 
293
  showlegend=False
294
  )
295
 
296
+ # Nodes
297
  node_x, node_y, node_text, node_hovertext = [], [], [], []
298
  node_color, node_size, node_line_color = [], [], []
299
  node_textposition = []
 
311
  node_type = G.nodes[node].get("node_type", "company")
312
 
313
  if node_type == "investor":
314
+ # Investor node
315
  node_text.append(node)
316
  node_color.append(investor_color_map[node])
317
  node_size.append(INVESTOR_SIZE)
 
322
  hovertext = f"<b>Investor:</b> {node}<br><b>Portfolio:</b> {len(portfolio_companies)} companies<br><b>Total Cap:</b> ${total_cap:.1f}B"
323
  node_hovertext.append(hovertext)
324
  else:
325
+ # Company node
326
  row = filtered_df[filtered_df["Company"] == node]
327
  if row.empty:
328
+ # Shouldn't happen, but fallback
329
  node_size.append(NODE_SIZE_MIN)
330
  node_color.append(COMPANY_COLOR)
331
  node_line_color.append(COMPANY_STROKE)
 
338
  industry = row["Industry"].values[0] if "Industry" in row else "N/A"
339
  country = row["Country"].values[0] if "Country" in row else "N/A"
340
 
341
+ # Size: sqrt-scaled, clamped
342
+ size = max(NODE_SIZE_MIN, min(NODE_SIZE_MAX, (valuation ** 0.5) * 8))
343
  node_size.append(size)
344
  node_color.append(COMPANY_COLOR)
345
  node_line_color.append(COMPANY_STROKE)
346
 
347
+ # Hovertext
348
  investors_str = row["Select_Investors"].values[0]
349
  hovertext = f"<b>Company:</b> {node}<br><b>Industry:</b> {industry}<br><b>Valuation:</b> ${valuation:.1f}B"
350
  if investors_str:
 
354
  hovertext += f" +{len(inv_list)-5} more"
355
  node_hovertext.append(hovertext)
356
 
357
+ # Label logic
358
  show_label = (
359
  show_all_labels or
360
  show_labels_for_range or
 
363
  node in top5_companies
364
  )
365
  if show_label:
366
+ # Bold if top 3
367
  top3 = set(filtered_df.nlargest(3, "Valuation_Billions")["Company"].tolist())
368
  if node in top3:
369
  node_text.append(f"<b>{node}</b>")
 
389
  showlegend=False
390
  )
391
 
392
+ # Summary annotation
393
  total_valuation = filtered_df["Valuation_Billions"].sum()
394
  num_investors = len(investor_list)
395
  num_companies = len(filtered_df)
 
422
  align='center'
423
  )
424
  ],
425
+ plot_bgcolor='#ffffff',
426
+ paper_bgcolor='#ffffff'
427
  )
428
 
429
  return fig
430
 
431
 
432
+ # ============================================================================
433
+ # GRADIO APP
434
+ # ============================================================================
435
  def main():
436
+ # Load data once
437
  try:
438
  data = load_and_clean_data(FILE_PATH)
439
  except Exception as e:
440
  logger.error(f"Failed to load data: {e}")
441
+ # Fallback Gradio UI showing error
442
  with gr.Blocks(title=TITLE) as demo:
443
+ gr.Markdown(f"# {TITLE}")
444
  gr.Markdown(ERROR_FILE.format(path=FILE_PATH) if "not found" in str(e) else str(e))
445
  demo.launch()
446
  return
447
 
448
  investor_company_mapping = build_investor_company_mapping(data)
449
 
450
+ # Prepare dropdown choices
451
  country_list = ["All"] + sorted(data["Country"].dropna().unique())
452
  industry_list = ["All"] + sorted(data["Industry"].dropna().unique())
453
  company_list = ["All"] + sorted(data["Company"].dropna().unique())
454
  investor_list_all = sorted(investor_company_mapping.keys())
455
 
456
+ # Check if City column exists for Nashville filter
457
  has_city = "City" in data.columns
458
 
459
+ # State for caching layout
460
  layout_cache_state = gr.State({})
461
 
462
  def app_logic(
 
467
  ):
468
  start = time.time()
469
 
470
+ # Apply filters
471
  filtered = apply_filters(
472
  data, country, industry, company, investors,
473
  exclude_countries, exclude_industries, exclude_companies, exclude_investors,
 
483
  subtitle = "No results"
484
  return empty_fig, subtitle, "", layout_cache
485
 
486
+ # Cap companies
487
  original_count = len(filtered)
488
  filtered, was_truncated = cap_companies(filtered, node_cap)
489
 
490
+ # Build investor list from filtered data
491
  filtered_inv_mapping = build_investor_company_mapping(filtered)
492
  current_investors = list(filtered_inv_mapping.keys())
493
 
494
+ # Build graph
495
  G = build_graph(filtered, current_investors, show_all_labels, valuation_range, quick_find)
496
 
497
+ # Generate figure
498
+ # Invalidate layout cache if node set changed
499
  current_nodes = set(G.nodes())
500
  if layout_cache.get("nodes") != current_nodes:
501
  layout_cache = {"nodes": current_nodes}
502
  fig = generate_plotly_figure(G, filtered, current_investors, show_all_labels, valuation_range, quick_find, layout_cache)
503
 
504
+ # Subtitle
505
  subtitle = SUBTITLE_TEMPLATE.format(
506
  country=country,
507
  industry=industry,
 
509
  count=len(filtered)
510
  )
511
 
512
+ # Truncation notice
513
  notice = ""
514
  if was_truncated:
515
  notice = TRUNCATION_NOTICE.format(cap=node_cap, total=original_count)
 
522
  def apply_nashville_filter():
523
  """Pre-fill Nashville filter."""
524
  if has_city:
525
+ return "United States", gr.update(), gr.update(), gr.update() # Set country, others unchanged
526
  else:
527
  logger.warning("City column not found; Nashville filter only sets Country")
528
  return "United States", gr.update(), gr.update(), gr.update()
 
530
  def clear_all():
531
  """Reset all filters to default."""
532
  return (
533
+ "All", "All", "All", [], # Include filters
534
+ [], [], [], [], # Exclude filters
535
+ "All", DEFAULT_NODE_CAP, False, "" # Valuation, node cap, labels, quick find
536
  )
537
 
538
  def clear_exclusions():
 
540
  return [], [], [], []
541
 
542
  with gr.Blocks(title=f"{TITLE} ({DATA_TIMESTAMP})", theme=gr.themes.Soft()) as demo:
543
+ gr.Markdown(f"# {TITLE}")
544
  gr.Markdown(f"*Updated {DATA_TIMESTAMP}*")
545
 
546
  subtitle_display = gr.Markdown("Active Scope: All • All • All • 0 companies")
 
580
  reset_view_btn = gr.Button("Reset View", variant="secondary", size="sm")
581
  download_csv_btn = gr.Button("Download Filtered CSV", variant="primary", size="sm")
582
 
583
+ # State
584
  layout_cache = gr.State({})
585
 
586
+ # Inputs and outputs
587
  inputs = [
588
  country_filter, industry_filter, company_filter, investor_filter,
589
  exclude_country, exclude_industry, exclude_company, exclude_investor,
 
592
  ]
593
  outputs = [graph_output, subtitle_display, truncation_notice, layout_cache]
594
 
595
+ # Event handlers (debounced via Gradio's built-in; for older versions, use time.sleep trick)
596
+ for control in inputs[:-1]: # Exclude layout_cache from triggers
597
  control.change(app_logic, inputs, outputs)
598
 
599
+ # Button actions
600
  nashville_btn.click(
601
  apply_nashville_filter,
602
  inputs=None,
 
620
  ).then(app_logic, inputs, outputs)
621
 
622
  reset_view_btn.click(
623
+ lambda: (gr.update(), gr.update(), "", {}), # Clear quick_find and layout cache
624
  inputs=None,
625
  outputs=[graph_output, subtitle_display, quick_find_box, layout_cache]
626
  )
627
 
628
+ # Download CSV (requires Gradio >=3.x File component; here we provide a placeholder)
629
  def export_csv(
630
  country, industry, company, investors,
631
  exclude_countries, exclude_industries, exclude_companies, exclude_investors,
 
650
 
651
  gr.Markdown("""
652
  ---
653
+ **Accessibility:** Use Tab to navigate controls. Press Enter to activate buttons. Graph nodes are keyboard-focusable.
654
+ **Color Legend:** Companies are teal-green. Investors are color-coded (see palette). Non-color cues: stroke outlines differentiate node types.
655
+ **Performance:** Graphs update in <500ms for ≤300 companies. Large datasets are auto-capped; adjust slider as needed.
656
  """)
657
 
658
+ # Initial render
659
  demo.load(app_logic, inputs, outputs)
660
 
661
  demo.launch(share=False, server_name="0.0.0.0", server_port=7860)