---
tags:
- model_hub_mixin
- pytorch_model_hub_mixin
- audio
- rhythm-game
- music
---

# GameChartEvaluator (GCE4)

A neural network model for evaluating the quality of rhythm game charts relative to their corresponding music. The model predicts a quality score in the range [0, 1] indicating how well a chart synchronizes with the music.

## Model Architecture

The model uses an early fusion approach with dilated convolutions for temporal analysis:

1. **Early Fusion**: Concatenates music and chart mel spectrograms along the channel dimension (80 + 80 = 160 channels).
2. **Dilated Residual Encoder**: 4 residual blocks with increasing dilation rates (1, 2, 4, 8) capture multi-scale temporal context while preserving the ~11.6ms frame resolution. This gives the model a **receptive field of ~0.73s** (63 frames), so each time-step's score depends on roughly 0.36s of context before and after it.
3. **Error-Sensitive Scoring Head**: Combines the average of the per-frame scores with the worst 10% of scores using a learnable mixing parameter.

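The receptive-field figure follows from the standard formula for stacked stride-1 convolutions, `RF = 1 + sum((kernel_size - 1) * dilation)` over every conv layer. The sketch below assumes one layout that reproduces the stated 63 frames: a kernel-3 input projection plus two kernel-3 dilated convolutions per residual block. Those kernel sizes and the 512-sample hop at 44.1 kHz (about 11.6 ms per frame) are assumptions, not values confirmed by this card, and other layouts can reach the same total.

```python
# Receptive field of stacked stride-1 1D convolutions:
#   RF = 1 + sum over conv layers of (kernel_size - 1) * dilation
# ASSUMPTION: hypothetical layout (kernel-3 projection, two kernel-3 convs per block).

HOP_SECONDS = 512 / 44100  # assumed hop of 512 samples at 44.1 kHz (~11.6 ms)

def receptive_field(layers):
    """layers: list of (kernel_size, dilation) pairs, one per conv layer."""
    return 1 + sum((k - 1) * d for k, d in layers)

layers = [(3, 1)]  # input projection (assumed kernel size 3)
for d in (1, 2, 4, 8):
    layers += [(3, d), (3, d)]  # two dilated convs per residual block (assumed)

rf_frames = receptive_field(layers)
print(rf_frames)                                  # 63
print(round(rf_frames * HOP_SECONDS, 2))          # 0.73 (seconds)
print(round((rf_frames // 2) * HOP_SECONDS, 2))   # 0.36 (seconds on each side)
```

Each doubling of the dilation doubles the span a block adds while the frame rate stays fixed, which is why four blocks are enough to cover about three quarters of a second at full resolution.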
```
Input: (B, 80, T) music_mels + (B, 80, T) chart_mels
↓ Concatenate
(B, 160, T)
↓ Conv1D Projection
(B, 128, T)
↓ Dilated ResBlocks × 4
(B, 128, T)
↓ Linear → Sigmoid (per-frame scores)
(B, T, 1)
↓ Error-Sensitive Pooling
(B,) final score
```
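A minimal NumPy sketch of the error-sensitive pooling step, assuming the head takes the mean of the lowest 10% of per-frame scores and mixes it with the overall mean through a sigmoid-squashed weight. The function name `error_sensitive_pool` and the `alpha_logit` argument are hypothetical stand-ins for the model's learnable mixing parameter.

```python
import numpy as np

def error_sensitive_pool(frame_scores, alpha_logit, worst_fraction=0.1):
    """Pool per-frame scores of shape (B, T) into one score per sample, shape (B,).

    Hypothetical sketch: mixes the mean score with the mean of the worst
    `worst_fraction` of frames. `alpha_logit` stands in for the learnable
    mixing parameter; the sigmoid keeps the mixing weight in (0, 1).
    """
    scores = np.asarray(frame_scores, dtype=np.float64)
    t = scores.shape[-1]
    k = max(1, int(worst_fraction * t))            # number of worst frames kept
    worst_k = np.sort(scores, axis=-1)[..., :k]    # lowest k scores per sample
    alpha = 1.0 / (1.0 + np.exp(-alpha_logit))     # mixing weight in (0, 1)
    return alpha * scores.mean(axis=-1) + (1.0 - alpha) * worst_k.mean(axis=-1)

# One sample of 20 frames: mostly well-synced, with two badly-synced frames.
trace = np.full((1, 20), 0.9)
trace[0, [5, 12]] = 0.1
print(error_sensitive_pool(trace, alpha_logit=0.0))
```

With an even mix, the two bad frames pull the result to about 0.46, while the plain mean alone would be 0.82, so a short badly-synced passage weighs far more than it would under simple averaging.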

## Usage

```python
import torch
from gce4 import GameChartEvaluator

model = GameChartEvaluator.from_pretrained("JacobLinCool/gce4")
model.eval()

# Input: 80-band mel spectrograms
music_mels = torch.randn(1, 80, 1000)  # (batch, freq, time)
chart_mels = torch.randn(1, 80, 1000)

# Get overall quality score (0-1)
with torch.no_grad():
    score = model(music_mels, chart_mels)  # shape: (batch,)
print(f"Quality Score: {score.item():.3f}")

# Get per-frame quality trace for explainability
with torch.no_grad():
    trace = model.predict_trace(music_mels, chart_mels)
# trace shape: (batch, time)
```

## Input Specifications

- **music_mels**: `(Batch, 80, Time)` - Mel spectrogram of the music
- **chart_mels**: `(Batch, 80, Time)` - Mel spectrogram of synthesized chart audio (click sounds at note positions)

Both inputs should be normalized and have the same number of time frames.
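
The card does not specify how chart audio is rendered or how the mels are normalized. The sketch below shows one plausible preprocessing path under those gaps: hypothetical helpers that render note onsets as short clicks, standardize each band of a log-mel spectrogram, and crop the two spectrograms to a common length. The click shape, sample rate, and normalization scheme are all assumptions; the 80-band mels themselves would come from whatever extractor produced the training features.

```python
import numpy as np

def synthesize_click_track(note_times, duration, sr=44100, click_len=0.01, freq=2000.0):
    """Render note onsets (in seconds) as short, exponentially decaying sine clicks.

    Hypothetical helper: click shape, length, and sample rate are assumptions.
    """
    audio = np.zeros(int(round(duration * sr)), dtype=np.float32)
    n = int(round(click_len * sr))
    t = np.arange(n) / sr
    click = np.sin(2 * np.pi * freq * t) * np.exp(-t / (click_len / 5))
    for onset in note_times:
        start = int(round(onset * sr))
        if 0 <= start < len(audio):
            end = min(start + n, len(audio))
            audio[start:end] += click[: end - start].astype(np.float32)
    return audio

def normalize_mel(log_mel, eps=1e-5):
    """Standardize each frequency band of an (80, T) log-mel to zero mean, unit variance."""
    mean = log_mel.mean(axis=-1, keepdims=True)
    std = log_mel.std(axis=-1, keepdims=True)
    return (log_mel - mean) / (std + eps)

def match_length(music_mel, chart_mel):
    """Crop both (80, T) spectrograms to the shorter T so the time axes agree."""
    t = min(music_mel.shape[-1], chart_mel.shape[-1])
    return music_mel[..., :t], chart_mel[..., :t]
```

The normalized, length-matched arrays would then be stacked into `(1, 80, T)` tensors before calling the model.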

## Output

- **forward()**: `(Batch,)` - Single quality score per sample in range [0, 1]
- **predict_trace()**: `(Batch, Time)` - Per-frame quality scores for interpretability
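
To turn the per-frame trace from `predict_trace()` into readable timestamps, frames that fall below a threshold can be grouped into contiguous spans. This is a hypothetical post-processing sketch: `low_quality_segments`, the 0.5 threshold, and the ~11.6 ms frame duration (a 512-sample hop at 44.1 kHz) are assumptions to adjust to your own features.

```python
import numpy as np

FRAME_SECONDS = 512 / 44100  # assumed hop (~11.6 ms per frame)

def low_quality_segments(trace, threshold=0.5, frame_seconds=FRAME_SECONDS):
    """Return (start_s, end_s) spans where a 1-D per-frame trace drops below threshold."""
    below = np.asarray(trace) < threshold
    # Pad with False so every run of low frames has a rising and a falling edge.
    edges = np.diff(np.concatenate(([False], below, [False])).astype(np.int8))
    starts = np.flatnonzero(edges == 1)
    ends = np.flatnonzero(edges == -1)  # exclusive frame index
    return [(s * frame_seconds, e * frame_seconds) for s, e in zip(starts, ends)]

# With the model: trace = model.predict_trace(music_mels, chart_mels)[0].numpy()
trace = np.array([0.9, 0.8, 0.3, 0.2, 0.9, 0.95, 0.4, 0.9])
for start, end in low_quality_segments(trace):
    print(f"{start:.3f}s - {end:.3f}s")
```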

## Model Configuration

| Parameter | Default | Description |
|-----------|---------|-------------|
| `input_dim` | 80 | Mel spectrogram frequency bins |
| `d_model` | 128 | Hidden dimension |
| `n_layers` | 4 | Number of residual blocks |

## Training

The model was trained to detect misaligned or poorly synchronized rhythm game charts by comparing music-chart pairs with various synthetic corruptions (time shifts, random note placement, etc.).
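
The corruption types named above can be sketched as operations on a chart's note onset times, applied before the chart is rendered to click audio. This is a hypothetical illustration: `make_negative_charts`, the offset value, and dropping notes shifted past the clip boundary are assumptions, and the "mismatch" variant (pairing the music with another song's chart) is inferred from the evaluation categories rather than documented for training.

```python
import numpy as np

def make_negative_charts(note_times, duration, rng, shift=0.1, other_song_notes=None):
    """Build corrupted variants of a chart's note onsets (seconds).

    Hypothetical sketch; the exact corruption recipe used for GCE4 is not documented.
    """
    notes = np.sort(np.asarray(note_times, dtype=np.float64))
    variants = {}
    # Time shift: move every note by the same offset, dropping notes pushed out of range.
    shifted = notes + shift
    variants["shift"] = shifted[(shifted >= 0) & (shifted < duration)]
    # Random placement: keep the note count but draw onsets uniformly over the clip.
    variants["random"] = np.sort(rng.uniform(0.0, duration, size=len(notes)))
    # Mismatch (assumed meaning): pair this music with another song's chart.
    if other_song_notes is not None:
        variants["mismatch"] = np.sort(np.asarray(other_song_notes, dtype=np.float64))
    return variants

rng = np.random.default_rng(0)
variants = make_negative_charts([0.5, 1.0, 1.5], duration=2.0, rng=rng, shift=0.6)
print(variants["shift"])  # the note at 1.5s lands past 2.0s and is dropped
```

In such a setup the unmodified chart would serve as the positive example and each variant as a negative one.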

## Evaluation Results

Evaluation was performed on 2,204 test samples at each segment duration. The model uses a severity parameter of 0.56.

### Overall Accuracy by Segment Duration

| Duration | Overall | Positive | Shift | Random | Mismatch |
|----------|---------|----------|-------|--------|----------|
| 5s | 81.85% | 95.69% | 79.04% | 97.41% | 97.41% |
| 10s | 83.35% | 96.55% | 80.60% | 97.41% | 100.00% |
| 20s | 84.66% | 96.55% | 82.06% | 99.14% | 100.00% |
| 30s | 85.30% | 95.69% | 82.81% | 100.00% | 100.00% |
| 60s | 85.98% | 95.69% | 83.62% | 100.00% | 100.00% |
| 120s | **86.25%** | 94.83% | **84.00%** | 100.00% | 100.00% |
| 180s | 85.57% | 94.83% | 83.19% | 100.00% | 100.00% |
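
The Overall column is consistent with a sample-weighted average of the category columns if the 2,204 samples split as 116 clips × (1 positive + 16 shifted + 1 random + 1 mismatch), with one shifted copy per offset in the shift table below. That split is inferred from the numbers, not stated in this card; the check converts each category accuracy back to a whole count and recomputes the overall figure.

```python
# Accuracies (%) copied from the table above: Overall, Positive, Shift, Random, Mismatch.
reported = {
    "5s": (81.85, 95.69, 79.04, 97.41, 97.41),
    "10s": (83.35, 96.55, 80.60, 97.41, 100.00),
    "20s": (84.66, 96.55, 82.06, 99.14, 100.00),
    "30s": (85.30, 95.69, 82.81, 100.00, 100.00),
    "60s": (85.98, 95.69, 83.62, 100.00, 100.00),
    "120s": (86.25, 94.83, 84.00, 100.00, 100.00),
    "180s": (85.57, 94.83, 83.19, 100.00, 100.00),
}

# ASSUMED split per duration: 116 clips, each scored as one correct chart,
# 16 shifted charts (one per tested offset), one random chart, one mismatched chart.
counts = (116, 16 * 116, 116, 116)  # Positive, Shift, Random, Mismatch
total = sum(counts)                  # 2204

recomputed = {}
for duration, (overall, *category_acc) in reported.items():
    # Convert each category accuracy back to a whole number of correct samples.
    correct = sum(round(acc / 100 * n) for acc, n in zip(category_acc, counts))
    recomputed[duration] = 100 * correct / total
    print(f"{duration:>4}: reported {overall:.2f}%, recomputed {recomputed[duration]:.2f}%")
```

Every row reproduces the reported Overall value to two decimals. Under this split the Shift category accounts for 1,856 of the 2,204 samples, which is why the Overall column tracks the Shift column so closely.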

### Shift Detection by Offset (120s segment)

| Offset | Accuracy | Offset | Accuracy |
|--------|----------|--------|----------|
| -0.50s | 91.38% | +0.50s | 92.24% |
| -0.30s | 89.66% | +0.30s | 89.66% |
| -0.20s | 91.38% | +0.20s | 94.83% |
| -0.10s | 100.00% | +0.10s | 100.00% |
| -0.05s | 100.00% | +0.05s | 100.00% |
| -0.03s | 91.38% | +0.03s | 95.69% |
| -0.02s | 84.48% | +0.02s | 88.79% |
| -0.01s | 20.69% | +0.01s | 13.79% |
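
Assuming every offset is evaluated on the same number of clips, the 120s Shift accuracy in the previous table should equal the plain mean of these 16 rows. The quick check below confirms this and shows how much the two ±0.01s rows pull the average down.

```python
# Per-offset Shift accuracies (%) for the 120s segment, copied from the table above.
offset_acc = {
    -0.50: 91.38, -0.30: 89.66, -0.20: 91.38, -0.10: 100.00,
    -0.05: 100.00, -0.03: 91.38, -0.02: 84.48, -0.01: 20.69,
    0.01: 13.79, 0.02: 88.79, 0.03: 95.69, 0.05: 100.00,
    0.10: 100.00, 0.20: 94.83, 0.30: 89.66, 0.50: 92.24,
}

mean_acc = sum(offset_acc.values()) / len(offset_acc)
print(f"mean over all 16 offsets: {mean_acc:.2f}%")  # 84.00%, matching the 120s Shift column

# Dropping the two sub-frame offsets (±0.01s) shows how much they cost.
resolvable = [acc for off, acc in offset_acc.items() if abs(off) > 0.011]
mean_resolvable = sum(resolvable) / len(resolvable)
print(f"mean without ±0.01s: {mean_resolvable:.1f}%")  # 93.5%
```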

### Analysis

The performance pattern follows from the model's temporal resolution and receptive field:

1. **Resolution Limit (±0.01s)**: Accuracy drops to 13.79–20.69% because the **10ms shift** is smaller than the model's temporal resolution (**~11.6ms per frame**). Sub-frame timing differences are hard for the convolutional encoder to resolve. Accuracy recovers to 84.48–95.69% once the shift reaches ±0.02s to ±0.03s (roughly 2 to 3 frames).
2. **Optimal Zone (±0.05s to ±0.10s)**: The model achieves **100% accuracy** here. These shifts are large enough to be resolved but small enough to fit well within the **~0.36s half-receptive field**, so the model can simultaneously "see" the music beat and the misaligned note and compare them directly.
3. **Field Boundary (±0.20s to ±0.50s)**: Accuracy dips to roughly 90–95%. A **0.50s shift** pushes the note outside the receptive field of its corresponding music beat, so the model can no longer compare them directly. Instead, it must detect "a note without a corresponding beat" (or vice versa), which is a harder inference task and prone to errors when the shift lands the note on a different valid beat.
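
Converting each tested offset into frames makes the three regimes above concrete. The sketch assumes a 512-sample hop at 44.1 kHz (about 11.6 ms per frame) and takes the half-receptive field as (63 - 1) / 2 = 31 frames; `describe_offset` is a hypothetical helper.

```python
FRAME_SECONDS = 512 / 44100  # assumed hop: ~11.6 ms per frame
HALF_RF_FRAMES = 31          # (63 - 1) / 2 frames on each side of a time-step

def describe_offset(offset_s):
    """Return (offset in frames, regime) for a chart shift given in seconds."""
    frames = abs(offset_s) / FRAME_SECONDS
    if frames < 1:
        return frames, "sub-frame"
    if frames <= HALF_RF_FRAMES:
        return frames, "inside half-receptive field"
    return frames, "outside half-receptive field"

for off in (0.01, 0.02, 0.03, 0.05, 0.10, 0.20, 0.30, 0.50):
    frames, zone = describe_offset(off)
    print(f"±{off:.2f}s = {frames:5.1f} frames -> {zone}")
```

Under these assumptions ±0.01s is under one frame, ±0.30s (about 26 frames) still sits near the edge of the half-receptive field, and ±0.50s (about 43 frames) falls outside it.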