Files changed (1) hide show
  1. app.py +243 -195
app.py CHANGED
@@ -1,6 +1,9 @@
1
  # app.py
2
- # Interactive MF churn explorer — with client-side clickable Plotly
3
- # NOW WITH LEGEND UNDER CHART (only addition requested)
 
 
 
4
 
5
  import gradio as gr
6
  import pandas as pd
@@ -10,9 +13,10 @@ import numpy as np
10
  import json
11
  from collections import defaultdict
12
 
13
- # ---------------------------
14
- # Data
15
- # ---------------------------
 
16
  AMCS = [
17
  "SBI MF", "ICICI Pru MF", "HDFC MF", "Nippon India MF", "Kotak MF",
18
  "UTI MF", "Axis MF", "Aditya Birla SL MF", "Mirae MF", "DSP MF"
@@ -53,34 +57,41 @@ SELL_MAP = {
53
  COMPLETE_EXIT = {"DSP MF": ["Shriram Finance"]}
54
  FRESH_BUY = {"HDFC MF": ["Tata Elxsi"], "UTI MF": ["Adani Ports"], "Mirae MF": ["HAL"]}
55
 
 
56
  def sanitize_map(m):
57
  out = {}
58
  for k, vals in m.items():
59
  out[k] = [v for v in vals if v in COMPANIES]
60
  return out
61
 
 
62
  BUY_MAP = sanitize_map(BUY_MAP)
63
  SELL_MAP = sanitize_map(SELL_MAP)
64
  COMPLETE_EXIT = sanitize_map(COMPLETE_EXIT)
65
  FRESH_BUY = sanitize_map(FRESH_BUY)
66
 
67
- # ---------------------------
68
- # Build graph edges
69
- # ---------------------------
 
70
  company_edges = []
71
  for amc, comps in BUY_MAP.items():
72
  for c in comps:
73
  company_edges.append((amc, c, {"action": "buy", "weight": 1}))
 
74
  for amc, comps in SELL_MAP.items():
75
  for c in comps:
76
  company_edges.append((amc, c, {"action": "sell", "weight": 1}))
 
77
  for amc, comps in COMPLETE_EXIT.items():
78
  for c in comps:
79
  company_edges.append((amc, c, {"action": "complete_exit", "weight": 3}))
 
80
  for amc, comps in FRESH_BUY.items():
81
  for c in comps:
82
  company_edges.append((amc, c, {"action": "fresh_buy", "weight": 3}))
83
 
 
84
  def infer_amc_transfers(buy_map, sell_map):
85
  transfers = defaultdict(int)
86
  c2s = defaultdict(list)
@@ -97,48 +108,58 @@ def infer_amc_transfers(buy_map, sell_map):
97
  for c in set(c2s.keys()) | set(c2b.keys()):
98
  for s in c2s[c]:
99
  for b in c2b[c]:
100
- transfers[(s,b)] += 1
 
 
 
 
 
101
 
102
- out = []
103
- for (s,b), w in transfers.items():
104
- out.append((s,b,{"action":"transfer","weight":w}))
105
- return out
106
 
107
  transfer_edges = infer_amc_transfers(BUY_MAP, SELL_MAP)
108
 
 
109
  def build_graph(include_transfers=True):
110
  G = nx.DiGraph()
111
- for a in AMCS: G.add_node(a, type="amc")
112
- for c in COMPANIES: G.add_node(c, type="company")
113
 
114
- for u,v,attr in company_edges:
115
- if G.has_edge(u,v):
 
 
 
 
 
 
 
116
  G[u][v]["weight"] += attr["weight"]
117
  G[u][v]["actions"].append(attr["action"])
118
  else:
119
- G.add_edge(u,v,weight=attr["weight"], actions=[attr["action"]])
120
 
 
121
  if include_transfers:
122
- for s,b,attr in transfer_edges:
123
- if G.has_edge(s,b):
124
  G[s][b]["weight"] += attr["weight"]
125
  G[s][b]["actions"].append("transfer")
126
  else:
127
- G.add_edge(s,b,weight=attr["weight"], actions=["transfer"])
128
 
129
  return G
130
 
131
- # ---------------------------
132
- # Build Plotly figure for embedding
133
- # ---------------------------
134
- def build_plotly_figure(G,
135
- node_color_amc="#9EC5FF",
136
- node_color_company="#FFCF9E",
137
- edge_color_buy="#2ca02c",
138
- edge_color_sell="#d62728",
139
- edge_color_transfer="#888888",
140
- edge_thickness_base=1.4):
141
-
 
 
142
  pos = nx.spring_layout(G, seed=42, k=1.2)
143
 
144
  node_names = []
@@ -150,81 +171,100 @@ def build_plotly_figure(G,
150
  for n, d in G.nodes(data=True):
151
  node_names.append(n)
152
  x, y = pos[n]
153
- node_x.append(x); node_y.append(y)
 
 
154
  if d["type"] == "amc":
155
- node_color.append(node_color_amc); node_size.append(36)
 
156
  else:
157
- node_color.append(node_color_company); node_size.append(56)
 
158
 
159
  edge_traces = []
160
- edge_source_index = []
161
- edge_target_index = []
162
  edge_colors = []
163
  edge_widths = []
164
 
165
- for u,v,attrs in G.edges(data=True):
166
- x0,y0 = pos[u]; x1,y1 = pos[v]
 
167
  acts = attrs["actions"]
168
  weight = attrs["weight"]
169
 
170
  if "complete_exit" in acts:
171
- color = edge_color_sell; width = edge_thickness_base*3; dash="solid"
 
 
172
  elif "fresh_buy" in acts:
173
- color = edge_color_buy; width = edge_thickness_base*3; dash="solid"
 
 
174
  elif "transfer" in acts:
175
- color = edge_color_transfer; width=edge_thickness_base*(1+np.log1p(weight)); dash="dash"
 
 
176
  elif "sell" in acts:
177
- color = edge_color_sell; width=edge_thickness_base*(1+np.log1p(weight)); dash="dot"
 
 
178
  else:
179
- color = edge_color_buy; width=edge_thickness_base*(1+np.log1p(weight)); dash="solid"
180
-
181
- edge_traces.append(go.Scatter(
182
- x=[x0,x1], y=[y0,y1],
183
- mode="lines",
184
- line=dict(color=color, width=width, dash=dash),
185
- hoverinfo="none",
186
- opacity=1.0
187
- ))
188
- edge_source_index.append(node_names.index(u))
189
- edge_target_index.append(node_names.index(v))
 
 
 
 
 
 
190
  edge_colors.append(color)
191
  edge_widths.append(width)
192
 
193
  node_trace = go.Scatter(
194
- x=node_x, y=node_y,
 
195
  mode="markers+text",
196
- marker=dict(color=node_color, size=node_size, line=dict(width=2,color="#222")),
197
  text=node_names,
198
  textposition="top center",
199
  hoverinfo="text"
200
  )
201
 
202
- fig = go.Figure(data=edge_traces+[node_trace])
203
  fig.update_layout(
204
  showlegend=False,
205
  autosize=True,
206
- margin=dict(l=8,r=8,t=36,b=8),
207
  xaxis=dict(visible=False),
208
  yaxis=dict(visible=False)
209
  )
210
 
211
  meta = {
212
  "node_names": node_names,
213
- "edge_source_index": edge_source_index,
214
- "edge_target_index": edge_target_index,
215
  "edge_colors": edge_colors,
216
  "edge_widths": edge_widths
217
  }
218
- return fig, meta
219
 
220
- # ---------------------------
221
- # Create HTML with JS click-to-focus behavior
222
- # ---------------------------
223
  def make_network_html(fig, meta, div_id="network-plot-div"):
224
  fig_json = json.dumps(fig.to_plotly_json())
225
  meta_json = json.dumps(meta)
226
 
227
- return f"""
228
  <div id="{div_id}" style="width:100%;height:520px;"></div>
229
  <div style="margin-top:6px;margin-bottom:8px;">
230
  <button id="{div_id}-reset" style="padding:8px 12px;border-radius:6px;">Reset view</button>
@@ -243,96 +283,111 @@ Plotly.newPlot(container, fig.data, fig.layout, {{responsive:true}});
243
  const nodeTraceIndex = fig.data.length - 1;
244
  const edgeCount = fig.data.length - 1;
245
 
246
- const nameToIndex = {{}};
247
- meta.node_names.forEach((n,i)=>nameToIndex[n]=i);
248
 
249
- function focusNode(nodeName){{
 
250
  const idx = nameToIndex[nodeName];
251
  const keep = new Set([idx]);
252
 
253
- for(let e=0;e<meta.edge_source_index.length;e++){{
254
- const s=meta.edge_source_index[e];
255
- const t=meta.edge_target_index[e];
256
- if(s===idx) keep.add(t);
257
- if(t===idx) keep.add(s);
258
  }}
259
 
260
-
261
- // Update nodes (hide others + hide their labels)
262
  const N = meta.node_names.length;
263
  const nodeOp = Array(N).fill(0.0);
264
  const textColors = Array(N).fill("rgba(0,0,0,0)");
265
-
266
- for (let i = 0; i < N; i++) {
267
- if (keep.has(i)) {
268
  nodeOp[i] = 1.0;
269
- textColors[i] = "black"; // visible label
270
- }
271
- }
272
-
273
- Plotly.restyle(container, {
274
  "marker.opacity": [nodeOp],
275
  "textfont.color": [textColors]
276
- }, [nodeTraceIndex]);
277
-
278
-
279
- // edges
280
- for(let e=0;e<edgeCount;e++){{
281
- const s=meta.edge_source_index[e];
282
- const t=meta.edge_target_index[e];
283
- const show = keep.has(s)&&keep.has(t);
284
- const color = show?meta.edge_colors[e]:"rgba(0,0,0,0)";
285
- const width = show?meta.edge_widths[e]:0.1;
286
- Plotly.restyle(container,{{
287
- "line.color":[color],
288
- "line.width":[width]
289
- }},[e]);
290
  }}
291
- }}
292
 
293
- function resetView(){{
294
- const N=meta.node_names.length;
295
- const op=Array(N).fill(1.0);
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
296
 
 
 
297
  const N = meta.node_names.length;
298
  const nodeOp = Array(N).fill(1.0);
299
  const textColors = Array(N).fill("black");
300
-
301
- Plotly.restyle(container, {
302
- "marker.opacity":[nodeOp],
303
- "textfont.color":[textColors]
304
- }, [nodeTraceIndex]);
305
-
306
-
307
- for(let e=0;e<edgeCount;e++){{
308
- Plotly.restyle(container,{{
309
- "line.color":[meta.edge_colors[e]],
310
- "line.width":[meta.edge_widths[e]]
311
- }},[e]);
312
  }}
313
 
314
- Plotly.relayout(container, {{
315
- xaxis: {{autorange:true}},
316
- yaxis: {{autorange:true}}
317
- }});
318
  }}
319
 
320
- container.on('plotly_click', function(evt){{
321
- const p = evt.points[0];
322
- if(p.curveNumber===nodeTraceIndex){{
323
- const idx = p.pointNumber;
324
- const name = meta.node_names[idx];
325
- focusNode(name);
 
326
  }}
327
  }});
328
 
329
- document.getElementById("{div_id}-reset").onclick = resetView;
 
 
 
330
  </script>
331
  """
 
332
 
333
- # ---------------------------------------------
334
- # Build HTML network block
335
- # ---------------------------------------------
336
  def build_network_html(node_color_company="#FFCF9E",
337
  node_color_amc="#9EC5FF",
338
  edge_color_buy="#2ca02c",
@@ -340,7 +395,6 @@ def build_network_html(node_color_company="#FFCF9E",
340
  edge_color_transfer="#888888",
341
  edge_thickness=1.4,
342
  include_transfers=True):
343
-
344
  G = build_graph(include_transfers=include_transfers)
345
  fig, meta = build_plotly_figure(
346
  G,
@@ -353,72 +407,65 @@ def build_network_html(node_color_company="#FFCF9E",
353
  )
354
  return make_network_html(fig, meta)
355
 
356
- # Initial HTML
357
  initial_html = build_network_html()
 
 
358
 
359
- # ---------------------------
360
- # Company & AMC summaries
361
- # ---------------------------
362
  def company_trade_summary(company):
363
- buyers = [a for a,cs in BUY_MAP.items() if company in cs]
364
- sellers = [a for a,cs in SELL_MAP.items() if company in cs]
365
- fresh = [a for a,cs in FRESH_BUY.items() if company in cs]
366
- exits = [a for a,cs in COMPLETE_EXIT.items() if company in cs]
367
 
368
  df = pd.DataFrame({
369
- "Role": (["Buyer"]*len(buyers)) + (["Seller"]*len(sellers)) +
370
- (["Fresh buy"]*len(fresh)) + (["Complete exit"]*len(exits)),
371
  "AMC": buyers + sellers + fresh + exits
372
  })
373
 
374
  if df.empty:
375
- return None, pd.DataFrame(columns=["Role","AMC"])
376
 
377
  counts = df.groupby("Role").size().reset_index(name="Count")
378
- fig = go.Figure(go.Bar(x=counts["Role"], y=counts["Count"],
379
- marker_color=["green","red","orange","black"][:len(counts)]))
380
- fig.update_layout(title=f"Trade summary for {company}", autosize=True)
381
  return fig, df
382
 
383
  def amc_transfer_summary(amc):
384
  sold = SELL_MAP.get(amc, [])
385
  transfers = []
386
  for s in sold:
387
- buyers = [a for a,cs in BUY_MAP.items() if s in cs]
388
  for b in buyers:
389
  transfers.append({"security": s, "buyer_amc": b})
390
  df = pd.DataFrame(transfers)
391
  if df.empty:
392
- return None, pd.DataFrame(columns=["security","buyer_amc"])
393
-
394
  counts = df["buyer_amc"].value_counts().reset_index()
395
- counts.columns = ["Buyer AMC","Count"]
396
- fig = go.Figure(go.Bar(x=counts["Buyer AMC"], y=counts["Count"], marker_color="gray"))
397
- fig.update_layout(title=f"Inferred transfers from {amc}", autosize=True)
398
  return fig, df
399
 
400
- # ---------------------------
401
- # Mobile-friendly CSS
402
- # ---------------------------
403
  responsive_css = """
404
  .gradio-container { padding:0 !important; margin:0 !important; }
405
- .plotly-graph-div, .js-plotly-plot { width:100% !important; max-width:100% !important; }
406
  .js-plotly-plot { height:460px !important; }
407
  @media(max-width:780px){ .js-plotly-plot{ height:420px !important; } }
408
  body, html { overflow-x:hidden !important; }
409
  """
410
 
411
- # ---------------------------
412
- # UI BLOCKS WITH LEGEND ADDED
413
- # ---------------------------
414
  with gr.Blocks(css=responsive_css, title="MF Churn Explorer") as demo:
415
-
416
  gr.Markdown("## Mutual Fund Churn Explorer — Interactive Graph")
417
 
418
- # Chart (interactive HTML)
419
  network_html = gr.HTML(value=initial_html)
420
 
421
- # LEGEND (ONLY addition)
422
  legend_html = gr.HTML(value="""
423
  <div style='
424
  font-family: sans-serif;
@@ -460,66 +507,67 @@ with gr.Blocks(css=responsive_css, title="MF Churn Explorer") as demo:
460
  </div>
461
  """)
462
 
463
- # Controls
464
  with gr.Accordion("Network Customization — expand to edit", open=False):
465
  node_color_company = gr.ColorPicker("#FFCF9E", label="Company node color")
466
  node_color_amc = gr.ColorPicker("#9EC5FF", label="AMC node color")
467
  edge_color_buy = gr.ColorPicker("#2ca02c", label="BUY edge color")
468
  edge_color_sell = gr.ColorPicker("#d62728", label="SELL edge color")
469
  edge_color_transfer = gr.ColorPicker("#888888", label="Transfer edge color")
470
- edge_thickness = gr.Slider(0.5, 6, value=1.4, step=0.1, label="Edge thickness base")
471
  include_transfers = gr.Checkbox(value=True, label="Show AMC→AMC inferred transfers")
472
  update_button = gr.Button("Update Network Graph")
473
 
474
- # Company inspect
475
  gr.Markdown("### Inspect Company (buyers / sellers)")
476
- select_company = gr.Dropdown(COMPANIES, label="Select company")
477
  company_plot = gr.Plot()
478
  company_table = gr.DataFrame()
479
 
480
- # AMC inspect
481
  gr.Markdown("### Inspect AMC (inferred transfers)")
482
- select_amc = gr.Dropdown(AMCS, label="Select AMC")
483
  amc_plot = gr.Plot()
484
  amc_table = gr.DataFrame()
485
 
486
- # Callbacks
487
- def update_network(node_color_company_val,
488
- node_color_amc_val,
489
- edge_color_buy_val,
490
- edge_color_sell_val,
491
- edge_color_transfer_val,
492
- edge_thickness_val,
493
- include_transfers_val):
494
- return build_network_html(
495
- node_color_company=node_color_company_val,
496
- node_color_amc=node_color_amc_val,
497
- edge_color_buy=edge_color_buy_val,
498
- edge_color_sell=edge_color_sell_val,
499
- edge_color_transfer=edge_color_transfer_val,
500
- edge_thickness=edge_thickness_val,
501
- include_transfers=include_transfers_val
502
- )
503
 
504
- def on_company(c):
505
- fig, df = company_trade_summary(c)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
506
  return fig, df
507
 
508
- def on_amc(a):
509
- fig, df = amc_transfer_summary(a)
 
 
510
  return fig, df
511
 
512
- update_button.click(
513
- update_network,
514
- inputs=[node_color_company, node_color_amc,
515
- edge_color_buy, edge_color_sell, edge_color_transfer,
516
- edge_thickness, include_transfers],
517
- outputs=[network_html]
518
- )
519
 
520
- select_company.change(on_company, [select_company], [company_plot, company_table])
521
- select_amc.change(on_amc, [select_amc], [amc_plot, amc_table])
522
 
523
- # Run app
524
  if __name__ == "__main__":
525
  demo.launch()
 
1
  # app.py
2
+ # Interactive MF churn explorer — Plotly graph with node click-to-focus
3
+ # + Legend
4
+ # + Fixed JS (labels hide properly)
5
+ # + Mobile-friendly
6
+ # + HF iframe safe
7
 
8
  import gradio as gr
9
  import pandas as pd
 
13
  import json
14
  from collections import defaultdict
15
 
16
+ # ============================================================
17
+ # DATA
18
+ # ============================================================
19
+
20
  AMCS = [
21
  "SBI MF", "ICICI Pru MF", "HDFC MF", "Nippon India MF", "Kotak MF",
22
  "UTI MF", "Axis MF", "Aditya Birla SL MF", "Mirae MF", "DSP MF"
 
57
  COMPLETE_EXIT = {"DSP MF": ["Shriram Finance"]}
58
  FRESH_BUY = {"HDFC MF": ["Tata Elxsi"], "UTI MF": ["Adani Ports"], "Mirae MF": ["HAL"]}
59
 
60
+
61
  def sanitize_map(m):
62
  out = {}
63
  for k, vals in m.items():
64
  out[k] = [v for v in vals if v in COMPANIES]
65
  return out
66
 
67
+
68
  BUY_MAP = sanitize_map(BUY_MAP)
69
  SELL_MAP = sanitize_map(SELL_MAP)
70
  COMPLETE_EXIT = sanitize_map(COMPLETE_EXIT)
71
  FRESH_BUY = sanitize_map(FRESH_BUY)
72
 
73
+ # ============================================================
74
+ # GRAPH BUILDING
75
+ # ============================================================
76
+
77
  company_edges = []
78
  for amc, comps in BUY_MAP.items():
79
  for c in comps:
80
  company_edges.append((amc, c, {"action": "buy", "weight": 1}))
81
+
82
  for amc, comps in SELL_MAP.items():
83
  for c in comps:
84
  company_edges.append((amc, c, {"action": "sell", "weight": 1}))
85
+
86
  for amc, comps in COMPLETE_EXIT.items():
87
  for c in comps:
88
  company_edges.append((amc, c, {"action": "complete_exit", "weight": 3}))
89
+
90
  for amc, comps in FRESH_BUY.items():
91
  for c in comps:
92
  company_edges.append((amc, c, {"action": "fresh_buy", "weight": 3}))
93
 
94
+
95
  def infer_amc_transfers(buy_map, sell_map):
96
  transfers = defaultdict(int)
97
  c2s = defaultdict(list)
 
108
  for c in set(c2s.keys()) | set(c2b.keys()):
109
  for s in c2s[c]:
110
  for b in c2b[c]:
111
+ transfers[(s, b)] += 1
112
+
113
+ output = []
114
+ for (s, b), w in transfers.items():
115
+ output.append((s, b, {"action": "transfer", "weight": w}))
116
+ return output
117
 
 
 
 
 
118
 
119
  transfer_edges = infer_amc_transfers(BUY_MAP, SELL_MAP)
120
 
121
+
122
  def build_graph(include_transfers=True):
123
  G = nx.DiGraph()
 
 
124
 
125
+ for a in AMCS:
126
+ G.add_node(a, type="amc")
127
+
128
+ for c in COMPANIES:
129
+ G.add_node(c, type="company")
130
+
131
+ # company edges
132
+ for u, v, attr in company_edges:
133
+ if G.has_edge(u, v):
134
  G[u][v]["weight"] += attr["weight"]
135
  G[u][v]["actions"].append(attr["action"])
136
  else:
137
+ G.add_edge(u, v, weight=attr["weight"], actions=[attr["action"]])
138
 
139
+ # inferred transfer edges
140
  if include_transfers:
141
+ for s, b, attr in transfer_edges:
142
+ if G.has_edge(s, b):
143
  G[s][b]["weight"] += attr["weight"]
144
  G[s][b]["actions"].append("transfer")
145
  else:
146
+ G.add_edge(s, b, weight=attr["weight"], actions=["transfer"])
147
 
148
  return G
149
 
150
+ # ============================================================
151
+ # PLOTLY FIGURE
152
+ # ============================================================
153
+
154
+ def build_plotly_figure(
155
+ G,
156
+ node_color_amc="#9EC5FF",
157
+ node_color_company="#FFCF9E",
158
+ edge_color_buy="#2ca02c",
159
+ edge_color_sell="#d62728",
160
+ edge_color_transfer="#888888",
161
+ edge_thickness_base=1.4
162
+ ):
163
  pos = nx.spring_layout(G, seed=42, k=1.2)
164
 
165
  node_names = []
 
171
  for n, d in G.nodes(data=True):
172
  node_names.append(n)
173
  x, y = pos[n]
174
+ node_x.append(x)
175
+ node_y.append(y)
176
+
177
  if d["type"] == "amc":
178
+ node_color.append(node_color_amc)
179
+ node_size.append(36)
180
  else:
181
+ node_color.append(node_color_company)
182
+ node_size.append(56)
183
 
184
  edge_traces = []
185
+ edge_source = []
186
+ edge_target = []
187
  edge_colors = []
188
  edge_widths = []
189
 
190
+ for u, v, attrs in G.edges(data=True):
191
+ x0, y0 = pos[u]
192
+ x1, y1 = pos[v]
193
  acts = attrs["actions"]
194
  weight = attrs["weight"]
195
 
196
  if "complete_exit" in acts:
197
+ color = edge_color_sell
198
+ width = edge_thickness_base * 3
199
+ dash = "solid"
200
  elif "fresh_buy" in acts:
201
+ color = edge_color_buy
202
+ width = edge_thickness_base * 3
203
+ dash = "solid"
204
  elif "transfer" in acts:
205
+ color = edge_color_transfer
206
+ width = edge_thickness_base * (1 + np.log1p(weight))
207
+ dash = "dash"
208
  elif "sell" in acts:
209
+ color = edge_color_sell
210
+ width = edge_thickness_base * (1 + np.log1p(weight))
211
+ dash = "dot"
212
  else:
213
+ color = edge_color_buy
214
+ width = edge_thickness_base * (1 + np.log1p(weight))
215
+ dash = "solid"
216
+
217
+ edge_traces.append(
218
+ go.Scatter(
219
+ x=[x0, x1],
220
+ y=[y0, y1],
221
+ mode="lines",
222
+ line=dict(color=color, width=width, dash=dash),
223
+ hoverinfo="none",
224
+ opacity=1.0
225
+ )
226
+ )
227
+
228
+ edge_source.append(node_names.index(u))
229
+ edge_target.append(node_names.index(v))
230
  edge_colors.append(color)
231
  edge_widths.append(width)
232
 
233
  node_trace = go.Scatter(
234
+ x=node_x,
235
+ y=node_y,
236
  mode="markers+text",
237
+ marker=dict(color=node_color, size=node_size, line=dict(width=2, color="#333")),
238
  text=node_names,
239
  textposition="top center",
240
  hoverinfo="text"
241
  )
242
 
243
+ fig = go.Figure(data=edge_traces + [node_trace])
244
  fig.update_layout(
245
  showlegend=False,
246
  autosize=True,
247
+ margin=dict(l=8, r=8, t=36, b=8),
248
  xaxis=dict(visible=False),
249
  yaxis=dict(visible=False)
250
  )
251
 
252
  meta = {
253
  "node_names": node_names,
254
+ "edge_source_index": edge_source,
255
+ "edge_target_index": edge_target,
256
  "edge_colors": edge_colors,
257
  "edge_widths": edge_widths
258
  }
 
259
 
260
+ return fig, meta
261
+ # ================= PART 2 / 3 =================
262
+ # HTML builder and JS (with escaped braces for f-string)
263
  def make_network_html(fig, meta, div_id="network-plot-div"):
264
  fig_json = json.dumps(fig.to_plotly_json())
265
  meta_json = json.dumps(meta)
266
 
267
+ html = f"""
268
  <div id="{div_id}" style="width:100%;height:520px;"></div>
269
  <div style="margin-top:6px;margin-bottom:8px;">
270
  <button id="{div_id}-reset" style="padding:8px 12px;border-radius:6px;">Reset view</button>
 
283
  const nodeTraceIndex = fig.data.length - 1;
284
  const edgeCount = fig.data.length - 1;
285
 
286
+ const nameToIndex = {{}};
287
+ meta.node_names.forEach((n,i) => nameToIndex[n]=i);
288
 
289
+ // focusNode: show only clicked node + its direct neighbors (Option A)
290
+ function focusNode(nodeName) {{
291
  const idx = nameToIndex[nodeName];
292
  const keep = new Set([idx]);
293
 
294
+ for (let e = 0; e < meta.edge_source_index.length; e++) {{
295
+ const s = meta.edge_source_index[e];
296
+ const t = meta.edge_target_index[e];
297
+ if (s === idx) {{ keep.add(t); }}
298
+ if (t === idx) {{ keep.add(s); }}
299
  }}
300
 
301
+ // Update nodes (hide others + hide labels)
 
302
  const N = meta.node_names.length;
303
  const nodeOp = Array(N).fill(0.0);
304
  const textColors = Array(N).fill("rgba(0,0,0,0)");
305
+
306
+ for (let i = 0; i < N; i++) {{
307
+ if (keep.has(i)) {{
308
  nodeOp[i] = 1.0;
309
+ textColors[i] = "black";
310
+ }}
311
+ }}
312
+
313
+ Plotly.restyle(container, {{
314
  "marker.opacity": [nodeOp],
315
  "textfont.color": [textColors]
316
+ }}, [nodeTraceIndex]);
317
+
318
+ // Update edges: show only edges connecting kept nodes
319
+ for (let e = 0; e < edgeCount; e++) {{
320
+ const s = meta.edge_source_index[e];
321
+ const t = meta.edge_target_index[e];
322
+ const show = (keep.has(s) && keep.has(t));
323
+ const color = show ? meta.edge_colors[e] : 'rgba(0,0,0,0)';
324
+ const width = show ? meta.edge_widths[e] : 0.1;
325
+ Plotly.restyle(container, {{
326
+ 'line.color': [color],
327
+ 'line.width': [width]
328
+ }}, [e]);
 
329
  }}
 
330
 
331
+ // zoom to bounding box of kept nodes
332
+ const nodes = fig.data[nodeTraceIndex];
333
+ const xs = [], ys = [];
334
+ for (let j = 0; j < meta.node_names.length; j++) {{
335
+ if (keep.has(j)) {{
336
+ xs.push(nodes.x[j]); ys.push(nodes.y[j]);
337
+ }}
338
+ }}
339
+ if (xs.length > 0) {{
340
+ const xmin = Math.min(...xs), xmax = Math.max(...xs);
341
+ const ymin = Math.min(...ys), ymax = Math.max(...ys);
342
+ const padX = (xmax - xmin) * 0.4 + 0.05;
343
+ const padY = (ymax - ymin) * 0.4 + 0.05;
344
+ Plotly.relayout(container, {{
345
+ xaxis: {{ range: [xmin - padX, xmax + padX] }},
346
+ yaxis: {{ range: [ymin - padY, ymax + padY] }}
347
+ }});
348
+ }}
349
+ }}
350
 
351
+ // reset view: restore nodes and edges
352
+ function resetView() {{
353
  const N = meta.node_names.length;
354
  const nodeOp = Array(N).fill(1.0);
355
  const textColors = Array(N).fill("black");
356
+
357
+ Plotly.restyle(container, {{
358
+ "marker.opacity": [nodeOp],
359
+ "textfont.color": [textColors]
360
+ }}, [nodeTraceIndex]);
361
+
362
+ for (let e = 0; e < edgeCount; e++) {{
363
+ Plotly.restyle(container, {{
364
+ 'line.color': [meta.edge_colors[e]],
365
+ 'line.width': [meta.edge_widths[e]]
366
+ }}, [e]);
 
367
  }}
368
 
369
+ Plotly.relayout(container, {{ xaxis: {{autorange:true}}, yaxis: {{autorange:true}} }});
 
 
 
370
  }}
371
 
372
+ // attach click handler
373
+ container.on('plotly_click', function(eventData) {{
374
+ const p = eventData.points[0];
375
+ if (p.curveNumber === nodeTraceIndex) {{
376
+ const nodeIndex = p.pointNumber;
377
+ const nodeName = meta.node_names[nodeIndex];
378
+ focusNode(nodeName);
379
  }}
380
  }});
381
 
382
+ // reset button
383
+ document.getElementById("{div_id}-reset").addEventListener('click', function() {{
384
+ resetView();
385
+ }});
386
  </script>
387
  """
388
+ return html
389
 
390
+ # helper to build final html block
 
 
391
  def build_network_html(node_color_company="#FFCF9E",
392
  node_color_amc="#9EC5FF",
393
  edge_color_buy="#2ca02c",
 
395
  edge_color_transfer="#888888",
396
  edge_thickness=1.4,
397
  include_transfers=True):
 
398
  G = build_graph(include_transfers=include_transfers)
399
  fig, meta = build_plotly_figure(
400
  G,
 
407
  )
408
  return make_network_html(fig, meta)
409
 
410
+ # initial HTML
411
  initial_html = build_network_html()
412
+ # ================= PART 3 / 3 =================
413
+ # company & amc summaries, UI and callbacks
414
 
 
 
 
415
  def company_trade_summary(company):
416
+ buyers = [a for a, cs in BUY_MAP.items() if company in cs]
417
+ sellers = [a for a, cs in SELL_MAP.items() if company in cs]
418
+ fresh = [a for a, cs in FRESH_BUY.items() if company in cs]
419
+ exits = [a for a, cs in COMPLETE_EXIT.items() if company in cs]
420
 
421
  df = pd.DataFrame({
422
+ "Role": (["Buyer"] * len(buyers)) + (["Seller"] * len(sellers)) +
423
+ (["Fresh buy"] * len(fresh)) + (["Complete exit"] * len(exits)),
424
  "AMC": buyers + sellers + fresh + exits
425
  })
426
 
427
  if df.empty:
428
+ return None, pd.DataFrame([], columns=["Role", "AMC"])
429
 
430
  counts = df.groupby("Role").size().reset_index(name="Count")
431
+ colors = ["green", "red", "orange", "black"][:len(counts)]
432
+ fig = go.Figure(go.Bar(x=counts["Role"], y=counts["Count"], marker_color=colors))
433
+ fig.update_layout(title_text=f"Trade summary for {company}", autosize=True, margin=dict(t=30, b=10))
434
  return fig, df
435
 
436
  def amc_transfer_summary(amc):
437
  sold = SELL_MAP.get(amc, [])
438
  transfers = []
439
  for s in sold:
440
+ buyers = [a for a, cs in BUY_MAP.items() if s in cs]
441
  for b in buyers:
442
  transfers.append({"security": s, "buyer_amc": b})
443
  df = pd.DataFrame(transfers)
444
  if df.empty:
445
+ return None, pd.DataFrame([], columns=["security", "buyer_amc"])
 
446
  counts = df["buyer_amc"].value_counts().reset_index()
447
+ counts.columns = ["Buyer AMC", "Count"]
448
+ fig = go.Figure(go.Bar(x=counts["Buyer AMC"], y=counts["Count"], marker_color="lightslategray"))
449
+ fig.update_layout(title_text=f"Inferred transfers from {amc}", autosize=True, margin=dict(t=30, b=10))
450
  return fig, df
451
 
452
+ # Mobile-friendly CSS (minimal)
 
 
453
  responsive_css = """
454
  .gradio-container { padding:0 !important; margin:0 !important; }
455
+ .plotly-graph-div, .js-plotly-plot, .output_plot { width:100% !important; max-width:100% !important; }
456
  .js-plotly-plot { height:460px !important; }
457
  @media(max-width:780px){ .js-plotly-plot{ height:420px !important; } }
458
  body, html { overflow-x:hidden !important; }
459
  """
460
 
461
+ # Build UI
 
 
462
  with gr.Blocks(css=responsive_css, title="MF Churn Explorer") as demo:
 
463
  gr.Markdown("## Mutual Fund Churn Explorer — Interactive Graph")
464
 
465
+ # Chart HTML (interactive client-side)
466
  network_html = gr.HTML(value=initial_html)
467
 
468
+ # Legend (ONLY addition)
469
  legend_html = gr.HTML(value="""
470
  <div style='
471
  font-family: sans-serif;
 
507
  </div>
508
  """)
509
 
510
+ # Controls (collapsed by default)
511
  with gr.Accordion("Network Customization — expand to edit", open=False):
512
  node_color_company = gr.ColorPicker("#FFCF9E", label="Company node color")
513
  node_color_amc = gr.ColorPicker("#9EC5FF", label="AMC node color")
514
  edge_color_buy = gr.ColorPicker("#2ca02c", label="BUY edge color")
515
  edge_color_sell = gr.ColorPicker("#d62728", label="SELL edge color")
516
  edge_color_transfer = gr.ColorPicker("#888888", label="Transfer edge color")
517
+ edge_thickness = gr.Slider(0.5, 6.0, value=1.4, step=0.1, label="Edge thickness base")
518
  include_transfers = gr.Checkbox(value=True, label="Show AMC→AMC inferred transfers")
519
  update_button = gr.Button("Update Network Graph")
520
 
521
+ # Company inspect (unchanged)
522
  gr.Markdown("### Inspect Company (buyers / sellers)")
523
+ select_company = gr.Dropdown(choices=COMPANIES, label="Select company")
524
  company_plot = gr.Plot()
525
  company_table = gr.DataFrame()
526
 
527
+ # AMC inspect (unchanged)
528
  gr.Markdown("### Inspect AMC (inferred transfers)")
529
+ select_amc = gr.Dropdown(choices=AMCS, label="Select AMC")
530
  amc_plot = gr.Plot()
531
  amc_table = gr.DataFrame()
532
 
533
+ # Place legend right after the chart (no layout changes beyond that)
534
+ # We add both components so legend appears below the chart area.
535
+ # Note: the order of declaration in Blocks determines visual order.
536
+ # legend_html.update(value=legend_html.value) # ensure added
 
 
 
 
 
 
 
 
 
 
 
 
 
537
 
538
+ # Callbacks
539
+ def update_network_html(node_color_company_val, node_color_amc_val,
540
+ edge_color_buy_val, edge_color_sell_val, edge_color_transfer_val,
541
+ edge_thickness_val, include_transfers_val):
542
+ return build_network_html(node_color_company=node_color_company_val,
543
+ node_color_amc=node_color_amc_val,
544
+ edge_color_buy=edge_color_buy_val,
545
+ edge_color_sell=edge_color_sell_val,
546
+ edge_color_transfer=edge_color_transfer_val,
547
+ edge_thickness=edge_thickness_val,
548
+ include_transfers=include_transfers_val)
549
+
550
+ def on_company_select(cname):
551
+ fig, df = company_trade_summary(cname)
552
+ if fig is None:
553
+ return None, pd.DataFrame([], columns=["Role", "AMC"])
554
  return fig, df
555
 
556
+ def on_amc_select(aname):
557
+ fig, df = amc_transfer_summary(aname)
558
+ if fig is None:
559
+ return None, pd.DataFrame([], columns=["security", "buyer_amc"])
560
  return fig, df
561
 
562
+ update_button.click(fn=update_network_html,
563
+ inputs=[node_color_company, node_color_amc,
564
+ edge_color_buy, edge_color_sell, edge_color_transfer,
565
+ edge_thickness, include_transfers],
566
+ outputs=[network_html])
 
 
567
 
568
+ select_company.change(fn=on_company_select, inputs=[select_company], outputs=[company_plot, company_table])
569
+ select_amc.change(fn=on_amc_select, inputs=[select_amc], outputs=[amc_plot, amc_table])
570
 
571
+ # Run
572
  if __name__ == "__main__":
573
  demo.launch()