| import streamlit as st |
| import networkx as nx |
| from pyvis.network import Network |
| import pickle |
| import math |
| import random |
| import requests |
| import os |
| from huggingface_hub import hf_hub_download |
|
|
| |
# Brand key -> filename of the pickled graph stored in that brand's
# HuggingFace repo (see load_graph_from_hf).
BRAND_GRAPHS = dict(
    drumeo='drumeo_graph.pkl',
    pianote='pianote_graph.pkl',
    singeo='singeo_graph.pkl',
    guitareo='guitareo_graph.pkl',
)

# HuggingFace token taken from the environment; None when HF_TOKEN is unset.
# Used both for Hub downloads and as the bearer token of the ranking API.
AUTH_TOKEN = os.environ.get('HF_TOKEN')

# Endpoint of the personalised rank_items service.
API_URL = "https://MusoraProductDepartment-PWGenerator.hf.space/rank_items/"
|
|
|
|
@st.cache_resource
def _download_brand_graph(brand):
    """Download and unpickle the graph for ``brand``; cached per brand.

    Deliberately lets exceptions propagate: Streamlit does not cache a call
    that raises, so a transient failure is retried on the next run instead of
    a ``None`` being cached for the rest of the process lifetime.

    Raises:
        KeyError: ``brand`` is not a key of ``BRAND_GRAPHS``.
        Exception: whatever ``hf_hub_download`` / ``pickle.load`` raise.
    """
    repo_id = f'MusoraProductDepartment/{brand}-graph'
    file_path = hf_hub_download(
        repo_id=repo_id,
        filename=BRAND_GRAPHS[brand],
        token=AUTH_TOKEN,
        cache_dir='/tmp',
        repo_type='model',
    )
    # SECURITY NOTE(review): pickle.load can execute arbitrary code embedded in
    # the file. This is only acceptable because the file comes from the
    # organisation's own Hub repo; never point this at an untrusted repo.
    with open(file_path, 'rb') as f:
        return pickle.load(f)


def load_graph_from_hf(brand):
    """
    Load the graph for the selected brand from HuggingFace Hub.

    Successful loads are cached per brand (see ``_download_brand_graph``);
    failures are reported in the UI and are NOT cached, so they are retried.

    Args:
        brand: One of the keys of ``BRAND_GRAPHS``.

    Returns:
        The unpickled graph, or ``None`` if the download or unpickling failed
        (the error has already been shown with ``st.error``).
    """
    try:
        return _download_brand_graph(brand)
    except Exception as e:  # UI boundary: report any failure to the user
        st.error(f"Error loading graph from HuggingFace: {e}")
        return None
|
|
|
|
def filter_graph(graph, node_threshold=10, edge_threshold=5):
    """Return a pruned copy of ``graph`` keeping only well-connected content.

    A node is kept when its degree in the *input* graph is at least
    ``node_threshold``. Among the kept nodes, an edge is kept when its
    ``weight`` attribute (0 if absent) is at least ``edge_threshold``.
    The input graph itself is left untouched.
    """
    keep = [node for node in graph.nodes if graph.degree(node) >= node_threshold]
    pruned = graph.subgraph(keep).copy()

    # Collect first, then delete: the edge view must not change while iterated.
    weak_edges = [
        (src, dst)
        for src, dst, attrs in pruned.edges(data=True)
        if attrs.get("weight", 0) < edge_threshold
    ]
    for src, dst in weak_edges:
        pruned.remove_edge(src, dst)

    return pruned
|
|
|
|
def get_rankings_from_api(brand, user_id, content_ids, timeout=60):
    """
    Call the rank_items API to fetch rankings for the given user and content IDs.

    Args:
        brand: Brand key; sent upper-cased.
        user_id: Student ID; must be convertible with ``int()``.
        content_ids: Iterable of content IDs, each convertible with ``int()``.
        timeout: Seconds to wait for the service. Without a timeout
            ``requests`` may wait indefinitely, freezing the Streamlit run.

    Returns:
        The decoded JSON response on success, or ``{}`` on any failure after
        reporting the problem with ``st.error``.
    """
    try:
        payload = {
            "brand": brand.upper(),
            "user_id": int(user_id),
            "content_ids": [int(content_id) for content_id in content_ids],
        }
    except (AttributeError, TypeError, ValueError) as e:
        # Bad caller input is not an API failure; say so instead of blaming the API.
        st.error(f"Invalid input for rank_items API: {e}")
        return {}

    headers = {
        "Authorization": f"Bearer {AUTH_TOKEN}",
        "accept": "application/json",
        "Content-Type": "application/json",
    }
    try:
        response = requests.post(API_URL, json=payload, headers=headers, timeout=timeout)
        response.raise_for_status()
        return response.json()
    except (requests.RequestException, ValueError) as e:
        # RequestException: connection errors, timeouts, HTTP error statuses.
        # ValueError: response body is not valid JSON.
        st.error(f"Error calling rank_items API: {e}")
        return {}
|
|
|
|
def rank_to_color(rank, max_rank):
    """
    Map a rank to a grayscale color, where dark gray indicates high relevance (low rank),
    and light gray indicates low relevance (high rank).

    Ranks in ``[0, max_rank]`` map linearly onto ``rgb(55, 55, 55)`` ..
    ``rgb(255, 255, 255)``. Any rank beyond ``max_rank`` gets the fixed fallback
    ``#E8E8E8``. A non-positive ``max_rank`` (nothing to scale against) also
    yields the fallback instead of raising ``ZeroDivisionError``.

    Args:
        rank: Position of the item; smaller means more relevant.
        max_rank: Largest rank that still receives a gradient colour.

    Returns:
        A CSS colour string.
    """
    if max_rank <= 0 or rank > max_rank:
        return "#E8E8E8"
    intensity = int(55 + (rank / max_rank) * 200)
    return f"rgb({intensity}, {intensity}, {intensity})"
|
|
|
|
def _rank_positions(rankings):
    """Return ``{content_id: 0-based position}`` for a ranked id list.

    The first occurrence wins, matching ``list.index``; an empty or ``None``
    ``rankings`` yields an empty dict. Built once so lookups are O(1).
    """
    positions = {}
    for position, content_id in enumerate(rankings or ()):
        positions.setdefault(content_id, position)
    return positions


def _top_weighted_neighbors(graph, node, top_k):
    """Return up to ``top_k`` ``(neighbor_id, weight)`` pairs from ``graph[node]``, heaviest first.

    A missing ``weight`` attribute counts as 0 (the same default used when
    filtering edges) instead of raising ``KeyError``.
    """
    weighted = [(int(neighbor), data.get('weight', 0)) for neighbor, data in graph[node].items()]
    weighted.sort(key=lambda pair: pair[1], reverse=True)
    return weighted[:top_k]


def _add_content_node(net, graph, node_id, rank_positions, max_rank, show_titles, background=None):
    """Add ``node_id`` to the pyvis ``net`` with size, colours, label and tooltip.

    Size is 0.15 x (in-degree + out-degree). When ``background`` is None the
    fill encodes flow: red if in-degree > 1.5 x out-degree, green if
    out-degree > 1.5 x in-degree, otherwise lightblue. With rankings, the
    border is the rank's grey shade (unranked content gets ``max_rank + 1``).
    Without rankings, the border matches the fill and the tooltip says ``Rank: N/A``.
    """
    title = graph.nodes[node_id].get('title', 'No title available')
    in_degree = graph.in_degree(node_id)
    out_degree = graph.out_degree(node_id)

    if background is None:
        if in_degree > out_degree * 1.5:
            background = 'red'
        elif out_degree > in_degree * 1.5:
            background = 'green'
        else:
            background = 'lightblue'

    if rank_positions:
        rank = rank_positions.get(node_id, max_rank + 1)
        border = rank_to_color(rank, max_rank)
    else:
        # No personalisation: showing a numeric rank here would be meaningless.
        rank = 'N/A'
        border = background

    label = f"{node_id}: {title[:15]}..." if show_titles else str(node_id)
    net.add_node(
        node_id,
        label=label,
        title=f"{title}, In-degree: {in_degree}, Out-degree: {out_degree}, Rank: {rank}",
        size=(in_degree + out_degree) * 0.15,
        color={"background": background, "border": border},
        borderWidth=3,
        borderWidthSelected=6,
    )


def dynamic_visualize_graph(graph, start_node, layers=3, top_k=5, show_titles=False, rankings=None):
    """Render a breadth-first "popular path" expansion of ``graph`` in Streamlit.

    Starting from ``start_node``, each of ``layers`` rounds follows the
    ``top_k`` heaviest links out of every node reached in the previous round.
    Each node is drawn once and linked from the first node that reached it.
    Links to nodes that are already drawn are not shown.

    Args:
        graph: Directed graph whose edges carry a ``weight`` attribute and whose
            nodes may carry a ``title`` attribute.
        start_node: ID of the first node (converted with ``int()``).
        layers: Number of expansion rounds.
        top_k: Maximum number of links followed per node.
        show_titles: Append the first 15 title characters to each label.
        rankings: Optional list of content IDs ordered by relevance for a
            student. When given, node borders are shaded by rank.

    Raises:
        KeyError: ``start_node`` is not a node of ``graph``.
    """
    net = Network(notebook=False, width="100%", height="600px", directed=True)
    net.set_options("""
var options = {
  "physics": {
    "barnesHut": {
      "gravitationalConstant": -15000,
      "centralGravity": 0.8
    }
  }
}
""")

    start = int(start_node)
    rank_positions = _rank_positions(rankings)
    max_rank = len(rankings) if rankings else 0

    _add_content_node(net, graph, start, rank_positions, max_rank, show_titles, background='darkblue')
    visited = {start}
    frontier = [start]

    for _ in range(layers):
        next_frontier = []
        for node in frontier:
            for neighbor, weight in _top_weighted_neighbors(graph, node, top_k):
                if neighbor in visited:
                    # Already drawn (this also skips self-loops): keep the tree shape.
                    continue
                _add_content_node(net, graph, neighbor, rank_positions, max_rank, show_titles)
                # Log scale keeps very heavy edges from dwarfing the rest.
                net.add_edge(node, neighbor, label=f"w:{weight}", width=math.log(weight + 1) * 8, color='lightgray')
                visited.add(neighbor)
                next_frontier.append(neighbor)
        frontier = next_frontier

    st.components.v1.html(net.generate_html(), height=600, scrolling=False)
|
|
|
|
st.title("Popular Path Expansion + Personalization")

# Brand selection
selected_brand = st.selectbox("Select a brand:", options=list(BRAND_GRAPHS.keys()))

# load_graph_from_hf caches successful loads, so calling it on every rerun is cheap.
G = load_graph_from_hf(selected_brand)
if G is None:
    # The loader already reported the failure; nothing below can work without a graph.
    st.stop()
if G.number_of_nodes() == 0:
    st.error("The selected graph has no nodes.")
    st.stop()

# Keep only nodes/edges that clear these (currently minimal) popularity thresholds.
node_degree_threshold = 1
edge_weight_threshold = 1
G_filtered = filter_graph(G, node_threshold=node_degree_threshold, edge_threshold=edge_weight_threshold)

# On first load or brand change, start from a random node among the 20 best-connected ones.
if st.session_state.get("selected_brand") != selected_brand:
    st.session_state.selected_brand = selected_brand
    popular_nodes = sorted(G.nodes, key=lambda n: G.in_degree(n) + G.out_degree(n), reverse=True)
    st.session_state.start_node = random.choice(popular_nodes[:20])

# Random start node, drawn from the filtered graph so "Expand Graph" accepts it.
if st.button("Random Selection"):
    candidates = list(G_filtered.nodes) or list(G.nodes)
    st.session_state.start_node = random.choice(candidates)

# Starting node input
start_node = st.number_input(
    "Enter the starting node ID:",
    value=st.session_state.start_node,
    step=1
)

# Optional student for personalised ranking
student_id = st.text_input("Enter a student ID (optional):", value="").strip()

show_titles = st.checkbox("Show content titles", value=False)

# Personalised ranking: fetched once per (brand, student) and kept in session
# state, so slider/checkbox reruns do not re-post every node id to the API.
# Failed calls are not stored, so they are retried on the next rerun.
rankings = []
if student_id:
    try:
        student_id_int = int(student_id)
    except ValueError:
        st.error("Student ID must be a whole number.")
    else:
        if "rankings_cache" not in st.session_state:
            st.session_state.rankings_cache = {}
        rankings_cache = st.session_state.rankings_cache
        cache_key = (selected_brand, student_id_int)
        if cache_key not in rankings_cache:
            response = get_rankings_from_api(selected_brand, student_id_int, list(G_filtered.nodes))
            if isinstance(response, dict) and 'ranked_content_ids' in response:
                rankings_cache[cache_key] = response['ranked_content_ids']
            elif response:
                st.error("Ranking API response did not include 'ranked_content_ids'.")
        rankings = rankings_cache.get(cache_key, [])

layers = st.slider("Depth to explore:", 1, 6, value=3)
top_k = st.slider("Branching factor (per node):", 1, 6, value=3)

if st.button("Expand Graph"):
    if start_node in G_filtered:
        dynamic_visualize_graph(G_filtered, start_node, layers=layers, top_k=top_k, show_titles=show_titles, rankings=rankings)
    else:
        st.error("The starting node is not in the graph!")
|
|