loopback / app.py
DanielRegaladoCardoso's picture
Upload app.py with huggingface_hub
f5cf7c3 verified
Raw
History Blame Contribute Delete
4.05 kB
"""Loopback — Gradio demo for the two-tower music recommender.
Pick tracks → average their embeddings → retrieve top-10 nearest neighbors with FAISS.
"""
from __future__ import annotations
import os
os.environ.setdefault("KMP_DUPLICATE_LIB_OK", "TRUE")
import faiss
import gradio as gr
import numpy as np
import polars as pl
import torch
from huggingface_hub import hf_hub_download
from model import TwoTower
# Hugging Face Hub sources. The trained checkpoint (CKPT_NAME) lives in
# MODEL_REPO; training interactions, vocab sizes and track labels live in
# the DATASET_REPO dataset repo.
CKPT_NAME = "two_tower_epoch3.pt"
MODEL_REPO = "DanielRegaladoCardoso/loopback-twotower"
DATASET_REPO = "DanielRegaladoCardoso/lastfm-1k-twotower"
def _build_artist_lookup(train_path: str, n_tracks: int) -> np.ndarray:
    """Map each track_idx to the artist_idx on its first row in train.parquet.

    Tracks that never appear in the training file keep artist_idx 0.
    NOTE(review): 0 may also be a real artist id — confirm that this default
    is acceptable for unseen tracks.
    """
    first_artist = (
        pl.read_parquet(train_path, columns=["track_idx", "artist_idx"])
        .group_by("track_idx")
        .agg(pl.col("artist_idx").first())
    )
    artist_lookup = np.zeros(n_tracks, dtype=np.int64)
    # Vectorized scatter instead of a Python-level loop over ~n_tracks rows.
    artist_lookup[first_artist["track_idx"].to_numpy()] = first_artist[
        "artist_idx"
    ].to_numpy()
    return artist_lookup


def _embed_tracks(
    model, artist_lookup: np.ndarray, n_tracks: int, batch_size: int = 8192
) -> np.ndarray:
    """Embed track ids 0..n_tracks-1 with ``model.track_tower`` in batches.

    Row i of the returned float32 matrix is the embedding of track_idx i
    (assumes track_tower returns one row per input id — TODO confirm).
    """
    chunks = []
    with torch.no_grad():
        for start in range(0, n_tracks, batch_size):
            end = min(start + batch_size, n_tracks)
            track_ids = torch.arange(start, end)
            artist_ids = torch.from_numpy(artist_lookup[start:end]).long()
            chunks.append(model.track_tower(track_ids, artist_ids).numpy())
    return np.concatenate(chunks).astype("float32")


def _load_labels(labels_path: str, n_tracks: int) -> list[str]:
    """Return display labels positioned so that ``labels[i]`` names track_idx i.

    FAISS ids are track_idx values, so the label list must be aligned by id,
    not just sorted; otherwise gaps in the labels file shift every later label.
    """
    labels_df = pl.read_parquet(labels_path).sort("track_idx")
    labels = labels_df["label"].to_list()
    track_ids = labels_df["track_idx"].to_numpy()
    if np.array_equal(track_ids, np.arange(n_tracks)):
        # Common case: exactly one label per track, ids 0..n_tracks-1.
        return labels
    print(
        f"Warning: track_labels.parquet does not cover track_idx "
        f"0..{n_tracks - 1} exactly; filling gaps with placeholders."
    )
    aligned = [f"<unlabeled track {i}>" for i in range(n_tracks)]
    for i, label in zip(track_ids.tolist(), labels):
        if 0 <= i < n_tracks:
            aligned[i] = label
    return aligned


def load_artifacts():
    """Download artifacts from the HF Hub and build the retrieval index.

    Returns:
        ``(index, labels, vecs)``:
        - a FAISS ``IndexFlatIP`` over every track embedding,
        - a list where ``labels[i]`` is the display label of track_idx ``i``,
        - the float32 embedding matrix whose row ``i`` is track_idx ``i``.
    """
    print("Downloading checkpoint + dataset metadata from HF Hub...")
    ckpt_path = hf_hub_download(MODEL_REPO, CKPT_NAME)
    train_path = hf_hub_download(DATASET_REPO, "train.parquet", repo_type="dataset")
    vocab_path = hf_hub_download(DATASET_REPO, "vocab.parquet", repo_type="dataset")
    labels_path = hf_hub_download(
        DATASET_REPO, "track_labels.parquet", repo_type="dataset"
    )

    vocab = pl.read_parquet(vocab_path).row(0, named=True)
    n_tracks = vocab["n_tracks"]

    # NOTE(security): weights_only=False unpickles arbitrary Python objects from
    # the downloaded checkpoint. Only acceptable because MODEL_REPO is trusted;
    # switch to weights_only=True if the checkpoint holds only tensors/primitives.
    ckpt = torch.load(ckpt_path, map_location="cpu", weights_only=False)
    model = TwoTower(
        vocab["n_users"], n_tracks, vocab["n_artists"], out_dim=ckpt["embed_dim"]
    )
    model.load_state_dict(ckpt["model"])
    model.eval()

    print("Building track_idx → artist_idx lookup...")
    artist_lookup = _build_artist_lookup(train_path, n_tracks)

    print(f"Embedding {n_tracks:,} tracks...")
    vecs = _embed_tracks(model, artist_lookup, n_tracks)

    print("Building FAISS index...")
    # Exhaustive inner-product search. Scores are cosine similarities only if
    # track_tower L2-normalizes its output — NOTE(review): confirm in model.py.
    index = faiss.IndexFlatIP(vecs.shape[1])
    index.add(vecs)

    labels = _load_labels(labels_path, n_tracks)
    print(f"Ready: {n_tracks:,} tracks indexed.")
    return index, labels, vecs
# Everything is loaded once at import time; these globals back `recommend`.
INDEX, LABELS, TRACK_VECS = load_artifacts()
# Reverse lookup: display label -> FAISS id / row of TRACK_VECS.
# When two tracks share a label, the later index wins.
LABEL_TO_IDX = dict(zip(LABELS, range(len(LABELS))))
# Dropdown choices: every 500th label, at most 3,000 of them, sorted.
SAMPLE_TRACKS = sorted(LABELS[: 500 * 3000 : 500])
def recommend(track_choices: list[str], k: int = 10) -> str:
    """Recommend the ``k`` tracks nearest to the mean embedding of the picks.

    Args:
        track_choices: Track labels selected in the UI. Labels not present in
            ``LABEL_TO_IDX`` are ignored.
        k: Number of recommendations to return.

    Returns:
        A newline-separated ranked list ``"{rank}. {label} (sim={score:.3f})"``
        that excludes the seed tracks themselves, or a short message when no
        usable seed track was given.
    """
    if not track_choices:
        return "Pick at least one track."
    idxs = [LABEL_TO_IDX[c] for c in track_choices if c in LABEL_TO_IDX]
    if not idxs:
        return "None of those tracks matched."

    seed = TRACK_VECS[idxs].mean(axis=0, keepdims=True)
    norm = np.linalg.norm(seed)
    if norm > 0:  # avoid NaNs if the seed vectors cancel out exactly
        seed = seed / norm

    exclude = set(idxs)
    # Over-fetch by the number of seeds so k results remain after dropping them.
    scores, neighbors = INDEX.search(seed.astype("float32"), k + len(exclude))
    ranked = [
        (int(n), float(s))
        for n, s in zip(neighbors[0], scores[0])
        # FAISS pads with id -1 when fewer results than requested exist.
        if n >= 0 and int(n) not in exclude
    ][:k]
    return "\n".join(
        f"{rank}. {LABELS[n]} (sim={s:.3f})"
        for rank, (n, s) in enumerate(ranked, start=1)
    )
with gr.Blocks(title="loopback", theme=gr.themes.Soft()) as demo:
    # Header: what the demo does, plus links to the code, dataset and model.
    gr.Markdown(
        "".join(
            [
                "# loopback — two-tower music recommender\n\n",
                "Pick tracks you'd listen to. The model averages their embeddings and retrieves the ",
                "10 nearest neighbors from a **1.5M-track catalog** (Last.fm 1K).\n\n",
                "Trained on 15.3M listening events. ",
                "[Code](https://github.com/DanielRegaladoUMiami/loopback) · ",
                "[Dataset](https://huggingface.co/datasets/DanielRegaladoCardoso/lastfm-1k-twotower) · ",
                "[Model](https://huggingface.co/DanielRegaladoCardoso/loopback-twotower)",
            ]
        )
    )
    picks = gr.Dropdown(label="Seed tracks", multiselect=True, choices=SAMPLE_TRACKS)
    btn = gr.Button("Recommend", variant="primary")
    out = gr.Textbox(lines=12, label="Top-10 recommendations")
    # The selected labels are passed to `recommend`; its text goes to `out`.
    btn.click(fn=recommend, inputs=picks, outputs=out)

if __name__ == "__main__":
    demo.launch()