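# (Hugging Face file-view header captured along with the source; kept as comments so the file parses.)
# AMCAnalysis / app.py
# singhn9's picture
# Update app.py
# 24312f4 verified
# raw
# history blame
# 15.1 kB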
# app.py
# Static weighted semi-layer arc diagram (L1 labels outside)
# With short labels by default & full label on click
import html
import json
from collections import defaultdict

import gradio as gr
import numpy as np
import pandas as pd
# ---------------------------
# Data
# ---------------------------
# AMCs (mutual-fund houses). Drawn as blue (#2b6fa6) nodes in the arc diagram
# and offered in the "Select AMC" dropdown.
AMCS = [
    "SBI MF", "ICICI Pru MF", "HDFC MF", "Nippon India MF", "Kotak MF",
    "UTI MF", "Axis MF", "Aditya Birla SL MF", "Mirae MF", "DSP MF"
]
# Companies that appear as trade targets. Drawn as orange (#f2c88d) nodes.
# This list is also the whitelist sanitize_map() filters the trade maps against,
# so a company missing here is silently dropped from every map below.
COMPANIES = [
    "HDFC Bank", "ICICI Bank", "Bajaj Finance", "Bajaj Finserv", "Adani Ports",
    "Tata Motors", "Shriram Finance", "HAL", "TCS", "AU Small Finance Bank",
    "Pearl Global", "Hindalco", "Tata Elxsi", "Cummins India", "Vedanta"
]
# AMC -> companies that AMC bought. Each pair becomes a green BUY arc.
# NOTE(review): the reporting period / data source is not recorded here — confirm.
BUY_MAP = {
    "SBI MF": ["Bajaj Finance", "AU Small Finance Bank"],
    "ICICI Pru MF": ["HDFC Bank"],
    "HDFC MF": ["Tata Elxsi", "TCS"],
    "Nippon India MF": ["Hindalco"],
    "Kotak MF": ["Bajaj Finance"],
    "UTI MF": ["Adani Ports", "Shriram Finance"],
    "Axis MF": ["Tata Motors", "Shriram Finance"],
    "Aditya Birla SL MF": ["AU Small Finance Bank"],
    "Mirae MF": ["Bajaj Finance", "HAL"],
    "DSP MF": ["Tata Motors", "Bajaj Finserv"]
}
# AMC -> companies that AMC sold. Each pair becomes a red dotted SELL arc.
SELL_MAP = {
    "SBI MF": ["Tata Motors"],
    "ICICI Pru MF": ["Bajaj Finance", "Adani Ports"],
    "HDFC MF": ["HDFC Bank"],
    "Nippon India MF": ["Hindalco"],
    "Kotak MF": ["AU Small Finance Bank"],
    "UTI MF": ["Hindalco", "TCS"],
    "Axis MF": ["TCS"],
    "Aditya Birla SL MF": ["Adani Ports"],
    "Mirae MF": ["TCS"],
    "DSP MF": ["HAL", "Shriram Finance"]
}
# Pairs flagged as complete exits / fresh buys. They are listed separately in the
# company summary. A flagged pair also present in SELL_MAP / BUY_MAP gets arc
# weight 3 instead of 1.
COMPLETE_EXIT = {"DSP MF": ["Shriram Finance"]}
FRESH_BUY = {"HDFC MF": ["Tata Elxsi"], "UTI MF": ["Adani Ports"], "Mirae MF": ["HAL"]}
def sanitize_map(m, valid=None):
    """Return a copy of *m* whose value lists keep only allowed company names.

    Args:
        m: mapping of AMC -> list of company names.
        valid: optional iterable of allowed names; defaults to COMPANIES.
            It is looked up at call time, so callers that omit it keep the
            original whitelist behaviour.

    Returns:
        A new dict with the same keys in the same order. Each list keeps its
        original order and duplicates, minus any name not in *valid*.
    """
    # Build the whitelist once instead of scanning a list for every value.
    allowed = set(COMPANIES if valid is None else valid)
    return {k: [v for v in vals if v in allowed] for k, vals in m.items()}
# Drop any company name that is not in COMPANIES from all four trade maps
# (the right-hand side is fully evaluated before the names are rebound).
BUY_MAP, SELL_MAP, COMPLETE_EXIT, FRESH_BUY = map(
    sanitize_map, (BUY_MAP, SELL_MAP, COMPLETE_EXIT, FRESH_BUY)
)
# ---------------------------
# Label maps (NEW)
# ---------------------------
# Compact node labels. The JS template receives this map as SHORT_LABEL_JS and
# shows these labels by default (and again after Reset).
# Keep the keys in sync with AMCS and COMPANIES.
SHORT_LABEL = {
    "SBI MF": "SBI",
    "ICICI Pru MF": "ICICI",
    "HDFC MF": "HDFC",
    "Nippon India MF": "NIP",
    "Kotak MF": "KOTAK",
    "UTI MF": "UTI",
    "Axis MF": "AXIS",
    "Aditya Birla SL MF": "ABSL",
    "Mirae MF": "MIRAE",
    "DSP MF": "DSP",
    "HDFC Bank": "HDFC Bk",
    "ICICI Bank": "ICICI Bk",
    "Bajaj Finance": "Bajaj Fin",
    "Bajaj Finserv": "Bajaj Fsrv",
    "Adani Ports": "AdaniPt",
    "Tata Motors": "TataMot",
    "Shriram Finance": "Shriram",
    "HAL": "HAL",
    "TCS": "TCS",
    "AU Small Finance Bank": "AU SFB",
    "Pearl Global": "PearlG",
    "Hindalco": "Hindalco",
    "Tata Elxsi": "Elxsi",
    "Cummins India": "Cummins",
    "Vedanta": "Vedanta"
}
# Full label = the node's own name. It is shown for the node the user clicks.
FULL_LABEL = {k: k for k in SHORT_LABEL}
# ---------------------------
# Infer AMC→AMC transfers
# ---------------------------
def infer_amc_transfers(buy_map, sell_map):
    """Infer AMC->AMC transfers from sells and buys of the same company.

    Each company occurrence that one AMC sold and a *different* AMC bought
    adds 1 to the (seller, buyer) pair. An AMC that both buys and sells the
    same company is not counted as transferring to itself.

    Args:
        buy_map: mapping AMC -> iterable of companies it bought.
        sell_map: mapping AMC -> iterable of companies it sold.

    Returns:
        defaultdict(int) mapping (seller_amc, buyer_amc) -> count. Keys are
        inserted in a deterministic order: sell_map order, then buy_map order.
    """
    sellers_of = defaultdict(list)
    for amc, comps in sell_map.items():
        for c in comps:
            sellers_of[c].append(amc)
    buyers_of = defaultdict(list)
    for amc, comps in buy_map.items():
        for c in comps:
            buyers_of[c].append(amc)

    transfers = defaultdict(int)
    # Only companies with both a seller and a buyer can yield a transfer.
    # Iterating the dict (not a set) keeps the output order stable across runs.
    for c, sellers in sellers_of.items():
        buyers = buyers_of.get(c, [])
        for seller in sellers:
            for buyer in buyers:
                if seller != buyer:
                    transfers[(seller, buyer)] += 1
    return transfers
# (seller AMC, buyer AMC) -> count, computed once from the sanitized trade maps.
TRANSFER_COUNTS = infer_amc_transfers(buy_map=BUY_MAP, sell_map=SELL_MAP)
# ---------------------------
# Mixed ordering to reduce crossings
# ---------------------------
def build_mixed_ordering(amcs, companies):
    """Interleave two sequences: amcs[0], companies[0], amcs[1], companies[1], ...

    Once the shorter sequence runs out, the remainder of the longer one follows
    in order. Alternating node kinds around the circle spreads the arcs out.
    """
    longest = max(len(amcs), len(companies))
    return [
        seq[pos]
        for pos in range(longest)
        for seq in (amcs, companies)
        if pos < len(seq)
    ]
# Node order around the circle (AMCs and companies interleaved) and each node's kind.
NODES = build_mixed_ordering(amcs=AMCS, companies=COMPANIES)
NODE_TYPE = {node: "amc" if node in AMCS else "company" for node in NODES}
# ---------------------------
# Build flows
# ---------------------------
def build_flows():
    """Assemble the edge lists drawn by the arc diagram.

    Returns:
        (buys, sells, transfers, loops):
          buys      -- list of (amc, company, weight). weight is 3 when the pair is
                       also in FRESH_BUY, else 1.
          sells     -- list of (company, amc, weight). weight is 3 when the pair is
                       also in COMPLETE_EXIT, else 1.
          transfers -- list of (seller_amc, buyer_amc, count) from TRANSFER_COUNTS,
                       excluding an AMC paired with itself.
          loops     -- sorted, de-duplicated (buyer_amc, company, seller_amc)
                       triples where one AMC bought a company that a *different*
                       AMC sold. A same-AMC "loop" would render as a spike rather
                       than an arc, so it is skipped.
    """
    buys = [
        (amc, c, 3 if c in FRESH_BUY.get(amc, []) else 1)
        for amc, comps in BUY_MAP.items()
        for c in comps
    ]
    sells = [
        (c, amc, 3 if c in COMPLETE_EXIT.get(amc, []) else 1)
        for amc, comps in SELL_MAP.items()
        for c in comps
    ]
    transfers = [(s, b, w) for (s, b), w in TRANSFER_COUNTS.items() if s != b]

    # Index sellers by company once, instead of scanning every sell per buy.
    sellers_of = defaultdict(set)
    for c, seller, _ in sells:
        sellers_of[c].add(seller)
    # Sorting makes the order (and the JSON sent to the page) deterministic.
    loops = sorted({
        (buyer, c, seller)
        for buyer, c, _ in buys
        for seller in sellers_of.get(c, ())
        if seller != buyer
    })
    return buys, sells, transfers, loops


BUYS, SELLS, TRANSFERS, LOOPS = build_flows()
# ---------------------------
# Inspect panels
# ---------------------------
def company_trade_summary(company):
    """Summarise which AMCs traded *company*.

    Returns:
        (fig, df): df has columns "Role" and "AMC" with one row per match. Rows
        are grouped in the order Buyer, Seller, Fresh buy, Complete exit.
        fig is a Plotly-style bar-chart dict counting rows per role (roles
        sorted alphabetically). When nothing matches, fig is None and df is
        empty.

    NOTE(review): gr.Plot documents matplotlib/Plotly/Altair/Bokeh figure
    objects; confirm a plain dict spec renders with the installed Gradio version.
    """
    role_maps = (
        ("Buyer", BUY_MAP),
        ("Seller", SELL_MAP),
        ("Fresh buy", FRESH_BUY),
        ("Complete exit", COMPLETE_EXIT),
    )
    rows = [
        (role, amc)
        for role, mapping in role_maps
        for amc, comps in mapping.items()
        if company in comps
    ]
    df = pd.DataFrame({
        "Role": [role for role, _ in rows],
        "AMC": [amc for _, amc in rows],
    })
    if df.empty:
        return None, df
    per_role = df.groupby("Role").size()
    fig = {
        "data": [{"type": "bar", "x": per_role.index.tolist(), "y": per_role.tolist()}],
        "layout": {"title": f"Trades for {company}"},
    }
    return fig, df
def amc_transfer_summary(amc, *, buy_map=None, sell_map=None):
    """Infer which other AMCs picked up the securities *amc* sold.

    For every company *amc* sold, each *other* AMC that bought the same company
    is recorded as a possible buyer. The selling AMC is never listed as buying
    from itself, even if it also bought that company.

    Args:
        amc: name of the selling AMC; an unknown name yields an empty result.
        buy_map: AMC -> bought companies; defaults to BUY_MAP.
        sell_map: AMC -> sold companies; defaults to SELL_MAP.

    Returns:
        (fig, df): df has columns "security" and "buyer_amc", one row per
        (sold company, buying AMC). fig is a Plotly-style bar-chart dict of
        rows per buyer AMC, most frequent first. When no transfer can be
        inferred, fig is None and df is empty.

    NOTE(review): gr.Plot documents matplotlib/Plotly/Altair/Bokeh figure
    objects; confirm a plain dict spec renders with the installed Gradio version.
    """
    buy_map = BUY_MAP if buy_map is None else buy_map
    sell_map = SELL_MAP if sell_map is None else sell_map
    rows = [
        {"security": sec, "buyer_amc": buyer}
        for sec in sell_map.get(amc, [])
        for buyer, comps in buy_map.items()
        if buyer != amc and sec in comps
    ]
    df = pd.DataFrame(rows)
    if df.empty:
        return None, df
    # value_counts sorts by count (descending); index = buyer AMC names.
    per_buyer = df["buyer_amc"].value_counts()
    fig = {
        "data": [{"type": "bar", "x": per_buyer.index.tolist(), "y": per_buyer.tolist()}],
        "layout": {"title": f"Inferred transfers from {amc}"},
    }
    return fig, df
# ---------------------------
# HTML template: safe, no f-strings
# ---------------------------
# HTML + D3 page for the arc diagram. The __UPPER_CASE__ tokens are replaced with
# JSON data by make_arc_html(). Plain str.replace is used instead of f-strings or
# .format, so the JS braces and `${...}` template literals need no escaping.
JS_TEMPLATE = """
<div id="arc-container" style="width:100%; height:720px;"></div>
<div style="margin-top:8px;">
<button id="arc-reset" style="padding:8px 12px; border-radius:6px;">Reset</button>
</div>
<div style="margin-top:10px; font-family:sans-serif; font-size:13px;">
<b>Legend</b><br/>
BUY = green solid<br/>
SELL = red dotted<br/>
TRANSFER = grey<br/>
LOOP = teal external arc<br/>
<div style="margin-top:6px;color:#666;font-size:12px;">Labels: short by default. Clicking a node shows full name.</div>
</div>
<script src="https://d3js.org/d3.v7.min.js"></script>
<script>
const NODES = __NODES__;
const NODE_TYPE = __NODE_TYPE__;
const BUYS = __BUYS__;
const SELLS = __SELLS__;
const TRANSFERS = __TRANSFERS__;
const LOOPS = __LOOPS__;
const SHORT_LABEL_JS = __SHORT_LABEL__;
const FULL_LABEL_JS = __FULL_LABEL__;
// Fall back to the raw node name when a label map has no entry for it.
function shortLabel(name){return SHORT_LABEL_JS[name] || name;}
function fullLabel(name){return FULL_LABEL_JS[name] || name;}
function draw() {
const container = document.getElementById("arc-container");
container.innerHTML = "";
const w = Math.min(920, container.clientWidth || 820);
const h = Math.max(420, Math.floor(w * 0.75));
const svg = d3.select(container).append("svg")
.attr("width","100%")
.attr("height",h)
.attr("viewBox",[-w/2,-h/2,w,h].join(" "));
const radius = Math.min(w,h)*0.36;
const n=NODES.length;
function angleFor(i){return (i/n)*2*Math.PI;}
const nodePos = NODES.map((name,i)=>{
const ang=angleFor(i)-Math.PI/2;
return {name,angle:ang,x:Math.cos(ang)*radius,y:Math.sin(ang)*radius};
});
const nameToIndex={};
NODES.forEach((nm,i)=>nameToIndex[nm]=i);
// Each node is a <g> translated to its position, so label x/y are offsets from the
// node centre (not the diagram centre): the label sits just outside the node circle.
const NODE_R=16;
const LABEL_OFFSET=NODE_R+10;
// CSS font-size needs a length unit; a bare number is not a valid value.
const FONT_PX=Math.max(10, Math.min(13, radius*0.038))+"px";
const group=svg.append("g").selectAll("g").data(nodePos).enter().append("g")
.attr("transform", d=>`translate(${d.x},${d.y})`);
group.append("circle")
.attr("r",NODE_R)
.style("fill",d=>NODE_TYPE[d.name]==="amc"?"#2b6fa6":"#f2c88d")
.style("stroke","#222")
.style("stroke-width",1)
.style("cursor","pointer");
group.append("text")
.attr("x", d => Math.cos(d.angle) * LABEL_OFFSET)
.attr("y", d => Math.sin(d.angle) * LABEL_OFFSET)
.attr("dy", "0.35em")
.style("font-family", "sans-serif")
.style("font-size", FONT_PX)
// Anchor labels away from the node: right-side labels start there, left-side ones end there.
.style("text-anchor", d => {
const c=Math.cos(d.angle);
return c>0.2?"start":(c<-0.2?"end":"middle");
})
.style("cursor","pointer")
.text(d => shortLabel(d.name));
function bezierPath(x0,y0,x1,y1,above=true){
const mx=(x0+x1)/2, my=(y0+y1)/2;
const dx=mx, dy=my;
const len=Math.sqrt(dx*dx+dy*dy)||1;
const ux=dx/len, uy=dy/len;
const offset=(above?-1:1)*Math.max(30,radius*0.9);
const cx=mx+ux*offset, cy=my+uy*offset;
return `M ${x0} ${y0} Q ${cx} ${cy} ${x1} ${y1}`;
}
const allW=[].concat(BUYS.map(d=>d[2]),SELLS.map(d=>d[2]),TRANSFERS.map(d=>d[2]));
const stroke=d3.scaleLinear().domain([1,Math.max(...allW,1)]).range([1.0,6.0]);
// BUYS top
const buyG=svg.append("g");
BUYS.forEach(b=>{
const a=b[0], c=b[1], wt=b[2];
if(!(a in nameToIndex)||!(c in nameToIndex))return;
const s=nodePos[nameToIndex[a]], t=nodePos[nameToIndex[c]];
buyG.append("path")
.attr("d",bezierPath(s.x,s.y,t.x,t.y,true))
.attr("fill","none")
.attr("stroke","#2e8540")
.attr("stroke-width",stroke(wt))
.attr("opacity",0.92)
.attr("data-src",a)
.attr("data-tgt",c);
});
// SELLS bottom
const sellG=svg.append("g");
SELLS.forEach(s=>{
const c=s[0], a=s[1], wt=s[2];
if(!(c in nameToIndex)||!(a in nameToIndex))return;
const sp=nodePos[nameToIndex[c]], tp=nodePos[nameToIndex[a]];
sellG.append("path")
.attr("d",bezierPath(sp.x,sp.y,tp.x,tp.y,false))
.attr("fill","none")
.attr("stroke","#c0392b")
.attr("stroke-width",stroke(wt))
.attr("stroke-dasharray","4,3")
.attr("opacity",0.86)
.attr("data-src",c)
.attr("data-tgt",a);
});
// transfers
const trG=svg.append("g");
TRANSFERS.forEach(t=>{
const s=t[0], b=t[1], wt=t[2];
if(!(s in nameToIndex)||!(b in nameToIndex))return;
const sp=nodePos[nameToIndex[s]], tp=nodePos[nameToIndex[b]];
const mx=(sp.x+tp.x)/2, my=(sp.y+tp.y)/2;
const path=`M ${sp.x} ${sp.y} Q ${mx*0.3} ${my*0.3} ${tp.x} ${tp.y}`;
trG.append("path")
.attr("d",path)
.attr("fill","none")
.attr("stroke","#7d7d7d")
.attr("stroke-width",stroke(wt))
.attr("opacity",0.7)
.attr("data-src",s)
.attr("data-tgt",b);
});
// loops (tagged with both AMCs and the company so highlighting can dim them)
const loopG=svg.append("g");
LOOPS.forEach(lp=>{
const a=lp[0], c=lp[1], b=lp[2];
if(!(a in nameToIndex)||!(b in nameToIndex))return;
const sa=nodePos[nameToIndex[a]], sb=nodePos[nameToIndex[b]];
const mx=(sa.x+sb.x)/2, my=(sa.y+sb.y)/2;
const len=Math.sqrt((sa.x-sb.x)**2+(sa.y-sb.y)**2);
const outward=Math.max(40,radius*0.28+len*0.12);
const ndx=mx, ndy=my;
const nlen=Math.sqrt(ndx*ndx+ndy*ndy)||1;
const ux=ndx/nlen, uy=ndy/nlen;
const cx=mx+ux*outward, cy=my+uy*outward;
const path=`M ${sa.x} ${sa.y} Q ${cx} ${cy} ${sb.x} ${sb.y}`;
loopG.append("path")
.attr("d",path)
.attr("fill","none")
.attr("stroke","#227a6d")
.attr("stroke-width",2.8)
.attr("opacity",0.95)
.attr("data-src",a)
.attr("data-tgt",b)
.attr("data-via",c);
});
// click -> highlight + full label
function touches(el,nodeName){
return el.getAttribute("data-src")===nodeName||
el.getAttribute("data-tgt")===nodeName||
el.getAttribute("data-via")===nodeName;
}
function setOpacityFor(nodeName){
group.selectAll("circle").style("opacity",d=>d.name===nodeName?1.0:0.18);
group.selectAll("text")
.style("opacity",d=>d.name===nodeName?1.0:0.28)
.text(d=>d.name===nodeName?fullLabel(d.name):shortLabel(d.name));
[buyG,sellG,trG,loopG].forEach(g=>{
g.selectAll("path").style("opacity",function(){
return touches(this,nodeName)?0.98:0.06;
});
});
}
function resetOpacity(){
group.selectAll("circle").style("opacity",1.0);
group.selectAll("text").style("opacity",1.0)
.text(d=>shortLabel(d.name));
buyG.selectAll("path").style("opacity",0.92);
sellG.selectAll("path").style("opacity",0.86);
trG.selectAll("path").style("opacity",0.7);
loopG.selectAll("path").style("opacity",0.95);
}
group.selectAll("circle").on("click",function(e,d){
setOpacityFor(d.name);
e.stopPropagation();
});
group.selectAll("text").on("click",function(e,d){
setOpacityFor(d.name);
e.stopPropagation();
});
document.getElementById("arc-reset").onclick=resetOpacity;
svg.on("click",()=>resetOpacity());
}
draw();
window.addEventListener("resize",draw);
</script>
"""
# ---------------------------
# Build final HTML
# ---------------------------
def make_arc_html(nodes, node_type, buys, sells, transfers, loops):
    """Fill JS_TEMPLATE's __PLACEHOLDER__ tokens with JSON-encoded data.

    Tokens are substituted one at a time in a fixed order. SHORT_LABEL and
    FULL_LABEL are read from module globals rather than passed in.
    """
    substitutions = {
        "__NODES__": nodes,
        "__NODE_TYPE__": node_type,
        "__BUYS__": buys,
        "__SELLS__": sells,
        "__TRANSFERS__": transfers,
        "__LOOPS__": loops,
        "__SHORT_LABEL__": SHORT_LABEL,
        "__FULL_LABEL__": FULL_LABEL,
    }
    page = JS_TEMPLATE
    for token, payload in substitutions.items():
        page = page.replace(token, json.dumps(payload))
    return page


initial_html = make_arc_html(NODES, NODE_TYPE, BUYS, SELLS, TRANSFERS, LOOPS)
# ---------------------------
# Gradio UI
# ---------------------------
with gr.Blocks(title="MF Churn — Arc Diagram (Short→Full Label on Click)") as demo:
    gr.Markdown("## Mutual Fund Churn — Weighted Arc Diagram (Short labels → Full label on click)")
    # The diagram is drawn by <script> tags (D3 + inline drawing code). Browsers do
    # not execute scripts inserted via innerHTML, which is how gr.HTML renders its
    # value (NOTE(review): confirm for the installed Gradio version). Passing the
    # page through an iframe `srcdoc` makes the browser parse it as a full document,
    # so the scripts run. The height leaves room for the 720px drawing area plus
    # the Reset button and legend.
    arc_srcdoc = html.escape(initial_html, quote=True)
    gr.HTML(
        f'<iframe title="Mutual fund churn arc diagram" srcdoc="{arc_srcdoc}" '
        'style="width:100%; height:920px; border:0;"></iframe>'
    )
    gr.Markdown("### Inspect Company / AMC")
    select_company = gr.Dropdown(COMPANIES, label="Select company")
    company_plot = gr.Plot()
    company_table = gr.DataFrame()
    select_amc = gr.Dropdown(AMCS, label="Select AMC")
    amc_plot = gr.Plot()
    amc_table = gr.DataFrame()
    # Each handler returns (figure-or-None, DataFrame) for the (plot, table) pair.
    select_company.change(company_trade_summary, inputs=[select_company], outputs=[company_plot, company_table])
    select_amc.change(amc_transfer_summary, inputs=[select_amc], outputs=[amc_plot, amc_table])

if __name__ == "__main__":
    demo.launch()