| """ |
| Dataset loader for maimai training data. |
| |
| Provides: |
| - SongDataset: scans datasets/ folder, parses all songs |
| - ChartDataset: PyTorch Dataset for chart note sequences |
| - Utility functions for note-to-tensor conversion |
| """ |
|
|
| from __future__ import annotations |
|
|
| import json |
| import pickle |
| from pathlib import Path |
| from typing import Iterator, Optional, Union |
|
|
| from .models import Cabinet, Chart, Difficulty, Song, TouchNote |
| from .parser import parse_maidata_file |
|
|
| |
| try: |
| import torch |
| from torch.utils.data import Dataset as TorchDataset |
| HAS_TORCH = True |
| except (ImportError, OSError): |
| HAS_TORCH = False |
|
|
|
|
class SongDataset:
    """
    A collection of parsed Song objects from the datasets/ directory.

    Every immediate sub-folder of ``root_dir`` that contains a
    ``maidata.txt`` file is treated as one song folder.

    Usage:
        ds = SongDataset("datasets")
        ds.load(use_cache=True)    # scan + parse all songs (with caching)
        print(len(ds))             # total songs
        song = ds[0]               # first song

    ``scan()`` may be called on its own to list the song folders, but it is
    not a prerequisite for ``load()``; ``load()`` scans by itself.
    """

    def __init__(self, root_dir: str | Path):
        """
        Args:
            root_dir: Directory whose sub-folders hold one song each.
        """
        self.root_dir = Path(root_dir)
        self.songs: list[Song] = []
        self._by_id: dict[str, Song] = {}
        # The pickle cache lives inside the dataset directory itself.
        self._cache_path = self.root_dir / ".mai_parser_cache.pkl"

    def scan(self) -> list[Path]:
        """
        Scan the dataset directory for all song folders (containing maidata.txt).

        Returns:
            The song folder paths, sorted by path. Nothing is stored on the
            instance.
        """
        return [
            item
            for item in sorted(self.root_dir.iterdir())
            if item.is_dir() and (item / "maidata.txt").exists()
        ]

    def load(self, use_cache: bool = True, max_workers: int = 8,
             progress: bool = True) -> "SongDataset":
        """
        Parse all songs from discovered folders.

        Args:
            use_cache: If True, try loading from cache first; save after parsing.
                The cache is used whenever it exists and is readable; it is
                not compared against the folders currently on disk.
            max_workers: Number of threads for parallel parsing.
            progress: Print progress information.

        Returns:
            ``self``, so calls can be chained.
        """
        if use_cache and self._cache_path.exists():
            cached = self._read_cache(progress)
            if cached is not None:
                self.songs = cached
                self._build_index()
                if progress:
                    print(f"Loaded {len(self.songs)} songs from cache.")
                return self

        folders = self.scan()
        if progress:
            print(f"Found {len(folders)} song folders. Parsing...")

        songs = self._parse_folders(folders, max_workers, progress)
        # as_completed() yields in completion order; sort for a stable result.
        songs.sort(key=lambda s: _sort_key(s.song_id))

        self.songs = songs
        self._build_index()

        if use_cache:
            self._write_cache(progress)

        if progress:
            print(f"Loaded {len(self.songs)} songs.")
        return self

    def _read_cache(self, progress: bool) -> list[Song] | None:
        """Return the cached song list, or None if the cache is unusable."""
        # NOTE: pickle can execute arbitrary code while loading. Only point
        # this loader at dataset directories whose cache file you trust.
        try:
            with open(self._cache_path, "rb") as f:
                cached = pickle.load(f)
            if not isinstance(cached, list):
                raise pickle.UnpicklingError(
                    f"expected a list of songs, got {type(cached).__name__}"
                )
            if not all(hasattr(entry, "song_id") for entry in cached):
                raise pickle.UnpicklingError("cached entries have no song_id")
        except (pickle.UnpicklingError, EOFError, KeyError, AttributeError,
                ImportError, IndexError, OSError) as e:
            # Besides UnpicklingError, unpickling a cache written by an older
            # version of the model classes can raise AttributeError,
            # ImportError or IndexError; all of them mean "re-parse".
            if progress:
                print(f"Cache invalid ({e}), re-parsing...")
            return None
        return cached

    def _parse_folders(self, folders: list[Path], max_workers: int,
                       progress: bool) -> list[Song]:
        """Parse every folder's maidata.txt in a thread pool (unsorted result)."""
        from concurrent.futures import ThreadPoolExecutor, as_completed

        songs: list[Song] = []
        with ThreadPoolExecutor(max_workers=max_workers) as executor:
            futures = {
                executor.submit(parse_maidata_file, folder / "maidata.txt"): folder
                for folder in folders
            }
            for done, future in enumerate(as_completed(futures), start=1):
                folder = futures[future]
                try:
                    song = future.result()
                    song.maidata_path = str(folder / "maidata.txt")
                    # Recorded unconditionally; the file is not checked here.
                    song.audio_path = str(folder / "track.mp3")
                    songs.append(song)
                except Exception as e:
                    # Best effort: one unparsable song must not abort the load.
                    if progress:
                        print(f"  [WARN] Failed to parse {folder.name}: {e}")

                if progress and done % 200 == 0:
                    print(f"  ... parsed {done}/{len(folders)}")
        return songs

    def _write_cache(self, progress: bool) -> None:
        """Pickle the current song list; a failed write is only reported."""
        try:
            with open(self._cache_path, "wb") as f:
                pickle.dump(self.songs, f, protocol=pickle.HIGHEST_PROTOCOL)
            if progress:
                print(f"Saved cache to {self._cache_path}")
        except OSError as e:
            if progress:
                print(f"  [WARN] Could not save cache: {e}")

    def _build_index(self) -> None:
        """Rebuild the song_id -> Song lookup (later duplicates win)."""
        self._by_id = {s.song_id: s for s in self.songs}

    def __len__(self) -> int:
        return len(self.songs)

    def __getitem__(self, idx: int) -> Song:
        return self.songs[idx]

    def __iter__(self) -> Iterator[Song]:
        return iter(self.songs)

    def get_by_id(self, song_id: str) -> Song | None:
        """Look up a song by its folder name / ID; None if unknown."""
        return self._by_id.get(song_id)

    @staticmethod
    def _level_in_range(chart, min_level: float | None,
                        max_level: float | None) -> bool:
        """True if chart.level_value lies within the given (optional) bounds."""
        if min_level is not None and chart.level_value < min_level:
            return False
        if max_level is not None and chart.level_value > max_level:
            return False
        return True

    def filter(
        self,
        cabinet: Cabinet | None = None,
        min_level: float | None = None,
        max_level: float | None = None,
        difficulty: Difficulty | None = None,
        genre: str | None = None,
        has_audio: bool = False,
    ) -> list[Song]:
        """
        Filter songs by various criteria.

        Args:
            cabinet: Filter by cabinet type (SD/DX).
            min_level: Minimum chart level (inclusive).
            max_level: Maximum chart level (inclusive).
            difficulty: Only songs with this difficulty. When given, the
                level bounds are applied to that chart only.
            genre: Filter by genre name (case-insensitive partial match).
            has_audio: Only songs where track.mp3 exists.

        When ``difficulty`` is None and a level bound is given, a song is
        kept if at least one of its charts lies within the bounds.

        Returns:
            The matching songs, in dataset order.
        """
        level_bounded = min_level is not None or max_level is not None
        wanted_genre = genre.lower() if genre is not None else None

        result = []
        for song in self.songs:
            if cabinet is not None and song.cabinet != cabinet:
                continue
            # A missing genre is treated as empty rather than crashing.
            if wanted_genre is not None and wanted_genre not in (song.genre or "").lower():
                continue
            if has_audio:
                audio_path = self.root_dir / song.song_id / "track.mp3"
                if not audio_path.exists():
                    continue
            if difficulty is not None:
                chart = song.get_chart(difficulty)
                if chart is None:
                    continue
                if not self._level_in_range(chart, min_level, max_level):
                    continue
            elif level_bounded and not any(
                self._level_in_range(chart, min_level, max_level)
                for chart in song.charts.values()
            ):
                continue
            result.append(song)
        return result

    def stats(self) -> dict:
        """
        Return summary statistics of the dataset.

        Returns:
            A dict with song counts per cabinet, the total chart count, the
            sorted distinct genres, and min/max/avg level per chart index
            under ``"level_distribution"`` (keys ``"lv_<index>"``).
        """
        total = len(self.songs)
        sd_count = sum(1 for s in self.songs if s.cabinet == Cabinet.SD)
        dx_count = sum(1 for s in self.songs if s.cabinet == Cabinet.DX)
        total_charts = sum(len(s.charts) for s in self.songs)
        genres = sorted({s.genre for s in self.songs if s.genre})

        # Group chart levels by the chart's key in song.charts.
        level_dist: dict[str, list[float]] = {}
        for song in self.songs:
            for idx, chart in song.charts.items():
                level_dist.setdefault(f"lv_{idx}", []).append(chart.level_value)

        return {
            "total_songs": total,
            "sd_songs": sd_count,
            "dx_songs": dx_count,
            "unknown_cabinet": total - sd_count - dx_count,
            "total_charts": total_charts,
            "genres": genres,
            "level_distribution": {
                k: {
                    "min": min(v) if v else 0,
                    "max": max(v) if v else 0,
                    "avg": sum(v) / len(v) if v else 0,
                }
                for k, v in level_dist.items()
            },
        }
|
|
|
|
def _sort_key(song_id: str) -> tuple:
    """
    Build a sort key that places integer-like IDs before all other IDs.

    IDs that ``int()`` accepts sort by numeric value in group 0; anything
    else sorts as a plain string in group 1.
    """
    try:
        numeric_id = int(song_id)
    except ValueError:
        return (1, song_id)
    return (0, numeric_id)
|
|
|
|
| |
|
|
| |
# Note-type name -> index of its slot in the one-hot prefix of a feature vector.
NOTE_TYPE = {
    "rest": 0,
    "tap": 1,
    "hold": 2,
    "slide": 3,
    "break": 4,
    "touch": 5,
    "end": 6,
}


# Inverse mapping: one-hot index -> note-type name.
NOTE_TYPE_INV = {v: k for k, v in NOTE_TYPE.items()}


# Default number of position slots; positions are numbered 1..NUM_POSITIONS.
NUM_POSITIONS = 8


def note_to_vector(note: TouchNote, num_positions: int = NUM_POSITIONS) -> list[float]:
    """
    Convert a single TouchNote to a feature vector.

    Format (fixed-length, ``len(NOTE_TYPE) + num_positions + 7`` entries):
        [type_rest, type_tap, type_hold, type_slide, type_break, type_touch, type_end,
         pos_1, pos_2, ..., pos_N,
         slide_path_len>=2, slide_path_len>=3, slide_path_len>=4, is_simultaneous,
         hold_duration[0], hold_duration[1],
         beat_div]

    Exactly one type slot is set. When several ``is_*`` flags are true the
    precedence is end > rest > touch > break > hold > slide, and a note with
    none of them set is encoded as a tap. The slide flags are cumulative: a
    path of length 3 sets both the ``>=2`` and ``>=3`` slots. ``beat_div`` is
    stored as-is (not normalized), converted to float.

    Args:
        note: The note to encode.
        num_positions: Number of position slots; positions outside
            ``1..num_positions`` are ignored.

    Returns:
        A flat list of floats.
    """
    num_types = len(NOTE_TYPE)
    vec = [0.0] * (num_types + num_positions + 4 + 2 + 1)

    # One-hot note type; the order of these checks defines the precedence.
    if note.is_end:
        kind = "end"
    elif note.is_rest:
        kind = "rest"
    elif note.is_touch:
        kind = "touch"
    elif note.is_break:
        kind = "break"
    elif note.is_hold:
        kind = "hold"
    elif note.is_slide:
        kind = "slide"
    else:
        kind = "tap"
    vec[NOTE_TYPE[kind]] = 1.0

    # Multi-hot positions (1-based in the note, 0-based in the vector).
    offset = num_types
    for pos in note.positions:
        if 1 <= pos <= num_positions:
            vec[offset + pos - 1] = 1.0

    # Cumulative slide-path length flags, followed by the simultaneity flag.
    offset += num_positions
    slide_len = len(note.slide_path)
    for slot, threshold in enumerate((2, 3, 4)):
        if slide_len >= threshold:
            vec[offset + slot] = 1.0
    vec[offset + 3] = 1.0 if note.is_simultaneous else 0.0

    # Hold duration pair; both slots stay 0.0 when there is no duration.
    offset += 4
    if note.hold_duration:
        vec[offset] = float(note.hold_duration[0])
        vec[offset + 1] = float(note.hold_duration[1])

    # Cast so the result really is a list of floats, as annotated.
    vec[-1] = float(note.beat_div)

    return vec
|
|
|
|
def chart_to_tensor(chart: Chart) -> "torch.Tensor":
    """
    Convert an entire chart's notes to a PyTorch tensor.

    Each note in ``chart.notes`` is encoded with ``note_to_vector`` using the
    default number of positions.

    Returns:
        A float32 tensor of shape (num_notes, feature_dim). A chart without
        notes yields shape (0, feature_dim) so callers can still read
        ``shape[1]``.

    Raises:
        ImportError: If PyTorch is not available.
    """
    if not HAS_TORCH:
        raise ImportError("PyTorch is required for tensor conversion. "
                          "Install with: pip install torch")
    vectors = [note_to_vector(n) for n in chart.notes]
    if not vectors:
        # torch.tensor([]) would be 1-D with shape (0,); build the 2-D empty
        # tensor explicitly. The width mirrors note_to_vector's layout:
        # type one-hot + positions + 4 flags + 2 hold slots + beat_div.
        feature_dim = len(NOTE_TYPE) + NUM_POSITIONS + 4 + 2 + 1
        return torch.zeros((0, feature_dim), dtype=torch.float32)
    return torch.tensor(vectors, dtype=torch.float32)
|
|
|
|
def chart_to_sequence(chart: Chart) -> list[list[float]]:
    """
    Encode every note of a chart as a plain Python feature vector.

    Works without PyTorch; each entry is the ``note_to_vector`` result for
    the corresponding note in ``chart.notes``.
    """
    sequence = []
    for note in chart.notes:
        sequence.append(note_to_vector(note))
    return sequence
|
|
|
|
| |
|
|
if HAS_TORCH:

    class ChartDataset(TorchDataset):
        """
        PyTorch Dataset for maimai chart training.

        Yields (note_sequence, metadata) pairs where:
        - note_sequence is a tensor of shape (seq_len, feature_dim)
        - metadata is a dict with song info

        With ``include_metadata=False`` only the tensor is returned.

        Usage:
            from torch.utils.data import DataLoader
            ds = ChartDataset(song_dataset, difficulty=Difficulty.MASTER)
            loader = DataLoader(ds, batch_size=32, collate_fn=ds.collate_fn)
        """

        def __init__(
            self,
            song_dataset: SongDataset,
            difficulty: Difficulty | None = None,
            min_level: float | None = None,
            max_level: float | None = None,
            pad_to: int | None = None,
            include_metadata: bool = True,
        ):
            """
            Args:
                song_dataset: Source of songs; its ``songs`` list is indexed
                    once here, so it should already be loaded.
                difficulty: Keep only charts whose key in ``song.charts``
                    equals ``difficulty.value``.
                min_level: Minimum chart level (inclusive).
                max_level: Maximum chart level (inclusive).
                pad_to: If set, sequences shorter than this are zero-padded
                    at the end; longer sequences are returned unchanged.
                include_metadata: Return ``(tensor, metadata)`` instead of
                    just the tensor.
            """
            self.song_dataset = song_dataset
            self.difficulty = difficulty
            self.min_level = min_level
            self.max_level = max_level
            self.pad_to = pad_to
            self.include_metadata = include_metadata

            # Flat index of selected charts: (song index, chart key) pairs.
            self._items: list[tuple[int, int]] = []
            for si, song in enumerate(song_dataset.songs):
                for idx, chart in song.charts.items():
                    if difficulty is not None and idx != difficulty.value:
                        continue
                    if min_level is not None and chart.level_value < min_level:
                        continue
                    if max_level is not None and chart.level_value > max_level:
                        continue
                    self._items.append((si, idx))

        def __len__(self) -> int:
            return len(self._items)

        def __getitem__(self, idx: int):
            """Return the encoded chart at ``idx`` (plus metadata if enabled)."""
            si, ci = self._items[idx]
            song = self.song_dataset[si]
            chart = song.charts[ci]

            tensor = chart_to_tensor(chart)

            if self.pad_to is not None and tensor.shape[0] < self.pad_to:
                pad = torch.zeros(
                    self.pad_to - tensor.shape[0], tensor.shape[1]
                )
                tensor = torch.cat([tensor, pad], dim=0)

            if self.include_metadata:
                return tensor, {
                    "song_id": song.song_id,
                    "title": song.title_clean,
                    "artist": song.artist,
                    "bpm": song.bpm,
                    "level": chart.level_value,
                    "difficulty": chart.difficulty_index,
                    "cabinet": song.cabinet.value,
                    "note_count": chart.note_count,
                }
            return tensor

        @staticmethod
        def collate_fn(batch):
            """
            Custom collate for variable-length sequences.

            Args:
                batch: Non-empty list of ``__getitem__`` results, either
                    tensors or ``(tensor, metadata)`` tuples.

            Returns:
                ``(padded, lengths)`` or ``(padded, lengths, metadata)``,
                where ``padded`` has shape (batch, max_len, feature_dim) and
                is zero-filled past each sequence's length, ``lengths`` holds
                the original sequence lengths, and ``metadata`` is the list
                of per-item dicts.
            """
            if isinstance(batch[0], tuple):
                sequences = [item[0] for item in batch]
                metadata = [item[1] for item in batch]
            else:
                sequences = batch
                metadata = None

            lengths = torch.tensor([s.shape[0] for s in sequences])
            max_len = lengths.max().item()
            feat_dim = sequences[0].shape[1]

            padded = torch.zeros(len(sequences), max_len, feat_dim)
            for i, seq in enumerate(sequences):
                padded[i, :seq.shape[0]] = seq

            if metadata is not None:
                return padded, lengths, metadata
            return padded, lengths

        def stats(self) -> dict:
            """
            Return dataset statistics.

            Returns:
                The number of selected charts and min/max/avg of their level
                and note count (all 0 when no chart is selected).
            """
            levels = []
            note_counts = []
            for si, ci in self._items:
                chart = self.song_dataset[si].charts[ci]
                levels.append(chart.level_value)
                note_counts.append(chart.note_count)

            return {
                "total_charts": len(self._items),
                "level_min": min(levels) if levels else 0,
                "level_max": max(levels) if levels else 0,
                "level_avg": sum(levels) / len(levels) if levels else 0,
                "note_count_min": min(note_counts) if note_counts else 0,
                "note_count_max": max(note_counts) if note_counts else 0,
                "note_count_avg": sum(note_counts) / len(note_counts)
                if note_counts else 0,
            }

else:

    class ChartDataset:
        """Stub when PyTorch is not installed."""

        def __init__(self, *args, **kwargs):
            # Fail at construction time so importing this module stays
            # possible without PyTorch.
            raise ImportError(
                "PyTorch is required for ChartDataset. "
                "Install with: pip install torch"
            )
|
|