import math

import torch
import torch.nn as nn
import torch.nn.functional as F
from huggingface_hub import PyTorchModelHubMixin


class ResBlock1D(nn.Module):
    """Dilated 1-D residual block (conv -> BN -> GELU -> conv -> BN -> add -> GELU).

    Both convolutions keep the channel count and use stride 1 with
    ``padding = (kernel_size - 1) * dilation // 2``, so for odd ``kernel_size``
    the time axis length is unchanged and the skip connection can be added
    directly. An even ``kernel_size`` combined with an odd ``dilation`` would
    shorten the output and make the residual add fail.

    Args:
        channels: Number of input and output channels.
        kernel_size: Convolution kernel width (expected to be odd).
        dilation: Dilation applied to both convolutions.
    """

    def __init__(self, channels, kernel_size=3, dilation=1):
        super().__init__()
        # "Same" padding for an odd kernel at the given dilation.
        padding = (kernel_size - 1) * dilation // 2
        self.conv1 = nn.Conv1d(
            channels, channels, kernel_size, padding=padding, dilation=dilation
        )
        self.bn1 = nn.BatchNorm1d(channels)
        self.conv2 = nn.Conv1d(
            channels, channels, kernel_size, padding=padding, dilation=dilation
        )
        self.bn2 = nn.BatchNorm1d(channels)

    def forward(self, x):
        """Apply the block to ``x`` of shape (Batch, channels, Time)."""
        res = x
        x = F.gelu(self.bn1(self.conv1(x)))
        x = self.bn2(self.conv2(x))
        # Activation is applied after the skip connection is added.
        return F.gelu(x + res)


class GameChartEvaluator(nn.Module, PyTorchModelHubMixin):
    """Score a chart against its music from two channel-stacked feature maps.

    The two inputs are concatenated along the channel axis, passed through a
    stack of stride-1 dilated residual blocks (time length is preserved), and
    projected to one sigmoid score per frame. ``forward`` pools those frame
    scores into one scalar per batch item; ``predict_trace`` returns the
    un-pooled per-frame scores.

    Args:
        input_dim: Channels of each individual input (music and chart).
        d_model: Channel width used inside the encoder.
        n_layers: Currently unused. The encoder is fixed at four blocks with
            dilations 1, 2, 4, 8; the parameter is kept so existing callers
            and saved configurations remain valid.
    """

    # Fraction of frames (lowest-scoring) averaged into the "worst" score.
    _WORST_FRACTION = 0.1

    def __init__(self, input_dim=80, d_model=128, n_layers=4):
        super().__init__()

        # --- Early Fusion ---
        # Music (input_dim) and chart (input_dim) are stacked on the channel
        # axis, so the projection sees input_dim * 2 channels.
        self.input_proj = nn.Conv1d(
            input_dim * 2, d_model, kernel_size=3, stride=1, padding=1
        )

        # --- Strict temporal encoder ---
        # No pooling and stride 1 everywhere: one output frame per input
        # frame. Growing dilations widen the context without downsampling.
        # NOTE: n_layers is intentionally not consulted here (see class doc).
        self.encoder = nn.Sequential(
            ResBlock1D(d_model, kernel_size=3, dilation=1),
            ResBlock1D(d_model, kernel_size=3, dilation=2),
            ResBlock1D(d_model, kernel_size=3, dilation=4),
            ResBlock1D(d_model, kernel_size=3, dilation=8),
            # Add more layers if you need wider context (e.g. 16, 32)
        )

        # --- Scoring head ---
        # Per-frame projection from d_model features to a single logit.
        self.quality_proj = nn.Linear(d_model, 1)

        # Learnable mixing weight; sigmoid(raw_severity) blends the worst-frame
        # mean with the overall mean (0.0 -> an even 0.5 / 0.5 blend at init).
        self.raw_severity = nn.Parameter(torch.tensor(0.0))

    def _frame_scores(self, music_mels, chart_mels):
        """Return per-frame quality scores in (0, 1), shape (Batch, Time, 1)."""
        # 1. Early fusion: (Batch, C, Time) x2 -> (Batch, 2C, Time)
        x = torch.cat([music_mels, chart_mels], dim=1)

        # 2. Feature extraction; time length is unchanged.
        x = F.gelu(self.input_proj(x))
        x = self.encoder(x)

        # 3. (Batch, Dim, Time) -> (Batch, Time, Dim) so Linear acts per frame.
        x = x.permute(0, 2, 1)
        return torch.sigmoid(self.quality_proj(x))

    def forward(self, music_mels, chart_mels):
        """Return one quality score per batch item.

        Args:
            music_mels: (Batch, input_dim, Time)
            chart_mels: (Batch, input_dim, Time)

        Returns:
            Tensor of shape (Batch,): a learned blend of the mean frame score
            and the mean of the lowest-scoring 10% of frames (at least one
            frame).
        """
        local_scores = self._frame_scores(music_mels, chart_mels)

        # Error-sensitive pooling: average over all frames ...
        avg_score = local_scores.mean(dim=1)

        # ... and over the k worst frames (always at least one).
        k = max(1, int(local_scores.size(1) * self._WORST_FRACTION))
        min_vals, _ = torch.topk(local_scores, k, dim=1, largest=False)
        worst_score = min_vals.mean(dim=1)

        alpha = torch.sigmoid(self.raw_severity)
        final_score = (alpha * worst_score) + ((1 - alpha) * avg_score)

        return final_score.squeeze(1)

    def predict_trace(self, music_mels, chart_mels):
        """Explainability helper: return the per-frame quality curve.

        Runs under ``torch.no_grad()``. That does not switch the module to
        eval mode, so call ``model.eval()`` first if the BatchNorm layers
        should use their running statistics.

        Args:
            music_mels: (Batch, input_dim, Time)
            chart_mels: (Batch, input_dim, Time)

        Returns:
            local_scores: (Batch, Time) - the quality score at every input
            frame (no temporal downsampling).
        """
        with torch.no_grad():
            local_scores = self._frame_scores(music_mels, chart_mels)
        return local_scores.squeeze(2)


if __name__ == "__main__":
    # Sanity Check
    from torchinfo import summary

    model = GameChartEvaluator()
    print(
        f"Model initialized. Learnable Severity: {torch.sigmoid(model.raw_severity).item():.2f}"
    )

    # Dummy data (Batch=2, Freq=80, Time=1000)
    m = torch.randn(2, 80, 1000)
    c = torch.randn(2, 80, 1000)

    output = model(m, c)
    print(f"Output shape: {output.shape}")  # Should be torch.Size([2])
    print(f"Scores: {output}")

    # Trace check
    trace = model.predict_trace(m, c)
    print(
        f"Trace shape: {trace.shape}"
    )  # Should be torch.Size([2, 1000]): no pooling, time length is preserved

    summary(model, input_data=[m, c])