File size: 2,256 Bytes
d171350
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
"""Difficulty constraints — enforce per-difficulty fret and chord rules."""

from midmid.datatypes import NoteEvent

# Fret indices each difficulty may use. enforce_constraints discards frets
# outside the set (or remaps a note whose frets are all disallowed).
ALLOWED_FRETS = dict(
    easy=set(range(3)),
    medium=set(range(4)),
    hard=set(range(5)),
    expert=set(range(5)),
)

# Largest number of simultaneous frets a single note may carry; larger
# chords keep only their lowest frets.
MAX_CHORD_SIZE = dict(
    easy=1,
    medium=2,
    hard=3,
    expert=5,
)

# Minimum tick gap between consecutive kept notes. At a resolution of 192
# these are a quarter, an eighth and a sixteenth note; 0 disables the check.
MIN_NOTE_SPACING = dict(
    easy=192,
    medium=96,
    hard=48,
    expert=0,
)


def enforce_constraints(
    notes: list[NoteEvent], difficulty: str, resolution: int = 192,
) -> list[NoteEvent]:
    """Reduce a note sequence so it obeys the rules for ``difficulty``.

    Notes are processed in input order. The input is presumably sorted by
    tick, since the spacing and sustain checks compare each note against
    the previously *kept* note. For each note:

    1. Frets outside ``ALLOWED_FRETS[difficulty]`` are discarded. If none
       remain, the note's lowest fret is replaced by the nearest allowed
       fret. A note with no frets at all is dropped.
    2. Chords larger than ``MAX_CHORD_SIZE[difficulty]`` keep only their
       lowest frets.
    3. A note closer than the minimum spacing to the previously kept note
       is dropped.
    4. A note that starts inside the previously kept note's sustain is
       dropped.
    5. When a note is kept, the preceding kept note's sustain is shortened
       so it ends at least a sixteenth before this note. It is rounded
       down to a whole sixteenth; anything shorter than a sixteenth
       becomes 0.

    Unknown difficulties fall back to the most permissive limits (all five
    frets, chords up to 5, no minimum spacing).

    Args:
        notes: Input notes.
        difficulty: A key of ``ALLOWED_FRETS`` such as ``"easy"``.
        resolution: Ticks per beat; a sixteenth is taken as
            ``resolution // 4``. ``MIN_NOTE_SPACING`` values are defined at
            resolution 192 and are scaled proportionally to this value.

    Returns:
        A new list of newly built ``NoteEvent`` objects. The input list and
        its notes are not modified.
    """
    # Reference resolution at which MIN_NOTE_SPACING tick counts are defined.
    base_resolution = 192

    allowed = ALLOWED_FRETS.get(difficulty, {0, 1, 2, 3, 4})
    max_chord = MAX_CHORD_SIZE.get(difficulty, 5)
    # Scale the spacing so it represents the same musical duration at any
    # resolution, consistent with how the sixteenth grid below is derived.
    min_spacing = (
        MIN_NOTE_SPACING.get(difficulty, 0) * resolution // base_resolution
    )
    # Loop invariant. Clamped to 1 so a tiny resolution cannot trigger a
    # division by zero when sustains are quantized below.
    sixteenth = max(resolution // 4, 1)

    result: list[NoteEvent] = []
    for note in notes:
        # Copy into a plain set: ``&`` returns the left operand's type, so
        # a frozenset fret_set would otherwise not support reassignment via
        # mutation and would leak into the output as-is.
        filtered = set(note.fret_set & allowed)
        if not filtered and note.fret_set:
            # Nothing playable at this difficulty: move the lowest original
            # fret onto the closest allowed fret so the note is not lost.
            lowest = min(note.fret_set)
            filtered = {min(allowed, key=lambda a: abs(a - lowest))}
        if not filtered:
            continue

        if len(filtered) > max_chord:
            filtered = set(sorted(filtered)[:max_chord])

        if min_spacing > 0 and result:
            if note.tick - result[-1].tick < min_spacing:
                continue

        # Do not place a note while the previous kept note is still held.
        if result and result[-1].sustain_ticks > 0:
            prev_end = result[-1].tick + result[-1].sustain_ticks
            if note.tick < prev_end:
                continue

        result.append(NoteEvent(
            tick=note.tick,
            fret_set=filtered,
            sustain_ticks=note.sustain_ticks,
            is_hopo=note.is_hopo,
        ))

        if len(result) >= 2 and result[-2].sustain_ticks > 0:
            prev = result[-2]
            # Leave at least a sixteenth of release time before this note
            # and keep the sustain length on the sixteenth grid.
            max_sustain = note.tick - prev.tick - sixteenth
            max_sustain = (max_sustain // sixteenth) * sixteenth
            if max_sustain < sixteenth:
                max_sustain = 0
            if prev.sustain_ticks > max_sustain:
                result[-2] = NoteEvent(
                    tick=prev.tick,
                    fret_set=prev.fret_set,
                    sustain_ticks=max_sustain,
                    is_hopo=prev.is_hopo,
                )

    return result