Added all brands
Browse files
app.py
CHANGED
|
@@ -4,12 +4,54 @@ from pyvis.network import Network
|
|
| 4 |
import pickle
|
| 5 |
import math
|
| 6 |
|
| 7 |
-
#
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 8 |
@st.cache_resource
|
| 9 |
-
def load_graph():
|
| 10 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 11 |
return pickle.load(f)
|
| 12 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 13 |
def dynamic_visualize_graph(graph, start_node, layers=3, top_k=5):
|
| 14 |
net = Network(notebook=False, width="100%", height="600px", directed=True)
|
| 15 |
net.set_options("""
|
|
@@ -62,18 +104,58 @@ def dynamic_visualize_graph(graph, start_node, layers=3, top_k=5):
|
|
| 62 |
html_content = net.generate_html()
|
| 63 |
st.components.v1.html(html_content, height=600, scrolling=False)
|
| 64 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 65 |
|
| 66 |
# Streamlit interface
|
| 67 |
-
st.title("
|
|
|
|
|
|
|
|
|
|
| 68 |
|
| 69 |
-
# Load the graph
|
| 70 |
-
G = load_graph()
|
| 71 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 72 |
import random
|
| 73 |
# Randomly sample a starting node
|
| 74 |
# Initialize session state for the starting node
|
| 75 |
-
if "start_node" not in st.session_state:
|
| 76 |
-
|
| 77 |
|
| 78 |
# Input: Starting node
|
| 79 |
start_node = st.number_input(
|
|
@@ -86,7 +168,7 @@ top_k = st.slider("Branching factor (per node):", 1, 6, value=3)
|
|
| 86 |
|
| 87 |
# Trigger the visualization
|
| 88 |
if st.button("Expand Graph"):
|
| 89 |
-
if start_node in
|
| 90 |
-
dynamic_visualize_graph(
|
| 91 |
else:
|
| 92 |
st.error("The starting node is not in the graph!")
|
|
|
|
| 4 |
import pickle
|
| 5 |
import math
|
| 6 |
|
| 7 |
+
# Pickle file that stores each brand's graph, keyed by brand name.
# The keys double as the options offered in the brand selector.
BRAND_GRAPHS = {
    "drumeo": "drumeo_pop_items_labels.pkl",
    "pianote": "pianote_pop_items_labels.pkl",
    "singeo": "singeo_pop_items_labels.pkl",
    "guitareo": "guitareo_pop_items_labels.pkl",
}
|
| 14 |
+
|
| 15 |
@st.cache_resource
def load_graph(brand):
    """
    Unpickle and return the graph stored for ``brand``.

    The result is wrapped in ``st.cache_resource``, so the file is only
    read again when the function is called with a different ``brand``.

    Parameters:
        brand (str): A key of ``BRAND_GRAPHS`` naming the graph to load.

    Returns:
        nx.DiGraph: The loaded graph.

    Raises:
        KeyError: If ``brand`` is not a key of ``BRAND_GRAPHS``.
        OSError: If the brand's pickle file cannot be opened.
    """
    pickle_path = BRAND_GRAPHS[brand]
    # NOTE: unpickling can execute arbitrary code; these files must only ever
    # come from a trusted source (they are expected to ship with the app).
    with open(pickle_path, "rb") as handle:
        graph = pickle.load(handle)
    return graph
|
| 26 |
|
| 27 |
+
def filter_graph(graph, node_threshold=10, edge_threshold=5):
    """
    Build a reduced copy of ``graph`` holding only well-connected nodes
    and sufficiently heavy edges. The input graph is left untouched.

    Parameters:
        graph (nx.DiGraph): The original graph.
        node_threshold (int): Minimum degree a node needs to be kept.
        edge_threshold (int): Minimum ``weight`` an edge needs to be kept.

    Returns:
        nx.DiGraph: An independent copy containing the surviving nodes
        and edges.

    Notes:
        * Node degree is measured on the original graph, before any edge
          is dropped; nodes are not re-checked after edge pruning.
        * An edge without a ``weight`` attribute is treated as weight 0.
    """
    # Step 1: collect the nodes whose degree reaches the threshold.
    well_connected = []
    for candidate in graph.nodes:
        if graph.degree(candidate) >= node_threshold:
            well_connected.append(candidate)

    # Step 2: copy the induced subgraph so it can be mutated independently.
    pruned = graph.subgraph(well_connected).copy()

    # Step 3: list the under-weight edges first, then drop them, so the
    # edge view is never modified while it is being iterated.
    light_edges = [
        (source, target)
        for source, target, attrs in pruned.edges(data=True)
        if attrs.get("weight", 0) < edge_threshold
    ]
    for source, target in light_edges:
        pruned.remove_edge(source, target)

    return pruned
|
| 54 |
+
|
| 55 |
def dynamic_visualize_graph(graph, start_node, layers=3, top_k=5):
|
| 56 |
net = Network(notebook=False, width="100%", height="600px", directed=True)
|
| 57 |
net.set_options("""
|
|
|
|
| 104 |
html_content = net.generate_html()
|
| 105 |
st.components.v1.html(html_content, height=600, scrolling=False)
|
| 106 |
|
| 107 |
+
def display_node_info(graph, node_id):
    """
    Print a node's attributes plus its incoming and outgoing edges.

    Output goes to standard output via ``print``. If the node is not in
    the graph, a single notice is printed and nothing else happens.

    Parameters:
        graph (nx.DiGraph): The graph containing the node.
        node_id (int or str): The ID of the node to inspect.

    Returns:
        None
    """
    # Guard clause: nothing to report for an unknown node.
    if node_id not in graph:
        print(f"Node {node_id} does not exist in the graph.")
        return

    # Node attributes, one per line.
    print(f"Attributes of node {node_id}:")
    node_attrs = graph.nodes[node_id]
    for name, value in node_attrs.items():
        print(f" {name}: {value}")

    # Incoming edges first, then outgoing. The edge accessor is looked up
    # only after its heading has been printed.
    sections = (
        (f"\nIncoming edges to node {node_id}:", "in_edges"),
        (f"\nOutgoing edges from node {node_id}:", "out_edges"),
    )
    for heading, accessor_name in sections:
        print(heading)
        edge_accessor = getattr(graph, accessor_name)
        for source, target, edge_attrs in edge_accessor(node_id, data=True):
            print(f" From {source} to {target} with attributes: {edge_attrs}")
|
| 136 |
+
|
| 137 |
|
| 138 |
# Streamlit interface
st.title("Interactive Graph Expansion with Tooltips")

# Brand selection: one option per entry in BRAND_GRAPHS.
selected_brand = st.selectbox("Select a brand:", options=list(BRAND_GRAPHS))

# Load the graph for the selected brand.
G = load_graph(selected_brand)

# Filter the graph for popular nodes and edges.
node_degree_threshold = 1  # Minimum degree for nodes
edge_weight_threshold = 1  # Minimum weight for edges
G_filtered = filter_graph(
    G,
    node_threshold=node_degree_threshold,
    edge_threshold=edge_weight_threshold,
)

import random

# Guard against an empty filtered graph: random.choice would raise
# IndexError on an empty list, so report the problem and halt this run.
candidate_nodes = list(G_filtered.nodes)
if not candidate_nodes:
    st.error("The filtered graph for this brand has no nodes.")
    st.stop()

# Streamlit re-executes the whole script on every widget interaction, so the
# random starting node is stored in session state and drawn only once per
# brand. It is re-drawn when the brand changes, because a node taken from one
# brand's graph is not expected to exist in another brand's graph.
if (
    "start_node" not in st.session_state
    or st.session_state.get("start_node_brand") != selected_brand
):
    st.session_state.start_node = random.choice(candidate_nodes)
    st.session_state.start_node_brand = selected_brand
|
| 159 |
|
| 160 |
# Input: Starting node
|
| 161 |
start_node = st.number_input(
|
|
|
|
| 168 |
|
| 169 |
# Render the expansion only when the user clicks the button, and only for
# a starting node that actually exists in the filtered graph.
if st.button("Expand Graph"):
    if start_node not in G_filtered:
        st.error("The starting node is not in the graph!")
    else:
        dynamic_visualize_graph(
            G_filtered,
            start_node,
            layers=layers,
            top_k=top_k,
        )
|