gmedin commited on
Commit
2aa8945
·
verified ·
1 Parent(s): a6a71d2

Added all brands

Browse files
Files changed (1) hide show
  1. app.py +92 -10
app.py CHANGED
@@ -4,12 +4,54 @@ from pyvis.network import Network
4
  import pickle
5
  import math
6
 
7
- # Load the graph from a pickle file
 
 
 
 
 
 
 
8
  @st.cache_resource
9
- def load_graph():
10
- with open('drumeo_multi_student_graph_labels.pkl', 'rb') as f:
 
 
 
 
 
 
 
11
  return pickle.load(f)
12
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
13
  def dynamic_visualize_graph(graph, start_node, layers=3, top_k=5):
14
  net = Network(notebook=False, width="100%", height="600px", directed=True)
15
  net.set_options("""
@@ -62,18 +104,58 @@ def dynamic_visualize_graph(graph, start_node, layers=3, top_k=5):
62
  html_content = net.generate_html()
63
  st.components.v1.html(html_content, height=600, scrolling=False)
64
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
65
 
66
  # Streamlit interface
67
- st.title("Popular Path Expander")
 
 
 
68
 
69
- # Load the graph
70
- G = load_graph()
71
 
 
 
 
 
 
 
 
72
  import random
73
  # Randomly sample a starting node
74
  # Initialize session state for the starting node
75
- if "start_node" not in st.session_state:
76
- st.session_state.start_node = random.choice(list(G.nodes)) # Randomly select a node once
77
 
78
  # Input: Starting node
79
  start_node = st.number_input(
@@ -86,7 +168,7 @@ top_k = st.slider("Branching factor (per node):", 1, 6, value=3)
86
 
87
  # Trigger the visualization
88
  if st.button("Expand Graph"):
89
- if start_node in G:
90
- dynamic_visualize_graph(G, start_node, layers=layers, top_k=top_k)
91
  else:
92
  st.error("The starting node is not in the graph!")
 
4
  import pickle
5
  import math
6
 
7
+ # Dictionary to map brands to their respective pickle files
8
+ BRAND_GRAPHS = {
9
+ 'drumeo': 'drumeo_pop_items_labels.pkl',
10
+ 'pianote': 'pianote_pop_items_labels.pkl',
11
+ 'singeo': 'singeo_pop_items_labels.pkl',
12
+ 'guitareo': 'guitareo_pop_items_labels.pkl'
13
+ }
14
+
15
  @st.cache_resource
16
+ def load_graph(brand):
17
+ """
18
+ Load the graph for the selected brand.
19
+ Parameters:
20
+ brand (str): The brand name corresponding to the graph to load.
21
+ Returns:
22
+ nx.DiGraph: The loaded graph.
23
+ """
24
+ with open(BRAND_GRAPHS[brand], 'rb') as f:
25
  return pickle.load(f)
26
 
27
+ def filter_graph(graph, node_threshold=10, edge_threshold=5):
28
+ """
29
+ Filters the graph to include only popular nodes and edges.
30
+
31
+ Parameters:
32
+ graph (nx.DiGraph): The original graph.
33
+ node_threshold (int): Minimum degree for a node to be included.
34
+ edge_threshold (int): Minimum weight for an edge to be included.
35
+
36
+ Returns:
37
+ nx.DiGraph: A filtered graph with popular nodes and edges.
38
+ """
39
+ # Identify popular nodes based on their degree
40
+ popular_nodes = [
41
+ node for node in graph.nodes
42
+ if graph.degree(node) >= node_threshold
43
+ ]
44
+
45
+ # Create a subgraph with only popular nodes
46
+ filtered_graph = graph.subgraph(popular_nodes).copy()
47
+
48
+ # Remove edges that don't meet the weight threshold
49
+ for u, v, data in list(filtered_graph.edges(data=True)):
50
+ if data.get("weight", 0) < edge_threshold:
51
+ filtered_graph.remove_edge(u, v)
52
+
53
+ return filtered_graph
54
+
55
  def dynamic_visualize_graph(graph, start_node, layers=3, top_k=5):
56
  net = Network(notebook=False, width="100%", height="600px", directed=True)
57
  net.set_options("""
 
104
  html_content = net.generate_html()
105
  st.components.v1.html(html_content, height=600, scrolling=False)
106
 
107
+ def display_node_info(graph, node_id):
108
+ """
109
+ Display all attributes of a node and its edges in the graph.
110
+
111
+ Parameters:
112
+ graph (nx.DiGraph): The graph containing the node.
113
+ node_id (int or str): The ID of the node to inspect.
114
+
115
+ Returns:
116
+ None
117
+ """
118
+ if node_id not in graph:
119
+ print(f"Node {node_id} does not exist in the graph.")
120
+ return
121
+
122
+ # Display node attributes
123
+ print(f"Attributes of node {node_id}:")
124
+ for attr, value in graph.nodes[node_id].items():
125
+ print(f" {attr}: {value}")
126
+
127
+ # Display incoming edges
128
+ print(f"\nIncoming edges to node {node_id}:")
129
+ for u, v, data in graph.in_edges(node_id, data=True):
130
+ print(f" From {u} to {v} with attributes: {data}")
131
+
132
+ # Display outgoing edges
133
+ print(f"\nOutgoing edges from node {node_id}:")
134
+ for u, v, data in graph.out_edges(node_id, data=True):
135
+ print(f" From {u} to {v} with attributes: {data}")
136
+
137
 
138
  # Streamlit interface
139
+ st.title("Interactive Graph Expansion with Tooltips")
140
+
141
+ # Brand Selection
142
+ selected_brand = st.selectbox("Select a brand:", options=list(BRAND_GRAPHS.keys()))
143
 
144
+ # Load the graph for the selected brand
145
+ G = load_graph(selected_brand)
146
 
147
+ # Filter the graph for popular nodes and edges
148
+ node_degree_threshold = 1 # Minimum degree for nodes
149
+ edge_weight_threshold = 1 # Minimum weight for edges
150
+ G_filtered = filter_graph(G, node_threshold=node_degree_threshold, edge_threshold=edge_weight_threshold)
151
+ #print('spot check degree', G.degree(389062))
152
+ #print('spot check degree', G_filtered.degree(389062))
153
+ #display_node_info(G_filtered, 389062)
154
  import random
155
  # Randomly sample a starting node
156
  # Initialize session state for the starting node
157
+ #if "start_node" not in st.session_state:
158
+ st.session_state.start_node = random.choice(list(G_filtered.nodes)) # Randomly select a node once
159
 
160
  # Input: Starting node
161
  start_node = st.number_input(
 
168
 
169
  # Trigger the visualization
170
  if st.button("Expand Graph"):
171
+ if start_node in G_filtered:
172
+ dynamic_visualize_graph(G_filtered, start_node, layers=layers, top_k=top_k)
173
  else:
174
  st.error("The starting node is not in the graph!")