# Source listing metadata: file size 4,652 bytes, revision d69b6f3.
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
from huggingface_hub import PyTorchModelHubMixin
class ResBlock1D(nn.Module):
    """Residual block for extracting rhythmic features from audio spectrograms.

    Applies Conv1d -> BatchNorm1d -> GELU -> Conv1d -> BatchNorm1d, adds the
    block input back (skip connection) and applies a final GELU. Both
    convolutions use stride 1 and symmetric padding, so the time axis keeps
    its length while the dilation widens the receptive field.

    Args:
        channels: Number of input channels; the output has the same number.
        kernel_size: Width of both convolution kernels.
        dilation: Dilation factor shared by both convolutions.

    Raises:
        ValueError: If ``(kernel_size - 1) * dilation`` is odd. Symmetric
            padding cannot preserve the sequence length in that case, so the
            residual addition would fail (or silently broadcast).
    """

    def __init__(self, channels: int, kernel_size: int = 3, dilation: int = 1):
        super().__init__()

        # Extra frames covered by a dilated kernel beyond its centre tap. The
        # output length is L + 2 * padding - span, so it equals L only when
        # the span can be split evenly between the two sides.
        span = (kernel_size - 1) * dilation
        if span % 2 != 0:
            raise ValueError(
                "ResBlock1D needs (kernel_size - 1) * dilation to be even so "
                "the residual connection keeps the sequence length; got "
                f"kernel_size={kernel_size}, dilation={dilation}"
            )
        padding = span // 2

        self.conv1 = nn.Conv1d(
            channels, channels, kernel_size, padding=padding, dilation=dilation
        )
        self.bn1 = nn.BatchNorm1d(channels)
        self.conv2 = nn.Conv1d(
            channels, channels, kernel_size, padding=padding, dilation=dilation
        )
        self.bn2 = nn.BatchNorm1d(channels)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return the block output; it has the same shape as ``x``.

        Args:
            x: Tensor laid out as (batch, channels, time), as required by
                ``nn.Conv1d`` / ``nn.BatchNorm1d``.
        """
        res = x
        x = F.gelu(self.bn1(self.conv1(x)))
        # No activation before the skip connection: GELU is applied once,
        # after the residual sum.
        x = self.bn2(self.conv2(x))
        return F.gelu(x + res)
class GameChartEvaluator(nn.Module, PyTorchModelHubMixin):
    """Score a chart against its music from a pair of mel spectrograms.

    The two spectrograms are concatenated along the channel axis (early
    fusion), passed through a stack of stride-1 dilated residual blocks, and
    projected to one sigmoid score per time frame. ``forward`` pools those
    frame scores into a single value per batch item; ``predict_trace``
    returns the un-pooled per-frame curve.

    Args:
        input_dim: Number of channels (mel bins) in EACH of the two inputs.
        d_model: Channel width used throughout the encoder.
        n_layers: Number of residual blocks. Block ``i`` uses dilation
            ``2 ** i``, so the default of 4 gives dilations 1, 2, 4, 8.
    """

    # Share of the lowest-scoring frames averaged into the "worst" score.
    _WORST_FRACTION = 0.1

    def __init__(self, input_dim: int = 80, d_model: int = 128, n_layers: int = 4):
        super().__init__()

        # --- Early Fusion ---
        # Music (input_dim) + Chart (input_dim) channels are stacked, so the
        # projection sees input_dim * 2 channels (160 with the defaults).
        self.input_proj = nn.Conv1d(
            input_dim * 2, d_model, kernel_size=3, stride=1, padding=1
        )

        # --- STRICT TEMPORAL ENCODER ---
        # No pooling and stride 1 everywhere, so the frame rate of the input
        # is preserved (the original author notes ~11 ms per frame).
        # Exponentially growing dilations add context without losing
        # resolution; raise n_layers for wider context (dilations 16, 32, ...).
        self.encoder = nn.Sequential(
            *(
                ResBlock1D(d_model, kernel_size=3, dilation=2**layer)
                for layer in range(n_layers)
            )
        )

        # --- SCORING HEAD ---
        # Simple projection of every frame's features to a scalar logit.
        self.quality_proj = nn.Linear(d_model, 1)

        # Learnable mixing weight between "worst frames" and "average" score.
        # Stored as a raw logit; sigmoid(0.0) = 0.5 at initialisation.
        self.raw_severity = nn.Parameter(torch.tensor(0.0))

    def _frame_scores(
        self, music_mels: torch.Tensor, chart_mels: torch.Tensor
    ) -> torch.Tensor:
        """Return per-frame quality scores in (0, 1), shape (Batch, Time, 1)."""
        # 1. Early fusion: concatenate along the channel dimension.
        #    Shape becomes (Batch, input_dim * 2, Time).
        x = torch.cat([music_mels, chart_mels], dim=1)

        # 2. Extract features (strictly local + dilated context).
        x = F.gelu(self.input_proj(x))
        x = self.encoder(x)

        # 3. Predict a score per frame.
        #    (Batch, Dim, Time) -> (Batch, Time, Dim) so Linear acts on Dim.
        x = x.permute(0, 2, 1)
        return torch.sigmoid(self.quality_proj(x))

    def forward(self, music_mels: torch.Tensor, chart_mels: torch.Tensor) -> torch.Tensor:
        """Return one pooled quality score per batch item, shape (Batch,).

        Args:
            music_mels: (Batch, input_dim, Time)
            chart_mels: (Batch, input_dim, Time)

        The result is ``alpha * worst + (1 - alpha) * average`` where
        ``average`` is the mean frame score, ``worst`` is the mean of the
        lowest 10% of frame scores (at least one frame) and
        ``alpha = sigmoid(raw_severity)`` is learned.
        """
        local_scores = self._frame_scores(music_mels, chart_mels)  # (Batch, Time, 1)

        # 4. Error-sensitive pooling: blend the plain average with the mean of
        #    the worst frames so short bad sections are not averaged away.
        avg_score = local_scores.mean(dim=1)
        k = max(1, int(local_scores.size(1) * self._WORST_FRACTION))
        min_vals, _ = torch.topk(local_scores, k, dim=1, largest=False)
        worst_score = min_vals.mean(dim=1)

        alpha = torch.sigmoid(self.raw_severity)
        final_score = (alpha * worst_score) + ((1 - alpha) * avg_score)
        return final_score.squeeze(1)

    def predict_trace(
        self, music_mels: torch.Tensor, chart_mels: torch.Tensor
    ) -> torch.Tensor:
        """Explainability method: return the frame-by-frame quality curve.

        Args:
            music_mels: (Batch, input_dim, Time)
            chart_mels: (Batch, input_dim, Time)

        Returns:
            local_scores: (Batch, Time) - the quality score at every input
            frame (the encoder never downsamples the time axis).

        NOTE: gradients are disabled here, but the train/eval mode is left
        untouched. The encoder contains BatchNorm layers, so call
        ``model.eval()`` first if the trace should use running statistics and
        must not update them.
        """
        with torch.no_grad():
            local_scores = self._frame_scores(music_mels, chart_mels)
        return local_scores.squeeze(2)
if __name__ == "__main__":
    # Sanity check: build the model with default settings, run both entry
    # points on random data and print the resulting shapes.
    model = GameChartEvaluator()
    print(
        f"Model initialized. Learnable Severity: {torch.sigmoid(model.raw_severity).item():.2f}"
    )

    # Dummy data (Batch=2, Freq=80, Time=1000)
    m = torch.randn(2, 80, 1000)
    c = torch.randn(2, 80, 1000)

    output = model(m, c)
    print(f"Output shape: {output.shape}")  # Should be torch.Size([2])
    print(f"Scores: {output}")

    # Trace check: the model never pools or strides over time, so the trace
    # has exactly as many frames as the input.
    trace = model.predict_trace(m, c)
    print(f"Trace shape: {trace.shape}")  # Should be torch.Size([2, 1000])

    # torchinfo is only needed for the optional summary table; import it last
    # and tolerate its absence so the shape checks above always run.
    try:
        from torchinfo import summary
    except ImportError:
        print("torchinfo is not installed; skipping model summary.")
    else:
        summary(model, input_data=[m, c])
|