"""HuggingFace Space: CNN + KNN music recommender demo (Gradio app)."""
| import gradio as gr | |
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| import numpy as np | |
| from pathlib import Path | |
| import json | |
| import h5py | |
| from sklearn.neighbors import NearestNeighbors | |
| from sklearn.preprocessing import normalize | |
| import random | |
| import gradio as gr | |
| from huggingface_hub import hf_hub_download | |
| import os | |
# ----------------------------------------------------------------------------
# Download model checkpoint and feature cache on startup (skipped if present).
# ----------------------------------------------------------------------------
_ASSET_REPO = "mrob937/music-recommender-assets"


def _download_asset(filename: str) -> None:
    """Fetch *filename* from the asset repo into the working dir if missing."""
    if not os.path.exists(filename):
        print(f"📥 Downloading {filename}...")
        hf_hub_download(
            repo_id=_ASSET_REPO,
            filename=filename,
            local_dir=".",
            repo_type="model",
        )


# The app needs both the trained weights and the cached MERT features.
for _asset in ("best_model.pth", "features_cache.h5"):
    _download_asset(_asset)
| # ============================================================================ | |
| # COPY YOUR PlaylistCNNClassifier CLASS HERE (from your notebook) | |
| # ============================================================================ | |
| class PlaylistCNNClassifier(nn.Module): | |
| """ | |
| CNN classifier for playlist classification from MERT audio features. | |
| Handles high-resolution input features [batch, time_steps, feature_dim] by: | |
| 1. Downsampling time dimension with adaptive pooling | |
| 2. Learning temporal patterns with 1D convolutions | |
| 3. Global pooling and classification | |
| Args: | |
| num_classes: Number of playlist classes to predict | |
| input_feature_dim: Feature dimension from MERT (default: 1024) | |
| target_time_steps: Target number of time steps after downsampling (default: 512) | |
| dropout: Dropout probability (default: 0.3) | |
| """ | |
| def __init__( | |
| self, | |
| num_classes: int, | |
| input_feature_dim: int = 1024, | |
| target_time_steps: int = 512, | |
| dropout: float = 0.3 | |
| ): | |
| super().__init__() | |
| self.num_classes = num_classes | |
| self.input_feature_dim = input_feature_dim | |
| self.target_time_steps = target_time_steps | |
| # Adaptive pooling to downsample time dimension | |
| self.adaptive_pool = nn.AdaptiveAvgPool1d(target_time_steps) | |
| # 1D Convolutional layers | |
| # Input: [batch, 1024, 512] | |
| self.conv1 = nn.Conv1d( | |
| in_channels=input_feature_dim, | |
| out_channels=512, | |
| kernel_size=3, | |
| stride=2, | |
| padding=1 | |
| ) | |
| self.bn1 = nn.BatchNorm1d(512) | |
| # After conv1: [batch, 512, 256] | |
| self.conv2 = nn.Conv1d( | |
| in_channels=512, | |
| out_channels=256, | |
| kernel_size=3, | |
| stride=2, | |
| padding=1 | |
| ) | |
| self.bn2 = nn.BatchNorm1d(256) | |
| # After conv2: [batch, 256, 128] | |
| self.conv3 = nn.Conv1d( | |
| in_channels=256, | |
| out_channels=128, | |
| kernel_size=3, | |
| stride=2, | |
| padding=1 | |
| ) | |
| self.bn3 = nn.BatchNorm1d(128) | |
| # After conv3: [batch, 128, 64] | |
| # Global pooling will reduce to [batch, 128] | |
| # Fully connected layers for classification | |
| self.fc1 = nn.Linear(128, 256) | |
| self.fc2 = nn.Linear(256, num_classes) | |
| self.dropout = nn.Dropout(dropout) | |
| self.relu = nn.ReLU() | |
| def forward(self, x: torch.Tensor) -> torch.Tensor: | |
| """ | |
| Forward pass. | |
| Args: | |
| x: Input tensor of shape [batch, time_steps, feature_dim] | |
| e.g., [batch, 16404, 1024] | |
| Returns: | |
| logits: Class logits of shape [batch, num_classes] | |
| """ | |
| # x shape: [batch, time_steps, feature_dim] | |
| # Need to transpose for Conv1d: [batch, feature_dim, time_steps] | |
| x = x.transpose(1, 2) # [batch, 1024, 16404] | |
| # Downsample time dimension | |
| x = self.adaptive_pool(x) # [batch, 1024, 512] | |
| # 1D Convolutional blocks | |
| x = self.conv1(x) | |
| x = self.bn1(x) | |
| x = self.relu(x) | |
| x = self.dropout(x) # [batch, 512, 256] | |
| x = self.conv2(x) | |
| x = self.bn2(x) | |
| x = self.relu(x) | |
| x = self.dropout(x) # [batch, 256, 128] | |
| x = self.conv3(x) | |
| x = self.bn3(x) | |
| x = self.relu(x) | |
| x = self.dropout(x) # [batch, 128, 64] | |
| # Global average pooling across time dimension | |
| x = F.adaptive_avg_pool1d(x, 1) # [batch, 128, 1] | |
| x = x.squeeze(-1) # [batch, 128] | |
| # Fully connected layers | |
| x = self.fc1(x) | |
| x = self.relu(x) | |
| x = self.dropout(x) # [batch, 256] | |
| x = self.fc2(x) # [batch, num_classes] | |
| return x | |
| # ============================================================================ | |
| # Music Recommender Class | |
| # ============================================================================ | |
class MusicRecommender:
    """Playlist-aware music recommender for the Gradio demo.

    Combines a trained PlaylistCNNClassifier (predicts likely playlists)
    with per-playlist cosine-KNN indices (finds similar songs) built from
    cached MERT features.
    """

    def __init__(self, model_path: str, cache_dir: str, device: str = "cpu"):
        """Load the checkpoint and cached features, then build KNN indices.

        Args:
            model_path: Checkpoint path; must contain 'label_to_name',
                'num_classes' and 'model_state_dict'.
            cache_dir: Directory containing 'features_cache.h5'.
            device: Torch device string (CPU on HF Spaces).
        """
        self.device = device
        self.n_neighbors = 5

        # --- Model ---
        checkpoint = torch.load(model_path, map_location=device)
        # JSON round-trips turn int keys into strings; normalize back to int.
        self.label_to_name = {int(k): v for k, v in checkpoint['label_to_name'].items()}
        self.num_classes = checkpoint['num_classes']
        self.model = PlaylistCNNClassifier(num_classes=self.num_classes).to(device)
        self.model.load_state_dict(checkpoint['model_state_dict'])
        self.model.eval()

        # --- Cached features ---
        with h5py.File(Path(cache_dir) / "features_cache.h5", "r") as f:
            self.features = f["features"][:]   # assumed [n_songs, time, dim] — matches model input
            self.labels = f["labels"][:]       # [n_songs] playlist label per song
            self.file_paths = [
                (p.decode() if isinstance(p, bytes) else p)
                for p in f["file_paths"][:]
            ]

        # Time-averaged, L2-normalized vectors so cosine distance is meaningful.
        self.features_flat = self.features.mean(axis=1)
        self.features_normalized = normalize(self.features_flat, norm='l2', axis=1)

        # --- One KNN index per playlist (restricts the search space) ---
        self.knn_models = {}
        for label in range(self.num_classes):
            playlist_indices = np.where(self.labels == label)[0]
            if len(playlist_indices) == 0:
                continue
            # +1 neighbor so the query song itself can be excluded later.
            n_neighbors_playlist = min(self.n_neighbors + 1, len(playlist_indices))
            knn = NearestNeighbors(n_neighbors=n_neighbors_playlist,
                                   metric='cosine', algorithm='brute')
            knn.fit(self.features_normalized[playlist_indices])
            self.knn_models[label] = {
                'model': knn,
                'indices': playlist_indices,
                'name': self.label_to_name[label],
            }

    def get_song_name(self, file_path: str) -> str:
        """Return the song title (file name without extension)."""
        return Path(file_path).stem

    def predict_playlists(self, features: np.ndarray, top_k: int = 2):
        """Predict the ``top_k`` most likely playlists for a song.

        Args:
            features: Raw per-song features, shape [time, dim].
            top_k: Number of (label, probability) pairs to return; clamped
                to ``num_classes`` so torch.topk cannot fail.

        Returns:
            List of (label, probability) tuples, most likely first.
        """
        k = min(top_k, self.num_classes)
        features_tensor = torch.from_numpy(features).float().unsqueeze(0).to(self.device)
        with torch.no_grad():
            logits = self.model(features_tensor)
            probs = torch.softmax(logits, dim=1)[0]
            topk_probs, topk_labels = torch.topk(probs, k=k)
        return [(topk_labels[i].item(), topk_probs[i].item()) for i in range(k)]

    def find_similar_in_playlist(self, query_features: np.ndarray,
                                 playlist_label: int, exclude_idx: int = None):
        """Return [(global_idx, cosine_distance), ...] of nearest songs in a playlist.

        ``exclude_idx`` (usually the query song itself) is filtered out.
        Returns an empty list for playlists without a KNN index.
        """
        if playlist_label not in self.knn_models:
            return []
        knn_data = self.knn_models[playlist_label]
        distances, indices = knn_data['model'].kneighbors([query_features])
        global_indices = knn_data['indices'][indices[0]]
        return [
            (idx, dist)
            for idx, dist in zip(global_indices, distances[0])
            if exclude_idx is None or idx != exclude_idx
        ]

    def _find_song_index(self, song_name: str):
        """Return the dataset index for *song_name* (exact match first, then
        substring, both case-insensitive), or None if nothing matches."""
        target = song_name.lower()
        names = [self.get_song_name(p).lower() for p in self.file_paths]
        for i, name in enumerate(names):
            if name == target:
                return i
        for i, name in enumerate(names):
            if target in name:
                return i
        return None

    def _neighbors(self, query_features_flat, playlist_label, song_idx):
        """Nearest neighbors within a playlist, truncated to ``n_neighbors``."""
        return self.find_similar_in_playlist(
            query_features_flat, playlist_label, exclude_idx=song_idx
        )[:self.n_neighbors]

    def _collect_candidates(self, top_playlists, query_features_flat,
                            song_idx, diversity_mode):
        """Gather scored (idx, score, playlist_label, distance) candidates."""
        candidates = []
        if diversity_mode == "diverse":
            # Similarity only — predicted playlist probability is ignored.
            for playlist_label, _ in top_playlists:
                for idx, dist in self._neighbors(query_features_flat, playlist_label, song_idx):
                    candidates.append((idx, 1 - dist, playlist_label, dist))
        elif diversity_mode == "balanced":
            for rank, (playlist_label, prob) in enumerate(top_playlists):
                # Top-2 playlists keep full probability weight; the rest are
                # softened but still contribute (encourages variety).
                boost = 1.0 if rank < 2 else 0.5
                for idx, dist in self._neighbors(query_features_flat, playlist_label, song_idx):
                    # +0.2 base score keeps exploratory picks competitive.
                    score = (1 - dist) * (prob * boost + 0.2)
                    candidates.append((idx, score, playlist_label, dist))
        else:  # "focused" — original behavior: top-2 playlists, prob-weighted
            for playlist_label, prob in top_playlists[:2]:
                for idx, dist in self._neighbors(query_features_flat, playlist_label, song_idx):
                    candidates.append((idx, (1 - dist) * prob, playlist_label, dist))
        return candidates

    def recommend(self, song_name: str, diversity_mode: str = "diverse"):
        """Recommend songs with configurable diversity.

        Args:
            song_name: Name of the query song.
            diversity_mode: One of:
                - "focused": only the top-2 predicted playlists
                - "balanced": top predictions plus exploratory picks
                - "diverse": equal weight to the top-5 playlists

        Returns:
            A 4-tuple (now-playing md, playlist md, table rows, next-song md).
            BUGFIX: failure paths previously returned 3 values while success
            returned 4, crashing the Gradio handler's unpacking — both paths
            now return 4.
        """
        song_idx = self._find_song_index(song_name)
        if song_idx is None:
            return f"❌ Song '{song_name}' not found in dataset.", "", [], ""

        song_path = self.file_paths[song_idx]
        query_features = self.features[song_idx]
        query_features_flat = self.features_normalized[song_idx]
        query_label = self.labels[song_idx]

        # "focused" keeps the original top-2; other modes widen to top-5.
        top_k = 2 if diversity_mode == "focused" else 5
        top_playlists = self.predict_playlists(query_features, top_k=top_k)

        candidates = self._collect_candidates(
            top_playlists, query_features_flat, song_idx, diversity_mode
        )
        if not candidates:
            return "❌ No recommendations found.", "", [], ""

        # De-duplicate (keep the highest score per song), then rank.
        best = {}
        for cand in candidates:
            idx = cand[0]
            if idx not in best or cand[1] > best[idx][1]:
                best[idx] = cand
        candidates = sorted(best.values(), key=lambda c: c[1], reverse=True)

        # --- Format output for the UI ---
        current_info = f"""
### 🎧 Now Playing
**{self.get_song_name(song_path)}**
📁 Playlist: {self.label_to_name[query_label]}
🎨 Diversity Mode: {diversity_mode.capitalize()}
"""
        playlist_info = f"### 🎯 Predicted Playlists (Top {len(top_playlists)})\n"
        for i, (pl_label, prob) in enumerate(top_playlists, 1):
            playlist_info += f"{i}. **{self.label_to_name[pl_label]}**: {prob*100:.1f}%\n"

        recommendations = []
        for i, (idx, score, pl_label, dist) in enumerate(candidates[:5], 1):
            recommendations.append([
                i,
                self.get_song_name(self.file_paths[idx]),
                self.label_to_name[pl_label],
                f"{(1 - dist) * 100:.1f}%",
                f"{score * 100:.1f}%",
            ])

        best_song = self.get_song_name(self.file_paths[candidates[0][0]])
        return current_info, playlist_info, recommendations, f"## ✨ Next Song: **{best_song}**"
# ----------------------------------------------------------------------------
# Initialize Recommender (runs once at Space startup)
# ----------------------------------------------------------------------------
print("🚀 Loading model and features...")
recommender = MusicRecommender(
    model_path="best_model.pth",  # downloaded at startup if missing
    cache_dir=".",                # features_cache.h5 lives in the repo root
    device="cpu"                  # HF Spaces default to CPU
)
print("✅ Model loaded!")

# Song titles for the dropdown, in the same order as the cached file paths.
all_songs = [recommender.get_song_name(path) for path in recommender.file_paths]
| # ============================================================================ | |
| # Gradio Interface Functions | |
| # ============================================================================ | |
def recommend_song(song_name: str):
    """Gradio handler: validate the dropdown value, then run the recommender.

    Returns the 4-tuple expected by the output wiring:
    (now-playing md, playlist md, table rows, next-song md).
    """
    if not song_name:
        return "⚠️ Please select a song first.", "", [], ""
    current, playlists, recs, next_song = recommender.recommend(song_name)
    return current, playlists, recs, next_song
def random_song():
    """Return a random song title from the library for the dropdown."""
    pick = random.randrange(len(all_songs))
    return all_songs[pick]
# ----------------------------------------------------------------------------
# Gradio Interface
# ----------------------------------------------------------------------------
with gr.Blocks(theme=gr.themes.Soft(), title="🎵 Music Recommender") as demo:
    # --- Header / explanation ---
    gr.Markdown(
        """
# 🎵 AI Music Recommendation System
This system uses a **CNN classifier** to predict playlist categories and **cosine similarity (KNN)**
to find the most similar songs within the top 2 predicted playlists.
### How it works:
1. 🎯 Predicts the top 2 most likely playlists for your song using a trained CNN
2. 🔍 Searches for similar songs only within those playlists (reduced search space)
3. ✨ Returns the most similar song based on audio features
---
"""
    )
    # --- Controls: song picker + action buttons, plus dataset summary ---
    with gr.Row():
        with gr.Column(scale=2):
            song_dropdown = gr.Dropdown(
                choices=all_songs,
                label="🎵 Select a Song",
                info="Choose a song from your music library",
                filterable=True
            )
            with gr.Row():
                recommend_btn = gr.Button("🎯 Get Recommendation", variant="primary", size="lg")
                random_btn = gr.Button("🎲 Random Song", size="lg")
        with gr.Column(scale=1):
            gr.Markdown(
                f"""
### 📊 Dataset Info
- **Total Songs**: {len(all_songs)}
- **Playlists**: {recommender.num_classes}
- **Model**: PlaylistCNN + KNN
"""
            )
    gr.Markdown("---")
    # --- Results: now-playing card, playlist predictions, next-song banner ---
    with gr.Row():
        with gr.Column():
            current_song = gr.Markdown("### 🎧 Select a song to get started")
        with gr.Column():
            playlist_pred = gr.Markdown("")
    next_song_output = gr.Markdown("", elem_classes="recommendation")
    gr.Markdown("### 🎼 Top 5 Similar Songs")
    recommendations_table = gr.Dataframe(
        headers=["Rank", "Song", "Playlist", "Similarity", "Score"],
        datatype=["number", "str", "str", "str", "str"],
        col_count=(5, "fixed"),
    )
    gr.Markdown(
        """
---
### 🔧 Technical Details
- **Features**: MERT audio embeddings (averaged across time)
- **Classification**: Custom CNN with 1D convolutions
- **Similarity**: Cosine similarity on normalized feature vectors
- **Search Strategy**: Top-2 playlist filtering for efficient KNN
"""
    )
    # --- Event wiring ---
    recommend_btn.click(
        fn=recommend_song,
        inputs=[song_dropdown],
        outputs=[current_song, playlist_pred, recommendations_table, next_song_output]
    )
    random_btn.click(
        fn=random_song,
        outputs=[song_dropdown]
    )
# Launch the app only when run as a script (HF Spaces executes this file).
if __name__ == "__main__":
    demo.launch()
| """Simplified recommender for Gradio demo.""" | |
| def __init__(self, model_path: str, cache_dir: str, device: str = "cpu"): | |
| self.device = device | |
| self.n_neighbors = 5 | |
| # Load model | |
| checkpoint = torch.load(model_path, map_location=device) | |
| label_to_name_raw = checkpoint['label_to_name'] | |
| self.label_to_name = {int(k): v for k, v in label_to_name_raw.items()} | |
| self.num_classes = checkpoint['num_classes'] | |
| self.model = PlaylistCNNClassifier(num_classes=self.num_classes).to(device) | |
| self.model.load_state_dict(checkpoint['model_state_dict']) | |
| self.model.eval() | |
| # Load cached features | |
| cache_path = Path(cache_dir) | |
| h5_file = cache_path / "features_cache.h5" | |
| with h5py.File(h5_file, "r") as f: | |
| self.features = f["features"][:] | |
| self.labels = f["labels"][:] | |
| self.file_paths = [ | |
| (p.decode() if isinstance(p, bytes) else p) | |
| for p in f["file_paths"][:] | |
| ] | |
| # Prepare features for KNN | |
| self.features_flat = self.features.mean(axis=1) | |
| self.features_normalized = normalize(self.features_flat, norm='l2', axis=1) | |
| # Build KNN indices per playlist | |
| self.knn_models = {} | |
| for label in range(self.num_classes): | |
| playlist_indices = np.where(self.labels == label)[0] | |
| if len(playlist_indices) == 0: | |
| continue | |
| playlist_features = self.features_normalized[playlist_indices] | |
| n_neighbors_playlist = min(self.n_neighbors + 1, len(playlist_indices)) | |
| knn = NearestNeighbors(n_neighbors=n_neighbors_playlist, metric='cosine', algorithm='brute') | |
| knn.fit(playlist_features) | |
| self.knn_models[label] = { | |
| 'model': knn, | |
| 'indices': playlist_indices, | |
| 'name': self.label_to_name[label] | |
| } | |
| def get_song_name(self, file_path: str) -> str: | |
| return Path(file_path).stem | |
| def predict_playlists(self, features: np.ndarray): | |
| features_tensor = torch.from_numpy(features).float().unsqueeze(0).to(self.device) | |
| with torch.no_grad(): | |
| logits = self.model(features_tensor) | |
| probs = torch.softmax(logits, dim=1)[0] | |
| top2_probs, top2_labels = torch.topk(probs, k=2) | |
| return [(top2_labels[i].item(), top2_probs[i].item()) for i in range(2)] | |
| def find_similar_in_playlist(self, query_features: np.ndarray, playlist_label: int, exclude_idx: int = None): | |
| if playlist_label not in self.knn_models: | |
| return [] | |
| knn_data = self.knn_models[playlist_label] | |
| knn_model = knn_data['model'] | |
| playlist_indices = knn_data['indices'] | |
| distances, indices = knn_model.kneighbors([query_features]) | |
| global_indices = playlist_indices[indices[0]] | |
| results = [] | |
| for idx, dist in zip(global_indices, distances[0]): | |
| if exclude_idx is not None and idx == exclude_idx: | |
| continue | |
| results.append((idx, dist)) | |
| return results | |
| def recommend(self, song_name: str, diversity_mode: str = "diverse"): | |
| """ | |
| Recommend songs with configurable diversity. | |
| Args: | |
| song_name: Name of the query song | |
| diversity_mode: One of: | |
| - "focused": Only from top predicted playlists (original behavior) | |
| - "balanced": Mix of predicted + exploratory recommendations | |
| - "diverse": Equal weight to all playlists, maximizing variety | |
| """ | |
| # Find song by name | |
| matching_paths = [ | |
| path for path in self.file_paths | |
| if self.get_song_name(path).lower() == song_name.lower() | |
| ] | |
| if not matching_paths: | |
| matching_paths = [ | |
| path for path in self.file_paths | |
| if song_name.lower() in self.get_song_name(path).lower() | |
| ] | |
| if not matching_paths: | |
| return f"β Song '{song_name}' not found in dataset.", "", [] | |
| song_path = matching_paths[0] | |
| song_idx = self.file_paths.index(song_path) | |
| query_features = self.features[song_idx] | |
| query_features_flat = self.features_normalized[song_idx] | |
| query_label = self.labels[song_idx] | |
| # Predict playlists (now get more for diversity modes) | |
| if diversity_mode == "focused": | |
| top_playlists = self.predict_playlists(query_features) | |
| else: | |
| # Get top 5 playlists for more variety | |
| features_tensor = torch.from_numpy(query_features).float().unsqueeze(0).to(self.device) | |
| with torch.no_grad(): | |
| logits = self.model(features_tensor) | |
| probs = torch.softmax(logits, dim=1)[0] | |
| k = min(5, self.num_classes) | |
| topk_probs, topk_labels = torch.topk(probs, k=k) | |
| top_playlists = [(topk_labels[i].item(), topk_probs[i].item()) for i in range(k)] | |
| # Find similar songs with diversity weighting | |
| candidates = [] | |
| if diversity_mode == "diverse": | |
| # Equal weight to all playlists - ignore predicted probabilities | |
| for playlist_label, _ in top_playlists: | |
| similar_songs = self.find_similar_in_playlist( | |
| query_features_flat, playlist_label, exclude_idx=song_idx | |
| ) | |
| for song_idx_candidate, distance in similar_songs[:self.n_neighbors]: | |
| # Use only similarity, not playlist probability | |
| score = (1 - distance) | |
| candidates.append((song_idx_candidate, score, playlist_label, distance)) | |
| elif diversity_mode == "balanced": | |
| # Mix of top predictions + exploratory picks | |
| for i, (playlist_label, prob) in enumerate(top_playlists): | |
| similar_songs = self.find_similar_in_playlist( | |
| query_features_flat, playlist_label, exclude_idx=song_idx | |
| ) | |
| # Boost factor decreases for lower-ranked playlists but not as much | |
| if i < 2: | |
| # Top 2 playlists get full probability weight | |
| boost = 1.0 | |
| else: | |
| # Others get partial boost to encourage diversity | |
| boost = 0.5 | |
| for song_idx_candidate, distance in similar_songs[:self.n_neighbors]: | |
| score = (1 - distance) * (prob * boost + 0.2) # +0.2 base score for diversity | |
| candidates.append((song_idx_candidate, score, playlist_label, distance)) | |
| else: # "focused" - original behavior | |
| for playlist_label, prob in top_playlists[:2]: | |
| similar_songs = self.find_similar_in_playlist( | |
| query_features_flat, playlist_label, exclude_idx=song_idx | |
| ) | |
| for song_idx_candidate, distance in similar_songs[:self.n_neighbors]: | |
| score = (1 - distance) * prob | |
| candidates.append((song_idx_candidate, score, playlist_label, distance)) | |
| if not candidates: | |
| return "β No recommendations found.", "", [] | |
| # Remove duplicates (keep highest score) | |
| seen = {} | |
| for candidate in candidates: | |
| idx = candidate[0] | |
| if idx not in seen or candidate[1] > seen[idx][1]: | |
| seen[idx] = candidate | |
| candidates = list(seen.values()) | |
| candidates.sort(key=lambda x: x[1], reverse=True) | |
| # Format output | |
| current_info = f""" | |
| ### π§ Now Playing | |
| **{self.get_song_name(song_path)}** | |
| π Playlist: {self.label_to_name[query_label]} | |
| π¨ Diversity Mode: {diversity_mode.capitalize()} | |
| """ | |
| num_playlists_shown = len(top_playlists) | |
| playlist_info = f"### π― Predicted Playlists (Top {num_playlists_shown})\n" | |
| for i, (pl_label, prob) in enumerate(top_playlists, 1): | |
| playlist_info += f"{i}. **{self.label_to_name[pl_label]}**: {prob*100:.1f}%\n" | |
| recommendations = [] | |
| for i, (idx, score, pl_label, dist) in enumerate(candidates[:5], 1): | |
| song = self.get_song_name(self.file_paths[idx]) | |
| playlist = self.label_to_name[pl_label] | |
| similarity = (1 - dist) * 100 | |
| recommendations.append([i, song, playlist, f"{similarity:.1f}%", f"{score*100:.1f}%"]) | |
| best_song = self.get_song_name(self.file_paths[candidates[0][0]]) | |
| return current_info, playlist_info, recommendations, f"## β¨ Next Song: **{best_song}**" | |
| # Initialize recommender (will be loaded when Space starts) | |
| print("π Loading model and features...") | |
| recommender = MusicRecommender( | |
| model_path="best_model.pth", # Upload this to your Space | |
| cache_dir=".", # features_cache.h5 should be in root | |
| device="cpu" # HuggingFace Spaces use CPU by default | |
| ) | |
| print("β Model loaded!") | |
| # Get list of all songs for dropdown | |
| all_songs = [recommender.get_song_name(path) for path in recommender.file_paths] | |
| def recommend_song(song_name: str): | |
| """Wrapper function for Gradio.""" | |
| if not song_name: | |
| return "β οΈ Please select a song first.", "", [], "" | |
| current, playlists, recs, next_song = recommender.recommend(song_name) | |
| return current, playlists, recs, next_song | |
| def random_song(): | |
| """Pick a random song.""" | |
| return random.choice(all_songs) | |
| # Create Gradio interface | |
| with gr.Blocks(theme=gr.themes.Soft(), title="π΅ Music Recommender") as demo: | |
| gr.Markdown( | |
| """ | |
| # π΅ AI Music Recommendation System | |
| This system uses a **CNN classifier** to predict playlist categories and **cosine similarity (KNN)** | |
| to find the most similar songs within the top 2 predicted playlists. | |
| ### How it works: | |
| 1. π― Predicts the top 2 most likely playlists for your song using a trained CNN | |
| 2. π Searches for similar songs only within those playlists (reduced search space) | |
| 3. β¨ Returns the most similar song based on audio features | |
| --- | |
| """ | |
| ) | |
| with gr.Row(): | |
| with gr.Column(scale=2): | |
| song_dropdown = gr.Dropdown( | |
| choices=all_songs, | |
| label="π΅ Select a Song", | |
| info="Choose a song from your music library", | |
| filterable=True | |
| ) | |
| with gr.Row(): | |
| recommend_btn = gr.Button("π― Get Recommendation", variant="primary", size="lg") | |
| random_btn = gr.Button("π² Random Song", size="lg") | |
| with gr.Column(scale=1): | |
| gr.Markdown( | |
| f""" | |
| ### π Dataset Info | |
| - **Total Songs**: {len(all_songs)} | |
| - **Playlists**: {recommender.num_classes} | |
| - **Model**: PlaylistCNN + KNN | |
| """ | |
| ) | |
| gr.Markdown("---") | |
| with gr.Row(): | |
| with gr.Column(): | |
| current_song = gr.Markdown("### π§ Select a song to get started") | |
| with gr.Column(): | |
| playlist_pred = gr.Markdown("") | |
| next_song_output = gr.Markdown("", elem_classes="recommendation") | |
| gr.Markdown("### πΌ Top 5 Similar Songs") | |
| recommendations_table = gr.Dataframe( | |
| headers=["Rank", "Song", "Playlist", "Similarity", "Score"], | |
| datatype=["number", "str", "str", "str", "str"], | |
| col_count=(5, "fixed"), | |
| ) | |
| gr.Markdown( | |
| """ | |
| --- | |
| ### π§ Technical Details | |
| - **Features**: MERT audio embeddings (averaged across time) | |
| - **Classification**: Custom CNN with 1D convolutions | |
| - **Similarity**: Cosine similarity on normalized feature vectors | |
| - **Search Strategy**: Top-2 playlist filtering for efficient KNN | |
| """ | |
| ) | |
| # Event handlers | |
| recommend_btn.click( | |
| fn=recommend_song, | |
| inputs=[song_dropdown], | |
| outputs=[current_song, playlist_pred, recommendations_table, next_song_output] | |
| ) | |
| random_btn.click( | |
| fn=random_song, | |
| outputs=[song_dropdown] | |
| ) | |
| # Launch the demo | |
| if __name__ == "__main__": | |
| demo.launch() | |