ChristophSchuhmann commited on
Commit
e6bc26d
·
verified ·
1 Parent(s): a8c6cd7

Initial upload: popularity prediction MLP head + evaluation report

Browse files
Files changed (3) hide show
  1. README.md +127 -0
  2. evaluation_report.html +0 -0
  3. popularity_head.pt +3 -0
README.md ADDED
@@ -0,0 +1,127 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ license: cc-by-4.0
3
+ tags:
4
+ - audio
5
+ - music
6
+ - whisper
7
+ - popularity-prediction
8
+ - laion
9
+ - laion-tunes
10
+ library_name: transformers
11
+ pipeline_tag: audio-classification
12
+ ---
13
+
14
+ # Music Popularity Predictor
15
+
16
+ Predicts **play count** and **upvote/like count** of AI-generated music tracks from audio alone.
17
+
18
+ ## Architecture
19
+
20
+ | Component | Details |
21
+ |-----------|---------|
22
+ | **Encoder** | [laion/music-whisper](https://huggingface.co/laion/music-whisper) (Whisper Small fine-tuned for music captioning, frozen) |
23
+ | **Pooling** | Encoder output (1500x768) → 10 segments of 150 frames → mean/max/min pool → 23,040-dim |
24
+ | **MLP Head** | 23040 → 1024 → 256 (LayerNorm) → two prediction heads (play count + upvote count) |
25
+ | **Output** | log1p-scaled: `log(1 + count)` — use `math.expm1()` to convert back |
26
+
27
+ ## Training
28
+
29
+ - **Data**: ~39,000 stratified samples from the [LAION-Tunes](https://huggingface.co/datasets/ai-music/ai-music-deduplicated) dataset (Suno, Udio, Mureka, Riffusion, Sonauto)
30
+ - **Loss**: Huber Loss
31
+ - **Optimizer**: AdamW (lr=5e-4, weight_decay=1e-4, cosine schedule, 3 epochs)
32
+ - **Best val loss**: 4.004 (epoch 2)
33
+
34
+ ### Evaluation (200 validation samples)
35
+
36
+ | Metric | Play Count | Upvote Count |
37
+ |--------|-----------|--------------|
38
+ | Pearson r | 0.145 | 0.102 |
39
+ | Log-Pearson r | 0.414 | 0.413 |
40
+ | Log MAE | 2.981 | 1.923 |
41
+
42
+ ## Usage
43
+
44
+ ```python
45
+ import torch
46
+ import torch.nn as nn
47
+ import numpy as np
48
+ import librosa
49
+ import math
50
+ from transformers import WhisperProcessor, WhisperForConditionalGeneration
51
+ from huggingface_hub import hf_hub_download
52
+
53
+
54
+ # --- Define the MLP head ---
55
+
56
+ class PopularityMLP(nn.Module):
57
+ def __init__(self):
58
+ super().__init__()
59
+ self.bottleneck = nn.Sequential(
60
+ nn.Linear(23040, 1024), nn.ReLU(), nn.Dropout(0.3),
61
+ nn.Linear(1024, 256), nn.ReLU(), nn.LayerNorm(256),
62
+ )
63
+ self.play_head = nn.Sequential(nn.Linear(256, 64), nn.ReLU(), nn.Linear(64, 1))
64
+ self.upvote_head = nn.Sequential(nn.Linear(256, 64), nn.ReLU(), nn.Linear(64, 1))
65
+
66
+ def forward(self, x):
67
+ feat = self.bottleneck(x)
68
+ return self.play_head(feat).squeeze(-1), self.upvote_head(feat).squeeze(-1)
69
+
70
+
71
+ # --- Load models ---
72
+
73
+ # Whisper encoder from laion/music-whisper
74
+ processor = WhisperProcessor.from_pretrained("laion/music-whisper")
75
+ whisper = WhisperForConditionalGeneration.from_pretrained(
76
+ "laion/music-whisper", torch_dtype=torch.float16
77
+ ).cuda().eval()
78
+ encoder = whisper.get_encoder()
79
+
80
+ # Popularity head from this repo
81
+ head_path = hf_hub_download("laion/music-popularity", "popularity_head.pt")
82
+ mlp = PopularityMLP().cuda()
83
+ mlp.load_state_dict(torch.load(head_path, map_location="cuda")["mlp_state_dict"])
84
+ mlp.eval()
85
+
86
+
87
+ # --- Run inference ---
88
+
89
+ audio, sr = librosa.load("song.mp3", sr=16000, mono=True)
90
+ audio = audio[:30 * 16000] # first 30 seconds
91
+ if len(audio) < 30 * 16000:
92
+ audio = np.pad(audio, (0, 30 * 16000 - len(audio)))
93
+
94
+ inputs = processor(audio, sampling_rate=16000, return_tensors="pt")
95
+
96
+ with torch.no_grad():
97
+ enc_out = encoder(inputs.input_features.cuda().half()).last_hidden_state # (1, 1500, 768)
98
+
99
+ # Segment pooling: 10 segments, mean/max/min
100
+ segments = enc_out.view(1, 10, 150, 768)
101
+ pooled = torch.cat([segments.mean(2), segments.max(2).values, segments.min(2).values], dim=2)
102
+ pooled = pooled.view(1, -1).float() # (1, 23040)
103
+
104
+ pred_play, pred_upvote = mlp(pooled)
105
+
106
+ print(f"Estimated plays: {math.expm1(pred_play.item()):,.0f}")
107
+ print(f"Estimated upvotes: {math.expm1(pred_upvote.item()):,.0f}")
108
+ ```
109
+
110
+ ## Files
111
+
112
+ | File | Description |
113
+ |------|-------------|
114
+ | `popularity_head.pt` | MLP head weights (91 MB) |
115
+ | `evaluation_report.html` | Detailed evaluation with plots |
116
+
117
+ The Whisper encoder is loaded separately from [laion/music-whisper](https://huggingface.co/laion/music-whisper).
118
+
119
+ ## License
120
+
121
+ CC BY 4.0 — Christoph Schuhmann / LAION
122
+
123
+ ## Acknowledgments
124
+
125
+ - Encoder: [laion/music-whisper](https://huggingface.co/laion/music-whisper) (OpenAI Whisper Small, fine-tuned for music captioning)
126
+ - Dataset: [LAION-Tunes](https://huggingface.co/datasets/ai-music/ai-music-deduplicated) (AI-generated music from Suno, Udio, Mureka, Riffusion, Sonauto)
127
+ - Developed by Christoph Schuhmann and the LAION community
evaluation_report.html ADDED
The diff for this file is too large to render. See raw diff
 
popularity_head.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:1b7d751eae3f7625257708b7163534cce28c48e6db9453a62fab467b4b6729af
3
+ size 95565585