Update app.py
Browse files
app.py
CHANGED
|
@@ -6,27 +6,43 @@ import math
|
|
| 6 |
import random
|
| 7 |
import requests
|
| 8 |
import os
|
|
|
|
| 9 |
|
| 10 |
-
# Dictionary to map brands to their respective
|
| 11 |
BRAND_GRAPHS = {
|
| 12 |
-
'drumeo': '
|
| 13 |
-
'pianote': '
|
| 14 |
-
'singeo': '
|
| 15 |
-
'guitareo': '
|
| 16 |
}
|
| 17 |
|
| 18 |
-
#
|
|
|
|
| 19 |
AUTH_TOKEN = os.getenv('HF_TOKEN')
|
| 20 |
-
|
| 21 |
API_URL = "https://MusoraProductDepartment-PWGenerator.hf.space/rank_items/"
|
| 22 |
|
| 23 |
|
| 24 |
@st.cache_resource
|
| 25 |
-
def
|
| 26 |
-
|
| 27 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 28 |
|
| 29 |
def filter_graph(graph, node_threshold=10, edge_threshold=5):
|
|
|
|
|
|
|
|
|
|
| 30 |
popular_nodes = [
|
| 31 |
node for node in graph.nodes
|
| 32 |
if graph.degree(node) >= node_threshold
|
|
@@ -40,6 +56,7 @@ def filter_graph(graph, node_threshold=10, edge_threshold=5):
|
|
| 40 |
|
| 41 |
return filtered_graph
|
| 42 |
|
|
|
|
| 43 |
def get_rankings_from_api(brand, user_id, content_ids):
|
| 44 |
"""
|
| 45 |
Call the rank_items API to fetch rankings for the given user and content IDs.
|
|
@@ -68,17 +85,10 @@ def rank_to_color(rank, max_rank):
|
|
| 68 |
"""
|
| 69 |
Map a rank to a grayscale color, where dark gray indicates high relevance (low rank),
|
| 70 |
and light gray indicates low relevance (high rank).
|
| 71 |
-
|
| 72 |
-
Parameters:
|
| 73 |
-
rank (int): The rank of the item.
|
| 74 |
-
max_rank (int): The maximum rank (for normalizing the gradient).
|
| 75 |
-
|
| 76 |
-
Returns:
|
| 77 |
-
str: Hex color code for the grayscale shade.
|
| 78 |
"""
|
| 79 |
if rank > max_rank: # Handle items without ranking
|
| 80 |
return "#E8E8E8" # Very light gray for unranked items
|
| 81 |
-
intensity = int(
|
| 82 |
return f"rgb({intensity}, {intensity}, {intensity})" # Grayscale
|
| 83 |
|
| 84 |
|
|
@@ -178,14 +188,14 @@ selected_brand = st.selectbox("Select a brand:", options=list(BRAND_GRAPHS.keys(
|
|
| 178 |
|
| 179 |
if "selected_brand" not in st.session_state or st.session_state.selected_brand != selected_brand:
|
| 180 |
st.session_state.selected_brand = selected_brand
|
| 181 |
-
G =
|
| 182 |
|
| 183 |
# Sort nodes by popularity (in-degree + out-degree) and select from top 20
|
| 184 |
popular_nodes = sorted(G.nodes, key=lambda n: G.in_degree(n) + G.out_degree(n), reverse=True)
|
| 185 |
top_20_nodes = popular_nodes[:20] if len(popular_nodes) > 20 else popular_nodes
|
| 186 |
st.session_state.start_node = random.choice(top_20_nodes)
|
| 187 |
else:
|
| 188 |
-
G =
|
| 189 |
|
| 190 |
# Random Selection Button
|
| 191 |
if st.button("Random Selection"):
|
|
@@ -201,7 +211,6 @@ start_node = st.number_input(
|
|
| 201 |
# Input: Student ID
|
| 202 |
student_id = st.text_input("Enter a student ID (optional):", value="")
|
| 203 |
|
| 204 |
-
|
| 205 |
# Toggle for showing content titles
|
| 206 |
show_titles = st.checkbox("Show content titles", value=False)
|
| 207 |
|
|
@@ -217,7 +226,7 @@ if student_id:
|
|
| 217 |
rankings = get_rankings_from_api(selected_brand, int(student_id), content_ids)
|
| 218 |
if rankings:
|
| 219 |
rankings = rankings['ranked_content_ids']
|
| 220 |
-
|
| 221 |
layers = st.slider("Depth to explore:", 1, 6, value=3)
|
| 222 |
top_k = st.slider("Branching factor (per node):", 1, 6, value=3)
|
| 223 |
|
|
|
|
| 6 |
import random
|
| 7 |
import requests
|
| 8 |
import os
|
| 9 |
+
from huggingface_hub import hf_hub_download
|
| 10 |
|
| 11 |
+
# Dictionary to map brands to their respective HuggingFace model repo files
|
| 12 |
BRAND_GRAPHS = {
|
| 13 |
+
'drumeo': 'drumeo_graph.pkl',
|
| 14 |
+
'pianote': 'pianote_graph.pkl',
|
| 15 |
+
'singeo': 'singeo_graph.pkl',
|
| 16 |
+
'guitareo': 'guitareo_graph.pkl'
|
| 17 |
}
|
| 18 |
|
| 19 |
+
# HuggingFace Repository Info
|
| 20 |
+
#HF_REPO = "MusoraProductDepartment/popular-path-graphs"
|
| 21 |
AUTH_TOKEN = os.getenv('HF_TOKEN')
|
|
|
|
| 22 |
API_URL = "https://MusoraProductDepartment-PWGenerator.hf.space/rank_items/"
|
| 23 |
|
| 24 |
|
| 25 |
@st.cache_resource
|
| 26 |
+
def load_graph_from_hf(brand):
|
| 27 |
+
"""
|
| 28 |
+
Load the graph for the selected brand from HuggingFace Hub.
|
| 29 |
+
"""
|
| 30 |
+
try:
|
| 31 |
+
# Download the file from HuggingFace Hub
|
| 32 |
+
HF_REPO = '{brand}-graph'
|
| 33 |
+
file_path = hf_hub_download(repo_id=HF_REPO, filename=BRAND_GRAPHS[brand], use_auth_token=AUTH_TOKEN)
|
| 34 |
+
# Load the graph
|
| 35 |
+
with open(file_path, 'rb') as f:
|
| 36 |
+
return pickle.load(f)
|
| 37 |
+
except Exception as e:
|
| 38 |
+
st.error(f"Error loading graph from HuggingFace: {e}")
|
| 39 |
+
return None
|
| 40 |
+
|
| 41 |
|
| 42 |
def filter_graph(graph, node_threshold=10, edge_threshold=5):
|
| 43 |
+
"""
|
| 44 |
+
Filters the graph to include only popular nodes and edges.
|
| 45 |
+
"""
|
| 46 |
popular_nodes = [
|
| 47 |
node for node in graph.nodes
|
| 48 |
if graph.degree(node) >= node_threshold
|
|
|
|
| 56 |
|
| 57 |
return filtered_graph
|
| 58 |
|
| 59 |
+
|
| 60 |
def get_rankings_from_api(brand, user_id, content_ids):
|
| 61 |
"""
|
| 62 |
Call the rank_items API to fetch rankings for the given user and content IDs.
|
|
|
|
| 85 |
"""
|
| 86 |
Map a rank to a grayscale color, where dark gray indicates high relevance (low rank),
|
| 87 |
and light gray indicates low relevance (high rank).
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 88 |
"""
|
| 89 |
if rank > max_rank: # Handle items without ranking
|
| 90 |
return "#E8E8E8" # Very light gray for unranked items
|
| 91 |
+
intensity = int(255 - (rank / max_rank) * 200) # Darker for lower ranks
|
| 92 |
return f"rgb({intensity}, {intensity}, {intensity})" # Grayscale
|
| 93 |
|
| 94 |
|
|
|
|
| 188 |
|
| 189 |
if "selected_brand" not in st.session_state or st.session_state.selected_brand != selected_brand:
|
| 190 |
st.session_state.selected_brand = selected_brand
|
| 191 |
+
G = load_graph_from_hf(selected_brand)
|
| 192 |
|
| 193 |
# Sort nodes by popularity (in-degree + out-degree) and select from top 20
|
| 194 |
popular_nodes = sorted(G.nodes, key=lambda n: G.in_degree(n) + G.out_degree(n), reverse=True)
|
| 195 |
top_20_nodes = popular_nodes[:20] if len(popular_nodes) > 20 else popular_nodes
|
| 196 |
st.session_state.start_node = random.choice(top_20_nodes)
|
| 197 |
else:
|
| 198 |
+
G = load_graph_from_hf(selected_brand)
|
| 199 |
|
| 200 |
# Random Selection Button
|
| 201 |
if st.button("Random Selection"):
|
|
|
|
| 211 |
# Input: Student ID
|
| 212 |
student_id = st.text_input("Enter a student ID (optional):", value="")
|
| 213 |
|
|
|
|
| 214 |
# Toggle for showing content titles
|
| 215 |
show_titles = st.checkbox("Show content titles", value=False)
|
| 216 |
|
|
|
|
| 226 |
rankings = get_rankings_from_api(selected_brand, int(student_id), content_ids)
|
| 227 |
if rankings:
|
| 228 |
rankings = rankings['ranked_content_ids']
|
| 229 |
+
|
| 230 |
layers = st.slider("Depth to explore:", 1, 6, value=3)
|
| 231 |
top_k = st.slider("Branching factor (per node):", 1, 6, value=3)
|
| 232 |
|