gmedin commited on
Commit
c0428c6
·
verified ·
1 Parent(s): 7d63ff8

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +31 -22
app.py CHANGED
@@ -6,27 +6,43 @@ import math
6
  import random
7
  import requests
8
  import os
 
9
 
10
- # Dictionary to map brands to their respective pickle files
11
  BRAND_GRAPHS = {
12
- 'drumeo': 'drumeo_pop_items_labels.pkl',
13
- 'pianote': 'pianote_pop_items_labels.pkl',
14
- 'singeo': 'singeo_pop_items_labels.pkl',
15
- 'guitareo': 'guitareo_pop_items_labels.pkl'
16
  }
17
 
18
- # API Authorization Token
 
19
  AUTH_TOKEN = os.getenv('HF_TOKEN')
20
-
21
  API_URL = "https://MusoraProductDepartment-PWGenerator.hf.space/rank_items/"
22
 
23
 
24
  @st.cache_resource
25
- def load_graph(brand):
26
- with open(BRAND_GRAPHS[brand], 'rb') as f:
27
- return pickle.load(f)
 
 
 
 
 
 
 
 
 
 
 
 
28
 
29
  def filter_graph(graph, node_threshold=10, edge_threshold=5):
 
 
 
30
  popular_nodes = [
31
  node for node in graph.nodes
32
  if graph.degree(node) >= node_threshold
@@ -40,6 +56,7 @@ def filter_graph(graph, node_threshold=10, edge_threshold=5):
40
 
41
  return filtered_graph
42
 
 
43
  def get_rankings_from_api(brand, user_id, content_ids):
44
  """
45
  Call the rank_items API to fetch rankings for the given user and content IDs.
@@ -68,17 +85,10 @@ def rank_to_color(rank, max_rank):
68
  """
69
  Map a rank to a grayscale color, where dark gray indicates high relevance (low rank),
70
  and light gray indicates low relevance (high rank).
71
-
72
- Parameters:
73
- rank (int): The rank of the item.
74
- max_rank (int): The maximum rank (for normalizing the gradient).
75
-
76
- Returns:
77
- str: Hex color code for the grayscale shade.
78
  """
79
  if rank > max_rank: # Handle items without ranking
80
  return "#E8E8E8" # Very light gray for unranked items
81
- intensity = int(55 + (rank / max_rank) * 200) # Scale intensity (darker for lower ranks)
82
  return f"rgb({intensity}, {intensity}, {intensity})" # Grayscale
83
 
84
 
@@ -178,14 +188,14 @@ selected_brand = st.selectbox("Select a brand:", options=list(BRAND_GRAPHS.keys(
178
 
179
  if "selected_brand" not in st.session_state or st.session_state.selected_brand != selected_brand:
180
  st.session_state.selected_brand = selected_brand
181
- G = load_graph(selected_brand)
182
 
183
  # Sort nodes by popularity (in-degree + out-degree) and select from top 20
184
  popular_nodes = sorted(G.nodes, key=lambda n: G.in_degree(n) + G.out_degree(n), reverse=True)
185
  top_20_nodes = popular_nodes[:20] if len(popular_nodes) > 20 else popular_nodes
186
  st.session_state.start_node = random.choice(top_20_nodes)
187
  else:
188
- G = load_graph(selected_brand)
189
 
190
  # Random Selection Button
191
  if st.button("Random Selection"):
@@ -201,7 +211,6 @@ start_node = st.number_input(
201
  # Input: Student ID
202
  student_id = st.text_input("Enter a student ID (optional):", value="")
203
 
204
-
205
  # Toggle for showing content titles
206
  show_titles = st.checkbox("Show content titles", value=False)
207
 
@@ -217,7 +226,7 @@ if student_id:
217
  rankings = get_rankings_from_api(selected_brand, int(student_id), content_ids)
218
  if rankings:
219
  rankings = rankings['ranked_content_ids']
220
- #print(rankings)
221
  layers = st.slider("Depth to explore:", 1, 6, value=3)
222
  top_k = st.slider("Branching factor (per node):", 1, 6, value=3)
223
 
 
6
  import random
7
  import requests
8
  import os
9
+ from huggingface_hub import hf_hub_download
10
 
11
+ # Dictionary to map brands to their respective HuggingFace model repo files
12
  BRAND_GRAPHS = {
13
+ 'drumeo': 'drumeo_graph.pkl',
14
+ 'pianote': 'pianote_graph.pkl',
15
+ 'singeo': 'singeo_graph.pkl',
16
+ 'guitareo': 'guitareo_graph.pkl'
17
  }
18
 
19
+ # HuggingFace Repository Info
20
+ #HF_REPO = "MusoraProductDepartment/popular-path-graphs"
21
  AUTH_TOKEN = os.getenv('HF_TOKEN')
 
22
  API_URL = "https://MusoraProductDepartment-PWGenerator.hf.space/rank_items/"
23
 
24
 
25
  @st.cache_resource
26
+ def load_graph_from_hf(brand):
27
+ """
28
+ Load the graph for the selected brand from HuggingFace Hub.
29
+ """
30
+ try:
31
+ # Download the file from HuggingFace Hub
32
+ HF_REPO = '{brand}-graph'
33
+ file_path = hf_hub_download(repo_id=HF_REPO, filename=BRAND_GRAPHS[brand], use_auth_token=AUTH_TOKEN)
34
+ # Load the graph
35
+ with open(file_path, 'rb') as f:
36
+ return pickle.load(f)
37
+ except Exception as e:
38
+ st.error(f"Error loading graph from HuggingFace: {e}")
39
+ return None
40
+
41
 
42
  def filter_graph(graph, node_threshold=10, edge_threshold=5):
43
+ """
44
+ Filters the graph to include only popular nodes and edges.
45
+ """
46
  popular_nodes = [
47
  node for node in graph.nodes
48
  if graph.degree(node) >= node_threshold
 
56
 
57
  return filtered_graph
58
 
59
+
60
  def get_rankings_from_api(brand, user_id, content_ids):
61
  """
62
  Call the rank_items API to fetch rankings for the given user and content IDs.
 
85
  """
86
  Map a rank to a grayscale color, where dark gray indicates high relevance (low rank),
87
  and light gray indicates low relevance (high rank).
 
 
 
 
 
 
 
88
  """
89
  if rank > max_rank: # Handle items without ranking
90
  return "#E8E8E8" # Very light gray for unranked items
91
+ intensity = int(255 - (rank / max_rank) * 200) # Darker for lower ranks
92
  return f"rgb({intensity}, {intensity}, {intensity})" # Grayscale
93
 
94
 
 
188
 
189
  if "selected_brand" not in st.session_state or st.session_state.selected_brand != selected_brand:
190
  st.session_state.selected_brand = selected_brand
191
+ G = load_graph_from_hf(selected_brand)
192
 
193
  # Sort nodes by popularity (in-degree + out-degree) and select from top 20
194
  popular_nodes = sorted(G.nodes, key=lambda n: G.in_degree(n) + G.out_degree(n), reverse=True)
195
  top_20_nodes = popular_nodes[:20] if len(popular_nodes) > 20 else popular_nodes
196
  st.session_state.start_node = random.choice(top_20_nodes)
197
  else:
198
+ G = load_graph_from_hf(selected_brand)
199
 
200
  # Random Selection Button
201
  if st.button("Random Selection"):
 
211
  # Input: Student ID
212
  student_id = st.text_input("Enter a student ID (optional):", value="")
213
 
 
214
  # Toggle for showing content titles
215
  show_titles = st.checkbox("Show content titles", value=False)
216
 
 
226
  rankings = get_rankings_from_api(selected_brand, int(student_id), content_ids)
227
  if rankings:
228
  rankings = rankings['ranked_content_ids']
229
+
230
  layers = st.slider("Depth to explore:", 1, 6, value=3)
231
  top_k = st.slider("Branching factor (per node):", 1, 6, value=3)
232