maiChartGen / mai_parser /dataset.py
Goldgom's picture
Upload MaiGenerator model (epoch 10) and inference code
8061544
Raw
History Blame Contribute Delete
15.4 kB
"""
Dataset loader for maimai training data.
Provides:
- SongDataset: scans datasets/ folder, parses all songs
- ChartDataset: PyTorch Dataset for chart note sequences
- Utility functions for note-to-tensor conversion
"""
from __future__ import annotations
import json
import pickle
from pathlib import Path
from typing import Iterator, Optional, Union
from .models import Cabinet, Chart, Difficulty, Song, TouchNote
from .parser import parse_maidata_file
# Try importing PyTorch (optional)
try:
import torch
from torch.utils.data import Dataset as TorchDataset
HAS_TORCH = True
except (ImportError, OSError):
HAS_TORCH = False
class SongDataset:
    """
    A collection of parsed Song objects from the datasets/ directory.

    Usage:
        ds = SongDataset("datasets")
        ds.scan()                   # discover all song folders
        ds.load(use_cache=True)     # parse all songs (with caching)
        print(len(ds))              # total songs
        song = ds[0]                # first song
    """

    # Everything that reading a stale or damaged cache can reasonably raise.
    # The pickle documentation notes that unpickling may raise AttributeError,
    # EOFError, ImportError and IndexError in addition to UnpicklingError
    # (for example after the model classes were renamed or moved).
    _CACHE_ERRORS = (
        pickle.UnpicklingError,
        EOFError,
        KeyError,
        AttributeError,
        ImportError,
        IndexError,
        TypeError,
        ValueError,
        OSError,
    )

    def __init__(self, root_dir: str | Path):
        """
        Args:
            root_dir: Directory that holds one sub-folder per song.
        """
        self.root_dir = Path(root_dir)
        self.songs: list[Song] = []
        self._by_id: dict[str, Song] = {}
        self._cache_path = self.root_dir / ".mai_parser_cache.pkl"

    def scan(self) -> list[Path]:
        """
        Scan the dataset directory for all song folders (containing maidata.txt).

        Returns the found song folder paths in sorted order.
        """
        return [
            item
            for item in sorted(self.root_dir.iterdir())
            if item.is_dir() and (item / "maidata.txt").exists()
        ]

    def load(self, use_cache: bool = True, max_workers: int = 8,
             progress: bool = True) -> "SongDataset":
        """
        Parse all songs from discovered folders.

        Args:
            use_cache: If True, try loading from cache first; save after parsing.
            max_workers: Number of threads for parallel parsing.
            progress: Print progress information.

        Returns:
            This dataset, so calls can be chained.
        """
        if use_cache and self._cache_path.exists():
            try:
                # NOTE(security): pickle executes code while loading. Only
                # point this class at dataset directories you trust.
                with open(self._cache_path, "rb") as f:
                    cached = pickle.load(f)
                if not isinstance(cached, list):
                    raise pickle.UnpicklingError("unexpected cache payload")
                self.songs = cached
                self._build_index()
            except self._CACHE_ERRORS as e:
                # An unreadable cache is never fatal: fall through and re-parse.
                if progress:
                    print(f"Cache invalid ({e}), re-parsing...")
            else:
                if progress:
                    print(f"Loaded {len(self.songs)} songs from cache.")
                return self

        folders = self.scan()
        if progress:
            print(f"Found {len(folders)} song folders. Parsing...")

        from concurrent.futures import ThreadPoolExecutor, as_completed

        songs: list[Song] = []
        with ThreadPoolExecutor(max_workers=max_workers) as executor:
            futures = {
                executor.submit(parse_maidata_file, f / "maidata.txt"): f
                for f in folders
            }
            for i, future in enumerate(as_completed(futures)):
                folder = futures[future]
                try:
                    song = future.result()
                    song.maidata_path = str(folder / "maidata.txt")
                    song.audio_path = str(folder / "track.mp3")
                    songs.append(song)
                except Exception as e:
                    # Best effort: one unparsable song must not abort the load.
                    if progress:
                        print(f"  [WARN] Failed to parse {folder.name}: {e}")
                if progress and (i + 1) % 200 == 0:
                    print(f"  ... parsed {i + 1}/{len(folders)}")

        # Futures complete in arbitrary order; sort by song_id (numeric first)
        # so the dataset order is deterministic.
        songs.sort(key=lambda s: _sort_key(s.song_id))
        self.songs = songs
        self._build_index()

        if use_cache:
            try:
                with open(self._cache_path, "wb") as f:
                    pickle.dump(self.songs, f, protocol=pickle.HIGHEST_PROTOCOL)
                if progress:
                    print(f"Saved cache to {self._cache_path}")
            except OSError as e:
                if progress:
                    print(f"  [WARN] Could not save cache: {e}")

        if progress:
            print(f"Loaded {len(self.songs)} songs.")
        return self

    def _build_index(self) -> None:
        """Rebuild the song_id -> Song lookup table from self.songs."""
        self._by_id = {s.song_id: s for s in self.songs}

    def __len__(self) -> int:
        return len(self.songs)

    def __getitem__(self, idx: int) -> Song:
        return self.songs[idx]

    def __iter__(self) -> Iterator[Song]:
        return iter(self.songs)

    def get_by_id(self, song_id: str) -> Song | None:
        """Look up a song by its folder name / ID; None when unknown."""
        return self._by_id.get(song_id)

    def filter(
        self,
        cabinet: Cabinet | None = None,
        min_level: float | None = None,
        max_level: float | None = None,
        difficulty: Difficulty | None = None,
        genre: str | None = None,
        has_audio: bool = False,
    ) -> list[Song]:
        """
        Filter songs by various criteria.

        Args:
            cabinet: Filter by cabinet type (SD/DX).
            min_level: Minimum chart level.
            max_level: Maximum chart level.
            difficulty: Only songs with this difficulty. When given, the level
                bounds are checked against that chart; otherwise a song is
                kept if any one of its charts lies within the bounds.
            genre: Filter by genre name (case-insensitive partial match).
            has_audio: Only songs where track.mp3 exists.

        Returns:
            The matching songs, in dataset order.
        """
        level_bounded = min_level is not None or max_level is not None
        wanted_genre = genre.lower() if genre is not None else None

        def in_level_range(chart) -> bool:
            if min_level is not None and chart.level_value < min_level:
                return False
            if max_level is not None and chart.level_value > max_level:
                return False
            return True

        result = []
        for song in self.songs:
            if cabinet is not None and song.cabinet != cabinet:
                continue
            if wanted_genre is not None and wanted_genre not in (song.genre or "").lower():
                continue
            if difficulty is not None:
                chart = song.get_chart(difficulty)
                if chart is None or not in_level_range(chart):
                    continue
            elif level_bounded and not any(
                in_level_range(c) for c in song.charts.values()
            ):
                continue
            # Filesystem check goes last so it only runs for songs that
            # passed every in-memory criterion.
            if has_audio:
                audio_path = self.root_dir / song.song_id / "track.mp3"
                if not audio_path.exists():
                    continue
            result.append(song)
        return result

    def stats(self) -> dict:
        """Return summary statistics of the dataset."""
        total = len(self.songs)
        sd_count = sum(1 for s in self.songs if s.cabinet == Cabinet.SD)
        dx_count = sum(1 for s in self.songs if s.cabinet == Cabinet.DX)
        total_charts = sum(len(s.charts) for s in self.songs)
        genres = {s.genre for s in self.songs if s.genre}

        # Level distribution per difficulty, keyed "lv_<chart index>".
        level_dist: dict[str, list[float]] = {}
        for song in self.songs:
            for idx, chart in song.charts.items():
                level_dist.setdefault(f"lv_{idx}", []).append(chart.level_value)

        return {
            "total_songs": total,
            "sd_songs": sd_count,
            "dx_songs": dx_count,
            "unknown_cabinet": total - sd_count - dx_count,
            "total_charts": total_charts,
            "genres": sorted(genres),
            "level_distribution": {
                k: {
                    "min": min(v) if v else 0,
                    "max": max(v) if v else 0,
                    "avg": sum(v) / len(v) if v else 0,
                }
                for k, v in level_dist.items()
            },
        }
def _sort_key(song_id: str) -> tuple:
"""Sort: numeric IDs first, then strings."""
try:
return (0, int(song_id))
except ValueError:
return (1, song_id)
# ─── Note to Tensor Conversion ──────────────────────────────────────────────
# Integer code assigned to each note category for training.
# The code is the position in this tuple, so the order must not change.
NOTE_TYPE = {
    name: code
    for code, name in enumerate(
        ("rest", "tap", "hold", "slide", "break", "touch", "end")
    )
}

# Code -> category name (inverse of NOTE_TYPE).
NOTE_TYPE_INV = dict(zip(NOTE_TYPE.values(), NOTE_TYPE.keys()))

# Button positions are numbered 1 through 8.
NUM_POSITIONS = 8
def note_to_vector(note: TouchNote, num_positions: int = NUM_POSITIONS) -> list[float]:
    """
    Encode one TouchNote as a flat, fixed-length feature list.

    Layout, in the order the slots are written (N = num_positions):
        [0:7]        one-hot note type:
                     rest, tap, hold, slide, break, touch, end
        [7:7+N]      multi-hot button positions 1..N
        next 3       slide-path length flags: len >= 2, len >= 3, len >= 4
        next 1       is_simultaneous flag
        next 2       first two elements of note.hold_duration (0.0 when unset)
        last 1       note.beat_div, stored as-is (no scaling is applied here)

    The type is decided by the first flag that is set, checked in the order
    end, rest, touch, break, hold, slide; a note with none of them is a tap.
    """
    type_width = 7
    flags_at = type_width + num_positions
    hold_at = flags_at + 4
    features = [0.0] * (hold_at + 2 + 1)

    # One-hot note type: first matching flag wins, tap is the fallback.
    type_priority = (
        ("is_end", 6),
        ("is_rest", 0),
        ("is_touch", 5),
        ("is_break", 4),
        ("is_hold", 2),
        ("is_slide", 3),
    )
    for flag_name, slot in type_priority:
        if getattr(note, flag_name):
            features[slot] = 1.0
            break
    else:
        features[1] = 1.0

    # Multi-hot buttons; positions outside 1..num_positions are ignored.
    for button in note.positions:
        if 1 <= button <= num_positions:
            features[type_width + button - 1] = 1.0

    # Cumulative slide-path length flags.
    path_len = len(note.slide_path)
    for k, threshold in enumerate((2, 3, 4)):
        if path_len >= threshold:
            features[flags_at + k] = 1.0

    features[flags_at + 3] = 1.0 if note.is_simultaneous else 0.0

    # Hold duration pair, left at 0.0 when the note has none.
    duration = note.hold_duration
    if duration:
        features[hold_at] = float(duration[0])
        features[hold_at + 1] = float(duration[1])

    # Beat division always occupies the final slot.
    features[-1] = note.beat_div
    return features
def chart_to_tensor(chart: Chart) -> "torch.Tensor":
    """
    Convert an entire chart's notes to a PyTorch tensor.

    Returns:
        A float32 tensor of shape (num_notes, feature_dim). A chart without
        notes yields shape (0, feature_dim) rather than a 1-D empty tensor,
        so callers can always read ``shape[1]``.

    Raises:
        ImportError: If PyTorch is not available.
    """
    if not HAS_TORCH:
        raise ImportError("PyTorch is required for tensor conversion. "
                          "Install with: pip install torch")
    vectors = [note_to_vector(n) for n in chart.notes]
    if not vectors:
        # torch.tensor([]) would have shape (0,). Keep the result 2-D using
        # the width note_to_vector() produces with its default
        # num_positions: 7 type slots + positions + 4 flags + 2 hold + 1 beat.
        feature_dim = 7 + NUM_POSITIONS + 4 + 2 + 1
        return torch.zeros((0, feature_dim), dtype=torch.float32)
    return torch.tensor(vectors, dtype=torch.float32)
def chart_to_sequence(chart: Chart) -> list[list[float]]:
    """Encode every note of *chart* with note_to_vector; needs no PyTorch."""
    return list(map(note_to_vector, chart.notes))
# ─── PyTorch Dataset (if torch available) ───────────────────────────────────
if HAS_TORCH:

    class ChartDataset(TorchDataset):
        """
        PyTorch Dataset for maimai chart training.

        Each item is either ``(note_sequence, metadata)`` or, when
        ``include_metadata`` is False, just ``note_sequence``:
        - note_sequence is a tensor of shape (seq_len, feature_dim)
        - metadata is a dict with song info

        Usage:
            from torch.utils.data import DataLoader
            ds = ChartDataset(song_dataset, difficulty=Difficulty.MASTER)
            loader = DataLoader(ds, batch_size=32, collate_fn=ds.collate_fn)
        """

        def __init__(
            self,
            song_dataset: SongDataset,
            difficulty: Difficulty | None = None,
            min_level: float | None = None,
            max_level: float | None = None,
            pad_to: int | None = None,
            include_metadata: bool = True,
        ):
            """
            Args:
                song_dataset: Loaded SongDataset to draw charts from.
                difficulty: Keep only charts whose index equals
                    ``difficulty.value``; None keeps every difficulty.
                min_level: Drop charts whose level_value is below this.
                max_level: Drop charts whose level_value is above this.
                pad_to: Zero-pad shorter sequences to this many rows.
                    Longer sequences are returned unchanged.
                include_metadata: Return a metadata dict with each tensor.
            """
            self.song_dataset = song_dataset
            self.difficulty = difficulty
            self.min_level = min_level
            self.max_level = max_level
            self.pad_to = pad_to
            self.include_metadata = include_metadata

            # Filtered list of (song_index, chart_index) pairs.
            self._items: list[tuple[int, int]] = []
            for si, song in enumerate(song_dataset.songs):
                for idx, chart in song.charts.items():
                    if difficulty is not None and idx != difficulty.value:
                        continue
                    if min_level is not None and chart.level_value < min_level:
                        continue
                    if max_level is not None and chart.level_value > max_level:
                        continue
                    self._items.append((si, idx))

        def __len__(self) -> int:
            return len(self._items)

        def __getitem__(self, idx: int):
            si, ci = self._items[idx]
            song = self.song_dataset[si]
            chart = song.charts[ci]
            tensor = chart_to_tensor(chart)

            if self.pad_to is not None and tensor.shape[0] < self.pad_to:
                # new_zeros keeps the padding on the same dtype and device
                # as the sequence it is appended to.
                pad = tensor.new_zeros(
                    (self.pad_to - tensor.shape[0], tensor.shape[1])
                )
                tensor = torch.cat([tensor, pad], dim=0)

            if self.include_metadata:
                return tensor, {
                    "song_id": song.song_id,
                    "title": song.title_clean,
                    "artist": song.artist,
                    "bpm": song.bpm,
                    "level": chart.level_value,
                    "difficulty": chart.difficulty_index,
                    "cabinet": song.cabinet.value,
                    "note_count": chart.note_count,
                }
            return tensor

        @staticmethod
        def collate_fn(batch):
            """
            Custom collate for variable-length sequences.

            Returns ``(padded, lengths, metadata)`` when the batch items are
            ``(tensor, metadata)`` tuples, otherwise ``(padded, lengths)``.
            ``padded`` has shape (batch, max_len, feature_dim) and ``lengths``
            holds each sequence's original row count.
            """
            if isinstance(batch[0], tuple):
                sequences = [item[0] for item in batch]
                metadata = [item[1] for item in batch]
            else:
                sequences = batch
                metadata = None

            # Zero-pad every sequence to the longest one in the batch.
            lengths = torch.tensor([s.shape[0] for s in sequences])
            max_len = lengths.max().item()
            feat_dim = sequences[0].shape[1]
            padded = torch.zeros(len(sequences), max_len, feat_dim)
            for i, seq in enumerate(sequences):
                padded[i, :seq.shape[0]] = seq

            if metadata is not None:
                return padded, lengths, metadata
            return padded, lengths

        def stats(self) -> dict:
            """Return level and note-count statistics over the selected charts."""
            charts = [self.song_dataset[si].charts[ci] for si, ci in self._items]
            levels = [chart.level_value for chart in charts]
            note_counts = [chart.note_count for chart in charts]
            return {
                "total_charts": len(self._items),
                "level_min": min(levels) if levels else 0,
                "level_max": max(levels) if levels else 0,
                "level_avg": sum(levels) / len(levels) if levels else 0,
                "note_count_min": min(note_counts) if note_counts else 0,
                "note_count_max": max(note_counts) if note_counts else 0,
                "note_count_avg": sum(note_counts) / len(note_counts)
                if note_counts else 0,
            }

else:

    class ChartDataset:
        """Stub when PyTorch is not installed."""

        def __init__(self, *args, **kwargs):
            raise ImportError(
                "PyTorch is required for ChartDataset. "
                "Install with: pip install torch"
            )