# app.py — mrob937's "music-recommender" Hugging Face Space (commit 4e2f6dc).
# (Removed non-Python page residue that preceded the code: user avatar line,
# space name "super_diverse", and the bare commit-hash line.)
import gradio as gr
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from pathlib import Path
import json
import h5py
from sklearn.neighbors import NearestNeighbors
from sklearn.preprocessing import normalize
import random
import gradio as gr
from huggingface_hub import hf_hub_download
import os
# ----------------------------------------------------------------------------
# Download model checkpoint and cached features on startup.
# Skipped when the files already exist (e.g. a warm restart of the Space).
# Both assets live in the same model repo; one loop covers both downloads.
# ----------------------------------------------------------------------------
_ASSETS = {
    "best_model.pth": "📥 Downloading model...",
    "features_cache.h5": "📥 Downloading features...",
}
for _fname, _msg in _ASSETS.items():
    if not os.path.exists(_fname):
        print(_msg)
        hf_hub_download(
            repo_id="mrob937/music-recommender-assets",
            filename=_fname,
            local_dir=".",  # place next to app.py so relative paths below work
            repo_type="model",
        )
# ============================================================================
# COPY YOUR PlaylistCNNClassifier CLASS HERE (from your notebook)
# ============================================================================
class PlaylistCNNClassifier(nn.Module):
    """CNN that maps MERT audio features to playlist logits.

    Input features arrive as [batch, time_steps, feature_dim] with a very
    long time axis (e.g. ~16k steps). The network:
      1. shrinks the time axis to a fixed length with adaptive avg pooling,
      2. learns temporal patterns with three strided 1D conv blocks,
      3. global-average-pools and classifies with a two-layer head.

    Args:
        num_classes: Number of playlist classes to predict.
        input_feature_dim: Feature dimension from MERT (default: 1024).
        target_time_steps: Time steps after downsampling (default: 512).
        dropout: Dropout probability (default: 0.3).
    """

    def __init__(
        self,
        num_classes: int,
        input_feature_dim: int = 1024,
        target_time_steps: int = 512,
        dropout: float = 0.3
    ):
        super().__init__()
        self.num_classes = num_classes
        self.input_feature_dim = input_feature_dim
        self.target_time_steps = target_time_steps

        # Fixed-length time axis regardless of the input's duration.
        self.adaptive_pool = nn.AdaptiveAvgPool1d(target_time_steps)

        # Three conv blocks; each stride-2 conv halves the time axis:
        # [B, 1024, 512] -> [B, 512, 256] -> [B, 256, 128] -> [B, 128, 64].
        # NOTE: attribute names (conv1/bn1/...) are part of the checkpoint's
        # state_dict keys and must not be renamed.
        self.conv1 = nn.Conv1d(input_feature_dim, 512, kernel_size=3, stride=2, padding=1)
        self.bn1 = nn.BatchNorm1d(512)
        self.conv2 = nn.Conv1d(512, 256, kernel_size=3, stride=2, padding=1)
        self.bn2 = nn.BatchNorm1d(256)
        self.conv3 = nn.Conv1d(256, 128, kernel_size=3, stride=2, padding=1)
        self.bn3 = nn.BatchNorm1d(128)

        # Classifier head on the globally pooled 128-dim vector.
        self.fc1 = nn.Linear(128, 256)
        self.fc2 = nn.Linear(256, num_classes)
        self.dropout = nn.Dropout(dropout)
        self.relu = nn.ReLU()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Compute class logits.

        Args:
            x: Input tensor of shape [batch, time_steps, feature_dim].

        Returns:
            Logits of shape [batch, num_classes].
        """
        # Conv1d expects channels-first: [batch, feature_dim, time_steps];
        # then clamp the time axis to the configured length.
        h = self.adaptive_pool(x.transpose(1, 2))

        # conv -> batchnorm -> ReLU -> dropout, applied three times.
        for conv, bn in ((self.conv1, self.bn1),
                         (self.conv2, self.bn2),
                         (self.conv3, self.bn3)):
            h = self.dropout(self.relu(bn(conv(h))))

        # Collapse the remaining time axis to one value per channel.
        h = F.adaptive_avg_pool1d(h, 1).squeeze(-1)  # [batch, 128]

        # Two-layer head produces the final logits.
        h = self.dropout(self.relu(self.fc1(h)))
        return self.fc2(h)
# ============================================================================
# Music Recommender Class
# ============================================================================
class MusicRecommender:
    """Simplified recommender for the Gradio demo.

    Loads a trained ``PlaylistCNNClassifier`` checkpoint plus a cache of
    precomputed audio features, then recommends songs by (1) predicting the
    most likely playlists for a query song with the CNN and (2) running
    cosine-similarity KNN within those playlists only.
    """

    def __init__(self, model_path: str, cache_dir: str, device: str = "cpu"):
        """Load the checkpoint and feature cache; build per-playlist KNN indices.

        Args:
            model_path: Path to a torch checkpoint holding
                ``model_state_dict``, ``num_classes`` and ``label_to_name``.
            cache_dir: Directory containing ``features_cache.h5``.
            device: Torch device string (Spaces default to "cpu").
        """
        self.device = device
        self.n_neighbors = 5  # neighbors returned per playlist

        # --- Model ------------------------------------------------------
        checkpoint = torch.load(model_path, map_location=device)
        # Checkpoint label keys may have been serialized as strings; normalize.
        label_to_name_raw = checkpoint['label_to_name']
        self.label_to_name = {int(k): v for k, v in label_to_name_raw.items()}
        self.num_classes = checkpoint['num_classes']
        self.model = PlaylistCNNClassifier(num_classes=self.num_classes).to(device)
        self.model.load_state_dict(checkpoint['model_state_dict'])
        self.model.eval()

        # --- Cached features --------------------------------------------
        cache_path = Path(cache_dir)
        h5_file = cache_path / "features_cache.h5"
        with h5py.File(h5_file, "r") as f:
            self.features = f["features"][:]   # per-song feature arrays
            self.labels = f["labels"][:]       # playlist label per song
            # h5py may return bytes; decode so matching/display work on str.
            self.file_paths = [
                (p.decode() if isinstance(p, bytes) else p)
                for p in f["file_paths"][:]
            ]

        # Time-averaged, L2-normalized vectors: cosine distance on these is
        # what the per-playlist KNN indices use.
        self.features_flat = self.features.mean(axis=1)
        self.features_normalized = normalize(self.features_flat, norm='l2', axis=1)

        # --- One KNN index per playlist ----------------------------------
        self.knn_models = {}
        for label in range(self.num_classes):
            playlist_indices = np.where(self.labels == label)[0]
            if len(playlist_indices) == 0:
                continue  # playlist has no cached songs; skip it
            playlist_features = self.features_normalized[playlist_indices]
            # +1 neighbor so the query song itself can be excluded later.
            n_neighbors_playlist = min(self.n_neighbors + 1, len(playlist_indices))
            knn = NearestNeighbors(n_neighbors=n_neighbors_playlist,
                                   metric='cosine', algorithm='brute')
            knn.fit(playlist_features)
            self.knn_models[label] = {
                'model': knn,
                # maps local KNN row -> global song index
                'indices': playlist_indices,
                'name': self.label_to_name[label]
            }

    def get_song_name(self, file_path: str) -> str:
        """Return a song's display name: its filename without the extension."""
        return Path(file_path).stem

    def predict_playlists(self, features: np.ndarray, k: int = 2):
        """Predict the top-k playlists for a song.

        Args:
            features: Raw (un-averaged) feature array for one song.
            k: How many playlists to return (clamped to ``num_classes``).
               Default 2 preserves the original top-2 behavior.

        Returns:
            List of ``(label, probability)`` tuples, best first.
        """
        k = min(k, self.num_classes)
        features_tensor = torch.from_numpy(features).float().unsqueeze(0).to(self.device)
        with torch.no_grad():
            logits = self.model(features_tensor)
            probs = torch.softmax(logits, dim=1)[0]
            topk_probs, topk_labels = torch.topk(probs, k=k)
        return [(topk_labels[i].item(), topk_probs[i].item()) for i in range(k)]

    def find_similar_in_playlist(self, query_features: np.ndarray,
                                 playlist_label: int, exclude_idx: int = None):
        """Return ``(global_index, cosine_distance)`` pairs of the nearest
        songs inside one playlist, optionally dropping the query song."""
        if playlist_label not in self.knn_models:
            return []  # no KNN index was built for this playlist
        knn_data = self.knn_models[playlist_label]
        distances, indices = knn_data['model'].kneighbors([query_features])
        # Translate local KNN rows back to global song indices.
        global_indices = knn_data['indices'][indices[0]]
        results = []
        for idx, dist in zip(global_indices, distances[0]):
            if exclude_idx is not None and idx == exclude_idx:
                continue  # never recommend the query song to itself
            results.append((idx, dist))
        return results

    def recommend(self, song_name: str, diversity_mode: str = "diverse"):
        """Recommend songs with configurable diversity.

        Args:
            song_name: Name of the query song (matched case-insensitively
                against file stems: exact match first, then substring).
            diversity_mode: One of:
                - "focused": Only from top predicted playlists (original behavior)
                - "balanced": Mix of predicted + exploratory recommendations
                - "diverse": Equal weight to all playlists, maximizing variety

        Returns:
            A 4-tuple ``(current_info, playlist_info, recommendations,
            next_song)`` matching the four Gradio outputs.  Error cases also
            return 4 values (previously they returned 3, which crashed the
            Gradio output unpacking).
        """
        # --- Locate the query song --------------------------------------
        matching_paths = [
            path for path in self.file_paths
            if self.get_song_name(path).lower() == song_name.lower()
        ]
        if not matching_paths:
            # Fall back to substring match.
            matching_paths = [
                path for path in self.file_paths
                if song_name.lower() in self.get_song_name(path).lower()
            ]
        if not matching_paths:
            return f"❌ Song '{song_name}' not found in dataset.", "", [], ""
        song_path = matching_paths[0]
        song_idx = self.file_paths.index(song_path)
        query_features = self.features[song_idx]
        query_features_flat = self.features_normalized[song_idx]
        query_label = self.labels[song_idx]

        # --- Predict candidate playlists --------------------------------
        # "focused" keeps the original top-2; diversity modes widen to top 5
        # (predict_playlists clamps k to num_classes).
        if diversity_mode == "focused":
            top_playlists = self.predict_playlists(query_features, k=2)
        else:
            top_playlists = self.predict_playlists(query_features, k=5)

        # --- Gather scored candidates per mode --------------------------
        # Each candidate: (global_idx, score, playlist_label, cosine_distance)
        candidates = []
        if diversity_mode == "diverse":
            # Equal weight to all playlists - ignore predicted probabilities.
            for playlist_label, _ in top_playlists:
                similar_songs = self.find_similar_in_playlist(
                    query_features_flat, playlist_label, exclude_idx=song_idx
                )
                for song_idx_candidate, distance in similar_songs[:self.n_neighbors]:
                    score = (1 - distance)  # similarity only
                    candidates.append((song_idx_candidate, score, playlist_label, distance))
        elif diversity_mode == "balanced":
            # Mix of top predictions + exploratory picks.
            for i, (playlist_label, prob) in enumerate(top_playlists):
                similar_songs = self.find_similar_in_playlist(
                    query_features_flat, playlist_label, exclude_idx=song_idx
                )
                # Top-2 playlists keep full probability weight; lower-ranked
                # ones get a partial boost so they still contribute.
                boost = 1.0 if i < 2 else 0.5
                for song_idx_candidate, distance in similar_songs[:self.n_neighbors]:
                    score = (1 - distance) * (prob * boost + 0.2)  # +0.2 base score for diversity
                    candidates.append((song_idx_candidate, score, playlist_label, distance))
        else:  # "focused" - original behavior
            for playlist_label, prob in top_playlists[:2]:
                similar_songs = self.find_similar_in_playlist(
                    query_features_flat, playlist_label, exclude_idx=song_idx
                )
                for song_idx_candidate, distance in similar_songs[:self.n_neighbors]:
                    score = (1 - distance) * prob
                    candidates.append((song_idx_candidate, score, playlist_label, distance))

        if not candidates:
            return "❌ No recommendations found.", "", [], ""

        # --- De-duplicate (keep highest score), then rank ----------------
        seen = {}
        for candidate in candidates:
            idx = candidate[0]
            if idx not in seen or candidate[1] > seen[idx][1]:
                seen[idx] = candidate
        candidates = list(seen.values())
        candidates.sort(key=lambda x: x[1], reverse=True)

        # --- Format Markdown / table output ------------------------------
        current_info = f"""
### 🎧 Now Playing
**{self.get_song_name(song_path)}**
📍 Playlist: {self.label_to_name[query_label]}
🎨 Diversity Mode: {diversity_mode.capitalize()}
"""
        num_playlists_shown = len(top_playlists)
        playlist_info = f"### 🎯 Predicted Playlists (Top {num_playlists_shown})\n"
        for i, (pl_label, prob) in enumerate(top_playlists, 1):
            playlist_info += f"{i}. **{self.label_to_name[pl_label]}**: {prob*100:.1f}%\n"

        recommendations = []
        for i, (idx, score, pl_label, dist) in enumerate(candidates[:5], 1):
            song = self.get_song_name(self.file_paths[idx])
            playlist = self.label_to_name[pl_label]
            similarity = (1 - dist) * 100
            recommendations.append([i, song, playlist, f"{similarity:.1f}%", f"{score*100:.1f}%"])

        best_song = self.get_song_name(self.file_paths[candidates[0][0]])
        return current_info, playlist_info, recommendations, f"## ✨ Next Song: **{best_song}**"
# ============================================================================
# Initialize Recommender
# ============================================================================
print("πŸš€ Loading model and features...")
recommender = MusicRecommender(
model_path="best_model.pth",
cache_dir=".",
device="cpu"
)
print("βœ… Model loaded!")
# Get list of all songs
all_songs = [recommender.get_song_name(path) for path in recommender.file_paths]
# ============================================================================
# Gradio Interface Functions
# ============================================================================
def recommend_song(song_name: str):
    """Gradio callback: map the dropdown selection to the four UI outputs.

    Args:
        song_name: Selected song display name (may be empty/None on load).

    Returns:
        A 4-tuple ``(current_info, playlist_info, recommendations,
        next_song)``.  Results from ``recommender.recommend`` are padded
        with empty strings if they come back short, so the four Gradio
        outputs always unpack cleanly.
    """
    if not song_name:
        return "⚠️ Please select a song first.", "", [], ""
    result = tuple(recommender.recommend(song_name))
    # Defensive: error paths historically returned 3 values instead of 4.
    if len(result) < 4:
        result = result + ("",) * (4 - len(result))
    return result
def random_song():
    """Pick a random song."""
    # Uniform choice over the module-level list of dropdown song names;
    # the returned value becomes the dropdown's new selection.
    return random.choice(all_songs)
# ============================================================================
# Gradio Interface
# ============================================================================
# ============================================================================
# Gradio Interface
# ============================================================================
# Declarative UI: song dropdown + two buttons on the left, dataset stats on
# the right, then result panels and a table wired to recommend_song().
with gr.Blocks(theme=gr.themes.Soft(), title="🎵 Music Recommender") as demo:
    # NOTE(review): this copy says "top 2 predicted playlists", but the
    # default diversity mode ("diverse") searches the top 5 playlists —
    # confirm which is intended and update the text accordingly.
    gr.Markdown(
        """
        # 🎵 AI Music Recommendation System
        This system uses a **CNN classifier** to predict playlist categories and **cosine similarity (KNN)**
        to find the most similar songs within the top 2 predicted playlists.
        ### How it works:
        1. 🎯 Predicts the top 2 most likely playlists for your song using a trained CNN
        2. 🔍 Searches for similar songs only within those playlists (reduced search space)
        3. ✨ Returns the most similar song based on audio features
        ---
        """
    )
    with gr.Row():
        with gr.Column(scale=2):
            # Searchable dropdown over every song stem in the feature cache.
            song_dropdown = gr.Dropdown(
                choices=all_songs,
                label="🎵 Select a Song",
                info="Choose a song from your music library",
                filterable=True
            )
            with gr.Row():
                recommend_btn = gr.Button("🎯 Get Recommendation", variant="primary", size="lg")
                random_btn = gr.Button("🎲 Random Song", size="lg")
        with gr.Column(scale=1):
            # Static dataset stats, rendered once at startup.
            gr.Markdown(
                f"""
                ### 📊 Dataset Info
                - **Total Songs**: {len(all_songs)}
                - **Playlists**: {recommender.num_classes}
                - **Model**: PlaylistCNN + KNN
                """
            )
    gr.Markdown("---")
    with gr.Row():
        with gr.Column():
            current_song = gr.Markdown("### 🎧 Select a song to get started")
        with gr.Column():
            playlist_pred = gr.Markdown("")
    next_song_output = gr.Markdown("", elem_classes="recommendation")
    gr.Markdown("### 🎼 Top 5 Similar Songs")
    recommendations_table = gr.Dataframe(
        headers=["Rank", "Song", "Playlist", "Similarity", "Score"],
        datatype=["number", "str", "str", "str", "str"],
        col_count=(5, "fixed"),
    )
    gr.Markdown(
        """
        ---
        ### 🔧 Technical Details
        - **Features**: MERT audio embeddings (averaged across time)
        - **Classification**: Custom CNN with 1D convolutions
        - **Similarity**: Cosine similarity on normalized feature vectors
        - **Search Strategy**: Top-2 playlist filtering for efficient KNN
        """
    )
    # Event handlers
    # recommend_song must return exactly the four outputs listed below.
    recommend_btn.click(
        fn=recommend_song,
        inputs=[song_dropdown],
        outputs=[current_song, playlist_pred, recommendations_table, next_song_output]
    )
    # Random pick only changes the dropdown; the user still clicks
    # "Get Recommendation" to run the recommender.
    random_btn.click(
        fn=random_song,
        outputs=[song_dropdown]
    )
# Launch
if __name__ == "__main__":
    demo.launch()
"""Simplified recommender for Gradio demo."""
def __init__(self, model_path: str, cache_dir: str, device: str = "cpu"):
self.device = device
self.n_neighbors = 5
# Load model
checkpoint = torch.load(model_path, map_location=device)
label_to_name_raw = checkpoint['label_to_name']
self.label_to_name = {int(k): v for k, v in label_to_name_raw.items()}
self.num_classes = checkpoint['num_classes']
self.model = PlaylistCNNClassifier(num_classes=self.num_classes).to(device)
self.model.load_state_dict(checkpoint['model_state_dict'])
self.model.eval()
# Load cached features
cache_path = Path(cache_dir)
h5_file = cache_path / "features_cache.h5"
with h5py.File(h5_file, "r") as f:
self.features = f["features"][:]
self.labels = f["labels"][:]
self.file_paths = [
(p.decode() if isinstance(p, bytes) else p)
for p in f["file_paths"][:]
]
# Prepare features for KNN
self.features_flat = self.features.mean(axis=1)
self.features_normalized = normalize(self.features_flat, norm='l2', axis=1)
# Build KNN indices per playlist
self.knn_models = {}
for label in range(self.num_classes):
playlist_indices = np.where(self.labels == label)[0]
if len(playlist_indices) == 0:
continue
playlist_features = self.features_normalized[playlist_indices]
n_neighbors_playlist = min(self.n_neighbors + 1, len(playlist_indices))
knn = NearestNeighbors(n_neighbors=n_neighbors_playlist, metric='cosine', algorithm='brute')
knn.fit(playlist_features)
self.knn_models[label] = {
'model': knn,
'indices': playlist_indices,
'name': self.label_to_name[label]
}
def get_song_name(self, file_path: str) -> str:
return Path(file_path).stem
def predict_playlists(self, features: np.ndarray):
features_tensor = torch.from_numpy(features).float().unsqueeze(0).to(self.device)
with torch.no_grad():
logits = self.model(features_tensor)
probs = torch.softmax(logits, dim=1)[0]
top2_probs, top2_labels = torch.topk(probs, k=2)
return [(top2_labels[i].item(), top2_probs[i].item()) for i in range(2)]
def find_similar_in_playlist(self, query_features: np.ndarray, playlist_label: int, exclude_idx: int = None):
if playlist_label not in self.knn_models:
return []
knn_data = self.knn_models[playlist_label]
knn_model = knn_data['model']
playlist_indices = knn_data['indices']
distances, indices = knn_model.kneighbors([query_features])
global_indices = playlist_indices[indices[0]]
results = []
for idx, dist in zip(global_indices, distances[0]):
if exclude_idx is not None and idx == exclude_idx:
continue
results.append((idx, dist))
return results
def recommend(self, song_name: str, diversity_mode: str = "diverse"):
"""
Recommend songs with configurable diversity.
Args:
song_name: Name of the query song
diversity_mode: One of:
- "focused": Only from top predicted playlists (original behavior)
- "balanced": Mix of predicted + exploratory recommendations
- "diverse": Equal weight to all playlists, maximizing variety
"""
# Find song by name
matching_paths = [
path for path in self.file_paths
if self.get_song_name(path).lower() == song_name.lower()
]
if not matching_paths:
matching_paths = [
path for path in self.file_paths
if song_name.lower() in self.get_song_name(path).lower()
]
if not matching_paths:
return f"❌ Song '{song_name}' not found in dataset.", "", []
song_path = matching_paths[0]
song_idx = self.file_paths.index(song_path)
query_features = self.features[song_idx]
query_features_flat = self.features_normalized[song_idx]
query_label = self.labels[song_idx]
# Predict playlists (now get more for diversity modes)
if diversity_mode == "focused":
top_playlists = self.predict_playlists(query_features)
else:
# Get top 5 playlists for more variety
features_tensor = torch.from_numpy(query_features).float().unsqueeze(0).to(self.device)
with torch.no_grad():
logits = self.model(features_tensor)
probs = torch.softmax(logits, dim=1)[0]
k = min(5, self.num_classes)
topk_probs, topk_labels = torch.topk(probs, k=k)
top_playlists = [(topk_labels[i].item(), topk_probs[i].item()) for i in range(k)]
# Find similar songs with diversity weighting
candidates = []
if diversity_mode == "diverse":
# Equal weight to all playlists - ignore predicted probabilities
for playlist_label, _ in top_playlists:
similar_songs = self.find_similar_in_playlist(
query_features_flat, playlist_label, exclude_idx=song_idx
)
for song_idx_candidate, distance in similar_songs[:self.n_neighbors]:
# Use only similarity, not playlist probability
score = (1 - distance)
candidates.append((song_idx_candidate, score, playlist_label, distance))
elif diversity_mode == "balanced":
# Mix of top predictions + exploratory picks
for i, (playlist_label, prob) in enumerate(top_playlists):
similar_songs = self.find_similar_in_playlist(
query_features_flat, playlist_label, exclude_idx=song_idx
)
# Boost factor decreases for lower-ranked playlists but not as much
if i < 2:
# Top 2 playlists get full probability weight
boost = 1.0
else:
# Others get partial boost to encourage diversity
boost = 0.5
for song_idx_candidate, distance in similar_songs[:self.n_neighbors]:
score = (1 - distance) * (prob * boost + 0.2) # +0.2 base score for diversity
candidates.append((song_idx_candidate, score, playlist_label, distance))
else: # "focused" - original behavior
for playlist_label, prob in top_playlists[:2]:
similar_songs = self.find_similar_in_playlist(
query_features_flat, playlist_label, exclude_idx=song_idx
)
for song_idx_candidate, distance in similar_songs[:self.n_neighbors]:
score = (1 - distance) * prob
candidates.append((song_idx_candidate, score, playlist_label, distance))
if not candidates:
return "❌ No recommendations found.", "", []
# Remove duplicates (keep highest score)
seen = {}
for candidate in candidates:
idx = candidate[0]
if idx not in seen or candidate[1] > seen[idx][1]:
seen[idx] = candidate
candidates = list(seen.values())
candidates.sort(key=lambda x: x[1], reverse=True)
# Format output
current_info = f"""
### 🎧 Now Playing
**{self.get_song_name(song_path)}**
πŸ“ Playlist: {self.label_to_name[query_label]}
🎨 Diversity Mode: {diversity_mode.capitalize()}
"""
num_playlists_shown = len(top_playlists)
playlist_info = f"### 🎯 Predicted Playlists (Top {num_playlists_shown})\n"
for i, (pl_label, prob) in enumerate(top_playlists, 1):
playlist_info += f"{i}. **{self.label_to_name[pl_label]}**: {prob*100:.1f}%\n"
recommendations = []
for i, (idx, score, pl_label, dist) in enumerate(candidates[:5], 1):
song = self.get_song_name(self.file_paths[idx])
playlist = self.label_to_name[pl_label]
similarity = (1 - dist) * 100
recommendations.append([i, song, playlist, f"{similarity:.1f}%", f"{score*100:.1f}%"])
best_song = self.get_song_name(self.file_paths[candidates[0][0]])
return current_info, playlist_info, recommendations, f"## ✨ Next Song: **{best_song}**"
# Initialize recommender (will be loaded when Space starts)
print("πŸš€ Loading model and features...")
recommender = MusicRecommender(
model_path="best_model.pth", # Upload this to your Space
cache_dir=".", # features_cache.h5 should be in root
device="cpu" # HuggingFace Spaces use CPU by default
)
print("βœ… Model loaded!")
# Get list of all songs for dropdown
all_songs = [recommender.get_song_name(path) for path in recommender.file_paths]
def recommend_song(song_name: str):
"""Wrapper function for Gradio."""
if not song_name:
return "⚠️ Please select a song first.", "", [], ""
current, playlists, recs, next_song = recommender.recommend(song_name)
return current, playlists, recs, next_song
def random_song():
"""Pick a random song."""
return random.choice(all_songs)
# Create Gradio interface
with gr.Blocks(theme=gr.themes.Soft(), title="🎡 Music Recommender") as demo:
gr.Markdown(
"""
# 🎡 AI Music Recommendation System
This system uses a **CNN classifier** to predict playlist categories and **cosine similarity (KNN)**
to find the most similar songs within the top 2 predicted playlists.
### How it works:
1. 🎯 Predicts the top 2 most likely playlists for your song using a trained CNN
2. πŸ” Searches for similar songs only within those playlists (reduced search space)
3. ✨ Returns the most similar song based on audio features
---
"""
)
with gr.Row():
with gr.Column(scale=2):
song_dropdown = gr.Dropdown(
choices=all_songs,
label="🎡 Select a Song",
info="Choose a song from your music library",
filterable=True
)
with gr.Row():
recommend_btn = gr.Button("🎯 Get Recommendation", variant="primary", size="lg")
random_btn = gr.Button("🎲 Random Song", size="lg")
with gr.Column(scale=1):
gr.Markdown(
f"""
### πŸ“Š Dataset Info
- **Total Songs**: {len(all_songs)}
- **Playlists**: {recommender.num_classes}
- **Model**: PlaylistCNN + KNN
"""
)
gr.Markdown("---")
with gr.Row():
with gr.Column():
current_song = gr.Markdown("### 🎧 Select a song to get started")
with gr.Column():
playlist_pred = gr.Markdown("")
next_song_output = gr.Markdown("", elem_classes="recommendation")
gr.Markdown("### 🎼 Top 5 Similar Songs")
recommendations_table = gr.Dataframe(
headers=["Rank", "Song", "Playlist", "Similarity", "Score"],
datatype=["number", "str", "str", "str", "str"],
col_count=(5, "fixed"),
)
gr.Markdown(
"""
---
### πŸ”§ Technical Details
- **Features**: MERT audio embeddings (averaged across time)
- **Classification**: Custom CNN with 1D convolutions
- **Similarity**: Cosine similarity on normalized feature vectors
- **Search Strategy**: Top-2 playlist filtering for efficient KNN
"""
)
# Event handlers
recommend_btn.click(
fn=recommend_song,
inputs=[song_dropdown],
outputs=[current_song, playlist_pred, recommendations_table, next_song_output]
)
random_btn.click(
fn=random_song,
outputs=[song_dropdown]
)
# Launch the demo
if __name__ == "__main__":
demo.launch()