Spaces:
Sleeping
Sleeping
| import hashlib | |
| import numpy as np | |
| import pandas as pd | |
| import streamlit as st | |
| import requests | |
| import networkx as nx | |
| from pyvis.network import Network | |
| import streamlit.components.v1 as components | |
| from pyspark.sql import SparkSession | |
| from pyspark.conf import SparkConf | |
| from pyspark.sql.functions import col | |
| from pyspark.sql.types import IntegerType | |
| from pyspark.ml.feature import StringIndexerModel | |
| from pyspark.ml.recommendation import ALSModel | |
def get_spark():
    """Return the local-mode SparkSession used for model inference.

    The session is named ``SMM2_Inference`` and runs on ``local[*]`` with the
    driver / memory / shuffle options listed below. ``getOrCreate`` reuses an
    already running session if there is one, in which case Spark may ignore
    some of these options.

    Returns:
        The active ``SparkSession`` with its log level set to ``WARN``.
    """
    extra_settings = {
        "spark.driver.memory": "8g",
        "spark.driver.maxResultSize": "2g",
        "spark.sql.shuffle.partitions": "8",
        "spark.memory.fraction": "0.6",
    }
    conf = SparkConf().setAppName("SMM2_Inference").setMaster("local[*]")
    for key, value in extra_settings.items():
        conf.set(key, value)

    session = SparkSession.builder.config(conf=conf).getOrCreate()
    session.sparkContext.setLogLevel("WARN")
    return session
@st.cache_resource
def load_models():
    """Load the fitted Spark ML artifacts used for inference.

    Streamlit re-executes the whole script on every widget interaction, and
    this function is called at module level. ``st.cache_resource`` keeps one
    copy of the models per server process instead of reloading them from disk
    on each rerun. The cached objects are shared by all sessions; this app
    only reads from them (``transform`` / ``recommendForUserSubset``).

    Returns:
        Tuple ``(pid_indexer, als_model)``: the ``StringIndexerModel`` stored
        at ``./src/models/pid_indexer`` and the ``ALSModel`` stored at
        ``./src/models/als_smm2``.
    """
    # Paths are relative to the process working directory.
    pid_indexer = StringIndexerModel.load("./src/models/pid_indexer")
    als_model = ALSModel.load("./src/models/als_smm2")
    return pid_indexer, als_model
@st.cache_data
def load_graph_data():
    """Load the level-similarity graph and attach each node's cluster name.

    Reads ``nodes.csv``, ``edges.csv`` and ``cluster_names.csv`` from
    ``./src/models/graph``. ``nodes.csv`` must have a ``label`` column and
    ``cluster_names.csv`` must have ``cluster_id`` and ``group_name`` columns.

    The result is cached with ``st.cache_data``, so the CSVs are parsed once
    instead of on every Streamlit rerun (e.g. each slider move). Each call
    still receives its own copy of the DataFrames.

    Returns:
        ``(nodes_df, edges_df)``. ``nodes_df`` gains a ``group_name`` column,
        which is NaN for labels with no matching ``cluster_id``.
    """
    nodes_df = pd.read_csv("./src/models/graph/nodes.csv")
    edges_df = pd.read_csv("./src/models/graph/edges.csv")
    cluster_df = pd.read_csv("./src/models/graph/cluster_names.csv")

    # Keep one name per cluster. A repeated cluster_id would otherwise
    # duplicate every matching node row in the left join below.
    cluster_names = cluster_df[["cluster_id", "group_name"]].drop_duplicates(
        subset="cluster_id"
    )
    # Attach group_name to each node via label -> cluster_id.
    nodes_df = nodes_df.merge(
        cluster_names,
        left_on="label",
        right_on="cluster_id",
        how="left",
    ).drop(columns=["cluster_id"])
    return nodes_df, edges_df
# Module-level inference resources used by maker_id_course_recommendations().
# Order matters: get_spark() runs first so its SparkConf options are the ones
# applied. getOrCreate() would reuse any session created earlier, and Spark may
# then ignore those options. Presumably the Spark ML loaders in load_models()
# run against this session. NOTE(review): confirm.
spark = get_spark()
pid_indexer_model, model = load_models()
charset = "0123456789BCDFGHJKLMNPQRSTVWXY"


def data_id_to_course_id(data_id: int) -> str:
    """Convert a numeric SMM2 level ``data_id`` into its course-ID string.

    ``data_id`` is XOR-scrambled with a fixed mask. The scrambled value is
    split into its low 20 bits and the bits above them, and these are packed
    with constant marker fields into one integer:

    ======  =====  ============================================
    bits    field  value
    ======  =====  ============================================
    40+     A      ``0b1000`` (constant)
    34-39   B      ``(data_id - 31) % 64``
    14-33   C      low 20 bits of the scrambled id
    13      D      0 (constant)
    12      E      1 (constant)
    0-11    F      scrambled id shifted right by 20
    ======  =====  ============================================

    The packed integer is written in base 30 using ``charset``, least
    significant digit first. For ``0 <= data_id < 2**32`` the result always
    has exactly 9 characters.
    """
    scrambled = data_id ^ 0b00010110100000001110000001111100
    packed = (
        (0b1000 << 40)                    # field A
        + (((data_id - 31) % 64) << 34)   # field B
        + ((scrambled & 0xFFFFF) << 14)   # field C: low 20 bits
        + (0b0 << 13)                     # field D
        + (0b1 << 12)                     # field E
        + (scrambled >> 20)               # field F: remaining high bits
    )

    digits = []
    while packed > 0:
        packed, remainder = divmod(packed, 30)
        digits.append(charset[remainder])
    return "".join(digits)
def maker_id_course_recommendations(maker_id: str, n_recs: int = 10):
    """Return makercentral.io URLs for the ALS model's top levels for a player.

    The player's ``pid`` is looked up through the tgrcode.com MM2 API. It is
    mapped to the model's integer user index with ``pid_indexer_model`` and
    passed to ``model.recommendForUserSubset``. Each recommended ``data_id``
    is converted to a course ID with ``data_id_to_course_id``.

    Args:
        maker_id: Maker ID typed by the user (untrusted input).
        n_recs: Number of levels to request from the model (default 10).

    Returns:
        A list of ``https://makercentral.io/levels/view/<course_id>`` URLs,
        in the order the model returned them.

    Raises:
        requests.RequestException: on network failure, timeout, or an HTTP
            error status from the API.
        ValueError: if the API response has no ``pid`` or the model returns
            no recommendation row for the player.
    """
    from urllib.parse import quote

    # Percent-encode the user-supplied ID so characters like "/" or "?"
    # cannot change which API path is requested.
    url = f"https://tgrcode.com/mm2/user_info/{quote(maker_id, safe='')}"
    # Without a timeout, a stalled API call would block this session forever.
    response = requests.get(url, timeout=30)
    response.raise_for_status()
    maker_data = response.json()
    if not isinstance(maker_data, dict) or "pid" not in maker_data:
        raise ValueError(f"No se encontró un jugador con el Maker-ID {maker_id}.")
    pid = str(maker_data["pid"])

    user_df = spark.createDataFrame([(pid,)], ["pid"])
    # Cast the indexer's "pidx" output to int for the ALS user column.
    user_indexed = pid_indexer_model.transform(user_df).withColumn(
        "pidx", col("pidx").cast(IntegerType())
    )
    # first() returns None instead of raising IndexError like collect()[0]
    # when the model produces no row for this user.
    recs_row = model.recommendForUserSubset(user_indexed.select("pidx"), n_recs).first()
    if recs_row is None:
        raise ValueError(f"No hay recomendaciones para el Maker-ID {maker_id}.")

    return [
        f"https://makercentral.io/levels/view/{data_id_to_course_id(rec.data_id)}"
        for rec in recs_row.recommendations
    ]
def maker_id_course_likes(maker_id: str):
    """Return makercentral.io URLs for the levels a player has liked.

    Args:
        maker_id: Maker ID typed by the user (untrusted input).

    Returns:
        One ``https://makercentral.io/levels/view/<course_id>`` URL for each
        entry of the ``courses`` list in the tgrcode.com ``get_liked``
        response, in API order.

    Raises:
        requests.RequestException: on network failure, timeout, or an HTTP
            error status from the API.
        ValueError: if the response has no ``courses`` list.
    """
    from urllib.parse import quote

    # Percent-encode the user-supplied ID so it stays a single path segment.
    url = f"https://tgrcode.com/mm2/get_liked/{quote(maker_id, safe='')}"
    # Without a timeout, a stalled API call would block this session forever.
    response = requests.get(url, timeout=30)
    response.raise_for_status()
    maker_likes = response.json()
    if not isinstance(maker_likes, dict) or "courses" not in maker_likes:
        raise ValueError(f"No se encontraron niveles gustados para el Maker-ID {maker_id}.")
    return [
        f"https://makercentral.io/levels/view/{like['course_id']}"
        for like in maker_likes["courses"]
    ]
def get_color_for_label(label: str) -> str:
    """Map a cluster label to a deterministic ``#rrggbb`` colour.

    The label is converted with ``str``, MD5-hashed, and the first 24 bits of
    the digest are used as the red, green and blue bytes. Equal labels
    therefore always get the same colour.
    """
    # The first six hex digits of the digest already are r, g and b as
    # zero-padded lowercase byte pairs, so they can be used verbatim.
    digest = hashlib.md5(str(label).encode()).hexdigest()
    return "#" + digest[:6]
def _display_text(value) -> str:
    """Return *value* as display text, mapping missing cells (None/NaN/NA) to ""."""
    if value is None or (pd.api.types.is_scalar(value) and pd.isna(value)):
        return ""
    return str(value)


def build_graph(nodes_df: pd.DataFrame, edges_df: pd.DataFrame, n_edges: int = 300, seed: int = 42) -> str:
    """Render a random sample of the level-similarity graph as pyvis HTML.

    ``min(n_edges, len(edges_df))`` edges are sampled with
    ``random_state=seed``, so the same inputs give the same subgraph. Only
    nodes that are an endpoint of a sampled edge are added with attributes.
    Each of them is coloured by its ``label`` via ``get_color_for_label``.

    Args:
        nodes_df: Node table with ``id`` and ``label`` columns and optional
            ``level_name`` / ``group_name`` columns.
        edges_df: Edge table with ``src``, ``dst`` and ``similarity`` columns.
        n_edges: Maximum number of edges to sample.
        seed: Random seed for the edge sample.

    Returns:
        The pyvis HTML document as a string, ready for ``components.html``.
    """
    # Random edge sample; the node set is limited to that sample's endpoints.
    sampled_edges = edges_df.sample(n=min(n_edges, len(edges_df)), random_state=seed)
    active_ids = set(sampled_edges["src"]) | set(sampled_edges["dst"])
    active_nodes = nodes_df[nodes_df["id"].isin(active_ids)]

    graph = nx.Graph()
    # to_dict("records") is much faster than iterrows and keeps each column's
    # own type instead of upcasting whole rows.
    for node in active_nodes.to_dict("records"):
        node_id = int(node["id"])
        # Missing CSV cells and unmatched clusters arrive as NaN. NaN is truthy
        # (so `or ""` does not catch it) and not sliceable, so normalise it here.
        level_name = _display_text(node.get("level_name"))
        group_name = _display_text(node.get("group_name"))
        graph.add_node(
            node_id,
            label=level_name[:20] or str(node_id),  # short label drawn on the node
            title=f"<b>{level_name}</b><br>Grupo: {group_name}<br>ID: {node_id}",  # hover tooltip
            color=get_color_for_label(node["label"]),
            group=group_name,
        )

    for edge in sampled_edges.to_dict("records"):
        similarity = edge["similarity"]
        graph.add_edge(int(edge["src"]), int(edge["dst"]), weight=similarity, width=similarity * 5)

    net = Network(height="600px", width="100%", notebook=False)
    net.from_nx(graph)
    # Build the HTML in memory. Writing a fixed "graph.html" in the working
    # directory let concurrent sessions overwrite each other's output.
    return net.generate_html()
# ── UI ────────────────────────────────────────────────────────────────────────
# Streamlit "magic" renders this bare string literal as Markdown.
"""
# Recomendación de Niveles Super Mario Maker 2
Desarrollé un sistema de recomendación de niveles para usuarios mediante el uso de un modelo ALS.
Utilicé el [Mario Maker 2 Dataset](https://tgrcode.com/posts/mario_maker_2_datasets) para este proyecto.
Más específicamente [TheGreatRambler/mm2_level_played](https://huggingface.co/datasets/TheGreatRambler/mm2_level_played), ya que este contiene las conexiones de usuarios con los niveles que han jugado, y dos valores booleanos para identificar si el usuario completó el nivel, y si el usuario le dio like al nivel.
"""
tab_pred, tab_graph = st.tabs(["Predicción", "Grafo"])
with tab_pred:
    with st.form("level_recommendation_form"):
        maker_id = st.text_input("Ingresa tu Maker-ID (9 caracteres)", max_chars=9)
        submitted = st.form_submit_button("Buscar Recomendaciones")

    if submitted:
        # Tolerate surrounding whitespace and lowercase typing. IDs presumably
        # use uppercase letters, like the course-ID alphabet in `charset`.
        maker_id = maker_id.strip().upper()
        if len(maker_id) != 9:
            st.error("El Maker-ID debe tener 9 caracteres.")
        else:
            # Only the API / model calls are inside the try; rendering is in `else`.
            try:
                course_rec = maker_id_course_recommendations(maker_id)
                course_like = maker_id_course_likes(maker_id)
            except requests.RequestException as e:
                st.error(f"No se pudo consultar la API de tgrcode.com: {e}")
            except Exception as e:
                # Top-level UI boundary: report anything else to the user.
                st.error(f"Error inesperado: {e}")
            else:
                col_rec, col_like = st.columns(2)
                with col_rec:
                    st.subheader("Niveles Recomendados")
                    for url in course_rec:
                        st.write(url)
                with col_like:
                    st.subheader("Niveles Gustados por el Jugador")
                    for url in course_like:
                        st.write(url)
with tab_graph:
    st.subheader("Grafo de Similitud entre Niveles")

    # Controls: a wide column for the edge-count slider, a narrow one for the seed.
    col_slider, col_seed = st.columns([3, 1])
    n_edges = col_slider.slider("Número de aristas a mostrar", min_value=50, max_value=5000, value=300, step=50)
    seed = col_seed.number_input("Semilla aleatoria", min_value=0, max_value=9999, value=42, step=1)

    try:
        nodes_df, edges_df = load_graph_data()
        html = build_graph(nodes_df, edges_df, n_edges=n_edges, seed=seed)
        components.html(html, height=620)
    except Exception as e:
        # Show a short message followed by the full traceback.
        st.error(f"Error cargando el grafo: {e}")
        st.exception(e)