JacobLinCool committed on
Commit
dce0030
·
verified ·
1 Parent(s): c8739ce

Upload 17 files

Browse files
TaikoChartEstimator/__init__.py ADDED
File without changes
TaikoChartEstimator/constants.py ADDED
@@ -0,0 +1,105 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Centralized Constants for TaikoChartEstimator
3
+
4
+ Consolidates all difficulty mappings, note types, and star ranges
5
+ to avoid duplication across modules.
6
+ """
7
+
8
+ from typing import Dict, Tuple
9
+
10
+ # =============================================================================
11
+ # Note Types
12
+ # =============================================================================
13
+
14
# Canonical note-type names. A name's position in this list is its integer ID.
NOTE_TYPES = [
    "Don",         # 0
    "Ka",          # 1
    "DonBig",      # 2
    "KaBig",       # 3
    "Roll",        # 4
    "RollBig",     # 5
    "Balloon",     # 6
    "BalloonAlt",  # 7
    "EndOf",       # 8
]

# Reverse lookup: note-type name -> integer ID (its index in NOTE_TYPES).
NOTE_TYPE_TO_ID: Dict[str, int] = dict(zip(NOTE_TYPES, range(len(NOTE_TYPES))))
NUM_NOTE_TYPES = len(NOTE_TYPES)
# The padding token takes the first ID past the real note types (9).
PAD_TOKEN_ID = NUM_NOTE_TYPES
31
+
32
+ # =============================================================================
33
+ # Difficulty Classes
34
+ # =============================================================================
35
+
36
+ # Original 5 classes
37
# The five per-chart difficulty labels, ordered from easiest to hardest.
DIFFICULTY_CLASSES = ["easy", "normal", "hard", "oni", "ura"]

# Merged label set in which "ura" shares a single class with "oni".
DIFFICULTY_CLASSES_MERGED = ["easy", "normal", "hard", "oni_ura"]
NUM_DIFFICULTY_CLASSES = len(DIFFICULTY_CLASSES)
NUM_DIFFICULTY_CLASSES_MERGED = len(DIFFICULTY_CLASSES_MERGED)

# Difficulty name -> class ID. Only the lowercase and Capitalized spellings
# are keys (e.g. "oni" and "Oni"); other casings such as "ONI" are not.
# Built with a comprehension so that no loop variables leak into the module
# namespace (and therefore into `from constants import *`).
DIFFICULTY_TO_ID: Dict[str, int] = {
    spelling: class_id
    for class_id, name in enumerate(DIFFICULTY_CLASSES)
    for spelling in (name, name.capitalize())
}

# Difficulty ordering for ranking comparisons. A difficulty's rank equals its
# class ID, so this is derived from DIFFICULTY_TO_ID rather than maintained by
# hand. It is an independent copy: mutating one mapping never affects the other.
DIFFICULTY_ORDER: Dict[str, int] = dict(DIFFICULTY_TO_ID)
63
+
64
+ # =============================================================================
65
+ # Star Ranges per Difficulty
66
+ # =============================================================================
67
+
68
+ # Star ranges by difficulty index
69
# Inclusive (min_star, max_star) bounds keyed by difficulty class ID.
STAR_RANGES_BY_ID: Dict[int, Tuple[int, int]] = dict(
    enumerate(
        [
            (1, 5),   # 0: easy
            (1, 7),   # 1: normal
            (1, 8),   # 2: hard
            (1, 10),  # 3: oni
            (1, 10),  # 4: ura
        ]
    )
)

# The same bounds keyed by difficulty name; every name is registered in both
# its lowercase and Capitalized spelling.
STAR_RANGES_BY_NAME: Dict[str, Tuple[int, int]] = {
    spelling: bounds
    for name, bounds in (
        ("easy", (1, 5)),
        ("normal", (1, 7)),
        ("hard", (1, 8)),
        ("oni", (1, 10)),
        ("ura", (1, 10)),
    )
    for spelling in (name, name.capitalize())
}
90
+
91
+ # =============================================================================
92
+ # Helper Functions
93
+ # =============================================================================
94
+
95
+
96
def merge_difficulty_class(class_id: int) -> int:
    """Map a class ID onto the merged scheme: ura (4) becomes oni (3).

    Every other ID is returned unchanged.
    """
    if class_id == 4:
        return 3
    return class_id
99
+
100
+
101
def get_difficulty_name(class_id: int, merged: bool = False) -> str:
    """Return the difficulty label for a class ID.

    Args:
        class_id: Index into the difficulty label list.
        merged: When True, look the ID up in the merged label list, where any
            ID above 3 is clamped to the last ("oni_ura") entry.

    Returns:
        The label from DIFFICULTY_CLASSES, or from DIFFICULTY_CLASSES_MERGED
        when ``merged`` is set.
    """
    if not merged:
        return DIFFICULTY_CLASSES[class_id]
    # Clamp so that ura (4) resolves to the shared "oni_ura" label.
    merged_id = class_id if class_id < 3 else 3
    return DIFFICULTY_CLASSES_MERGED[merged_id]
TaikoChartEstimator/data/__init__.py ADDED
@@ -0,0 +1,29 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ TaikoChartEstimator Data Pipeline
3
+
4
+ Provides event tokenization, dataset loading, and audio processing for
5
+ MIL-based Taiko chart difficulty estimation.
6
+ """
7
+
8
+ from .audio import AudioProcessor
9
+ from .dataset import (
10
+ ChartBag,
11
+ SongGroup,
12
+ TaikoChartDataset,
13
+ WithinSongPairSampler,
14
+ collate_chart_bags,
15
+ )
16
+ from .tokenizer import NOTE_TYPE_TO_ID, NOTE_TYPES, EventToken, EventTokenizer
17
+
18
+ __all__ = [
19
+ "EventToken",
20
+ "EventTokenizer",
21
+ "NOTE_TYPES",
22
+ "NOTE_TYPE_TO_ID",
23
+ "TaikoChartDataset",
24
+ "ChartBag",
25
+ "SongGroup",
26
+ "WithinSongPairSampler",
27
+ "collate_chart_bags",
28
+ "AudioProcessor",
29
+ ]
TaikoChartEstimator/data/audio.py ADDED
@@ -0,0 +1,231 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Audio Processing for Taiko Chart Estimation
3
+
4
+ Handles mel spectrogram extraction and alignment with chart events.
5
+ """
6
+
7
+ from typing import Optional
8
+
9
+ import numpy as np
10
+ import torch
11
+ import torch.nn.functional as F
12
+ import torchaudio
13
+ import torchaudio.transforms as T
14
+
15
+
16
+ class AudioProcessor:
17
+ """
18
+ Processes audio waveforms into mel spectrograms for model input.
19
+
20
+ Features:
21
+ - Mel spectrogram extraction with configurable parameters
22
+ - Window extraction aligned with chart timing
23
+ - Optional augmentation (time stretch, pitch shift)
24
+ """
25
+
26
+ def __init__(
27
+ self,
28
+ sample_rate: int = 22050,
29
+ n_mels: int = 128,
30
+ n_fft: int = 2048,
31
+ hop_length: int = 512,
32
+ f_min: float = 20.0,
33
+ f_max: float = 8000.0,
34
+ normalize: bool = True,
35
+ ):
36
+ """
37
+ Initialize audio processor.
38
+
39
+ Args:
40
+ sample_rate: Target sample rate for audio
41
+ n_mels: Number of mel frequency bins
42
+ n_fft: FFT window size
43
+ hop_length: Hop length for STFT
44
+ f_min: Minimum frequency for mel filterbank
45
+ f_max: Maximum frequency for mel filterbank
46
+ normalize: Whether to normalize spectrograms
47
+ """
48
+ self.sample_rate = sample_rate
49
+ self.n_mels = n_mels
50
+ self.n_fft = n_fft
51
+ self.hop_length = hop_length
52
+ self.f_min = f_min
53
+ self.f_max = f_max
54
+ self.normalize = normalize
55
+
56
+ # Mel spectrogram transform
57
+ self.mel_transform = T.MelSpectrogram(
58
+ sample_rate=sample_rate,
59
+ n_mels=n_mels,
60
+ n_fft=n_fft,
61
+ hop_length=hop_length,
62
+ f_min=f_min,
63
+ f_max=f_max,
64
+ power=2.0,
65
+ )
66
+
67
+ # Amplitude to dB
68
+ self.amplitude_to_db = T.AmplitudeToDB(stype="power", top_db=80)
69
+
70
+ # Resampler cache
71
+ self._resamplers: dict[int, T.Resample] = {}
72
+
73
+ def _get_resampler(self, orig_sr: int) -> T.Resample:
74
+ """Get or create a resampler for the given source sample rate."""
75
+ if orig_sr not in self._resamplers:
76
+ self._resamplers[orig_sr] = T.Resample(orig_sr, self.sample_rate)
77
+ return self._resamplers[orig_sr]
78
+
79
+ def process_audio(
80
+ self,
81
+ waveform: np.ndarray | torch.Tensor,
82
+ orig_sample_rate: int,
83
+ ) -> torch.Tensor:
84
+ """
85
+ Process raw audio waveform to mel spectrogram.
86
+
87
+ Args:
88
+ waveform: Audio waveform array [samples] or [channels, samples]
89
+ orig_sample_rate: Original sample rate of the audio
90
+
91
+ Returns:
92
+ Mel spectrogram tensor [n_mels, time_frames]
93
+ """
94
+ # Convert to tensor if needed
95
+ if isinstance(waveform, np.ndarray):
96
+ waveform = torch.from_numpy(waveform).float()
97
+
98
+ # Ensure 2D [channels, samples]
99
+ if waveform.dim() == 1:
100
+ waveform = waveform.unsqueeze(0)
101
+
102
+ # Convert stereo to mono
103
+ if waveform.size(0) > 1:
104
+ waveform = waveform.mean(dim=0, keepdim=True)
105
+
106
+ # Resample if needed
107
+ if orig_sample_rate != self.sample_rate:
108
+ resampler = self._get_resampler(orig_sample_rate)
109
+ waveform = resampler(waveform)
110
+
111
+ # Compute mel spectrogram
112
+ mel_spec = self.mel_transform(waveform)
113
+
114
+ # Convert to dB scale
115
+ mel_spec_db = self.amplitude_to_db(mel_spec)
116
+
117
+ # Remove channel dimension
118
+ mel_spec_db = mel_spec_db.squeeze(0)
119
+
120
+ # Normalize if requested
121
+ if self.normalize:
122
+ mel_spec_db = (mel_spec_db - mel_spec_db.mean()) / (
123
+ mel_spec_db.std() + 1e-8
124
+ )
125
+
126
+ return mel_spec_db
127
+
128
+ def time_to_frame(self, time_sec: float) -> int:
129
+ """Convert time in seconds to frame index."""
130
+ return int(time_sec * self.sample_rate / self.hop_length)
131
+
132
+ def frame_to_time(self, frame_idx: int) -> float:
133
+ """Convert frame index to time in seconds."""
134
+ return frame_idx * self.hop_length / self.sample_rate
135
+
136
+ def extract_window(
137
+ self,
138
+ mel_spec: torch.Tensor,
139
+ start_time: float,
140
+ end_time: float,
141
+ pad_value: float = 0.0,
142
+ ) -> torch.Tensor:
143
+ """
144
+ Extract a time window from mel spectrogram.
145
+
146
+ Args:
147
+ mel_spec: Full mel spectrogram [n_mels, time_frames]
148
+ start_time: Window start time in seconds
149
+ end_time: Window end time in seconds
150
+ pad_value: Value for padding if window extends beyond spectrogram
151
+
152
+ Returns:
153
+ Window tensor [n_mels, window_frames]
154
+ """
155
+ start_frame = self.time_to_frame(start_time)
156
+ end_frame = self.time_to_frame(end_time)
157
+
158
+ # Clamp to valid range
159
+ start_frame = max(0, start_frame)
160
+ end_frame = min(mel_spec.size(1), end_frame)
161
+
162
+ window = mel_spec[:, start_frame:end_frame]
163
+
164
+ # Pad if window is shorter than expected
165
+ expected_frames = self.time_to_frame(end_time - start_time)
166
+ if window.size(1) < expected_frames:
167
+ pad_size = expected_frames - window.size(1)
168
+ window = F.pad(window, (0, pad_size), value=pad_value)
169
+
170
+ return window
171
+
172
+ def extract_windows_for_instances(
173
+ self,
174
+ mel_spec: torch.Tensor,
175
+ instance_times: list[tuple[float, float]],
176
+ fixed_frames: Optional[int] = None,
177
+ ) -> list[torch.Tensor]:
178
+ """
179
+ Extract mel spectrogram windows aligned with chart instances.
180
+
181
+ Args:
182
+ mel_spec: Full mel spectrogram [n_mels, time_frames]
183
+ instance_times: List of (start_time, end_time) for each instance
184
+ fixed_frames: If provided, resize all windows to this frame count
185
+
186
+ Returns:
187
+ List of window tensors
188
+ """
189
+ windows = []
190
+
191
+ for start_time, end_time in instance_times:
192
+ window = self.extract_window(mel_spec, start_time, end_time)
193
+
194
+ if fixed_frames is not None and window.size(1) != fixed_frames:
195
+ # Resize to fixed frame count
196
+ window = F.interpolate(
197
+ window.unsqueeze(0),
198
+ size=fixed_frames,
199
+ mode="linear",
200
+ align_corners=False,
201
+ ).squeeze(0)
202
+
203
+ windows.append(window)
204
+
205
+ return windows
206
+
207
+ def compute_onset_strength(self, mel_spec: torch.Tensor) -> torch.Tensor:
208
+ """
209
+ Compute onset strength envelope from mel spectrogram.
210
+
211
+ Useful for beat tracking and rhythm analysis.
212
+
213
+ Args:
214
+ mel_spec: Mel spectrogram [n_mels, time_frames]
215
+
216
+ Returns:
217
+ Onset strength envelope [time_frames]
218
+ """
219
+ # Compute first-order difference
220
+ diff = torch.diff(mel_spec, dim=1)
221
+
222
+ # Half-wave rectify (keep only positive changes)
223
+ diff = F.relu(diff)
224
+
225
+ # Sum across frequency bins
226
+ onset_env = diff.sum(dim=0)
227
+
228
+ # Pad to match original length
229
+ onset_env = F.pad(onset_env, (1, 0))
230
+
231
+ return onset_env
TaikoChartEstimator/data/dataset.py ADDED
@@ -0,0 +1,427 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Taiko Chart Dataset for MIL-based Difficulty Estimation
3
+
4
+ Loads data from JacobLinCool/taiko-1000-parsed and provides:
5
+ - ChartBag: A single chart with its instances (windows)
6
+ - SongGroup: All difficulty charts for a single song (for ranking loss)
7
+ - Within-song pair sampling for training
8
+ """
9
+
10
+ from dataclasses import dataclass, field
11
+ from typing import Iterator, Optional
12
+
13
+ import numpy as np
14
+ import torch
15
+ from datasets import Dataset as HFDataset
16
+ from datasets import load_dataset
17
+ from torch.utils.data import Dataset, Sampler
18
+
19
+ # Import from centralized constants
20
+ from ..constants import (
21
+ DIFFICULTY_CLASSES,
22
+ DIFFICULTY_ORDER,
23
+ NOTE_TYPE_TO_ID,
24
+ )
25
+ from ..constants import (
26
+ DIFFICULTY_TO_ID as DIFFICULTY_TO_CLASS_ID,
27
+ )
28
+ from ..constants import (
29
+ STAR_RANGES_BY_NAME as STAR_RANGES,
30
+ )
31
+ from .audio import AudioProcessor
32
+ from .tokenizer import EventToken, EventTokenizer
33
+
34
+
35
@dataclass
class ChartBag:
    """
    One chart, stored as a bag of windowed instances for multiple-instance
    learning (MIL).

    Attributes:
        song_id: Identifier of the song the chart belongs to
        difficulty: Difficulty name (easy/normal/hard/oni/ura)
        difficulty_class_id: Integer class ID of the difficulty
        star: Star rating label
        is_right_censored: Star label sits at the top of the difficulty's
            star range, so it is only a lower bound
        is_left_censored: Star label sits at the bottom of the difficulty's
            star range, so it is only an upper bound
        instances: One token tensor per window
        instance_masks: One attention mask per window
        instance_times: (start, end) time of each window
        audio_mel: Full-song mel spectrogram, when audio is loaded
    """

    # Identity and labels (required)
    song_id: str
    difficulty: str
    difficulty_class_id: int
    star: int
    is_right_censored: bool
    is_left_censored: bool
    # Per-window data; each bag gets its own fresh lists
    instances: list[torch.Tensor] = field(default_factory=list)
    instance_masks: list[torch.Tensor] = field(default_factory=list)
    instance_times: list[tuple[float, float]] = field(default_factory=list)
    # Optional audio features
    audio_mel: Optional[torch.Tensor] = None

    def __len__(self) -> int:
        """Number of instances (windows) held by this bag."""
        return len(self.instances)
66
+
67
+
68
@dataclass
class SongGroup:
    """
    Every chart of one song, grouped for the within-song ranking loss.

    Ranking follows the difficulty order easy < normal < hard < oni < ura.
    """

    song_id: str
    charts: list[ChartBag] = field(default_factory=list)

    def get_ranking_pairs(self) -> list[tuple[ChartBag, ChartBag]]:
        """
        Build the adjacent-difficulty pairs used by the ranking loss.

        Charts are ordered by DIFFICULTY_ORDER (unknown difficulty names rank
        as 0) and each chart is paired with its immediate successor.

        Returns:
            List of (easier_chart, harder_chart) tuples
        """

        def rank(chart: ChartBag) -> int:
            return DIFFICULTY_ORDER.get(chart.difficulty, 0)

        ordered = sorted(self.charts, key=rank)
        # Zip the ordered list against itself shifted by one position.
        return list(zip(ordered, ordered[1:]))
96
+
97
+
98
class TaikoChartDataset(Dataset):
    """
    PyTorch Dataset for Taiko chart difficulty estimation.

    Loads from HuggingFace dataset and provides ChartBag instances.
    Supports multi-scale windowing and optional audio features.
    """

    def __init__(
        self,
        split: str = "train",
        dataset_name: str = "JacobLinCool/taiko-1000-parsed",
        window_measures: list[int] = [2, 4],
        hop_measures: int = 2,
        max_instances_per_chart: int = 64,
        max_tokens_per_instance: int = 128,
        include_audio: bool = False,
        cache_dir: Optional[str] = None,
    ):
        """
        Initialize dataset.

        Args:
            split: Dataset split ("train" or "test")
            dataset_name: HuggingFace dataset name
            window_measures: Window sizes in measures for multi-scale
            hop_measures: Hop size in measures
            max_instances_per_chart: Maximum instances to keep per chart
            max_tokens_per_instance: Maximum tokens per instance
            include_audio: Whether to load and process audio
            cache_dir: Cache directory for dataset
        """
        self.split = split
        # Copy so the shared default list (and the caller's list) can never be
        # mutated through this instance.
        self.window_measures = list(window_measures)
        self.hop_measures = hop_measures
        self.max_instances_per_chart = max_instances_per_chart
        self.max_tokens_per_instance = max_tokens_per_instance
        self.include_audio = include_audio

        # Initialize processors
        self.tokenizer = EventTokenizer()
        self.audio_processor = AudioProcessor() if include_audio else None

        # Load HuggingFace dataset
        self.hf_dataset = load_dataset(
            dataset_name,
            split=split,
            cache_dir=cache_dir,
        )

        # Build index of all charts (song_idx, difficulty)
        self._build_chart_index()

    @staticmethod
    def _has_chart(song: dict, difficulty: str) -> bool:
        """Return True if `song` has a non-empty segment list for `difficulty`."""
        if difficulty not in song or song[difficulty] is None:
            return False
        diff_data = song[difficulty]
        return bool(diff_data.get("segments")) and len(diff_data["segments"]) > 0

    def _build_chart_index(self):
        """Build an index of all available charts across songs."""
        self.chart_index: list[tuple[int, str]] = []  # (song_idx, difficulty)
        self.song_groups: dict[int, SongGroup] = {}  # song_idx -> SongGroup

        for song_idx in range(len(self.hf_dataset)):
            song = self.hf_dataset[song_idx]

            # Same difficulty list that get_song_group iterates over
            available_diffs = [
                diff for diff in DIFFICULTY_CLASSES if self._has_chart(song, diff)
            ]
            self.chart_index.extend((song_idx, diff) for diff in available_diffs)

            # Register the song. The group starts without charts; use
            # get_song_group() to obtain a populated group.
            if available_diffs:
                self.song_groups[song_idx] = SongGroup(song_id=f"song_{song_idx}")

    def __len__(self) -> int:
        """Total number of charts across all songs."""
        return len(self.chart_index)

    def _process_chart(
        self,
        song_data: dict,
        song_idx: int,
        difficulty: str,
    ) -> ChartBag:
        """
        Process a single chart into a ChartBag.

        Args:
            song_data: One row of the HuggingFace dataset
            song_idx: Index of that row, used to build the song ID
            difficulty: Difficulty key of the chart inside ``song_data``

        Returns:
            ChartBag holding the windowed token tensors, masks, time ranges,
            labels and (optionally) the song's mel spectrogram
        """
        song_id = f"song_{song_idx}"
        diff_data = song_data[difficulty]

        # Get star rating and censoring info. A label on the edge of the
        # difficulty's star range is treated as a bound, not an exact value.
        star = diff_data.get("level", 5)  # Default to 5 if missing
        min_star, max_star = STAR_RANGES.get(difficulty, (1, 10))
        is_right_censored = star >= max_star
        is_left_censored = star <= min_star

        # Get difficulty class ID
        diff_class_id = DIFFICULTY_TO_CLASS_ID.get(difficulty, 0)

        # Tokenize chart notes
        segments = diff_data.get("segments", [])
        tokens = self.tokenizer.tokenize_chart(segments)

        # Create multi-scale windows: one pass per window size, all pooled
        all_instances = []
        all_masks = []
        all_times = []

        for window_size in self.window_measures:
            windows = self.tokenizer.create_windows(
                tokens,
                window_measures=window_size,
                hop_measures=self.hop_measures,
            )

            for window_tokens in windows:
                if not window_tokens:
                    continue

                # Convert to tensor
                tensor, mask = self.tokenizer.tokens_to_tensor(
                    window_tokens,
                    max_length=self.max_tokens_per_instance,
                )

                # Pad to max length
                tensor, mask = self.tokenizer.pad_sequence(
                    tensor, mask, self.max_tokens_per_instance
                )

                # Record time range (first to last token in the window)
                start_time = window_tokens[0].timestamp
                end_time = window_tokens[-1].timestamp

                all_instances.append(tensor)
                all_masks.append(mask)
                all_times.append((start_time, end_time))

        # Limit number of instances
        if len(all_instances) > self.max_instances_per_chart:
            # Keep evenly spaced instances across the pooled list
            indices = np.linspace(
                0, len(all_instances) - 1, self.max_instances_per_chart, dtype=int
            )
            all_instances = [all_instances[i] for i in indices]
            all_masks = [all_masks[i] for i in indices]
            all_times = [all_times[i] for i in indices]

        # Process audio if requested
        audio_mel = None
        if self.include_audio and "audio" in song_data:
            audio_data = song_data["audio"]
            if audio_data is not None:
                waveform = audio_data.get("array")
                sr = audio_data.get("sampling_rate", 22050)
                if waveform is not None:
                    audio_mel = self.audio_processor.process_audio(waveform, sr)

        return ChartBag(
            song_id=song_id,
            difficulty=difficulty,
            difficulty_class_id=diff_class_id,
            star=star,
            is_right_censored=is_right_censored,
            is_left_censored=is_left_censored,
            instances=all_instances,
            instance_masks=all_masks,
            instance_times=all_times,
            audio_mel=audio_mel,
        )

    def __getitem__(self, idx: int) -> ChartBag:
        """Return the chart at position ``idx`` of the chart index."""
        song_idx, difficulty = self.chart_index[idx]
        song_data = self.hf_dataset[song_idx]
        return self._process_chart(song_data, song_idx, difficulty)

    def get_song_group(self, song_idx: int) -> SongGroup:
        """
        Get all charts for a song as a SongGroup.

        Args:
            song_idx: Index in the HuggingFace dataset

        Returns:
            SongGroup with all available difficulty charts
        """
        song_data = self.hf_dataset[song_idx]
        group = SongGroup(song_id=f"song_{song_idx}")

        for diff in DIFFICULTY_CLASSES:
            if self._has_chart(song_data, diff):
                group.charts.append(self._process_chart(song_data, song_idx, diff))

        return group

    def get_all_song_indices(self) -> list[int]:
        """Get list of unique song indices in the dataset."""
        return list(self.song_groups.keys())
301
+
302
+
303
class WithinSongBatchSampler(Sampler[list[int]]):
    """
    Batch sampler that never splits a song's charts across two batches.

    Keeping all charts of a song together means the within-song ranking loss
    always sees complete song groups inside a batch.
    """

    def __init__(
        self,
        dataset: TaikoChartDataset,
        min_batch_size: int = 16,
        shuffle: bool = True,
        seed: int = 2025,
    ):
        """
        Set up the sampler.

        Args:
            dataset: The TaikoChartDataset to draw chart indices from
            min_batch_size: A batch is emitted once it holds at least this
                many charts
            shuffle: Shuffle song order (and chart order within a song) on
                every iteration
            seed: Seed for the sampler's random generator
        """
        self.dataset = dataset
        self.min_batch_size = min_batch_size
        self.shuffle = shuffle
        self.rng = np.random.default_rng(seed)

        # Group chart indices by the song they belong to.
        grouped: dict[int, list[int]] = {}
        for chart_idx, (song_idx, _difficulty) in enumerate(dataset.chart_index):
            grouped.setdefault(song_idx, []).append(chart_idx)

        self.song_to_charts: dict[int, list[int]] = grouped
        self.song_indices = list(grouped)

    def __iter__(self) -> Iterator[list[int]]:
        """Yield lists of chart indices; each list holds whole song groups."""
        order = list(self.song_indices)
        if self.shuffle:
            self.rng.shuffle(order)

        pending: list[int] = []

        for song_idx in order:
            members = list(self.song_to_charts[song_idx])
            if self.shuffle:
                self.rng.shuffle(members)

            # A song always contributes all of its charts at once.
            pending += members

            # Keep accumulating songs until the batch is large enough.
            if len(pending) < self.min_batch_size:
                continue
            yield pending
            pending = []

        # Flush whatever is left as a final, smaller batch.
        if pending:
            yield pending

    def __len__(self) -> int:
        """Approximate batch count: total charts // min_batch_size, at least 1."""
        return max(1, len(self.dataset) // self.min_batch_size)


# Former class name, kept as an alias for backward compatibility
WithinSongPairSampler = WithinSongBatchSampler
374
+
375
+
376
def collate_chart_bags(bags: "list[ChartBag]", max_seq_len: int = 128) -> dict:
    """
    Collate function for ChartBag instances.

    Every bag is padded with all-zero instances (and all-zero masks) up to
    the largest instance count in the batch. Padding tensors take the dtype
    of the real tensors in the batch, so the dtype of the collated output
    does not depend on whether a particular batch needed padding.

    Args:
        bags: List of ChartBag instances to collate
        max_seq_len: Fallback sequence length for padding empty instances

    Returns:
        A dictionary suitable for model input, with keys "instances",
        "instance_masks", "instance_counts", "difficulty_class", "star",
        "is_right_censored", "is_left_censored", "song_ids", "difficulties".

    Raises:
        ValueError: If ``bags`` is empty.
        RuntimeError: If no bag in the batch has any instance (there is
            nothing to stack).
    """
    # Stack instances: need to handle variable numbers
    max_instances = max(len(bag.instances) for bag in bags)

    # Reference tensors from the first non-empty bag; they provide the
    # sequence length and dtypes used for padding.
    ref_instance = next((bag.instances[0] for bag in bags if bag.instances), None)
    ref_mask = next(
        (bag.instance_masks[0] for bag in bags if bag.instance_masks), None
    )
    inferred_seq_len = (
        ref_instance.shape[0] if ref_instance is not None else max_seq_len
    )
    instance_dtype = ref_instance.dtype if ref_instance is not None else torch.float32
    mask_dtype = ref_mask.dtype if ref_mask is not None else torch.float32

    batch_instances = []
    batch_masks = []
    instance_counts = []

    for bag in bags:
        # Work on copies so the bag's own lists are never modified
        instances = list(bag.instances)
        masks = list(bag.instance_masks)

        # Pad to max_instances
        n_pad = max_instances - len(instances)
        if n_pad > 0:
            # Infer shape from existing instances or use fallback (6 features)
            pad_shape = instances[0].shape if instances else (inferred_seq_len, 6)
            instances += [
                torch.zeros(pad_shape, dtype=instance_dtype) for _ in range(n_pad)
            ]
            masks += [
                torch.zeros(pad_shape[0], dtype=mask_dtype) for _ in range(n_pad)
            ]

        batch_instances.append(torch.stack(instances))
        batch_masks.append(torch.stack(masks))
        instance_counts.append(len(bag.instances))

    return {
        "instances": torch.stack(batch_instances),  # [B, N, L, 6]
        "instance_masks": torch.stack(batch_masks),  # [B, N, L]
        "instance_counts": torch.tensor(instance_counts),  # [B]
        "difficulty_class": torch.tensor([b.difficulty_class_id for b in bags]),  # [B]
        "star": torch.tensor([b.star for b in bags], dtype=torch.float32),  # [B]
        "is_right_censored": torch.tensor([b.is_right_censored for b in bags]),  # [B]
        "is_left_censored": torch.tensor([b.is_left_censored for b in bags]),  # [B]
        "song_ids": [b.song_id for b in bags],  # List[str]
        "difficulties": [b.difficulty for b in bags],  # List[str]
    }
TaikoChartEstimator/data/tokenizer.py ADDED
@@ -0,0 +1,337 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Event Tokenizer for Taiko Chart Notes
3
+
4
+ Converts raw chart note data into event tokens suitable for sequence modeling.
5
+ Handles 9 note types with continuous features (BPM, scroll, timestamp, duration).
6
+ """
7
+
8
+ from dataclasses import dataclass
9
+ from typing import Optional
10
+
11
+ import numpy as np
12
+ import torch
13
+
14
+ # Import from centralized constants
15
+ from ..constants import (
16
+ DIFFICULTY_ORDER,
17
+ NOTE_TYPE_TO_ID,
18
+ NOTE_TYPES,
19
+ PAD_TOKEN_ID,
20
+ )
21
+ from ..constants import (
22
+ STAR_RANGES_BY_NAME as STAR_RANGES,
23
+ )
24
+
25
+
26
@dataclass
class EventToken:
    """One note (or other event) of a chart, with its timing context."""

    # Absolute time in seconds
    timestamp: float
    # Position within the measure (0-1)
    beat_position: float
    # Integer ID from NOTE_TYPE_TO_ID
    note_type: int
    # Duration for rolls/balloons (0 for regular notes)
    duration: float
    # BPM in effect at this event
    bpm: float
    # Scroll speed multiplier
    scroll: float
    # Whether the event falls in GOGO time (increased scoring)
    gogo: bool

    def to_tensor(self) -> torch.Tensor:
        """Pack the token into a float32 vector.

        Layout: [type_id, beat_pos, duration, bpm, scroll, gogo]. The
        timestamp is not part of the vector.
        """
        features = (
            self.note_type,
            self.beat_position,
            self.duration,
            self.bpm,
            self.scroll,
            float(self.gogo),
        )
        return torch.tensor(features, dtype=torch.float32)
51
+
52
+
53
+ class EventTokenizer:
54
+ """
55
+ Tokenizes Taiko chart data into event token sequences.
56
+
57
+ Features:
58
+ - Extracts note events from segments
59
+ - Computes beat-relative positions
60
+ - Normalizes continuous features (BPM, scroll)
61
+ - Creates beat-aligned windows for MIL instances
62
+ """
63
+
64
+ def __init__(
65
+ self,
66
+ bpm_mean: float = 150.0,
67
+ bpm_std: float = 50.0,
68
+ scroll_mean: float = 1.0,
69
+ scroll_std: float = 0.5,
70
+ max_duration: float = 4.0, # Max roll/balloon duration in beats
71
+ ):
72
+ self.bpm_mean = bpm_mean
73
+ self.bpm_std = bpm_std
74
+ self.scroll_mean = scroll_mean
75
+ self.scroll_std = scroll_std
76
+ self.max_duration = max_duration
77
+
78
def tokenize_chart(self, segments: list[dict]) -> "list[EventToken]":
    """
    Convert chart segments to a list of EventTokens.

    Notes whose ``note_type`` is not a key of NOTE_TYPE_TO_ID are skipped.

    Args:
        segments: List of segment dicts from the dataset. Each segment must
            have a "timestamp"; "measure_num" (default 4) and "notes"
            (default empty) are optional.

    Returns:
        List of EventToken objects, sorted by timestamp
    """
    long_note_types = ("Roll", "RollBig", "Balloon", "BalloonAlt")
    tokens = []

    for segment in segments:
        segment_start = segment["timestamp"]
        measure_num = segment.get("measure_num", 4)
        # NOTE(review): the segment's "measure_den" is not used, so one beat
        # is always counted as one unit of measure_num — confirm this is
        # intended for meters that are not x/4.

        for note in segment.get("notes", []):
            note_type_str = note.get("note_type", "Don")
            if note_type_str not in NOTE_TYPE_TO_ID:
                continue  # Skip unknown note types

            note_time = note.get("timestamp", segment_start)

            # Beat position within the measure, as a fraction of
            # measure_num. The positivity check comes before the modulo so
            # that a zero numerator yields 0.0 instead of raising
            # ZeroDivisionError.
            if measure_num > 0:
                beats_since_segment = (
                    (note_time - segment_start) * note.get("bpm", 120) / 60
                )
                beat_position = (beats_since_segment % measure_num) / measure_num
            else:
                beat_position = 0.0

            # Long notes take the note's "delay" field as a duration hint;
            # all other notes have zero duration.
            # NOTE(review): presumably "delay" approximates the time until
            # the matching EndOf event — verify against the dataset schema.
            duration = 0.0
            if note_type_str in long_note_types:
                duration = note.get("delay", 0.0)

            tokens.append(
                EventToken(
                    timestamp=note_time,
                    beat_position=beat_position,
                    note_type=NOTE_TYPE_TO_ID[note_type_str],
                    duration=min(duration, self.max_duration),
                    bpm=note.get("bpm", 120.0),
                    scroll=note.get("scroll", 1.0),
                    gogo=note.get("gogo", False),
                )
            )

    # Sort by timestamp
    tokens.sort(key=lambda t: t.timestamp)
    return tokens
132
+
133
def compute_note_density(
    self, tokens: "list[EventToken]", window_sec: float = 1.0
) -> list[float]:
    """
    Compute local note density for each token (notes per second in window).

    For every token, counts the tokens whose timestamp lies within
    ``window_sec / 2`` of it (both bounds inclusive, the token itself
    included) and divides the count by ``window_sec``.

    Args:
        tokens: List of EventTokens (any order)
        window_sec: Window size in seconds for density calculation

    Returns:
        List of density values, one per token, in the order of ``tokens``
    """
    if not tokens:
        return []

    timestamps = np.array([t.timestamp for t in tokens])
    half_window = window_sec / 2

    # Sort once, then locate each window's bounds by binary search instead of
    # scanning every timestamp for every token.
    ordered = np.sort(timestamps)
    # side="left" keeps timestamps equal to the lower bound inside the window,
    # side="right" keeps timestamps equal to the upper bound inside it.
    first = np.searchsorted(ordered, timestamps - half_window, side="left")
    last = np.searchsorted(ordered, timestamps + half_window, side="right")

    return ((last - first) / window_sec).tolist()
161
+
162
def create_windows(
    self,
    tokens: list[EventToken],
    window_measures: int = 4,
    hop_measures: int = 2,
    default_bpm: float = 120.0,
) -> list[list[EventToken]]:
    """
    Create beat-aligned windows from token sequence, respecting BPM changes.

    Tokens are first split into BPM-consistent segments; windows never span
    two segments. Within a segment, windows are half-open intervals
    [start, start + window) advanced by the hop size, and empty windows are
    skipped. Measures are assumed to be 4/4.

    Tokens sitting exactly at the point where the next window would have
    started (a segment whose tokens all share one timestamp, or a last
    token landing on a window boundary) are emitted as a final window
    instead of being silently dropped.

    Args:
        tokens: List of EventTokens, sorted by timestamp
        window_measures: Window size in measures
        hop_measures: Hop size in measures
        default_bpm: BPM used when a segment's first token has bpm <= 0

    Returns:
        List of token subsequences (windows)

    Raises:
        ValueError: If window_measures or hop_measures is not positive, or
            if no positive BPM is available for a segment
    """
    if not tokens:
        return []
    if window_measures <= 0 or hop_measures <= 0:
        # A non-positive hop would never advance the window start.
        raise ValueError(
            "window_measures and hop_measures must be positive, got "
            f"{window_measures} and {hop_measures}"
        )

    beats_per_measure = 4  # Assuming 4/4 time

    all_windows = []
    for segment_tokens in self._split_by_bpm(tokens, threshold=5.0):
        if not segment_tokens:
            continue

        # Use this segment's BPM for window calculation
        first_bpm = segment_tokens[0].bpm
        segment_bpm = first_bpm if first_bpm > 0 else default_bpm
        if segment_bpm <= 0:
            raise ValueError(f"default_bpm must be positive, got {default_bpm}")

        measure_duration = (beats_per_measure * 60) / segment_bpm
        window_duration = window_measures * measure_duration
        hop_duration = hop_measures * measure_duration

        end_time = segment_tokens[-1].timestamp
        current_start = segment_tokens[0].timestamp
        covered_until = current_start  # exclusive end of the last window

        while current_start < end_time:
            window_end = current_start + window_duration
            window_tokens = [
                t for t in segment_tokens if current_start <= t.timestamp < window_end
            ]
            if window_tokens:  # Only add non-empty windows
                all_windows.append(window_tokens)

            covered_until = window_end
            current_start += hop_duration

        # The loop stops once the next start reaches end_time. Tokens at or
        # after that start which no previous window covered would otherwise
        # be lost; tokens inside an intentional hop > window gap are left out.
        uncovered_from = max(covered_until, current_start)
        tail_tokens = [t for t in segment_tokens if t.timestamp >= uncovered_from]
        if tail_tokens:
            all_windows.append(tail_tokens)

    return all_windows
227
+
228
+ def _split_by_bpm(
229
+ self,
230
+ tokens: list[EventToken],
231
+ threshold: float = 5.0,
232
+ ) -> list[list[EventToken]]:
233
+ """
234
+ Split token list into segments with consistent BPM.
235
+
236
+ Args:
237
+ tokens: List of EventTokens sorted by timestamp
238
+ threshold: BPM difference threshold to trigger a new segment
239
+
240
+ Returns:
241
+ List of token lists, one per BPM segment
242
+ """
243
+ if not tokens:
244
+ return []
245
+
246
+ segments = []
247
+ current_segment = [tokens[0]]
248
+ current_bpm = tokens[0].bpm
249
+
250
+ for token in tokens[1:]:
251
+ if abs(token.bpm - current_bpm) > threshold:
252
+ # BPM changed significantly, start new segment
253
+ if current_segment:
254
+ segments.append(current_segment)
255
+ current_segment = [token]
256
+ current_bpm = token.bpm
257
+ else:
258
+ current_segment.append(token)
259
+
260
+ # Don't forget the last segment
261
+ if current_segment:
262
+ segments.append(current_segment)
263
+
264
+ return segments
265
+
266
def tokens_to_tensor(
    self,
    tokens: list[EventToken],
    max_length: Optional[int] = None,
    normalize: bool = True,
) -> tuple[torch.Tensor, torch.Tensor]:
    """
    Convert a token list to a feature tensor and attention mask.

    No padding is applied here; the result has one row per (possibly
    truncated) token.

    Args:
        tokens: List of EventTokens
        max_length: Maximum sequence length (None = no limit)
        normalize: Whether to normalize continuous features (BPM and scroll)
            with the tokenizer's mean/std statistics

    Returns:
        Tuple of (token_tensor, attention_mask)
        token_tensor: [seq_len, 6] - [type, beat_pos, duration, bpm, scroll, gogo]
        attention_mask: [seq_len] - 1 for real tokens
        When there are no tokens (empty input, or truncated to nothing by
        max_length), a single all-zero row with a zero mask is returned.
    """
    # Truncate first so that an empty result takes the placeholder path
    # below instead of reaching torch.stack with an empty list.
    if tokens and max_length is not None and len(tokens) > max_length:
        tokens = tokens[:max_length]

    if not tokens:
        # Placeholder row, fully masked out
        return torch.zeros(1, 6), torch.zeros(1)

    # Stack token tensors
    tensor = torch.stack([t.to_tensor() for t in tokens])

    if normalize:
        # Normalize BPM (column 3)
        tensor[:, 3] = (tensor[:, 3] - self.bpm_mean) / self.bpm_std
        # Normalize scroll (column 4)
        tensor[:, 4] = (tensor[:, 4] - self.scroll_mean) / self.scroll_std

    # Create attention mask (all 1s for real tokens)
    mask = torch.ones(len(tokens))

    return tensor, mask
306
+
307
def pad_sequence(
    self,
    tensor: torch.Tensor,
    mask: torch.Tensor,
    target_length: int,
) -> tuple[torch.Tensor, torch.Tensor]:
    """
    Pad (or truncate) tensor and mask to target length.

    Padding rows are all zero except column 0, which is set to
    PAD_TOKEN_ID; padded mask entries are 0.

    Args:
        tensor: [seq_len, 6] token tensor
        mask: [seq_len] attention mask
        target_length: Target sequence length

    Returns:
        Padded tensor and mask, each with target_length rows
    """
    current_length = tensor.size(0)

    if current_length >= target_length:
        return tensor[:target_length], mask[:target_length]

    pad_length = target_length - current_length

    # Allocate padding on the inputs' devices: torch.cat refuses to join
    # tensors that live on different devices.
    pad_tensor = torch.zeros(pad_length, tensor.size(1), device=tensor.device)
    pad_tensor[:, 0] = PAD_TOKEN_ID  # Set type to PAD
    pad_mask = torch.zeros(pad_length, device=mask.device)

    padded_tensor = torch.cat([tensor, pad_tensor], dim=0)
    padded_mask = torch.cat([mask, pad_mask], dim=0)

    return padded_tensor, padded_mask
TaikoChartEstimator/eval/__init__.py ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ TaikoChartEstimator Evaluation Package
3
+ """
4
+
5
+ from .evaluator import Evaluator
6
+ from .metrics import (
7
+ DecompressionMetrics,
8
+ DifficultyMetrics,
9
+ MILHealthMetrics,
10
+ MonotonicityMetrics,
11
+ StarMetrics,
12
+ )
13
+
14
+ __all__ = [
15
+ "DifficultyMetrics",
16
+ "StarMetrics",
17
+ "MonotonicityMetrics",
18
+ "DecompressionMetrics",
19
+ "MILHealthMetrics",
20
+ "Evaluator",
21
+ ]
TaikoChartEstimator/eval/evaluator.py ADDED
@@ -0,0 +1,476 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Evaluator for TaikoChartEstimator
3
+
4
+ Orchestrates evaluation across all metric types and generates reports.
5
+ """
6
+
7
+ import argparse
8
+ import json
9
+ from datetime import datetime
10
+ from pathlib import Path
11
+ from typing import Optional
12
+
13
+ import numpy as np
14
+ import torch
15
+ from torch.utils.data import DataLoader
16
+ from tqdm import tqdm
17
+
18
+ from ..data import TaikoChartDataset, collate_chart_bags
19
+ from ..model import ModelConfig, TaikoChartEstimator
20
+ from .metrics import (
21
+ DecompressionMetrics,
22
+ DifficultyMetrics,
23
+ MILHealthMetrics,
24
+ MonotonicityMetrics,
25
+ StarMetrics,
26
+ )
27
+
28
+
29
class Evaluator:
    """
    Comprehensive evaluator for TaikoChartEstimator.

    Runs inference over a dataloader, computes every metric family
    (difficulty, star, monotonicity, decompression, MIL health) and renders
    the results as a markdown report and a JSON file.
    """

    def __init__(
        self,
        model: TaikoChartEstimator,
        device: torch.device = torch.device("cpu"),
    ):
        """
        Args:
            model: Model to evaluate (expected to already be on `device`)
            device: Device that input batches are moved to
        """
        self.model = model
        self.device = device

        # Initialize metric calculators
        self.difficulty_metrics = DifficultyMetrics()
        self.star_metrics = StarMetrics()
        self.monotonicity_metrics = MonotonicityMetrics()
        self.decompression_metrics = DecompressionMetrics()
        self.mil_health_metrics = MILHealthMetrics()

    @staticmethod
    def _stack_attention(rows: list) -> np.ndarray:
        """
        Stack per-sample attention arrays into one array.

        If every row is 1-D and the lengths differ, rows are right-padded
        with zeros to the longest length; otherwise this is a plain np.stack.
        """
        arrays = [np.asarray(row) for row in rows]
        if all(arr.ndim == 1 for arr in arrays):
            width = max(arr.shape[0] for arr in arrays)
            if any(arr.shape[0] != width for arr in arrays):
                # NOTE(review): assumes batches are padded to their own max
                # bag size, so rows from different batches may differ in
                # length - confirm against collate_chart_bags.
                padded = np.zeros((len(arrays), width), dtype=arrays[0].dtype)
                for row_idx, arr in enumerate(arrays):
                    padded[row_idx, : arr.shape[0]] = arr
                return padded
        return np.stack(arrays)

    @torch.no_grad()
    def run_inference(
        self,
        dataloader: DataLoader,
    ) -> dict:
        """
        Run inference on entire dataset and collect predictions.

        Args:
            dataloader: Yields batch dicts with the keys read below
                ("instances", "instance_masks", "instance_counts",
                "difficulty_class", "star", "song_ids", "difficulties",
                "is_right_censored", "is_left_censored")

        Returns:
            Dict with all predictions and metadata. Numeric entries are
            numpy arrays; "song_ids" and "difficulties" stay lists;
            "attention_weights" is a stacked array, or an empty list when
            the model returned no "average_attention".
        """
        self.model.eval()

        results = {
            "pred_difficulty_class": [],
            "true_difficulty_class": [],
            "pred_star": [],
            "true_star": [],
            "raw_score": [],
            "song_ids": [],
            "difficulties": [],
            "is_right_censored": [],
            "is_left_censored": [],
            "attention_weights": [],
            "instance_counts": [],
        }

        for batch in tqdm(dataloader, desc="Running inference"):
            instances = batch["instances"].to(self.device)
            instance_masks = batch["instance_masks"].to(self.device)
            instance_counts = batch["instance_counts"].to(self.device)
            difficulty_class = batch["difficulty_class"].to(self.device)

            output = self.model(
                instances,
                instance_masks,
                instance_counts,
                difficulty_hint=difficulty_class,
                return_attention=True,
            )

            # Collect predictions
            results["pred_difficulty_class"].extend(
                output.difficulty_logits.argmax(dim=-1).cpu().numpy()
            )
            results["true_difficulty_class"].extend(batch["difficulty_class"].numpy())
            results["pred_star"].extend(output.raw_star.cpu().numpy())
            results["true_star"].extend(batch["star"].numpy())
            results["raw_score"].extend(output.raw_score.cpu().numpy())
            results["song_ids"].extend(batch["song_ids"])
            results["difficulties"].extend(batch["difficulties"])
            results["is_right_censored"].extend(batch["is_right_censored"].numpy())
            results["is_left_censored"].extend(batch["is_left_censored"].numpy())
            results["instance_counts"].extend(instance_counts.cpu().numpy())

            # Collect attention weights (average across branches if multi-branch)
            if "average_attention" in output.attention_info:
                results["attention_weights"].extend(
                    output.attention_info["average_attention"].cpu().numpy()
                )

        # Convert to numpy arrays
        for key in [
            "pred_difficulty_class",
            "true_difficulty_class",
            "pred_star",
            "true_star",
            "raw_score",
            "is_right_censored",
            "is_left_censored",
            "instance_counts",
        ]:
            results[key] = np.array(results[key])

        if results["attention_weights"]:
            results["attention_weights"] = self._stack_attention(
                results["attention_weights"]
            )

        return results

    def compute_all_metrics(self, results: dict) -> dict:
        """
        Compute all metrics from inference results.

        Args:
            results: Output of run_inference

        Returns:
            Dict with all metrics organized by category ("difficulty",
            "star", "monotonicity", "decompression", and "mil_health" when
            attention weights were collected)
        """
        all_metrics = {}

        # Difficulty classification metrics
        all_metrics["difficulty"] = self.difficulty_metrics.compute(
            results["pred_difficulty_class"],
            results["true_difficulty_class"],
        )

        # Star regression metrics
        all_metrics["star"] = self.star_metrics.compute(
            results["pred_star"],
            results["true_star"],
            results["true_difficulty_class"],
            results["is_right_censored"],
            results["is_left_censored"],
        )

        # Monotonicity metrics
        all_metrics["monotonicity"] = self.monotonicity_metrics.compute(
            results["raw_score"],
            results["song_ids"],
            results["difficulties"],
        )

        # Decompression metrics
        all_metrics["decompression"] = self.decompression_metrics.compute(
            results["pred_star"],
            results["true_star"],
            results["true_difficulty_class"],
        )

        # MIL health metrics
        if len(results.get("attention_weights", [])) > 0:
            all_metrics["mil_health"] = self.mil_health_metrics.compute(
                results["attention_weights"],
                results["instance_counts"],
            )

        return all_metrics

    def generate_report(
        self,
        metrics: dict,
        output_path: Optional[Path] = None,
    ) -> str:
        """
        Generate a human-readable report from metrics.

        Missing metric values are rendered as 0.

        Args:
            metrics: Output of compute_all_metrics
            output_path: If given, the report is also written there (UTF-8)

        Returns:
            Report as markdown string
        """
        lines = []
        lines.append("# TaikoChartEstimator Evaluation Report")
        lines.append(f"\nGenerated: {datetime.now().isoformat()}\n")

        # Difficulty Classification
        lines.append("## Difficulty Classification")
        lines.append("")
        d_metrics = metrics.get("difficulty", {})
        lines.append(f"- **Accuracy**: {d_metrics.get('accuracy', 0):.4f}")
        lines.append(
            f"- **Balanced Accuracy**: {d_metrics.get('balanced_accuracy', 0):.4f}"
        )
        lines.append(f"- **Macro F1**: {d_metrics.get('macro_f1', 0):.4f}")
        lines.append(
            f"- **±1 Accuracy**: {d_metrics.get('plus_minus_1_accuracy', 0):.4f}"
        )
        lines.append("")

        # Per-class F1. The metrics use the key "f1_oni_ura" when oni and ura
        # are merged, so follow whichever key set is actually present rather
        # than always looking up the unmerged names.
        lines.append("### Per-Class F1")
        if "f1_oni_ura" in d_metrics:
            class_names = ["easy", "normal", "hard", "oni_ura"]
        else:
            class_names = ["easy", "normal", "hard", "oni", "ura"]
        for cls in class_names:
            f1 = d_metrics.get(f"f1_{cls}", 0)
            label = "/".join(part.capitalize() for part in cls.split("_"))
            lines.append(f"- {label}: {f1:.4f}")
        lines.append("")

        # Star Regression
        lines.append("## Star Rating Prediction")
        lines.append("")
        s_metrics = metrics.get("star", {})
        lines.append("### Overall")
        lines.append(f"- **MAE**: {s_metrics.get('mae', 0):.4f}")
        lines.append(f"- **RMSE**: {s_metrics.get('rmse', 0):.4f}")
        lines.append(f"- **Spearman ρ**: {s_metrics.get('spearman_rho', 0):.4f}")
        lines.append("")

        lines.append("### Uncensored Samples")
        lines.append(f"- **MAE**: {s_metrics.get('mae_uncensored', 0):.4f}")
        lines.append(
            f"- **Spearman ρ**: {s_metrics.get('spearman_rho_uncensored', 0):.4f}"
        )
        lines.append("")

        lines.append("### Censoring Consistency")
        lines.append(
            f"- **Right Censor Violation Rate**: {s_metrics.get('right_censor_violation_rate', 0):.4f}"
        )
        lines.append(
            f"- **Right Censor Mean Shortfall**: {s_metrics.get('right_censor_mean_shortfall', 0):.4f}"
        )
        lines.append(
            f"- **Left Censor Violation Rate**: {s_metrics.get('left_censor_violation_rate', 0):.4f}"
        )
        lines.append("")

        # Monotonicity
        lines.append("## Within-Song Monotonicity")
        lines.append("")
        m_metrics = metrics.get("monotonicity", {})
        lines.append(
            f"- **Violation Rate**: {m_metrics.get('violation_rate', 0):.4f} ({m_metrics.get('n_violations', 0)}/{m_metrics.get('n_pairs', 0)} pairs)"
        )
        lines.append(
            f"- **Mean Violation Margin**: {m_metrics.get('mean_violation_margin', 0):.4f}"
        )
        lines.append(
            f"- **Mean Kendall τ (within-song)**: {m_metrics.get('mean_kendall_tau_within_song', 0):.4f}"
        )
        lines.append("")

        # Decompression
        lines.append("## 10-Star Decompression")
        lines.append("")
        dec_metrics = metrics.get("decompression", {})
        lines.append(
            f"- **Std (10-star predictions)**: {dec_metrics.get('std_10star', 0):.4f}"
        )
        lines.append(
            f"- **Range (10-star predictions)**: {dec_metrics.get('range_10star', 0):.4f}"
        )
        if "p90_p50_10star" in dec_metrics:
            lines.append(f"- **P90 - P50**: {dec_metrics.get('p90_p50_10star', 0):.4f}")
            lines.append(f"- **P99 - P90**: {dec_metrics.get('p99_p90_10star', 0):.4f}")
        lines.append("")

        # MIL Health
        if "mil_health" in metrics:
            lines.append("## MIL Attention Health")
            lines.append("")
            mil_metrics = metrics["mil_health"]
            lines.append(
                f"- **Mean Attention Entropy**: {mil_metrics.get('mean_attention_entropy', 0):.4f}"
            )
            lines.append(
                f"- **Mean Effective Instances**: {mil_metrics.get('mean_effective_instances', 0):.4f}"
            )
            lines.append(
                f"- **Mean Top-5% Mass**: {mil_metrics.get('mean_top5_mass', 0):.4f}"
            )

            if mil_metrics.get("attention_collapse_warning", False):
                lines.append("")
                lines.append(
                    "> ⚠️ **Warning**: Attention collapse detected! "
                    "Model may be relying on too few instances."
                )
            lines.append("")

        report = "\n".join(lines)

        if output_path:
            # The report contains non-ASCII characters (±, ρ, τ, ⚠️); a
            # platform default such as cp1252 cannot encode them.
            output_path.write_text(report, encoding="utf-8")

        return report

    def evaluate(
        self,
        dataloader: DataLoader,
        output_dir: Optional[Path] = None,
    ) -> dict:
        """
        Run full evaluation pipeline.

        Args:
            dataloader: DataLoader for evaluation data
            output_dir: Optional directory to save results; when given,
                "metrics.json" and "report.md" are written there

        Returns:
            Dict with all metrics
        """
        # Run inference
        results = self.run_inference(dataloader)

        # Compute metrics
        metrics = self.compute_all_metrics(results)

        # Generate report
        report = self.generate_report(metrics)

        if output_dir:
            output_dir.mkdir(parents=True, exist_ok=True)

            # Save metrics as JSON
            # Convert numpy types for JSON serialization
            def convert_numpy(obj):
                if isinstance(obj, np.ndarray):
                    return obj.tolist()
                elif isinstance(obj, np.integer):
                    return int(obj)
                elif isinstance(obj, np.floating):
                    return float(obj)
                elif isinstance(obj, (np.bool_, bool)):
                    return bool(obj)
                elif isinstance(obj, dict):
                    return {k: convert_numpy(v) for k, v in obj.items()}
                elif isinstance(obj, (list, tuple)):
                    # Tuples become lists in JSON anyway; recurse so numpy
                    # values nested inside them are converted as well.
                    return [convert_numpy(v) for v in obj]
                return obj

            metrics_serializable = convert_numpy(metrics)
            with open(output_dir / "metrics.json", "w", encoding="utf-8") as f:
                json.dump(metrics_serializable, f, indent=2)

            # Save report (UTF-8: it contains non-ASCII characters)
            (output_dir / "report.md").write_text(report, encoding="utf-8")

            print(f"Results saved to {output_dir}")

        return metrics
356
+
357
+
358
def load_model_from_checkpoint(
    checkpoint_path: Path,
    device: torch.device,
) -> TaikoChartEstimator:
    """
    Build a TaikoChartEstimator from saved weights and put it in eval mode.

    Two on-disk layouts are accepted:
    1. A single checkpoint file holding "config" and "model_state_dict"
    2. A directory written by model.save_pretrained()

    Args:
        checkpoint_path: Checkpoint file or pretrained directory
        device: Device the model is moved to

    Returns:
        The loaded model, on `device`, in eval mode
    """
    source = Path(checkpoint_path)

    if not source.is_dir():
        # Single-file checkpoint.
        # NOTE(security): torch.load unpickles the file; only load
        # checkpoints from trusted sources.
        state = torch.load(source, map_location=device)
        model = TaikoChartEstimator(ModelConfig(**state["config"]))
        model.load_state_dict(state["model_state_dict"])
    else:
        # save_pretrained directory
        model = TaikoChartEstimator.from_pretrained(
            source,
        ).to(device)

    model = model.to(device)
    model.eval()
    return model
394
+
395
+
396
def main():
    """Command-line entry point: load a checkpoint, evaluate it, print a summary."""
    cli = argparse.ArgumentParser(description="Evaluate TaikoChartEstimator")
    cli.add_argument(
        "--checkpoint", type=str, required=True, help="Path to model checkpoint"
    )
    cli.add_argument(
        "--dataset",
        type=str,
        default="JacobLinCool/taiko-1000-parsed",
        help="HuggingFace dataset name",
    )
    cli.add_argument(
        "--split", type=str, default="test", help="Dataset split to evaluate"
    )
    cli.add_argument("--batch-size", type=int, default=16, help="Batch size")
    cli.add_argument(
        "--output-dir",
        type=str,
        default="eval_results",
        help="Output directory for results",
    )
    # Prefer the GPU when one is available.
    fallback_device = "cuda" if torch.cuda.is_available() else "cpu"
    cli.add_argument(
        "--device",
        type=str,
        default=fallback_device,
        help="Device to use",
    )
    cli.add_argument(
        "--num-workers", type=int, default=4, help="Number of data loader workers"
    )
    opts = cli.parse_args()

    device = torch.device(opts.device)

    # Model
    print(f"Loading model from {opts.checkpoint}")
    model = load_model_from_checkpoint(Path(opts.checkpoint), device)

    # Data
    print(f"Loading {opts.split} dataset...")
    dataset = TaikoChartDataset(
        split=opts.split,
        dataset_name=opts.dataset,
    )
    loader = DataLoader(
        dataset,
        batch_size=opts.batch_size,
        shuffle=False,
        collate_fn=collate_chart_bags,
        num_workers=opts.num_workers,
    )

    print(f"Evaluating on {len(dataset)} samples...")

    # Evaluation (also writes metrics.json and report.md to the output dir)
    metrics = Evaluator(model, device).evaluate(
        loader,
        output_dir=Path(opts.output_dir),
    )

    # Summary of headline metrics
    rule = "=" * 50
    print("\n" + rule)
    print("EVALUATION SUMMARY")
    print(rule)
    print(f"Difficulty Macro-F1: {metrics['difficulty']['macro_f1']:.4f}")
    print(f"Star MAE (uncensored): {metrics['star']['mae_uncensored']:.4f}")
    print(f"Star Spearman ρ: {metrics['star']['spearman_rho']:.4f}")
    print(
        f"Monotonicity Violation Rate: {metrics['monotonicity']['violation_rate']:.4f}"
    )
    print(
        f"10-Star Decompression Std: {metrics['decompression'].get('std_10star', 0):.4f}"
    )
    print(rule)
473
+
474
+
475
+ if __name__ == "__main__":
476
+ main()
TaikoChartEstimator/eval/metrics.py ADDED
@@ -0,0 +1,501 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Evaluation Metrics for TaikoChartEstimator
3
+
4
+ Comprehensive metrics covering:
5
+ - Difficulty classification
6
+ - Star rating regression (with censoring awareness)
7
+ - Monotonicity constraints
8
+ - 10-star decompression
9
+ - MIL attention health
10
+ """
11
+
12
+ from dataclasses import dataclass, field
13
+ from typing import Optional
14
+
15
+ import numpy as np
16
+ from scipy.stats import kendalltau, spearmanr
17
+ from sklearn.metrics import (
18
+ accuracy_score,
19
+ balanced_accuracy_score,
20
+ confusion_matrix,
21
+ f1_score,
22
+ )
23
+
24
+ from ..constants import STAR_RANGES_BY_ID
25
+
26
+
27
@dataclass
class DifficultyMetrics:
    """
    Metrics for difficulty classification (easy/normal/hard/oni/ura).

    Includes ordinal-aware metrics since difficulties are ordered.
    Note: by default ura (4) and oni (3) are treated as the same class.
    """

    merge_ura_oni: bool = True  # Treat ura and oni as the same class

    def _merge_classes(self, arr: np.ndarray) -> np.ndarray:
        """Return a copy with ura (4) mapped to oni (3); no-op when merging is off."""
        if self.merge_ura_oni:
            arr = np.array(arr)  # copy, so the caller's array is untouched
            arr[arr == 4] = 3  # Map ura -> oni
        return arr

    def compute(
        self,
        predictions: np.ndarray,
        targets: np.ndarray,
    ) -> dict:
        """
        Compute classification metrics.

        Args:
            predictions: Predicted difficulty class indices [N]
            targets: True difficulty class indices [N]

        Returns:
            Dict with accuracy, balanced accuracy, macro/weighted F1, one
            "f1_<class>" entry per class, ordinal error metrics and the
            confusion matrix. Per-class F1 and the confusion matrix always
            cover every class in order, including classes absent from the
            data (their F1 is 0).
        """
        metrics = {}

        predictions = np.asarray(predictions)
        targets = np.asarray(targets)

        # Merge ura and oni classes if enabled
        predictions = self._merge_classes(predictions)
        targets = self._merge_classes(targets)

        # Standard classification metrics
        metrics["accuracy"] = accuracy_score(targets, predictions)
        metrics["balanced_accuracy"] = balanced_accuracy_score(targets, predictions)
        metrics["macro_f1"] = f1_score(targets, predictions, average="macro")
        metrics["weighted_f1"] = f1_score(targets, predictions, average="weighted")

        # Per-class F1 (4 classes when merged: easy, normal, hard, oni/ura)
        if self.merge_ura_oni:
            class_names = ["easy", "normal", "hard", "oni_ura"]
        else:
            class_names = ["easy", "normal", "hard", "oni", "ura"]
        # Pin the label set explicitly. Without it, sklearn only reports the
        # classes that occur in the data, so positional indexing would attach
        # scores to the wrong class names whenever a class is missing.
        labels = list(range(len(class_names)))
        per_class_f1 = f1_score(
            targets, predictions, labels=labels, average=None, zero_division=0
        )
        for name, score in zip(class_names, per_class_f1):
            metrics[f"f1_{name}"] = score

        # Ordinal-aware metrics (difficulties are ordered)
        abs_diff = np.abs(predictions - targets)
        metrics["mean_absolute_error_ordinal"] = abs_diff.mean()
        metrics["plus_minus_1_accuracy"] = (abs_diff <= 1).mean()
        metrics["plus_minus_2_accuracy"] = (abs_diff <= 2).mean()

        # Confusion matrix, with the same fixed label order as above
        metrics["confusion_matrix"] = confusion_matrix(
            targets, predictions, labels=labels
        )

        return metrics
92
+
93
+
94
@dataclass
class StarMetrics:
    """
    Metrics for star rating prediction with censoring awareness.

    Separates metrics for:
    - Uncensored samples (true regression quality)
    - Right-censored samples (label at the top of the star range)
    - Left-censored samples (label at the bottom of the star range)
    """

    # Maps difficulty class index -> (min_star, max_star)
    star_ranges: dict = field(default_factory=lambda: STAR_RANGES_BY_ID.copy())

    def compute(
        self,
        predictions: np.ndarray,
        targets: np.ndarray,
        difficulties: np.ndarray,
        is_right_censored: Optional[np.ndarray] = None,
        is_left_censored: Optional[np.ndarray] = None,
    ) -> dict:
        """
        Compute star regression metrics.

        Args:
            predictions: Predicted star ratings [N]
            targets: Target star labels [N]
            difficulties: Difficulty class indices [N]
            is_right_censored: Mask for right-censored samples (any
                truthy/falsy dtype); derived from star_ranges when None
            is_left_censored: Mask for left-censored samples (any
                truthy/falsy dtype); derived from star_ranges when None

        Returns:
            Dict with overall, uncensored and per-censoring-side metrics
        """
        metrics = {}

        predictions = np.asarray(predictions)
        targets = np.asarray(targets)
        difficulties = np.asarray(difficulties)

        # Auto-detect whichever censoring mask was not provided; a mask the
        # caller did supply is kept as is.
        if is_right_censored is None or is_left_censored is None:
            auto_right = np.zeros(len(predictions), dtype=bool)
            auto_left = np.zeros(len(predictions), dtype=bool)

            for diff_idx, (min_star, max_star) in self.star_ranges.items():
                mask = difficulties == diff_idx
                auto_right[mask] = targets[mask] >= max_star
                auto_left[mask] = targets[mask] <= min_star

            if is_right_censored is None:
                is_right_censored = auto_right
            if is_left_censored is None:
                is_left_censored = auto_left

        # Force boolean masks. With 0/1 integer arrays, `~` is a bitwise NOT
        # and indexing selects positions 0/1 instead of masking.
        is_right_censored = np.asarray(is_right_censored, dtype=bool)
        is_left_censored = np.asarray(is_left_censored, dtype=bool)

        # Overall metrics
        metrics["mae"] = np.abs(predictions - targets).mean()
        metrics["rmse"] = np.sqrt(((predictions - targets) ** 2).mean())

        if len(predictions) > 1:
            rho, p_value = spearmanr(predictions, targets)
            metrics["spearman_rho"] = rho
            metrics["spearman_pvalue"] = p_value
        else:
            metrics["spearman_rho"] = 0.0
            metrics["spearman_pvalue"] = 1.0

        # Uncensored samples: true regression quality
        uncensored_mask = ~(is_right_censored | is_left_censored)
        if uncensored_mask.sum() > 0:
            uncensored_preds = predictions[uncensored_mask]
            uncensored_targets = targets[uncensored_mask]

            metrics["mae_uncensored"] = np.abs(
                uncensored_preds - uncensored_targets
            ).mean()
            metrics["rmse_uncensored"] = np.sqrt(
                ((uncensored_preds - uncensored_targets) ** 2).mean()
            )

            if len(uncensored_preds) > 1:
                rho, _ = spearmanr(uncensored_preds, uncensored_targets)
                metrics["spearman_rho_uncensored"] = rho
            else:
                metrics["spearman_rho_uncensored"] = 0.0
        else:
            metrics["mae_uncensored"] = 0.0
            metrics["rmse_uncensored"] = 0.0
            metrics["spearman_rho_uncensored"] = 0.0

        # Right-censored (at max star): check violation
        if is_right_censored.sum() > 0:
            right_preds = predictions[is_right_censored]
            right_targets = targets[is_right_censored]

            # Violation: prediction below the max star bound
            violation_mask = right_preds < right_targets
            metrics["right_censor_violation_rate"] = violation_mask.mean()

            if violation_mask.sum() > 0:
                metrics["right_censor_mean_shortfall"] = (
                    right_targets[violation_mask] - right_preds[violation_mask]
                ).mean()
            else:
                metrics["right_censor_mean_shortfall"] = 0.0

            metrics["right_censor_count"] = is_right_censored.sum()
        else:
            metrics["right_censor_violation_rate"] = 0.0
            metrics["right_censor_mean_shortfall"] = 0.0
            metrics["right_censor_count"] = 0

        # Left-censored (at min star): check violation
        if is_left_censored.sum() > 0:
            left_preds = predictions[is_left_censored]
            left_targets = targets[is_left_censored]

            # Violation: prediction above the min star bound
            violation_mask = left_preds > left_targets
            metrics["left_censor_violation_rate"] = violation_mask.mean()

            if violation_mask.sum() > 0:
                metrics["left_censor_mean_overshoot"] = (
                    left_preds[violation_mask] - left_targets[violation_mask]
                ).mean()
            else:
                metrics["left_censor_mean_overshoot"] = 0.0

            metrics["left_censor_count"] = is_left_censored.sum()
        else:
            metrics["left_censor_violation_rate"] = 0.0
            metrics["left_censor_mean_overshoot"] = 0.0
            metrics["left_censor_count"] = 0

        return metrics
220
+
221
+
222
@dataclass
class MonotonicityMetrics:
    """
    Within-song monotonicity checks.

    For charts of the same song, a harder difficulty is expected to get a
    strictly higher score than an easier one.
    """

    # Difficulty name -> ordinal rank (both lower- and capitalized spellings)
    difficulty_order: dict = field(
        default_factory=lambda: {
            "easy": 0,
            "Easy": 0,
            "normal": 1,
            "Normal": 1,
            "hard": 2,
            "Hard": 2,
            "oni": 3,
            "Oni": 3,
            "ura": 4,
            "Ura": 4,
        }
    )

    def compute(
        self,
        raw_scores: np.ndarray,
        song_ids: list[str],
        difficulties: list[str],
    ) -> dict:
        """
        Compute monotonicity metrics.

        Charts are grouped by song and ordered by difficulty rank (unknown
        names rank as 0). Each adjacent pair counts as a violation when the
        easier chart scores at least as high as the harder one. Kendall's
        tau between scores and expected order is also computed per song.

        Args:
            raw_scores: Raw difficulty scores [N]
            song_ids: Song identifiers
            difficulties: Difficulty names

        Returns:
            Dict with pair/violation counts, violation rate, violation
            margins and within-song Kendall tau statistics
        """
        # Group charts by song, keeping first-seen song order
        charts_by_song: dict[str, list] = {}
        for position, song in enumerate(song_ids):
            charts_by_song.setdefault(song, []).append(
                {
                    "idx": position,
                    "difficulty": difficulties[position],
                    "score": raw_scores[position],
                }
            )

        pair_total = 0
        violation_total = 0
        margins = []
        taus = []

        for charts in charts_by_song.values():
            if len(charts) < 2:
                continue  # a single chart has nothing to compare against

            ranked = sorted(
                charts, key=lambda c: self.difficulty_order.get(c["difficulty"], 0)
            )
            scores = [chart["score"] for chart in ranked]

            # Adjacent (easier, harder) pairs
            for easier, harder in zip(scores, scores[1:]):
                pair_total += 1
                if easier >= harder:
                    violation_total += 1
                    margins.append(easier - harder)

            # Kendall's tau of scores against the expected ordering
            tau, _ = kendalltau(scores, list(range(len(ranked))))
            if not np.isnan(tau):
                taus.append(tau)

        metrics = {}
        metrics["n_pairs"] = pair_total
        metrics["n_violations"] = violation_total
        metrics["violation_rate"] = (
            violation_total / pair_total if pair_total > 0 else 0.0
        )

        metrics["mean_violation_margin"] = np.mean(margins) if margins else 0.0
        metrics["max_violation_margin"] = np.max(margins) if margins else 0.0

        metrics["mean_kendall_tau_within_song"] = np.mean(taus) if taus else 0.0
        metrics["min_kendall_tau_within_song"] = np.min(taus) if taus else 0.0

        return metrics
332
+
333
+
334
@dataclass
class DecompressionMetrics:
    """
    Metrics for 10-star decompression.

    Measures how widely the predictions are spread among charts whose
    target label sits at the top of its difficulty's star range.
    """

    def compute(
        self,
        predictions: np.ndarray,
        targets: np.ndarray,
        difficulties: np.ndarray,
    ) -> dict:
        """
        Compute decompression metrics for max-star samples.

        Args:
            predictions: Predicted star ratings (can exceed range)
            targets: Target star labels
            difficulties: Difficulty indices (0-4)

        Returns:
            Dict with per-difficulty ``std_*_max`` / ``range_*_max`` /
            ``n_samples_*_max`` entries and combined ``*_10star`` entries.
            A group is reported only when it has at least 2 samples;
            percentile gaps only when it has at least 10.
        """
        # Accept any array-like (lists included); the boolean masks below
        # need real arrays.
        predictions = np.asarray(predictions)
        targets = np.asarray(targets)
        difficulties = np.asarray(difficulties)

        metrics = {}

        # Top of the star range per difficulty index, and the index -> name
        # table used in the metric keys.
        max_stars = {0: 5, 1: 7, 2: 8, 3: 10, 4: 10}
        diff_names = ["easy", "normal", "hard", "oni", "ura"]

        for diff_idx, max_star in max_stars.items():
            mask = (difficulties == diff_idx) & (targets >= max_star)
            n_at_max = int(mask.sum())
            if n_at_max < 2:
                continue

            preds_at_max = predictions[mask]
            diff_name = diff_names[diff_idx]

            # Spread of predictions
            metrics[f"std_{diff_name}_max"] = float(preds_at_max.std())

            # Percentile gaps
            gaps = self._percentile_gaps(preds_at_max)
            if gaps is not None:
                metrics[f"p90_p50_{diff_name}"] = gaps[0]
                metrics[f"p99_p90_{diff_name}"] = gaps[1]

            # Range
            metrics[f"range_{diff_name}_max"] = float(
                preds_at_max.max() - preds_at_max.min()
            )
            metrics[f"n_samples_{diff_name}_max"] = n_at_max

        # Overall 10-star decompression (oni + ura combined)
        max_10_mask = (targets >= 10) & ((difficulties == 3) | (difficulties == 4))
        n_10star = int(max_10_mask.sum())
        if n_10star >= 2:
            preds_10star = predictions[max_10_mask]

            metrics["std_10star"] = float(preds_10star.std())
            metrics["range_10star"] = float(preds_10star.max() - preds_10star.min())
            metrics["n_samples_10star"] = n_10star

            gaps = self._percentile_gaps(preds_10star)
            if gaps is not None:
                metrics["p90_p50_10star"] = gaps[0]
                metrics["p99_p90_10star"] = gaps[1]

        return metrics

    @staticmethod
    def _percentile_gaps(values: np.ndarray):
        """Return (p90 - p50, p99 - p90), or None with fewer than 10 values."""
        if len(values) < 10:
            return None
        p50, p90, p99 = np.percentile(values, [50, 90, 99])
        return float(p90 - p50), float(p99 - p90)
408
+
409
+
410
@dataclass
class MILHealthMetrics:
    """
    Metrics for MIL attention health.

    Monitors attention distribution to detect collapse
    (model focusing on too few instances).
    """

    def compute(
        self,
        attention_weights: np.ndarray,
        instance_counts: Optional[np.ndarray] = None,
    ) -> dict:
        """
        Compute MIL attention health metrics.

        Args:
            attention_weights: Attention weights [N_samples, N_instances]
            instance_counts: Number of valid instances per sample; the first
                ``instance_counts[i]`` columns of row ``i`` are treated as
                valid. When omitted every column is valid.

        Returns:
            Dict with entropy, effective-instance and top-5%-mass statistics
            plus ``health_ratio`` and ``attention_collapse_warning``. Empty
            when no sample has a valid instance.
        """
        metrics = {}
        attention_weights = np.asarray(attention_weights)
        n_samples, n_instances = attention_weights.shape

        # Mask invalid instances if counts provided
        if instance_counts is not None:
            counts = np.asarray(instance_counts)
            mask = np.arange(n_instances)[None, :] < counts[:, None]
        else:
            mask = np.ones_like(attention_weights, dtype=bool)

        # Per-sample statistics; higher entropy = more distributed attention.
        entropies = []
        effective_ns = []
        top5_masses = []
        valid_sizes = []  # instances actually used for each contributing sample

        for i in range(n_samples):
            attn = attention_weights[i, mask[i]]
            if len(attn) == 0:
                continue
            valid_sizes.append(len(attn))

            # Normalize to sum to 1
            attn = attn / (attn.sum() + 1e-8)

            # Entropy
            entropies.append(-np.sum(attn * np.log(attn + 1e-8)))

            # Effective number of instances (inverse of concentration)
            effective_ns.append(1.0 / (np.sum(attn**2) + 1e-8))

            # Top-5% mass
            k = max(1, int(len(attn) * 0.05))
            top5_masses.append(np.sort(attn)[-k:].sum())

        if not valid_sizes:
            return metrics

        metrics["mean_attention_entropy"] = float(np.mean(entropies))
        metrics["min_attention_entropy"] = float(np.min(entropies))
        metrics["std_attention_entropy"] = float(np.std(entropies))

        metrics["mean_effective_instances"] = float(np.mean(effective_ns))
        metrics["min_effective_instances"] = float(np.min(effective_ns))

        metrics["mean_top5_mass"] = float(np.mean(top5_masses))
        metrics["max_top5_mass"] = float(np.max(top5_masses))

        # Health assessment: share of the available instances effectively
        # used. The denominator is taken over the same samples as the
        # numerator, using the number of instances that were really present;
        # skipped (empty) bags and counts beyond the padded width would
        # otherwise distort the ratio.
        collapse_ratio = float(np.mean(effective_ns) / np.mean(valid_sizes))
        metrics["health_ratio"] = collapse_ratio
        # Less than 10% of instances used
        metrics["attention_collapse_warning"] = bool(collapse_ratio < 0.1)

        return metrics
TaikoChartEstimator/model/__init__.py ADDED
@@ -0,0 +1,36 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ TaikoChartEstimator Model Package
3
+
4
+ Provides the MIL-based difficulty estimation model with:
5
+ - Instance encoder (Transformer-based)
6
+ - MIL aggregator with multi-branch attention
7
+ - Multi-head outputs (raw score, difficulty class, star rating)
8
+ """
9
+
10
+ from .aggregator import GatedMILAggregator, MILAggregator
11
+ from .encoder import InstanceEncoder, TCNInstanceEncoder
12
+ from .heads import DifficultyClassifier, MonotonicCalibrator, RawScoreHead
13
+ from .losses import (
14
+ CensoredRegressionLoss,
15
+ CurriculumScheduler,
16
+ TotalLoss,
17
+ WithinSongRankingLoss,
18
+ )
19
+ from .model import ModelConfig, ModelOutput, TaikoChartEstimator
20
+
21
+ __all__ = [
22
+ "InstanceEncoder",
23
+ "TCNInstanceEncoder",
24
+ "MILAggregator",
25
+ "GatedMILAggregator",
26
+ "RawScoreHead",
27
+ "DifficultyClassifier",
28
+ "MonotonicCalibrator",
29
+ "TaikoChartEstimator",
30
+ "ModelConfig",
31
+ "ModelOutput",
32
+ "WithinSongRankingLoss",
33
+ "CensoredRegressionLoss",
34
+ "TotalLoss",
35
+ "CurriculumScheduler",
36
+ ]
TaikoChartEstimator/model/aggregator.py ADDED
@@ -0,0 +1,383 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ MIL Bag Aggregator for Taiko Chart Estimation
3
+
4
+ Implements Multiple Instance Learning aggregation with:
5
+ - Three-way pooling (mean, top-k, attention)
6
+ - Multi-branch attention (ACMIL-inspired)
7
+ - Stochastic top-k masking to prevent attention collapse
8
+ """
9
+
10
+ from typing import Optional
11
+
12
+ import torch
13
+ import torch.nn as nn
14
+ import torch.nn.functional as F
15
+
16
+
17
+ class AttentionBranch(nn.Module):
18
+ """Single attention branch for multi-branch attention."""
19
+
20
+ def __init__(self, d_instance: int, d_hidden: int = 64):
21
+ super().__init__()
22
+ self.attention = nn.Sequential(
23
+ nn.Linear(d_instance, d_hidden),
24
+ nn.Tanh(),
25
+ nn.Linear(d_hidden, 1),
26
+ )
27
+
28
+ def forward(
29
+ self,
30
+ instances: torch.Tensor,
31
+ mask: Optional[torch.Tensor] = None,
32
+ ) -> tuple[torch.Tensor, torch.Tensor]:
33
+ """
34
+ Args:
35
+ instances: [batch, n_instances, d_instance]
36
+ mask: [batch, n_instances], 1 for valid, 0 for padding
37
+
38
+ Returns:
39
+ pooled: [batch, d_instance]
40
+ attention_weights: [batch, n_instances]
41
+ """
42
+ # Compute attention scores
43
+ scores = self.attention(instances).squeeze(-1) # [batch, n_instances]
44
+
45
+ # Apply mask
46
+ if mask is not None:
47
+ scores = scores.masked_fill(mask == 0, float("-inf"))
48
+
49
+ # Softmax
50
+ attn_weights = F.softmax(scores, dim=-1)
51
+
52
+ # Handle all-masked case
53
+ if mask is not None:
54
+ attn_weights = attn_weights.masked_fill(mask == 0, 0.0)
55
+
56
+ # Weighted sum
57
+ pooled = (instances * attn_weights.unsqueeze(-1)).sum(dim=1)
58
+
59
+ return pooled, attn_weights
60
+
61
+
62
+ class MILAggregator(nn.Module):
63
+ """
64
+ Multiple Instance Learning aggregator with ACMIL-inspired design.
65
+
66
+ Combines three complementary pooling strategies:
67
+ 1. Mean pooling: Captures overall difficulty/stamina
68
+ 2. Top-K pooling: Captures peak difficulty segments
69
+ 3. Multi-branch attention: Learns multiple discriminative patterns
70
+
71
+ Features stochastic top-k masking during training to prevent
72
+ the model from relying on only a few "hardest" instances.
73
+ """
74
+
75
+ def __init__(
76
+ self,
77
+ d_instance: int = 256,
78
+ n_branches: int = 3,
79
+ top_k_ratio: float = 0.1,
80
+ stochastic_mask_prob: float = 0.3,
81
+ dropout: float = 0.1,
82
+ ):
83
+ """
84
+ Initialize MIL aggregator.
85
+
86
+ Args:
87
+ d_instance: Dimension of instance embeddings
88
+ n_branches: Number of attention branches
89
+ top_k_ratio: Fraction of instances for top-k pooling
90
+ stochastic_mask_prob: Probability of masking top instances during training
91
+ dropout: Dropout rate
92
+ """
93
+ super().__init__()
94
+
95
+ self.d_instance = d_instance
96
+ self.n_branches = n_branches
97
+ self.top_k_ratio = top_k_ratio
98
+ self.stochastic_mask_prob = stochastic_mask_prob
99
+
100
+ # Top-K scoring network
101
+ self.topk_scorer = nn.Sequential(
102
+ nn.Linear(d_instance, 64),
103
+ nn.ReLU(),
104
+ nn.Linear(64, 1),
105
+ )
106
+
107
+ # Multi-branch attention
108
+ self.attention_branches = nn.ModuleList(
109
+ [AttentionBranch(d_instance, d_hidden=64) for _ in range(n_branches)]
110
+ )
111
+
112
+ # Fusion layer: combines mean (1) + topk (1) + branches (n_branches) = 2 + n_branches
113
+ n_pooled = 2 + n_branches
114
+ self.fusion = nn.Sequential(
115
+ nn.Linear(d_instance * n_pooled, d_instance * 2),
116
+ nn.LayerNorm(d_instance * 2),
117
+ nn.GELU(),
118
+ nn.Dropout(dropout),
119
+ nn.Linear(d_instance * 2, d_instance * 2),
120
+ )
121
+
122
+ self.output_dim = d_instance * 2
123
+
124
+ def _mean_pool(
125
+ self,
126
+ instances: torch.Tensor,
127
+ mask: Optional[torch.Tensor] = None,
128
+ ) -> torch.Tensor:
129
+ """Mean pooling over instances."""
130
+ if mask is not None:
131
+ mask_expanded = mask.unsqueeze(-1)
132
+ pooled = (instances * mask_expanded).sum(dim=1)
133
+ pooled = pooled / mask_expanded.sum(dim=1).clamp(min=1)
134
+ else:
135
+ pooled = instances.mean(dim=1)
136
+ return pooled
137
+
138
+ def _topk_pool(
139
+ self,
140
+ instances: torch.Tensor,
141
+ mask: Optional[torch.Tensor] = None,
142
+ ) -> tuple[torch.Tensor, torch.Tensor]:
143
+ """
144
+ Top-K pooling based on learned scores.
145
+
146
+ Returns:
147
+ pooled: [batch, d_instance]
148
+ topk_mask: [batch, n_instances] binary mask of selected instances
149
+ """
150
+ batch_size, n_instances, _ = instances.shape
151
+
152
+ # Compute scores
153
+ scores = self.topk_scorer(instances).squeeze(-1) # [batch, n_instances]
154
+
155
+ # Apply mask
156
+ if mask is not None:
157
+ scores = scores.masked_fill(mask == 0, float("-inf"))
158
+
159
+ # Determine k
160
+ if mask is not None:
161
+ valid_counts = mask.sum(dim=1) # [batch]
162
+ k = (valid_counts * self.top_k_ratio).clamp(min=1).long()
163
+ max_k = k.max().item()
164
+ else:
165
+ k = max(1, int(n_instances * self.top_k_ratio))
166
+ max_k = k
167
+
168
+ # Get top-k indices
169
+ _, topk_indices = scores.topk(max_k, dim=1) # [batch, max_k]
170
+
171
+ # Create topk mask
172
+ topk_mask = torch.zeros_like(mask if mask is not None else scores)
173
+ topk_mask.scatter_(1, topk_indices, 1.0)
174
+
175
+ # Pool top-k instances
176
+ if mask is not None:
177
+ combined_mask = topk_mask * mask
178
+ else:
179
+ combined_mask = topk_mask
180
+
181
+ mask_expanded = combined_mask.unsqueeze(-1)
182
+ pooled = (instances * mask_expanded).sum(dim=1)
183
+ pooled = pooled / mask_expanded.sum(dim=1).clamp(min=1)
184
+
185
+ return pooled, topk_mask
186
+
187
+ def _stochastic_topk_mask(
188
+ self,
189
+ instances: torch.Tensor,
190
+ mask: Optional[torch.Tensor] = None,
191
+ ) -> torch.Tensor:
192
+ """
193
+ Create stochastic mask that randomly drops top instances.
194
+
195
+ This prevents attention collapse by forcing the model to
196
+ learn from non-peak instances during training.
197
+ """
198
+ if not self.training:
199
+ return mask
200
+
201
+ batch_size, n_instances, _ = instances.shape
202
+
203
+ # Get top-k scores
204
+ with torch.no_grad():
205
+ scores = self.topk_scorer(instances).squeeze(-1)
206
+ if mask is not None:
207
+ scores = scores.masked_fill(mask == 0, float("-inf"))
208
+
209
+ k = max(1, int(n_instances * self.top_k_ratio))
210
+ _, topk_indices = scores.topk(k, dim=1)
211
+
212
+ # Create mask that drops top instances with some probability
213
+ drop_mask = torch.ones_like(mask if mask is not None else scores)
214
+
215
+ # For each batch, randomly decide whether to drop top instances
216
+ drop_decision = (
217
+ torch.rand(batch_size, device=instances.device) < self.stochastic_mask_prob
218
+ )
219
+
220
+ for i in range(batch_size):
221
+ if drop_decision[i]:
222
+ drop_mask[i, topk_indices[i]] = 0.0
223
+
224
+ if mask is not None:
225
+ return mask * drop_mask
226
+ return drop_mask
227
+
228
+ def forward(
229
+ self,
230
+ instances: torch.Tensor,
231
+ mask: Optional[torch.Tensor] = None,
232
+ return_attention: bool = True,
233
+ ) -> tuple[torch.Tensor, dict]:
234
+ """
235
+ Aggregate instance embeddings to bag embedding.
236
+
237
+ Args:
238
+ instances: [batch, n_instances, d_instance]
239
+ mask: [batch, n_instances], 1 for valid, 0 for padding
240
+ return_attention: Whether to return attention weights for analysis
241
+
242
+ Returns:
243
+ bag_embedding: [batch, output_dim]
244
+ attention_info: Dict with attention weights and metrics
245
+ """
246
+ # Apply stochastic top-k masking during training
247
+ if self.training:
248
+ stoch_mask = self._stochastic_topk_mask(instances, mask)
249
+ else:
250
+ stoch_mask = mask
251
+
252
+ # 1. Mean pooling (stamina/overall representation)
253
+ mean_pooled = self._mean_pool(instances, mask)
254
+
255
+ # 2. Top-K pooling (peak difficulty)
256
+ topk_pooled, topk_mask = self._topk_pool(instances, mask)
257
+
258
+ # 3. Multi-branch attention pooling
259
+ branch_outputs = []
260
+ branch_attns = []
261
+
262
+ for branch in self.attention_branches:
263
+ pooled, attn = branch(instances, stoch_mask)
264
+ branch_outputs.append(pooled)
265
+ branch_attns.append(attn)
266
+
267
+ # Concatenate all pooled representations
268
+ all_pooled = [mean_pooled, topk_pooled] + branch_outputs
269
+ concatenated = torch.cat(
270
+ all_pooled, dim=-1
271
+ ) # [batch, d_instance * (2 + n_branches)]
272
+
273
+ # Fuse
274
+ bag_embedding = self.fusion(concatenated)
275
+
276
+ # Compute attention health metrics
277
+ attention_info = {}
278
+ if return_attention:
279
+ # Stack all attention weights
280
+ all_attn = torch.stack(
281
+ branch_attns, dim=1
282
+ ) # [batch, n_branches, n_instances]
283
+
284
+ # Average attention across branches
285
+ avg_attn = all_attn.mean(dim=1) # [batch, n_instances]
286
+
287
+ # Attention entropy (higher = more distributed)
288
+ entropy = -(avg_attn * (avg_attn + 1e-8).log()).sum(dim=-1)
289
+
290
+ # Effective number of instances (inverse of concentration)
291
+ effective_n = 1.0 / (avg_attn**2).sum(dim=-1)
292
+
293
+ # Top-5% mass
294
+ k = max(1, int(instances.size(1) * 0.05))
295
+ top5_mass = avg_attn.topk(k, dim=-1).values.sum(dim=-1)
296
+
297
+ attention_info = {
298
+ "branch_attentions": all_attn, # [batch, n_branches, n_instances]
299
+ "average_attention": avg_attn, # [batch, n_instances]
300
+ "topk_mask": topk_mask, # [batch, n_instances]
301
+ "entropy": entropy, # [batch]
302
+ "effective_n": effective_n, # [batch]
303
+ "top5_mass": top5_mass, # [batch]
304
+ }
305
+
306
+ return bag_embedding, attention_info
307
+
308
+
309
+ class GatedMILAggregator(nn.Module):
310
+ """
311
+ Alternative MIL aggregator using gated attention.
312
+
313
+ Allows instance embeddings to modulate attention via gating,
314
+ which can capture more nuanced importance patterns.
315
+ """
316
+
317
+ def __init__(
318
+ self,
319
+ d_instance: int = 256,
320
+ d_hidden: int = 128,
321
+ dropout: float = 0.1,
322
+ ):
323
+ super().__init__()
324
+
325
+ self.attention_v = nn.Sequential(
326
+ nn.Linear(d_instance, d_hidden),
327
+ nn.Tanh(),
328
+ )
329
+
330
+ self.attention_u = nn.Sequential(
331
+ nn.Linear(d_instance, d_hidden),
332
+ nn.Sigmoid(),
333
+ )
334
+
335
+ self.attention_w = nn.Linear(d_hidden, 1)
336
+
337
+ self.output_proj = nn.Sequential(
338
+ nn.Linear(d_instance, d_instance * 2),
339
+ nn.LayerNorm(d_instance * 2),
340
+ nn.GELU(),
341
+ nn.Dropout(dropout),
342
+ )
343
+
344
+ self.output_dim = d_instance * 2
345
+
346
+ def forward(
347
+ self,
348
+ instances: torch.Tensor,
349
+ mask: Optional[torch.Tensor] = None,
350
+ return_attention: bool = True,
351
+ ) -> tuple[torch.Tensor, dict]:
352
+ """
353
+ Args:
354
+ instances: [batch, n_instances, d_instance]
355
+ mask: [batch, n_instances]
356
+
357
+ Returns:
358
+ bag_embedding: [batch, output_dim]
359
+ attention_info: Dict with attention weights
360
+ """
361
+ # Gated attention
362
+ v = self.attention_v(instances) # [batch, n_instances, d_hidden]
363
+ u = self.attention_u(instances) # [batch, n_instances, d_hidden]
364
+
365
+ scores = self.attention_w(v * u).squeeze(-1) # [batch, n_instances]
366
+
367
+ if mask is not None:
368
+ scores = scores.masked_fill(mask == 0, float("-inf"))
369
+
370
+ attn_weights = F.softmax(scores, dim=-1)
371
+
372
+ if mask is not None:
373
+ attn_weights = attn_weights.masked_fill(mask == 0, 0.0)
374
+
375
+ # Weighted sum
376
+ pooled = (instances * attn_weights.unsqueeze(-1)).sum(dim=1)
377
+
378
+ # Project to output
379
+ bag_embedding = self.output_proj(pooled)
380
+
381
+ attention_info = {"attention": attn_weights} if return_attention else {}
382
+
383
+ return bag_embedding, attention_info
TaikoChartEstimator/model/encoder.py ADDED
@@ -0,0 +1,348 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Instance Encoder for Taiko Chart MIL
3
+
4
+ Encodes a sequence of event tokens into a fixed-size vector representation.
5
+ Uses Transformer encoder for capturing rhythm patterns and dependencies.
6
+ """
7
+
8
+ import math
9
+ from typing import Optional
10
+
11
+ import torch
12
+ import torch.nn as nn
13
+ import torch.nn.functional as F
14
+
15
+
16
+ class PositionalEncoding(nn.Module):
17
+ """Sinusoidal positional encoding for sequences."""
18
+
19
+ def __init__(self, d_model: int, max_len: int = 512, dropout: float = 0.1):
20
+ super().__init__()
21
+ self.dropout = nn.Dropout(p=dropout)
22
+
23
+ # Create positional encoding matrix
24
+ position = torch.arange(max_len).unsqueeze(1)
25
+ div_term = torch.exp(
26
+ torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model)
27
+ )
28
+
29
+ pe = torch.zeros(max_len, d_model)
30
+ pe[:, 0::2] = torch.sin(position * div_term)
31
+ pe[:, 1::2] = torch.cos(position * div_term)
32
+
33
+ self.register_buffer("pe", pe.unsqueeze(0)) # [1, max_len, d_model]
34
+
35
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
36
+ """
37
+ Args:
38
+ x: Tensor of shape [batch, seq_len, d_model]
39
+ """
40
+ x = x + self.pe[:, : x.size(1)]
41
+ return self.dropout(x)
42
+
43
+
44
+ class ContinuousFeatureEncoder(nn.Module):
45
+ """
46
+ Encodes continuous features (BPM, scroll, beat_pos, duration) to d_model dimension.
47
+ Uses learned linear projections with optional normalization.
48
+ """
49
+
50
+ def __init__(
51
+ self,
52
+ n_continuous: int = 5, # beat_pos, duration, bpm, scroll, gogo
53
+ d_model: int = 256,
54
+ use_layernorm: bool = True,
55
+ ):
56
+ super().__init__()
57
+ self.projection = nn.Linear(n_continuous, d_model)
58
+ self.layernorm = nn.LayerNorm(d_model) if use_layernorm else nn.Identity()
59
+
60
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
61
+ """
62
+ Args:
63
+ x: Continuous features [batch, seq_len, n_continuous]
64
+ """
65
+ return self.layernorm(self.projection(x))
66
+
67
+
68
+ class InstanceEncoder(nn.Module):
69
+ """
70
+ Encodes a sequence of event tokens to a fixed-size vector.
71
+
72
+ Input: Token sequence [batch, seq_len, 6]
73
+ - Column 0: note_type (discrete, 0-9)
74
+ - Column 1: beat_position (continuous, 0-1)
75
+ - Column 2: duration (continuous, normalized)
76
+ - Column 3: bpm (continuous, normalized)
77
+ - Column 4: scroll (continuous, normalized)
78
+ - Column 5: gogo (binary, 0/1)
79
+
80
+ Output: Instance embedding [batch, d_model]
81
+ """
82
+
83
+ def __init__(
84
+ self,
85
+ d_model: int = 256,
86
+ n_heads: int = 4,
87
+ n_layers: int = 4,
88
+ d_feedforward: int = 512,
89
+ dropout: float = 0.1,
90
+ n_note_types: int = 10, # 9 types + padding
91
+ max_seq_len: int = 128,
92
+ pooling: str = "cls", # "cls", "mean", or "max"
93
+ ):
94
+ """
95
+ Initialize instance encoder.
96
+
97
+ Args:
98
+ d_model: Model dimension
99
+ n_heads: Number of attention heads
100
+ n_layers: Number of transformer layers
101
+ d_feedforward: Feedforward dimension
102
+ dropout: Dropout rate
103
+ n_note_types: Number of note type categories
104
+ max_seq_len: Maximum sequence length
105
+ pooling: Pooling strategy for sequence to vector
106
+ """
107
+ super().__init__()
108
+
109
+ self.d_model = d_model
110
+ self.pooling = pooling
111
+
112
+ # Discrete feature embedding (note type)
113
+ self.type_embedding = nn.Embedding(n_note_types, d_model, padding_idx=9)
114
+
115
+ # Continuous feature encoder
116
+ self.continuous_encoder = ContinuousFeatureEncoder(
117
+ n_continuous=5, # beat_pos, duration, bpm, scroll, gogo
118
+ d_model=d_model,
119
+ )
120
+
121
+ # Feature fusion
122
+ self.fusion = nn.Linear(d_model * 2, d_model)
123
+ self.fusion_norm = nn.LayerNorm(d_model)
124
+
125
+ # Positional encoding (max_len+1 to accommodate CLS token)
126
+ self.pos_encoder = PositionalEncoding(d_model, max_seq_len + 1, dropout)
127
+
128
+ # CLS token for pooling
129
+ if pooling == "cls":
130
+ self.cls_token = nn.Parameter(torch.randn(1, 1, d_model))
131
+
132
+ # Transformer encoder
133
+ encoder_layer = nn.TransformerEncoderLayer(
134
+ d_model=d_model,
135
+ nhead=n_heads,
136
+ dim_feedforward=d_feedforward,
137
+ dropout=dropout,
138
+ activation="gelu",
139
+ batch_first=True,
140
+ norm_first=True, # Pre-LN for stability
141
+ )
142
+ self.transformer = nn.TransformerEncoder(
143
+ encoder_layer,
144
+ num_layers=n_layers,
145
+ )
146
+
147
+ # Output projection
148
+ self.output_norm = nn.LayerNorm(d_model)
149
+
150
+ def forward(
151
+ self,
152
+ tokens: torch.Tensor,
153
+ mask: Optional[torch.Tensor] = None,
154
+ ) -> torch.Tensor:
155
+ """
156
+ Encode token sequence to vector.
157
+
158
+ Args:
159
+ tokens: Token tensor [batch, seq_len, 6]
160
+ mask: Attention mask [batch, seq_len], 1 for valid, 0 for padding
161
+
162
+ Returns:
163
+ Instance embedding [batch, d_model]
164
+ """
165
+ batch_size, seq_len, _ = tokens.shape
166
+
167
+ # Split discrete and continuous features
168
+ note_types = tokens[:, :, 0].long() # [batch, seq_len]
169
+ continuous_feats = tokens[:, :, 1:] # [batch, seq_len, 5]
170
+
171
+ # Embed discrete features
172
+ type_emb = self.type_embedding(note_types) # [batch, seq_len, d_model]
173
+
174
+ # Encode continuous features
175
+ cont_emb = self.continuous_encoder(
176
+ continuous_feats
177
+ ) # [batch, seq_len, d_model]
178
+
179
+ # Fuse embeddings
180
+ fused = self.fusion(torch.cat([type_emb, cont_emb], dim=-1))
181
+ fused = self.fusion_norm(fused) # [batch, seq_len, d_model]
182
+
183
+ # Add CLS token if using CLS pooling
184
+ if self.pooling == "cls":
185
+ cls_tokens = self.cls_token.expand(batch_size, -1, -1)
186
+ fused = torch.cat([cls_tokens, fused], dim=1) # [batch, 1+seq_len, d_model]
187
+
188
+ # Extend mask for CLS token
189
+ if mask is not None:
190
+ cls_mask = torch.ones(
191
+ batch_size, 1, device=mask.device, dtype=mask.dtype
192
+ )
193
+ mask = torch.cat([cls_mask, mask], dim=1)
194
+
195
+ # Add positional encoding
196
+ fused = self.pos_encoder(fused)
197
+
198
+ # Create attention mask for transformer (True = ignore)
199
+ if mask is not None:
200
+ attn_mask = mask == 0 # Invert: 0 -> True (ignore)
201
+ else:
202
+ attn_mask = None
203
+
204
+ # Apply transformer
205
+ encoded = self.transformer(fused, src_key_padding_mask=attn_mask)
206
+
207
+ # Pool to vector
208
+ if self.pooling == "cls":
209
+ output = encoded[:, 0] # CLS token
210
+ elif self.pooling == "mean":
211
+ if mask is not None:
212
+ # Masked mean (exclude padding)
213
+ mask_expanded = mask.unsqueeze(-1) # [batch, seq_len, 1]
214
+ output = (encoded * mask_expanded).sum(dim=1) / mask_expanded.sum(
215
+ dim=1
216
+ ).clamp(min=1)
217
+ else:
218
+ output = encoded.mean(dim=1)
219
+ elif self.pooling == "max":
220
+ if mask is not None:
221
+ # Masked max (set padding to -inf)
222
+ mask_expanded = mask.unsqueeze(-1)
223
+ encoded = encoded.masked_fill(mask_expanded == 0, float("-inf"))
224
+ output = encoded.max(dim=1).values
225
+ else:
226
+ raise ValueError(f"Unknown pooling method: {self.pooling}")
227
+
228
+ return self.output_norm(output)
229
+
230
+
231
+ class TCNBlock(nn.Module):
232
+ """Temporal Convolutional Network block with residual connection."""
233
+
234
+ def __init__(
235
+ self,
236
+ in_channels: int,
237
+ out_channels: int,
238
+ kernel_size: int = 3,
239
+ dilation: int = 1,
240
+ dropout: float = 0.1,
241
+ ):
242
+ super().__init__()
243
+
244
+ padding = (kernel_size - 1) * dilation // 2
245
+
246
+ self.conv1 = nn.Conv1d(
247
+ in_channels, out_channels, kernel_size, padding=padding, dilation=dilation
248
+ )
249
+ self.conv2 = nn.Conv1d(
250
+ out_channels, out_channels, kernel_size, padding=padding, dilation=dilation
251
+ )
252
+
253
+ self.norm1 = nn.BatchNorm1d(out_channels)
254
+ self.norm2 = nn.BatchNorm1d(out_channels)
255
+
256
+ self.dropout = nn.Dropout(dropout)
257
+
258
+ # Residual connection
259
+ self.residual = (
260
+ nn.Conv1d(in_channels, out_channels, 1)
261
+ if in_channels != out_channels
262
+ else nn.Identity()
263
+ )
264
+
265
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
266
+ """
267
+ Args:
268
+ x: [batch, channels, seq_len]
269
+ """
270
+ residual = self.residual(x)
271
+
272
+ out = F.gelu(self.norm1(self.conv1(x)))
273
+ out = self.dropout(out)
274
+ out = F.gelu(self.norm2(self.conv2(out)))
275
+ out = self.dropout(out)
276
+
277
+ return out + residual
278
+
279
+
280
+ class TCNInstanceEncoder(nn.Module):
281
+ """
282
+ Alternative instance encoder using Temporal Convolutional Network.
283
+ Faster than Transformer with stronger local inductive bias.
284
+ """
285
+
286
+ def __init__(
287
+ self,
288
+ d_model: int = 256,
289
+ n_layers: int = 4,
290
+ kernel_size: int = 3,
291
+ dropout: float = 0.1,
292
+ n_note_types: int = 10,
293
+ ):
294
+ super().__init__()
295
+
296
+ self.d_model = d_model
297
+
298
+ # Input projection
299
+ self.type_embedding = nn.Embedding(n_note_types, d_model // 2, padding_idx=9)
300
+ self.continuous_proj = nn.Linear(5, d_model // 2)
301
+
302
+ # TCN layers with exponentially increasing dilation
303
+ self.tcn_layers = nn.ModuleList(
304
+ [
305
+ TCNBlock(d_model, d_model, kernel_size, dilation=2**i, dropout=dropout)
306
+ for i in range(n_layers)
307
+ ]
308
+ )
309
+
310
+ self.output_norm = nn.LayerNorm(d_model)
311
+
312
+ def forward(
313
+ self,
314
+ tokens: torch.Tensor,
315
+ mask: Optional[torch.Tensor] = None,
316
+ ) -> torch.Tensor:
317
+ """
318
+ Args:
319
+ tokens: [batch, seq_len, 6]
320
+ mask: [batch, seq_len]
321
+
322
+ Returns:
323
+ [batch, d_model]
324
+ """
325
+ # Embed inputs
326
+ note_types = tokens[:, :, 0].long()
327
+ continuous = tokens[:, :, 1:]
328
+
329
+ type_emb = self.type_embedding(note_types)
330
+ cont_emb = self.continuous_proj(continuous)
331
+
332
+ x = torch.cat([type_emb, cont_emb], dim=-1) # [batch, seq_len, d_model]
333
+
334
+ # Convert to channels-first for conv
335
+ x = x.transpose(1, 2) # [batch, d_model, seq_len]
336
+
337
+ # Apply TCN layers
338
+ for layer in self.tcn_layers:
339
+ x = layer(x)
340
+
341
+ # Global average pooling
342
+ if mask is not None:
343
+ mask_expanded = mask.unsqueeze(1) # [batch, 1, seq_len]
344
+ x = (x * mask_expanded).sum(dim=-1) / mask_expanded.sum(dim=-1).clamp(min=1)
345
+ else:
346
+ x = x.mean(dim=-1)
347
+
348
+ return self.output_norm(x)
TaikoChartEstimator/model/heads.py ADDED
@@ -0,0 +1,398 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Output Heads for Taiko Chart Estimation
3
+
4
+ Three heads for multi-task learning:
5
+ - Head A: Raw difficulty score (unbounded)
6
+ - Head B: Difficulty classification (4-5 classes)
7
+ - Head C: Monotonic star calibration
8
+ """
9
+
10
+ from typing import Optional
11
+
12
+ import numpy as np
13
+ import torch
14
+ import torch.nn as nn
15
+ import torch.nn.functional as F
16
+
17
+
18
class RawScoreHead(nn.Module):
    """
    Head A: Unbounded raw difficulty score.

    Produces a single real-valued score per bag. This is the continuous
    difficulty scale that exists before any mapping to display star ratings.
    """

    def __init__(
        self,
        d_input: int = 512,
        d_hidden: int = 128,
        dropout: float = 0.1,
    ):
        """
        Args:
            d_input: Width of the incoming bag embedding
            d_hidden: Width of the first hidden layer (the second is half of it)
            dropout: Dropout probability after the first hidden layer
        """
        super().__init__()

        d_bottleneck = d_hidden // 2
        stages = [
            nn.Linear(d_input, d_hidden),
            nn.LayerNorm(d_hidden),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(d_hidden, d_bottleneck),
            nn.GELU(),
            nn.Linear(d_bottleneck, 1),
        ]
        self.mlp = nn.Sequential(*stages)

        # Start training from outputs near the middle of the ~1-10 range
        self._init_weights()

    def _init_weights(self):
        """Set the output layer to bias 5.0 and constant weight 0.01."""
        output_layer = self.mlp[-1]
        with torch.no_grad():
            output_layer.bias.fill_(5.0)
            output_layer.weight.fill_(0.01)

    def forward(self, bag_embedding: torch.Tensor) -> torch.Tensor:
        """
        Args:
            bag_embedding: [batch, d_input]

        Returns:
            raw_score: [batch] unbounded difficulty score
        """
        score = self.mlp(bag_embedding)  # [batch, 1]
        return score.squeeze(-1)
63
+
64
+
65
class DifficultyClassifier(nn.Module):
    """
    Head B: Difficulty classification.

    Scores a bag embedding against the difficulty classes: easy, normal, hard,
    oni, ura (5 classes), or the 4-class variant with oni and ura merged.
    """

    def __init__(
        self,
        d_input: int = 512,
        n_classes: int = 5,
        d_hidden: int = 128,
        dropout: float = 0.1,
    ):
        """
        Args:
            d_input: Width of the incoming bag embedding
            n_classes: Number of difficulty classes to score
            d_hidden: Width of the hidden layer
            dropout: Dropout probability before the output layer
        """
        super().__init__()

        self.n_classes = n_classes

        stages = [
            nn.Linear(d_input, d_hidden),
            nn.LayerNorm(d_hidden),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(d_hidden, n_classes),
        ]
        self.mlp = nn.Sequential(*stages)

    def forward(self, bag_embedding: torch.Tensor) -> torch.Tensor:
        """
        Args:
            bag_embedding: [batch, d_input]

        Returns:
            logits: [batch, n_classes] classification logits
        """
        logits = self.mlp(bag_embedding)
        return logits

    def predict(self, bag_embedding: torch.Tensor) -> torch.Tensor:
        """Return the index of the highest-scoring class for each bag."""
        return torch.argmax(self.forward(bag_embedding), dim=-1)
106
+
107
+
108
class MonotonicSpline(nn.Module):
    """
    Monotonic piecewise-linear map from raw score to star rating.

    The output is a non-negative combination of hinge functions
    ``relu(x - knot)`` placed at fixed, evenly spaced knots, plus a learnable
    offset. Because every coefficient is positive (softplus) and every hinge is
    non-decreasing, the map is non-decreasing in ``x``. Inputs are clamped to
    ``input_range``, so the map is flat outside of it.
    """

    def __init__(
        self,
        n_knots: int = 8,
        input_range: tuple[float, float] = (0, 15),
        output_range: tuple[float, float] = (1, 10),
    ):
        """
        Args:
            n_knots: Number of evenly spaced knots across ``input_range``
            input_range: (min, max) raw-score interval covered by the knots
            output_range: (min, max) star range; the min initialises the offset
                and the width scales the normalised coefficients
        """
        super().__init__()

        self.n_knots = n_knots
        self.input_range = input_range
        self.output_range = output_range

        # Fixed knot positions, stored as a buffer so they follow the module
        self.register_buffer(
            "knots", torch.linspace(input_range[0], input_range[1], n_knots)
        )

        # Unconstrained parameters; softplus makes them positive in forward()
        self.raw_coefficients = nn.Parameter(torch.ones(n_knots))

        # Learnable vertical offset, starting at the bottom of the star range
        self.offset = nn.Parameter(torch.tensor(float(output_range[0])))

    def _compute_basis(self, x: torch.Tensor) -> torch.Tensor:
        """Return the hinge basis ``relu(clamp(x) - knot)`` as [batch, n_knots]."""
        lower = self.input_range[0]
        upper = self.input_range[1]

        # Keep inputs inside the knot interval so the output cannot blow up
        bounded = x.clamp(lower, upper).unsqueeze(-1)  # [batch, 1]
        hinge = F.relu(bounded - self.knots.unsqueeze(0))  # [batch, n_knots]

        # Cap every hinge at the width of the input interval
        return hinge.clamp(max=upper - lower)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Map raw score to star rating (monotonically).

        Args:
            x: Raw scores [batch]

        Returns:
            Star ratings [batch]
        """
        # Positive coefficients, normalised to sum to the output range width
        positive = F.softplus(self.raw_coefficients)
        span = self.output_range[1] - self.output_range[0]
        weights = positive / positive.sum() * span

        basis = self._compute_basis(x)  # [batch, n_knots]

        return (basis * weights).sum(dim=-1) + self.offset
180
+
181
+
182
class MonotonicMLP(nn.Module):
    """
    Monotonic MLP using positive weight constraints.

    Ensures f(x1) >= f(x2) whenever x1 >= x2 by using the absolute value of
    every weight in the forward pass together with monotonic (Softplus)
    activations.
    """

    def __init__(
        self,
        d_hidden: int = 64,
        n_layers: int = 3,
    ):
        """
        Args:
            d_hidden: Width of every hidden layer
            n_layers: Number of linear layers (the last one maps to 1 output)
        """
        super().__init__()

        linears = []
        in_dim = 1
        for i in range(n_layers):
            is_last = i == n_layers - 1
            out_dim = 1 if is_last else d_hidden
            linears.append(nn.Linear(in_dim, out_dim))
            in_dim = out_dim

        self.layers = nn.ModuleList(linears)

        # Softplus after every hidden layer, identity after the output layer.
        # Kept in a ModuleList so the activations are registered submodules;
        # they hold no parameters, so state_dict keys are unaffected.
        self.activations = nn.ModuleList(
            [nn.Softplus() for _ in range(n_layers - 1)] + [nn.Identity()]
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x: Raw scores [batch]

        Returns:
            Calibrated scores [batch]
        """
        out = x.unsqueeze(-1)  # [batch, 1]

        for layer, activation in zip(self.layers, self.activations):
            # |W| keeps every layer non-decreasing in each of its inputs
            out = activation(F.linear(out, layer.weight.abs(), layer.bias))

        return out.squeeze(-1)
231
+
232
+
233
class MonotonicCalibrator(nn.Module):
    """
    Head C: Monotonic calibration from raw score to star rating.

    Maintains a separate calibrator per difficulty level, since
    the star ranges differ (easy: 1-5, normal: 1-7, etc.)

    Properties:
        - Each calibrator is built to be non-decreasing in the raw score
        - Outputs are not clipped, so they can fall outside the nominal range;
          use ``clip_to_display`` to obtain values for the UI
    """

    # Star range assumed for any difficulty index missing from ``star_ranges``
    _DEFAULT_RANGE = (1, 10)

    def __init__(
        self,
        method: str = "spline",  # "spline" or "mlp"
        n_difficulties: int = 5,
        star_ranges: Optional[dict] = None,
    ):
        """
        Args:
            method: Calibration method ("spline"; any other value selects "mlp")
            n_difficulties: Number of difficulty classes
            star_ranges: Dict mapping difficulty index to (min, max) star range
        """
        super().__init__()

        self.method = method
        self.n_difficulties = n_difficulties

        # Default star ranges per difficulty
        if star_ranges is None:
            star_ranges = {
                0: (1, 5),  # easy
                1: (1, 7),  # normal
                2: (1, 8),  # hard
                3: (1, 10),  # oni
                4: (1, 10),  # ura
            }
        self.star_ranges = star_ranges

        # Create calibrators per difficulty
        if method == "spline":
            self.calibrators = nn.ModuleList(
                [
                    MonotonicSpline(
                        n_knots=8,
                        input_range=(0, 15),
                        output_range=self._range_for(i),
                    )
                    for i in range(n_difficulties)
                ]
            )
        else:
            self.calibrators = nn.ModuleList(
                [MonotonicMLP(d_hidden=32, n_layers=3) for _ in range(n_difficulties)]
            )

            # Learnable affine rescaling of the MLP output, initialised to the
            # width and lower bound of each difficulty's star range
            self.scales = nn.ParameterList(
                [
                    nn.Parameter(
                        torch.tensor(
                            float(self._range_for(i)[1] - self._range_for(i)[0])
                        )
                    )
                    for i in range(n_difficulties)
                ]
            )
            self.offsets = nn.ParameterList(
                [
                    nn.Parameter(torch.tensor(float(self._range_for(i)[0])))
                    for i in range(n_difficulties)
                ]
            )

    def _range_for(self, diff_idx: int) -> tuple:
        """Return the (min, max) star range of a difficulty, with fallback."""
        return self.star_ranges.get(diff_idx, self._DEFAULT_RANGE)

    def _calibrate(self, diff_idx: int, raw_score: torch.Tensor) -> torch.Tensor:
        """Apply the calibrator of one difficulty to a tensor of raw scores."""
        stars = self.calibrators[diff_idx](raw_score)
        if self.method != "spline":
            # MLP output is rescaled by the per-difficulty scale and offset
            stars = stars * self.scales[diff_idx] + self.offsets[diff_idx]
        return stars

    def forward(
        self,
        raw_score: torch.Tensor,
        difficulty: torch.Tensor,
    ) -> torch.Tensor:
        """
        Map raw scores to star ratings based on difficulty.

        Args:
            raw_score: [batch] raw difficulty scores
            difficulty: [batch] difficulty class indices

        Returns:
            star_rating: [batch] calibrated star ratings (can be < min or > max).
            Entries whose difficulty index is outside
            ``range(n_difficulties)`` are left at 0.
        """
        star_ratings = torch.zeros_like(raw_score)

        # Route every sample through the calibrator of its own difficulty
        for diff_idx in range(self.n_difficulties):
            mask = difficulty == diff_idx
            if mask.any():
                star_ratings[mask] = self._calibrate(diff_idx, raw_score[mask])

        return star_ratings

    def forward_all(
        self,
        raw_score: torch.Tensor,
    ) -> torch.Tensor:
        """
        Compute star ratings for all difficulties at once.

        Args:
            raw_score: [batch] raw scores

        Returns:
            star_ratings: [batch, n_difficulties] star per difficulty
        """
        all_stars = [
            self._calibrate(diff_idx, raw_score)
            for diff_idx in range(self.n_difficulties)
        ]
        return torch.stack(all_stars, dim=-1)

    def clip_to_display(
        self,
        star_rating: torch.Tensor,
        difficulty: torch.Tensor,
    ) -> torch.Tensor:
        """
        Clip star ratings to display range for UI.

        Args:
            star_rating: [batch] raw star ratings (can be outside range)
            difficulty: [batch] difficulty indices

        Returns:
            display_star: [batch] clipped to valid range per difficulty
        """
        display_star = star_rating.clone()

        for diff_idx in range(self.n_difficulties):
            mask = difficulty == diff_idx
            if mask.any():
                # Same fallback as the constructor, so a difficulty missing
                # from star_ranges is clipped to (1, 10) instead of raising
                min_star, max_star = self._range_for(diff_idx)
                display_star[mask] = display_star[mask].clamp(min_star, max_star)

        return display_star
TaikoChartEstimator/model/losses.py ADDED
@@ -0,0 +1,431 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Loss Functions for Taiko Chart Estimation
3
+
4
+ Implements:
5
+ - Within-song ranking loss (monotonicity constraint)
6
+ - Censored regression loss (handles star boundary labels)
7
+ - Multi-task loss combiner with curriculum scheduling
8
+ """
9
+
10
+ from typing import Optional
11
+
12
+ import torch
13
+ import torch.nn as nn
14
+ import torch.nn.functional as F
15
+
16
+ from ..constants import STAR_RANGES_BY_ID as STAR_RANGES
17
+
18
+
19
class WithinSongRankingLoss(nn.Module):
    """
    Hinge ranking loss that enforces within-song monotonicity.

    Given pairs of charts from the same song, the harder chart should score
    above the easier one by at least ``margin``:

        L = max(0, margin - (s_harder - s_easier))
    """

    def __init__(self, margin: float = 0.5):
        """
        Args:
            margin: Minimum required difference between difficulty levels
        """
        super().__init__()
        self.margin = margin

    def forward(
        self,
        s_easier: torch.Tensor,
        s_harder: torch.Tensor,
    ) -> torch.Tensor:
        """
        Compute ranking loss for pairs.

        Args:
            s_easier: [n_pairs] scores for easier charts
            s_harder: [n_pairs] scores for harder charts

        Returns:
            Scalar loss value (a constant 0 when there are no pairs)
        """
        # The mean of an empty tensor is NaN, so short-circuit that case
        if s_easier.numel() == 0:
            return torch.tensor(0.0, device=s_easier.device)

        gap = s_harder - s_easier
        hinge = F.relu(self.margin - gap)
        return hinge.mean()

    def compute_violation_rate(
        self,
        s_easier: torch.Tensor,
        s_harder: torch.Tensor,
    ) -> float:
        """Return the fraction of pairs where the easier chart scores >= the harder."""
        if s_easier.numel() == 0:
            return 0.0

        inverted = torch.ge(s_easier, s_harder)
        return inverted.float().mean().item()
71
+
72
+
73
class CensoredRegressionLoss(nn.Module):
    """
    Censored regression loss for star ratings.

    Handles the fact that boundary labels are censored:
    - label == max_star: true value is >= max_star (right-censored)
    - label == min_star: true value is <= min_star (left-censored)

    For censored samples, only predictions on the wrong side of the label are
    penalised (linearly); predictions beyond the bound cost nothing.
    """

    def __init__(
        self,
        uncensored_loss: str = "huber",  # "huber", "mse", "mae"
        huber_delta: float = 0.5,
        star_ranges: Optional[dict] = None,
    ):
        """
        Args:
            uncensored_loss: Loss type for uncensored samples
            huber_delta: Delta for Huber loss
            star_ranges: Dict mapping difficulty index to (min, max) range
        """
        super().__init__()

        self.uncensored_loss = uncensored_loss
        self.huber_delta = huber_delta
        self.star_ranges = star_ranges if star_ranges is not None else STAR_RANGES

    def _uncensored_loss(
        self,
        pred: torch.Tensor,
        target: torch.Tensor,
    ) -> torch.Tensor:
        """
        Compute the per-sample loss for uncensored samples.

        Raises:
            ValueError: If ``self.uncensored_loss`` is not a known loss type
        """
        if self.uncensored_loss == "huber":
            return F.huber_loss(pred, target, delta=self.huber_delta, reduction="none")
        elif self.uncensored_loss == "mse":
            return F.mse_loss(pred, target, reduction="none")
        elif self.uncensored_loss == "mae":
            return F.l1_loss(pred, target, reduction="none")
        else:
            raise ValueError(f"Unknown loss type: {self.uncensored_loss}")

    def _detect_censoring(
        self,
        pred_star: torch.Tensor,
        target_star: torch.Tensor,
        difficulty: torch.Tensor,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Derive (right, left) censoring flags from labels and star ranges."""
        batch_size = pred_star.size(0)
        right = torch.zeros(batch_size, dtype=torch.bool, device=pred_star.device)
        left = torch.zeros(batch_size, dtype=torch.bool, device=pred_star.device)

        for diff_idx, (min_star, max_star) in self.star_ranges.items():
            mask = difficulty == diff_idx
            right[mask] = target_star[mask] >= max_star
            left[mask] = target_star[mask] <= min_star

        return right, left

    def forward(
        self,
        pred_star: torch.Tensor,
        target_star: torch.Tensor,
        difficulty: torch.Tensor,
        is_right_censored: Optional[torch.Tensor] = None,
        is_left_censored: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """
        Compute censored regression loss.

        Args:
            pred_star: [batch] predicted star ratings
            target_star: [batch] target star labels
            difficulty: [batch] difficulty class indices
            is_right_censored: [batch] bool, True if label is at max (right-censored)
            is_left_censored: [batch] bool, True if label is at min (left-censored)

        Either flag may be omitted; a missing flag is auto-detected from the
        labels while a provided flag is used as given.

        Returns:
            Scalar loss value
        """
        # Auto-detect only the flags that were not provided, so that passing
        # one flag does not cause it to be overwritten by detection.
        if is_right_censored is None or is_left_censored is None:
            auto_right, auto_left = self._detect_censoring(
                pred_star, target_star, difficulty
            )
            if is_right_censored is None:
                is_right_censored = auto_right
            if is_left_censored is None:
                is_left_censored = auto_left

        # Boolean masks are required for mask indexing and for `~` below
        right_mask = is_right_censored.bool()
        left_mask = is_left_censored.bool()

        # Compute losses per sample
        losses = torch.zeros_like(pred_star)

        # Right-censored: only penalize if pred < target
        if right_mask.any():
            losses[right_mask] = F.relu(target_star[right_mask] - pred_star[right_mask])

        # Left-censored: only penalize if pred > target. Written after the
        # right-censored case, so it wins for samples flagged as both.
        if left_mask.any():
            losses[left_mask] = F.relu(pred_star[left_mask] - target_star[left_mask])

        # Uncensored: standard loss
        uncensored_mask = ~(right_mask | left_mask)
        if uncensored_mask.any():
            losses[uncensored_mask] = self._uncensored_loss(
                pred_star[uncensored_mask],
                target_star[uncensored_mask],
            )

        return losses.mean()

    def compute_censoring_metrics(
        self,
        pred_star: torch.Tensor,
        target_star: torch.Tensor,
        difficulty: torch.Tensor,
    ) -> dict:
        """
        Compute censoring-related metrics.

        Returns:
            Dict keyed by ``right_violation_rate_{d}``, ``mean_shortfall_{d}``,
            ``left_violation_rate_{d}`` and ``mean_overshoot_{d}`` for every
            difficulty ``d`` that has boundary-labelled samples in the batch
        """
        metrics = {}

        for diff_idx, (min_star, max_star) in self.star_ranges.items():
            mask = difficulty == diff_idx
            if not mask.any():
                continue

            preds = pred_star[mask]
            targets = target_star[mask]

            # Right-censored samples (at max)
            right_mask = targets >= max_star
            if right_mask.any():
                right_preds = preds[right_mask]
                violation_rate = (right_preds < max_star).float().mean().item()
                mean_shortfall = F.relu(max_star - right_preds).mean().item()

                metrics[f"right_violation_rate_{diff_idx}"] = violation_rate
                metrics[f"mean_shortfall_{diff_idx}"] = mean_shortfall

            # Left-censored samples (at min)
            left_mask = targets <= min_star
            if left_mask.any():
                left_preds = preds[left_mask]
                violation_rate = (left_preds > min_star).float().mean().item()
                mean_overshoot = F.relu(left_preds - min_star).mean().item()

                metrics[f"left_violation_rate_{diff_idx}"] = violation_rate
                metrics[f"mean_overshoot_{diff_idx}"] = mean_overshoot

        return metrics
223
+
224
+
225
class TotalLoss(nn.Module):
    """
    Weighted sum of the three training objectives.

    Terms:
    - Classification loss (difficulty prediction)
    - Censored star regression loss
    - Within-song ranking loss (monotonicity)

    The three weights can be changed during training via ``set_weights`` to
    support curriculum learning.
    Note: When merge_ura_oni=True, ura (4) and oni (3) are treated as the same class.
    """

    def __init__(
        self,
        lambda_cls: float = 1.0,
        lambda_star: float = 1.0,
        lambda_rank: float = 1.0,
        class_weights: Optional[torch.Tensor] = None,
        ranking_margin: float = 0.5,
        star_loss_type: str = "huber",
        merge_ura_oni: bool = True,
    ):
        """
        Args:
            lambda_cls: Weight for classification loss
            lambda_star: Weight for star regression loss
            lambda_rank: Weight for ranking loss
            class_weights: Optional class weights for classification
                (NOTE(review): with merge_ura_oni=True the weights are applied
                to the merged log-probabilities, so they presumably need one
                entry per merged class - confirm)
            ranking_margin: Margin for ranking hinge loss
            star_loss_type: Loss type for star regression
            merge_ura_oni: If True, treat ura (4) as oni (3) for classification
        """
        super().__init__()

        self.lambda_cls = lambda_cls
        self.lambda_star = lambda_star
        self.lambda_rank = lambda_rank
        self.merge_ura_oni = merge_ura_oni

        # One sub-loss per objective
        self.cls_loss = nn.CrossEntropyLoss(weight=class_weights)
        self.star_loss = CensoredRegressionLoss(uncensored_loss=star_loss_type)
        self.rank_loss = WithinSongRankingLoss(margin=ranking_margin)

    def set_weights(
        self,
        lambda_cls: Optional[float] = None,
        lambda_star: Optional[float] = None,
        lambda_rank: Optional[float] = None,
    ):
        """Overwrite any loss weight that is given; None leaves it unchanged."""
        updates = (
            ("lambda_cls", lambda_cls),
            ("lambda_star", lambda_star),
            ("lambda_rank", lambda_rank),
        )
        for attr, value in updates:
            if value is not None:
                setattr(self, attr, value)

    def _classification_loss(
        self,
        difficulty_logits: torch.Tensor,
        target_difficulty: torch.Tensor,
    ) -> torch.Tensor:
        """Cross-entropy over difficulty classes, optionally merging oni and ura."""
        if not self.merge_ura_oni:
            return self.cls_loss(difficulty_logits, target_difficulty)

        # Relabel ura (class 4) as oni (class 3) without touching the input
        merged_target = target_difficulty.masked_fill(target_difficulty == 4, 3)

        # Merge in log-probability space:
        # logsumexp(log P(oni), log P(ura)) = log(P(oni) + P(ura))
        log_probs = F.log_softmax(difficulty_logits, dim=-1)
        merged_log_probs = log_probs[:, :4].clone()
        merged_log_probs[:, 3] = torch.logsumexp(log_probs[:, 3:5], dim=-1)

        return F.nll_loss(
            merged_log_probs,
            merged_target,
            weight=self.cls_loss.weight,
        )

    def forward(
        self,
        difficulty_logits: torch.Tensor,
        pred_star: torch.Tensor,
        target_difficulty: torch.Tensor,
        target_star: torch.Tensor,
        is_right_censored: Optional[torch.Tensor] = None,
        is_left_censored: Optional[torch.Tensor] = None,
        ranking_pairs: Optional[tuple[torch.Tensor, torch.Tensor]] = None,
    ) -> dict[str, torch.Tensor]:
        """
        Compute total loss with breakdown.

        Args:
            difficulty_logits: [batch, n_classes] difficulty predictions
            pred_star: [batch] predicted star ratings
            target_difficulty: [batch] target difficulty classes
            target_star: [batch] target star labels
            is_right_censored: [batch] right-censoring flags
            is_left_censored: [batch] left-censoring flags
            ranking_pairs: Optional (s_easier, s_harder) for ranking loss

        Returns:
            Dict with total loss and breakdown:
                - "cls": Classification loss
                - "star": Star regression loss
                - "rank": Ranking loss (constant 0 if no pairs were provided)
                - "total": Combined weighted loss
        """
        cls_loss = self._classification_loss(difficulty_logits, target_difficulty)

        # The star loss receives the unmerged difficulty indices
        star_loss = self.star_loss(
            pred_star,
            target_star,
            target_difficulty,
            is_right_censored,
            is_left_censored,
        )

        if ranking_pairs is None:
            rank_loss = torch.tensor(0.0, device=pred_star.device)
        else:
            s_easier, s_harder = ranking_pairs
            rank_loss = self.rank_loss(s_easier, s_harder)

        weighted_cls = self.lambda_cls * cls_loss
        weighted_star = self.lambda_star * star_loss
        weighted_rank = self.lambda_rank * rank_loss

        return {
            "cls": cls_loss,
            "star": star_loss,
            "rank": rank_loss,
            "total": weighted_cls + weighted_star + weighted_rank,
        }
370
+
371
+
372
class CurriculumScheduler:
    """
    Scheduler for curriculum learning of loss weights.

    Early training: focus on classification (coarse alignment)
    Later training: increase ranking + star loss (fine-grained)

    Every weight is interpolated linearly between its start and end value
    along a progress variable ``t`` in [0, 1]. ``t`` moves from 0 to 0.5 over
    the warmup steps and from 0.5 to 1 over the remaining steps, so the
    schedule is continuous at the warmup boundary.
    """

    def __init__(
        self,
        total_steps: int,
        warmup_fraction: float = 0.2,
        cls_start: float = 2.0,
        cls_end: float = 0.5,
        rank_start: float = 0.1,
        rank_end: float = 1.5,
        star_start: float = 0.5,
        star_end: float = 1.5,
    ):
        """
        Args:
            total_steps: Total training steps
            warmup_fraction: Fraction of training for warmup
            *_start/*_end: Start and end values for each loss weight
        """
        self.total_steps = total_steps
        self.warmup_steps = int(total_steps * warmup_fraction)

        self.cls_start = cls_start
        self.cls_end = cls_end
        self.rank_start = rank_start
        self.rank_end = rank_end
        self.star_start = star_start
        self.star_end = star_end

    def _progress(self, step: int) -> float:
        """Return the schedule progress in [0, 1] for a training step."""
        if self.warmup_steps <= 0:
            # No warmup phase: a single linear ramp over the whole run
            t = step / self.total_steps if self.total_steps > 0 else 1.0
        elif step < self.warmup_steps:
            # During warmup: interpolate from start to mid
            t = 0.5 * step / self.warmup_steps
        else:
            # After warmup: continue from mid to end
            remaining = self.total_steps - self.warmup_steps
            if remaining <= 0:
                # Warmup spans the whole run; nothing left to interpolate
                t = 1.0
            else:
                t = 0.5 + 0.5 * (step - self.warmup_steps) / remaining

        # Clamp so out-of-range steps hold the start / end weights
        return min(1.0, max(0.0, t))

    def get_weights(self, step: int) -> dict[str, float]:
        """
        Get loss weights for current step.

        Returns:
            Dict with lambda_cls, lambda_star, lambda_rank
        """
        t = self._progress(step)

        # Linear interpolation
        lambda_cls = self.cls_start + t * (self.cls_end - self.cls_start)
        lambda_rank = self.rank_start + t * (self.rank_end - self.rank_start)
        lambda_star = self.star_start + t * (self.star_end - self.star_start)

        return {
            "lambda_cls": lambda_cls,
            "lambda_star": lambda_star,
            "lambda_rank": lambda_rank,
        }
TaikoChartEstimator/model/model.py ADDED
@@ -0,0 +1,374 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Main TaikoChartEstimator Model
3
+
4
+ Combines instance encoder, MIL aggregator, and output heads
5
+ into a unified model for difficulty estimation.
6
+ """
7
+
8
+ from dataclasses import dataclass
9
+ from typing import Optional
10
+
11
+ import torch
12
+ import torch.nn as nn
13
+ from huggingface_hub import PyTorchModelHubMixin
14
+
15
+ from ..data.tokenizer import DIFFICULTY_ORDER
16
+ from .aggregator import GatedMILAggregator, MILAggregator
17
+ from .encoder import InstanceEncoder, TCNInstanceEncoder
18
+ from .heads import DifficultyClassifier, MonotonicCalibrator, RawScoreHead
19
+
20
+
21
@dataclass
class ModelConfig:
    """
    Hyperparameters for TaikoChartEstimator.

    Groups the settings of the instance encoder, the MIL aggregator and the
    output heads, plus the per-difficulty star ranges.
    """

    # --- Instance encoder ---
    encoder_type: str = "transformer"  # "transformer" or "tcn"
    d_model: int = 256
    n_encoder_layers: int = 4
    n_heads: int = 4
    d_feedforward: int = 512
    encoder_dropout: float = 0.1
    max_seq_len: int = 128
    encoder_pooling: str = "cls"

    # --- MIL aggregator ---
    aggregator_type: str = "multibranch"  # "multibranch" or "gated"
    n_attention_branches: int = 3
    top_k_ratio: float = 0.1
    stochastic_mask_prob: float = 0.3
    aggregator_dropout: float = 0.1

    # --- Output heads ---
    n_difficulty_classes: int = 5  # easy, normal, hard, oni, ura
    head_hidden_dim: int = 128
    head_dropout: float = 0.1
    calibrator_method: str = "spline"  # "spline" or "mlp"

    # --- Star ranges: difficulty index -> (min, max); None selects defaults ---
    star_ranges: Optional[dict] = None

    def __post_init__(self):
        """Fill in default star ranges, or normalise the ones that were given."""
        if self.star_ranges is None:
            self.star_ranges = {
                0: (1, 5),  # easy
                1: (1, 7),  # normal
                2: (1, 8),  # hard
                3: (1, 10),  # oni
                4: (1, 10),  # ura
            }
            return

        # A JSON round trip turns int keys into strings and tuples into lists;
        # restore int keys and tuple values.
        normalized = {}
        for key, bounds in self.star_ranges.items():
            normalized[int(key)] = tuple(bounds) if isinstance(bounds, list) else bounds
        self.star_ranges = normalized
67
+
68
+
69
@dataclass
class ModelOutput:
    """
    Result bundle returned by the TaikoChartEstimator forward pass.

    Attributes:
        raw_score: [batch] unbounded difficulty score
        difficulty_logits: [batch, n_classes] difficulty logits
        raw_star: [batch] star rating (can be < 1 or > 10)
        display_star: [batch] star rating clipped to range
        attention_info: MIL attention weights and metrics
        instance_embeddings: [batch, n_instances, d_model] for analysis
    """

    raw_score: torch.Tensor
    difficulty_logits: torch.Tensor
    raw_star: torch.Tensor
    display_star: torch.Tensor
    attention_info: dict
    instance_embeddings: torch.Tensor
79
+
80
+
81
class TaikoChartEstimator(nn.Module, PyTorchModelHubMixin):
    """
    MIL-based Taiko chart difficulty estimation model.

    Takes a bag of chart instances (beat-aligned windows) and predicts:
    1. Raw difficulty score (unbounded, ℝ)
    2. Difficulty class (easy/normal/hard/oni/ura)
    3. Star rating (per difficulty, can exceed nominal range)

    Architecture:
    - Instance Encoder: Transformer or TCN to encode each window
    - MIL Aggregator: Multi-branch attention pooling
    - Output Heads: Raw score, classifier, monotonic calibrator
    """

    def __init__(self, config: Optional[ModelConfig] = None):
        """
        Initialize model.

        Args:
            config: Model configuration (uses defaults if None)
        """
        super().__init__()

        if config is None:
            config = ModelConfig()
        self.config = config

        # Build instance encoder; any encoder_type other than "transformer"
        # selects the TCN encoder.
        if config.encoder_type == "transformer":
            self.instance_encoder = InstanceEncoder(
                d_model=config.d_model,
                n_heads=config.n_heads,
                n_layers=config.n_encoder_layers,
                d_feedforward=config.d_feedforward,
                dropout=config.encoder_dropout,
                max_seq_len=config.max_seq_len,
                pooling=config.encoder_pooling,
            )
        else:
            self.instance_encoder = TCNInstanceEncoder(
                d_model=config.d_model,
                n_layers=config.n_encoder_layers,
                dropout=config.encoder_dropout,
            )

        # Build MIL aggregator; any aggregator_type other than "multibranch"
        # selects the gated aggregator.
        if config.aggregator_type == "multibranch":
            self.aggregator = MILAggregator(
                d_instance=config.d_model,
                n_branches=config.n_attention_branches,
                top_k_ratio=config.top_k_ratio,
                stochastic_mask_prob=config.stochastic_mask_prob,
                dropout=config.aggregator_dropout,
            )
        else:
            self.aggregator = GatedMILAggregator(
                d_instance=config.d_model,
                dropout=config.aggregator_dropout,
            )

        # Output heads all consume the aggregator's bag embedding.
        bag_dim = self.aggregator.output_dim

        self.raw_score_head = RawScoreHead(
            d_input=bag_dim,
            d_hidden=config.head_hidden_dim,
            dropout=config.head_dropout,
        )

        self.difficulty_classifier = DifficultyClassifier(
            d_input=bag_dim,
            n_classes=config.n_difficulty_classes,
            d_hidden=config.head_hidden_dim,
            dropout=config.head_dropout,
        )

        self.calibrator = MonotonicCalibrator(
            method=config.calibrator_method,
            n_difficulties=config.n_difficulty_classes,
            star_ranges=config.star_ranges,
        )

    def encode_instances(
        self,
        instances: torch.Tensor,
        instance_masks: torch.Tensor,
    ) -> torch.Tensor:
        """
        Encode all instances in a batch.

        Args:
            instances: [batch, n_instances, seq_len, 6] token sequences
            instance_masks: [batch, n_instances, seq_len] attention masks

        Returns:
            instance_embeddings: [batch, n_instances, d_model]
        """
        batch_size, n_instances, seq_len, n_features = instances.shape

        # Flatten batch and instances. ``reshape`` is used instead of ``view``
        # so that non-contiguous inputs (e.g. sliced or transposed tensors)
        # are accepted rather than raising a RuntimeError.
        flat_instances = instances.reshape(
            batch_size * n_instances, seq_len, n_features
        )
        flat_masks = instance_masks.reshape(batch_size * n_instances, seq_len)

        # Encode every instance independently
        flat_embeddings = self.instance_encoder(flat_instances, flat_masks)

        # Restore the [batch, n_instances, ...] layout
        return flat_embeddings.reshape(batch_size, n_instances, -1)

    def forward(
        self,
        instances: torch.Tensor,
        instance_masks: torch.Tensor,
        instance_counts: Optional[torch.Tensor] = None,
        difficulty_hint: Optional[torch.Tensor] = None,
        return_attention: bool = True,
    ) -> ModelOutput:
        """
        Forward pass through the model.

        Args:
            instances: [batch, n_instances, seq_len, 6] token sequences
            instance_masks: [batch, n_instances, seq_len] token masks
            instance_counts: [batch] number of valid instances per sample
            difficulty_hint: [batch] difficulty class for calibration (uses predicted if None)
            return_attention: Whether to return attention weights

        Returns:
            ModelOutput with all predictions
        """
        n_instances = instances.shape[1]

        # Create instance-level mask from counts
        if instance_counts is not None:
            positions = torch.arange(n_instances, device=instances.device).unsqueeze(0)
            bag_mask = (positions < instance_counts.unsqueeze(1)).float()
        else:
            # Infer from instance masks (if any token is valid, instance is valid)
            bag_mask = (instance_masks.sum(dim=-1) > 0).float()

        # Encode instances
        instance_embeddings = self.encode_instances(instances, instance_masks)

        # Aggregate to bag embedding
        bag_embedding, attention_info = self.aggregator(
            instance_embeddings,
            bag_mask,
            return_attention=return_attention,
        )

        # Raw score prediction (unbounded)
        raw_score = self.raw_score_head(bag_embedding)

        # Difficulty classification
        difficulty_logits = self.difficulty_classifier(bag_embedding)

        # Calibrate with the supplied difficulty when given, otherwise with
        # the classifier's own argmax prediction.
        if difficulty_hint is not None:
            calibration_diff = difficulty_hint
        else:
            calibration_diff = difficulty_logits.argmax(dim=-1)

        # Calibrate to star rating
        raw_star = self.calibrator(raw_score, calibration_diff)
        display_star = self.calibrator.clip_to_display(raw_star, calibration_diff)

        return ModelOutput(
            raw_score=raw_score,
            difficulty_logits=difficulty_logits,
            raw_star=raw_star,
            display_star=display_star,
            attention_info=attention_info,
            instance_embeddings=instance_embeddings,
        )

    def predict(
        self,
        instances: torch.Tensor,
        instance_masks: torch.Tensor,
        instance_counts: Optional[torch.Tensor] = None,
    ) -> dict:
        """
        Convenience method for inference.

        Runs the forward pass under ``torch.no_grad()`` so no autograd graph
        is built. The module's train/eval mode is NOT changed here; call
        ``model.eval()`` first if deterministic predictions are required.

        Returns dict with human-readable outputs:
        - difficulty_class: Predicted difficulty name
        - difficulty_class_id: Predicted difficulty class index tensor
        - raw_score: Unbounded difficulty score
        - raw_star: Star rating (may exceed range)
        - display_star: Star rating for display (clipped)
        """
        with torch.no_grad():
            output = self.forward(
                instances,
                instance_masks,
                instance_counts,
                difficulty_hint=None,
                return_attention=False,
            )

        difficulty_names = ["easy", "normal", "hard", "oni", "ura"]
        predicted_class = output.difficulty_logits.argmax(dim=-1)

        return {
            "difficulty_class": [difficulty_names[c] for c in predicted_class.tolist()],
            "difficulty_class_id": predicted_class,
            "raw_score": output.raw_score,
            "raw_star": output.raw_star,
            "display_star": output.display_star,
        }

    def get_ranking_pairs_from_batch(
        self,
        raw_scores: torch.Tensor,
        song_ids: list[str],
        difficulties: list[str],
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """
        Extract within-song ranking pairs from a batch.

        Charts that share a song ID are sorted by ``DIFFICULTY_ORDER`` and
        each adjacent pair in that ordering becomes one (easier, harder) pair.

        Args:
            raw_scores: [batch] raw difficulty scores
            song_ids: List of song IDs
            difficulties: List of difficulty names

        Returns:
            (s_easier, s_harder) tensors for ranking loss. Both are empty
            tensors (same dtype/device as ``raw_scores``) when no song
            appears at least twice in the batch.
        """
        # Group batch positions by song
        song_to_indices: dict[str, list[int]] = {}
        for position, song_id in enumerate(song_ids):
            song_to_indices.setdefault(song_id, []).append(position)

        easier_scores = []
        harder_scores = []

        for indices in song_to_indices.values():
            if len(indices) < 2:
                continue

            # Sort by difficulty; unknown names fall back to rank 0
            sorted_indices = sorted(
                indices, key=lambda i: DIFFICULTY_ORDER.get(difficulties[i], 0)
            )

            # Adjacent entries in the sorted order form the pairs
            for easier_idx, harder_idx in zip(sorted_indices, sorted_indices[1:]):
                easier_scores.append(raw_scores[easier_idx])
                harder_scores.append(raw_scores[harder_idx])

        if not easier_scores:
            # new_empty keeps dtype and device aligned with raw_scores
            # (the previous torch.tensor([]) was always float32).
            return raw_scores.new_empty(0), raw_scores.new_empty(0)

        return (
            torch.stack(easier_scores),
            torch.stack(harder_scores),
        )
348
+
349
+
350
def create_model(
    d_model: int = 256,
    n_layers: int = 4,
    encoder_type: str = "transformer",
    **kwargs,
) -> TaikoChartEstimator:
    """
    Build a TaikoChartEstimator from a handful of common settings.

    Args:
        d_model: Model dimension
        n_layers: Number of encoder layers (forwarded as ``n_encoder_layers``)
        encoder_type: "transformer" or "tcn"
        **kwargs: Additional ModelConfig fields to override

    Returns:
        TaikoChartEstimator constructed from the resulting ModelConfig
    """
    return TaikoChartEstimator(
        ModelConfig(
            n_encoder_layers=n_layers,
            d_model=d_model,
            encoder_type=encoder_type,
            **kwargs,
        )
    )
TaikoChartEstimator/train/__init__.py ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
"""
TaikoChartEstimator Training Package
"""

# Importing the package eagerly imports the ``__main__`` submodule (the
# training script) and re-exports it through ``__all__``.
# NOTE(review): this presumably makes `python -m <package>.train` import the
# script once as a submodule and then run it again as ``__main__`` — confirm
# whether the eager import is intentional.
from . import __main__

__all__ = ["__main__"]
TaikoChartEstimator/train/__main__.py ADDED
@@ -0,0 +1,808 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Training Script for TaikoChartEstimator
3
+
4
+ Main entry point for training the MIL-based difficulty estimation model.
5
+ Supports:
6
+ - Multi-task learning (classification + regression + ranking)
7
+ - Curriculum learning for loss weights
8
+ - TensorBoard logging
9
+ - Multi-objective checkpoint selection
10
+ """
11
+
12
+ import argparse
13
+ import json
14
+ import os
15
+ from datetime import datetime
16
+ from pathlib import Path
17
+ from typing import Optional
18
+
19
+ import numpy as np
20
+ import torch
21
+ import torch.optim as optim
22
+ from scipy.stats import spearmanr
23
+ from sklearn.metrics import (
24
+ balanced_accuracy_score,
25
+ f1_score,
26
+ precision_score,
27
+ recall_score,
28
+ )
29
+ from torch.utils.data import DataLoader, Subset
30
+ from torch.utils.tensorboard import SummaryWriter
31
+ from tqdm import tqdm
32
+
33
+ from ..data import TaikoChartDataset, WithinSongPairSampler, collate_chart_bags
34
+ from ..data.tokenizer import DIFFICULTY_ORDER
35
+ from ..model import CurriculumScheduler, ModelConfig, TaikoChartEstimator, TotalLoss
36
+
37
+
38
+ def parse_args():
39
+ parser = argparse.ArgumentParser(description="Train TaikoChartEstimator")
40
+
41
+ # Data arguments
42
+ parser.add_argument(
43
+ "--dataset",
44
+ type=str,
45
+ default="JacobLinCool/taiko-1000-parsed",
46
+ help="HuggingFace dataset name",
47
+ )
48
+ parser.add_argument(
49
+ "--cache-dir", type=str, default=None, help="Cache directory for dataset"
50
+ )
51
+ parser.add_argument(
52
+ "--include-audio", action="store_true", help="Include audio features (slower)"
53
+ )
54
+
55
+ # Model arguments
56
+ parser.add_argument("--d-model", type=int, default=256, help="Model dimension")
57
+ parser.add_argument(
58
+ "--n-layers", type=int, default=4, help="Number of encoder layers"
59
+ )
60
+ parser.add_argument(
61
+ "--encoder-type",
62
+ type=str,
63
+ default="transformer",
64
+ choices=["transformer", "tcn"],
65
+ help="Instance encoder type",
66
+ )
67
+ parser.add_argument(
68
+ "--n-branches", type=int, default=3, help="Number of attention branches in MIL"
69
+ )
70
+
71
+ # Training arguments
72
+ parser.add_argument(
73
+ "--epochs", type=int, default=100, help="Number of training epochs"
74
+ )
75
+ parser.add_argument("--batch-size", type=int, default=16, help="Batch size")
76
+ parser.add_argument("--lr", type=float, default=1e-4, help="Learning rate")
77
+ parser.add_argument("--weight-decay", type=float, default=0.01, help="Weight decay")
78
+ parser.add_argument(
79
+ "--grad-clip", type=float, default=1.0, help="Gradient clipping norm"
80
+ )
81
+
82
+ # Loss weights
83
+ parser.add_argument(
84
+ "--lambda-cls", type=float, default=1.0, help="Classification loss weight"
85
+ )
86
+ parser.add_argument(
87
+ "--lambda-star", type=float, default=1.0, help="Star regression loss weight"
88
+ )
89
+ parser.add_argument(
90
+ "--lambda-rank", type=float, default=1.0, help="Ranking loss weight"
91
+ )
92
+ parser.add_argument(
93
+ "--use-curriculum",
94
+ action="store_true",
95
+ help="Use curriculum learning for loss weights",
96
+ )
97
+
98
+ # Checkpointing and logging
99
+ parser.add_argument(
100
+ "--output-dir", type=str, default="outputs", help="Output directory"
101
+ )
102
+ parser.add_argument(
103
+ "--tensorboard-dir", type=str, default="runs", help="TensorBoard log directory"
104
+ )
105
+ parser.add_argument(
106
+ "--save-every", type=int, default=5, help="Save checkpoint every N epochs"
107
+ )
108
+ parser.add_argument(
109
+ "--eval-every", type=int, default=1, help="Evaluate every N epochs"
110
+ )
111
+
112
+ # Misc
113
+ parser.add_argument("--seed", type=int, default=2025, help="Random seed")
114
+ parser.add_argument(
115
+ "--device",
116
+ type=str,
117
+ default="cuda" if torch.cuda.is_available() else "cpu",
118
+ help="Device to use",
119
+ )
120
+ parser.add_argument(
121
+ "--overfit-batch",
122
+ action="store_true",
123
+ help="Overfit on a single batch (for debugging)",
124
+ )
125
+ parser.add_argument(
126
+ "--num-workers", type=int, default=16, help="Number of data loader workers"
127
+ )
128
+
129
+ return parser.parse_args()
130
+
131
+
132
+ def set_seed(seed: int):
133
+ """Set random seeds for reproducibility."""
134
+ torch.manual_seed(seed)
135
+ torch.cuda.manual_seed_all(seed)
136
+ np.random.seed(seed)
137
+
138
+
139
+ def compute_class_weights(
140
+ dataset: TaikoChartDataset, merge_ura_oni: bool = True
141
+ ) -> torch.Tensor:
142
+ """Compute class weights based on class frequencies.
143
+
144
+ Args:
145
+ dataset: The training dataset
146
+ merge_ura_oni: If True, treat ura and oni as the same class (4 classes total)
147
+
148
+ Returns:
149
+ Class weights tensor (4 or 5 weights depending on merge_ura_oni)
150
+ """
151
+ n_classes = 4 if merge_ura_oni else 5
152
+ class_counts = [0] * n_classes
153
+
154
+ for song_idx, diff in dataset.chart_index:
155
+ diff_id = {"easy": 0, "normal": 1, "hard": 2, "oni": 3, "ura": 4}.get(diff, 0)
156
+ # Merge ura into oni if enabled
157
+ if merge_ura_oni and diff_id == 4:
158
+ diff_id = 3
159
+ class_counts[diff_id] += 1
160
+
161
+ total = sum(class_counts)
162
+ weights = [
163
+ total / (n_classes * count) if count > 0 else 1.0 for count in class_counts
164
+ ]
165
+
166
+ return torch.tensor(weights, dtype=torch.float32)
167
+
168
+
169
+ def extract_ranking_pairs(
170
+ batch: dict, raw_scores: torch.Tensor
171
+ ) -> tuple[torch.Tensor, torch.Tensor]:
172
+ """Extract within-song ranking pairs from batch."""
173
+ song_ids = batch["song_ids"]
174
+ difficulties = batch["difficulties"]
175
+
176
+ # Group by song
177
+ song_to_indices: dict[str, list[int]] = {}
178
+ for i, song_id in enumerate(song_ids):
179
+ if song_id not in song_to_indices:
180
+ song_to_indices[song_id] = []
181
+ song_to_indices[song_id].append(i)
182
+
183
+ easier_scores = []
184
+ harder_scores = []
185
+
186
+ for song_id, indices in song_to_indices.items():
187
+ if len(indices) < 2:
188
+ continue
189
+
190
+ # Sort by difficulty
191
+ sorted_indices = sorted(
192
+ indices, key=lambda i: DIFFICULTY_ORDER.get(difficulties[i], 0)
193
+ )
194
+
195
+ # Create adjacent pairs
196
+ for i in range(len(sorted_indices) - 1):
197
+ easier_idx = sorted_indices[i]
198
+ harder_idx = sorted_indices[i + 1]
199
+
200
+ easier_scores.append(raw_scores[easier_idx])
201
+ harder_scores.append(raw_scores[harder_idx])
202
+
203
+ if not easier_scores:
204
+ return (
205
+ torch.tensor([], device=raw_scores.device),
206
+ torch.tensor([], device=raw_scores.device),
207
+ )
208
+
209
+ return torch.stack(easier_scores), torch.stack(harder_scores)
210
+
211
+
212
+ def train_epoch(
213
+ model: TaikoChartEstimator,
214
+ dataloader: DataLoader,
215
+ criterion: TotalLoss,
216
+ optimizer: optim.Optimizer,
217
+ scheduler: Optional[optim.lr_scheduler._LRScheduler],
218
+ device: torch.device,
219
+ epoch: int,
220
+ writer: Optional[SummaryWriter] = None,
221
+ curriculum: Optional[CurriculumScheduler] = None,
222
+ grad_clip: float = 1.0,
223
+ ) -> dict:
224
+ """Train for one epoch."""
225
+ model.train()
226
+
227
+ total_loss = 0.0
228
+ total_cls_loss = 0.0
229
+ total_star_loss = 0.0
230
+ total_rank_loss = 0.0
231
+ n_batches = 0
232
+ n_ranking_pairs = 0
233
+
234
+ pbar = tqdm(dataloader, desc=f"Epoch {epoch}")
235
+
236
+ for batch_idx, batch in enumerate(pbar):
237
+ global_step = epoch * len(dataloader) + batch_idx
238
+
239
+ # Update curriculum weights
240
+ if curriculum is not None:
241
+ weights = curriculum.get_weights(global_step)
242
+ criterion.set_weights(**weights)
243
+
244
+ # Move batch to device
245
+ instances = batch["instances"].to(device)
246
+ instance_masks = batch["instance_masks"].to(device)
247
+ instance_counts = batch["instance_counts"].to(device)
248
+ difficulty_class = batch["difficulty_class"].to(device)
249
+ star = batch["star"].to(device)
250
+ is_right_censored = batch["is_right_censored"].to(device)
251
+ is_left_censored = batch["is_left_censored"].to(device)
252
+
253
+ # Forward pass
254
+ output = model(
255
+ instances,
256
+ instance_masks,
257
+ instance_counts,
258
+ difficulty_hint=difficulty_class, # Use ground truth for training
259
+ )
260
+
261
+ # Extract ranking pairs
262
+ s_easier, s_harder = extract_ranking_pairs(batch, output.raw_score)
263
+ ranking_pairs = (s_easier, s_harder) if s_easier.numel() > 0 else None
264
+ n_ranking_pairs += s_easier.numel()
265
+
266
+ # Compute losses
267
+ losses = criterion(
268
+ difficulty_logits=output.difficulty_logits,
269
+ pred_star=output.raw_star,
270
+ target_difficulty=difficulty_class,
271
+ target_star=star,
272
+ is_right_censored=is_right_censored,
273
+ is_left_censored=is_left_censored,
274
+ ranking_pairs=ranking_pairs,
275
+ )
276
+
277
+ # Backward pass
278
+ optimizer.zero_grad()
279
+ losses["total"].backward()
280
+
281
+ # Gradient clipping
282
+ if grad_clip > 0:
283
+ torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip)
284
+
285
+ optimizer.step()
286
+
287
+ # Track losses
288
+ total_loss += losses["total"].item()
289
+ total_cls_loss += losses["cls"].item()
290
+ total_star_loss += losses["star"].item()
291
+ total_rank_loss += losses["rank"].item()
292
+ n_batches += 1
293
+
294
+ # Update progress bar
295
+ pbar.set_postfix(
296
+ {
297
+ "loss": f"{losses['total'].item():.4f}",
298
+ "cls": f"{losses['cls'].item():.4f}",
299
+ "star": f"{losses['star'].item():.4f}",
300
+ "rank": f"{losses['rank'].item():.4f}",
301
+ }
302
+ )
303
+
304
+ # Log to TensorBoard
305
+ if writer is not None and batch_idx % 10 == 0:
306
+ writer.add_scalar("train/loss_total", losses["total"].item(), global_step)
307
+ writer.add_scalar("train/loss_cls", losses["cls"].item(), global_step)
308
+ writer.add_scalar("train/loss_star", losses["star"].item(), global_step)
309
+ writer.add_scalar("train/loss_rank", losses["rank"].item(), global_step)
310
+
311
+ # Log attention health metrics
312
+ if "entropy" in output.attention_info:
313
+ writer.add_scalar(
314
+ "train/attention_entropy",
315
+ output.attention_info["entropy"].mean().item(),
316
+ global_step,
317
+ )
318
+ if "effective_n" in output.attention_info:
319
+ writer.add_scalar(
320
+ "train/effective_instances",
321
+ output.attention_info["effective_n"].mean().item(),
322
+ global_step,
323
+ )
324
+ if "top5_mass" in output.attention_info:
325
+ writer.add_scalar(
326
+ "train/top5_attention_mass",
327
+ output.attention_info["top5_mass"].mean().item(),
328
+ global_step,
329
+ )
330
+
331
+ if scheduler is not None:
332
+ scheduler.step()
333
+
334
+ return {
335
+ "loss": total_loss / n_batches,
336
+ "cls_loss": total_cls_loss / n_batches,
337
+ "star_loss": total_star_loss / n_batches,
338
+ "rank_loss": total_rank_loss / n_batches,
339
+ "n_ranking_pairs": n_ranking_pairs,
340
+ }
341
+
342
+
343
+ @torch.no_grad()
344
+ def evaluate(
345
+ model: TaikoChartEstimator,
346
+ dataloader: DataLoader,
347
+ criterion: TotalLoss,
348
+ device: torch.device,
349
+ ) -> dict:
350
+ """Evaluate model on validation set."""
351
+ model.eval()
352
+
353
+ all_pred_class = []
354
+ all_true_class = []
355
+ all_pred_star = []
356
+ all_true_star = []
357
+ all_raw_scores = []
358
+ all_difficulties = []
359
+ all_song_ids = []
360
+ all_is_right_censored = []
361
+
362
+ total_loss = 0.0
363
+ n_batches = 0
364
+
365
+ for batch in tqdm(dataloader, desc="Evaluating"):
366
+ instances = batch["instances"].to(device)
367
+ instance_masks = batch["instance_masks"].to(device)
368
+ instance_counts = batch["instance_counts"].to(device)
369
+ difficulty_class = batch["difficulty_class"].to(device)
370
+ star = batch["star"].to(device)
371
+ is_right_censored = batch["is_right_censored"].to(device)
372
+ is_left_censored = batch["is_left_censored"].to(device)
373
+
374
+ output = model(
375
+ instances,
376
+ instance_masks,
377
+ instance_counts,
378
+ difficulty_hint=difficulty_class,
379
+ )
380
+
381
+ # Compute loss
382
+ losses = criterion(
383
+ difficulty_logits=output.difficulty_logits,
384
+ pred_star=output.raw_star,
385
+ target_difficulty=difficulty_class,
386
+ target_star=star,
387
+ is_right_censored=is_right_censored,
388
+ is_left_censored=is_left_censored,
389
+ )
390
+
391
+ total_loss += losses["total"].item()
392
+ n_batches += 1
393
+
394
+ # Collect predictions
395
+ all_pred_class.extend(output.difficulty_logits.argmax(dim=-1).cpu().tolist())
396
+ all_true_class.extend(difficulty_class.cpu().tolist())
397
+ all_pred_star.extend(output.raw_star.cpu().tolist())
398
+ all_true_star.extend(star.cpu().tolist())
399
+ all_raw_scores.extend(output.raw_score.cpu().tolist())
400
+ all_difficulties.extend(batch["difficulties"])
401
+ all_song_ids.extend(batch["song_ids"])
402
+ all_is_right_censored.extend(is_right_censored.cpu().tolist())
403
+
404
+ # Compute metrics
405
+ all_pred_class = np.array(all_pred_class)
406
+ all_true_class = np.array(all_true_class)
407
+ all_pred_star = np.array(all_pred_star)
408
+ all_true_star = np.array(all_true_star)
409
+ all_raw_scores = np.array(all_raw_scores)
410
+ all_is_right_censored = np.array(all_is_right_censored)
411
+
412
+ # Merge ura (4) and oni (3) for classification metrics
413
+ # They are essentially the same difficulty level
414
+ all_pred_class_merged = all_pred_class.copy()
415
+ all_true_class_merged = all_true_class.copy()
416
+ all_pred_class_merged[all_pred_class_merged == 4] = 3 # Map ura -> oni
417
+ all_true_class_merged[all_true_class_merged == 4] = 3 # Map ura -> oni
418
+
419
+ # Classification metrics (using merged classes)
420
+ macro_f1 = f1_score(all_true_class_merged, all_pred_class_merged, average="macro")
421
+ balanced_acc = balanced_accuracy_score(all_true_class_merged, all_pred_class_merged)
422
+ plus_minus_1_acc = (
423
+ np.abs(all_pred_class_merged - all_true_class_merged) <= 1
424
+ ).mean()
425
+
426
+ # Per-difficulty classification metrics (precision, recall, F1)
427
+ diff_names_cls = ["easy", "normal", "hard", "oni_ura"]
428
+ per_diff_cls_metrics = {}
429
+
430
+ per_class_f1 = f1_score(
431
+ all_true_class_merged, all_pred_class_merged, average=None, labels=[0, 1, 2, 3]
432
+ )
433
+ per_class_precision = precision_score(
434
+ all_true_class_merged,
435
+ all_pred_class_merged,
436
+ average=None,
437
+ labels=[0, 1, 2, 3],
438
+ zero_division=0,
439
+ )
440
+ per_class_recall = recall_score(
441
+ all_true_class_merged,
442
+ all_pred_class_merged,
443
+ average=None,
444
+ labels=[0, 1, 2, 3],
445
+ zero_division=0,
446
+ )
447
+
448
+ for i, name in enumerate(diff_names_cls):
449
+ if i < len(per_class_f1):
450
+ per_diff_cls_metrics[f"f1_{name}"] = per_class_f1[i]
451
+ per_diff_cls_metrics[f"precision_{name}"] = per_class_precision[i]
452
+ per_diff_cls_metrics[f"recall_{name}"] = per_class_recall[i]
453
+
454
+ # Star regression metrics (on uncensored samples)
455
+ uncensored_mask = ~all_is_right_censored
456
+ if uncensored_mask.sum() > 0:
457
+ mae_uncensored = np.abs(
458
+ all_pred_star[uncensored_mask] - all_true_star[uncensored_mask]
459
+ ).mean()
460
+ spearman_rho, _ = spearmanr(all_pred_star, all_true_star)
461
+ else:
462
+ mae_uncensored = 0.0
463
+ spearman_rho = 0.0
464
+
465
+ # Per-difficulty Star MAE & RMSE (using merged oni/ura as same class)
466
+ diff_names_merged = ["easy", "normal", "hard", "oni_ura"]
467
+ per_diff_star_metrics = {}
468
+
469
+ for diff_idx, diff_name in enumerate(diff_names_merged):
470
+ if diff_idx == 3:
471
+ # oni_ura: merge classes 3 and 4
472
+ mask = (all_true_class == 3) | (all_true_class == 4)
473
+ else:
474
+ mask = all_true_class == diff_idx
475
+
476
+ if mask.sum() > 0:
477
+ diff_pred = all_pred_star[mask]
478
+ diff_true = all_true_star[mask]
479
+ diff_errors = diff_pred - diff_true
480
+
481
+ per_diff_star_metrics[f"mae_star_{diff_name}"] = np.abs(diff_errors).mean()
482
+ per_diff_star_metrics[f"rmse_star_{diff_name}"] = np.sqrt(
483
+ (diff_errors**2).mean()
484
+ )
485
+ else:
486
+ per_diff_star_metrics[f"mae_star_{diff_name}"] = 0.0
487
+ per_diff_star_metrics[f"rmse_star_{diff_name}"] = 0.0
488
+
489
+ # Monotonicity metrics
490
+ song_groups: dict[str, list] = {}
491
+ for i, song_id in enumerate(all_song_ids):
492
+ if song_id not in song_groups:
493
+ song_groups[song_id] = []
494
+ song_groups[song_id].append(
495
+ {
496
+ "difficulty": all_difficulties[i],
497
+ "raw_score": all_raw_scores[i],
498
+ }
499
+ )
500
+
501
+ n_violations = 0
502
+ n_pairs = 0
503
+
504
+ for song_id, charts in song_groups.items():
505
+ if len(charts) < 2:
506
+ continue
507
+
508
+ sorted_charts = sorted(
509
+ charts, key=lambda c: DIFFICULTY_ORDER.get(c["difficulty"], 0)
510
+ )
511
+
512
+ for i in range(len(sorted_charts) - 1):
513
+ n_pairs += 1
514
+ if sorted_charts[i]["raw_score"] >= sorted_charts[i + 1]["raw_score"]:
515
+ n_violations += 1
516
+
517
+ violation_rate = n_violations / n_pairs if n_pairs > 0 else 0.0
518
+
519
+ # Decompression metrics (for 10-star samples)
520
+ max_star_mask = all_true_star >= 10.0
521
+ if max_star_mask.sum() > 1:
522
+ pred_10star = all_pred_star[max_star_mask]
523
+ decompression_std = pred_10star.std()
524
+ p90_p50 = np.percentile(pred_10star, 90) - np.percentile(pred_10star, 50)
525
+ else:
526
+ decompression_std = 0.0
527
+ p90_p50 = 0.0
528
+
529
+ result = {
530
+ "loss": total_loss / n_batches,
531
+ "macro_f1": macro_f1,
532
+ "balanced_accuracy": balanced_acc,
533
+ "plus_minus_1_accuracy": plus_minus_1_acc,
534
+ "mae_uncensored": mae_uncensored,
535
+ "spearman_rho": spearman_rho,
536
+ "monotonicity_violation_rate": violation_rate,
537
+ "decompression_std": decompression_std,
538
+ "decompression_p90_p50": p90_p50,
539
+ }
540
+ # Add per-difficulty classification metrics
541
+ result.update(per_diff_cls_metrics)
542
+ # Add per-difficulty star metrics
543
+ result.update(per_diff_star_metrics)
544
+
545
+ return result
546
+
547
+
548
+ def save_checkpoint(
549
+ model: TaikoChartEstimator,
550
+ optimizer: optim.Optimizer,
551
+ epoch: int,
552
+ metrics: dict,
553
+ output_dir: Path,
554
+ name: str = "checkpoint",
555
+ ):
556
+ """Save model checkpoint."""
557
+ checkpoint = {
558
+ "epoch": epoch,
559
+ "model_state_dict": model.state_dict(),
560
+ "optimizer_state_dict": optimizer.state_dict(),
561
+ "metrics": metrics,
562
+ "config": model.config.__dict__,
563
+ }
564
+
565
+ pretrained_path = output_dir / "pretrained" / name
566
+ model.save_pretrained(pretrained_path)
567
+
568
+ path = output_dir / f"{name}_epoch{epoch}.pt"
569
+ torch.save(checkpoint, path)
570
+
571
+ # Also save as latest
572
+ latest_path = output_dir / f"{name}_latest.pt"
573
+ torch.save(checkpoint, latest_path)
574
+
575
+ return path
576
+
577
+
578
+ def main():
579
+ args = parse_args()
580
+ set_seed(args.seed)
581
+
582
+ # Create output directories
583
+ timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
584
+ output_dir = Path(args.output_dir) / timestamp
585
+ output_dir.mkdir(parents=True, exist_ok=True)
586
+
587
+ tensorboard_dir = Path(args.tensorboard_dir) / timestamp
588
+ writer = SummaryWriter(tensorboard_dir)
589
+
590
+ # Save args
591
+ with open(output_dir / "args.json", "w") as f:
592
+ json.dump(vars(args), f, indent=2)
593
+
594
+ print(f"Output directory: {output_dir}")
595
+ print(f"TensorBoard directory: {tensorboard_dir}")
596
+
597
+ # Load datasets
598
+ print("Loading datasets...")
599
+ train_dataset = TaikoChartDataset(
600
+ split="train",
601
+ dataset_name=args.dataset,
602
+ include_audio=args.include_audio,
603
+ cache_dir=args.cache_dir,
604
+ )
605
+
606
+ val_dataset = TaikoChartDataset(
607
+ split="test",
608
+ dataset_name=args.dataset,
609
+ include_audio=args.include_audio,
610
+ cache_dir=args.cache_dir,
611
+ )
612
+
613
+ print(f"Train samples: {len(train_dataset)}, Val samples: {len(val_dataset)}")
614
+
615
+ # Create data loaders
616
+ if args.overfit_batch:
617
+ # Take a small subset for debugging
618
+ train_dataset = Subset(train_dataset, list(range(min(32, len(train_dataset)))))
619
+ val_dataset = Subset(val_dataset, list(range(min(8, len(val_dataset)))))
620
+
621
+ train_sampler = WithinSongPairSampler(
622
+ train_dataset
623
+ if not isinstance(train_dataset, torch.utils.data.Subset)
624
+ else train_dataset.dataset,
625
+ min_batch_size=args.batch_size,
626
+ shuffle=True,
627
+ seed=args.seed,
628
+ )
629
+
630
+ train_loader = DataLoader(
631
+ train_dataset,
632
+ batch_sampler=train_sampler if not args.overfit_batch else None,
633
+ batch_size=args.batch_size if args.overfit_batch else 1,
634
+ shuffle=args.overfit_batch,
635
+ collate_fn=collate_chart_bags,
636
+ num_workers=args.num_workers,
637
+ pin_memory=True,
638
+ )
639
+
640
+ val_loader = DataLoader(
641
+ val_dataset,
642
+ batch_size=args.batch_size,
643
+ shuffle=False,
644
+ collate_fn=collate_chart_bags,
645
+ num_workers=args.num_workers,
646
+ pin_memory=True,
647
+ )
648
+
649
+ # Create model
650
+ print("Creating model...")
651
+ config = ModelConfig(
652
+ encoder_type=args.encoder_type,
653
+ d_model=args.d_model,
654
+ n_encoder_layers=args.n_layers,
655
+ n_attention_branches=args.n_branches,
656
+ )
657
+ model = TaikoChartEstimator(config)
658
+ model = model.to(args.device)
659
+
660
+ # Count parameters
661
+ n_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
662
+ print(f"Model parameters: {n_params:,}")
663
+
664
+ # Create optimizer and scheduler
665
+ optimizer = optim.AdamW(
666
+ model.parameters(),
667
+ lr=args.lr,
668
+ weight_decay=args.weight_decay,
669
+ )
670
+
671
+ scheduler = optim.lr_scheduler.CosineAnnealingLR(
672
+ optimizer,
673
+ T_max=args.epochs,
674
+ eta_min=args.lr * 0.01,
675
+ )
676
+
677
+ # Create loss function
678
+ class_weights = compute_class_weights(
679
+ train_dataset
680
+ if not isinstance(train_dataset, torch.utils.data.Subset)
681
+ else train_dataset.dataset
682
+ ).to(args.device)
683
+
684
+ criterion = TotalLoss(
685
+ lambda_cls=args.lambda_cls,
686
+ lambda_star=args.lambda_star,
687
+ lambda_rank=args.lambda_rank,
688
+ class_weights=class_weights,
689
+ )
690
+
691
+ # Curriculum scheduler
692
+ curriculum = None
693
+ if args.use_curriculum:
694
+ total_steps = args.epochs * len(train_loader)
695
+ curriculum = CurriculumScheduler(total_steps)
696
+
697
+ # Composite score function for model selection
698
+ def compute_composite_score(metrics: dict) -> float:
699
+ """
700
+ Compute weighted composite score for model selection.
701
+
702
+ Weights prioritize Spearman (star ranking) as the core objective.
703
+ - Spearman ρ: 55% (star prediction ranking accuracy)
704
+ - Macro-F1: 25% (difficulty classification)
705
+ - Violation Rate: 20% (monotonicity constraint)
706
+ """
707
+ # Clamp to reasonable ranges observed in training
708
+ f1 = max(0.70, min(0.90, metrics["macro_f1"]))
709
+ spearman = max(0.80, min(0.98, metrics["spearman_rho"]))
710
+ violation = max(0.0, min(0.10, metrics["monotonicity_violation_rate"]))
711
+
712
+ # Normalize to 0-1
713
+ f1_norm = (f1 - 0.70) / 0.20
714
+ spearman_norm = (spearman - 0.80) / 0.18
715
+ violation_norm = 1.0 - violation / 0.10 # Lower is better
716
+
717
+ return 0.6 * spearman_norm + 0.25 * f1_norm + 0.15 * violation_norm
718
+
719
+ # Training loop
720
+ print("Starting training...")
721
+ best_metrics = {
722
+ "macro_f1": 0.0,
723
+ "spearman_rho": 0.0,
724
+ "monotonicity_violation_rate": 1.0,
725
+ }
726
+ best_composite_score = 0.0
727
+
728
+ for epoch in range(1, args.epochs + 1):
729
+ # Train
730
+ train_metrics = train_epoch(
731
+ model=model,
732
+ dataloader=train_loader,
733
+ criterion=criterion,
734
+ optimizer=optimizer,
735
+ scheduler=scheduler,
736
+ device=torch.device(args.device),
737
+ epoch=epoch,
738
+ writer=writer,
739
+ curriculum=curriculum,
740
+ grad_clip=args.grad_clip,
741
+ )
742
+
743
+ print(
744
+ f"Epoch {epoch} - Train Loss: {train_metrics['loss']:.4f}, "
745
+ f"Cls: {train_metrics['cls_loss']:.4f}, Star: {train_metrics['star_loss']:.4f}, "
746
+ f"Rank: {train_metrics['rank_loss']:.4f} ({train_metrics['n_ranking_pairs']} pairs)"
747
+ )
748
+
749
+ # Log training metrics
750
+ writer.add_scalar("epoch/train_loss", train_metrics["loss"], epoch)
751
+ writer.add_scalar("epoch/learning_rate", scheduler.get_last_lr()[0], epoch)
752
+
753
+ # Evaluate
754
+ if epoch % args.eval_every == 0:
755
+ val_metrics = evaluate(
756
+ model=model,
757
+ dataloader=val_loader,
758
+ criterion=criterion,
759
+ device=torch.device(args.device),
760
+ )
761
+
762
+ # Compute composite score
763
+ composite_score = compute_composite_score(val_metrics)
764
+
765
+ print(
766
+ f"Epoch {epoch} - Val Loss: {val_metrics['loss']:.4f}, "
767
+ f"Macro-F1: {val_metrics['macro_f1']:.4f}, "
768
+ f"Spearman: {val_metrics['spearman_rho']:.4f}, "
769
+ f"Violation Rate: {val_metrics['monotonicity_violation_rate']:.4f}, "
770
+ f"Decomp Std: {val_metrics['decompression_std']:.4f}, "
771
+ f"Composite: {composite_score:.4f}"
772
+ )
773
+
774
+ # Log validation metrics
775
+ for key, value in val_metrics.items():
776
+ writer.add_scalar(f"val/{key}", value, epoch)
777
+ writer.add_scalar("val/composite_score", composite_score, epoch)
778
+
779
+ # Save best model based on composite score
780
+ if composite_score > best_composite_score:
781
+ best_composite_score = composite_score
782
+ best_metrics = val_metrics
783
+ save_checkpoint(
784
+ model, optimizer, epoch, val_metrics, output_dir, "best"
785
+ )
786
+ print(f" -> New best model saved! (Composite: {composite_score:.4f})")
787
+
788
+ # Periodic checkpoint
789
+ if epoch % args.save_every == 0:
790
+ save_checkpoint(
791
+ model, optimizer, epoch, train_metrics, output_dir, "checkpoint"
792
+ )
793
+
794
+ # Save final model
795
+ save_checkpoint(model, optimizer, args.epochs, best_metrics, output_dir, "final")
796
+
797
+ print(f"\nTraining complete!")
798
+ print(f"Best Composite Score: {best_composite_score:.4f}")
799
+ print(f" - Macro-F1: {best_metrics['macro_f1']:.4f}")
800
+ print(f" - Spearman: {best_metrics['spearman_rho']:.4f}")
801
+ print(f" - Violation Rate: {best_metrics['monotonicity_violation_rate']:.4f}")
802
+ print(f"Checkpoints saved to: {output_dir}")
803
+
804
+ writer.close()
805
+
806
+
807
+ if __name__ == "__main__":
808
+ main()