"""Difficulty constraints — enforce per-difficulty fret and chord rules."""

from midmid.datatypes import NoteEvent

# Fret indices each difficulty is permitted to use.
ALLOWED_FRETS = {
    "easy": {0, 1, 2},
    "medium": {0, 1, 2, 3},
    "hard": {0, 1, 2, 3, 4},
    "expert": {0, 1, 2, 3, 4},
}

# Largest number of simultaneous frets (chord size) per difficulty.
MAX_CHORD_SIZE = {
    "easy": 1,
    "medium": 2,
    "hard": 3,
    "expert": 5,
}

# Minimum tick distance between consecutive kept notes, expressed in ticks at
# _SPACING_BASE_RESOLUTION ticks per quarter note (192 / 96 / 48 are a
# quarter / eighth / sixteenth there). enforce_constraints rescales these to
# the chart's actual resolution.
MIN_NOTE_SPACING = {
    "easy": 192,
    "medium": 96,
    "hard": 48,
    "expert": 0,
}

# Resolution (ticks per quarter note) that MIN_NOTE_SPACING is written in.
_SPACING_BASE_RESOLUTION = 192


def _remap_frets(fret_set: set[int], allowed: set[int], max_chord: int) -> set[int]:
    """Return the frets of *fret_set* playable with *allowed*, capped at *max_chord*.

    If none of the frets are allowed, the lowest fret is snapped to the nearest
    allowed fret. Chords that are too large keep only their lowest frets. An
    empty input yields an empty set.
    """
    # Copy into a mutable set: ``frozenset & set`` would return a frozenset.
    frets = set(fret_set)
    playable = frets & allowed
    if not playable and frets:
        lowest = min(frets)
        # Iterate in sorted order so a tie resolves to the lower fret.
        playable = {min(sorted(allowed), key=lambda a: abs(a - lowest))}
    if len(playable) > max_chord:
        playable = set(sorted(playable)[:max_chord])
    return playable


def _trim_sustain(prev: NoteEvent, next_tick: int, sixteenth: int) -> NoteEvent:
    """Shorten *prev*'s sustain so it ends at least a sixteenth before *next_tick*.

    The allowed length is snapped down to the sixteenth grid; if less than one
    sixteenth of room remains, the sustain is removed. *prev* is returned
    unchanged when it has no sustain or already fits.
    """
    if prev.sustain_ticks <= 0:
        return prev
    max_sustain = (next_tick - prev.tick - sixteenth) // sixteenth * sixteenth
    if max_sustain < sixteenth:
        max_sustain = 0
    if prev.sustain_ticks <= max_sustain:
        return prev
    return NoteEvent(
        tick=prev.tick,
        fret_set=prev.fret_set,
        sustain_ticks=max_sustain,
        is_hopo=prev.is_hopo,
    )


def enforce_constraints(
    notes: list[NoteEvent],
    difficulty: str,
    resolution: int = 192,
) -> list[NoteEvent]:
    """Return a copy of *notes* reduced to what *difficulty* allows.

    For each input note, in order:

    * frets outside the difficulty's allowed set are dropped; if none remain,
      the note's lowest fret is snapped to the nearest allowed fret;
    * chords larger than the difficulty's limit keep only their lowest frets;
    * the note is skipped if it starts too soon after the last kept note
      (MIN_NOTE_SPACING, rescaled to *resolution*) or before that note's
      sustain ends;
    * otherwise the last kept note's sustain is shortened so it ends at least
      a sixteenth before this note, snapped to the sixteenth grid.

    Unknown difficulties fall back to frets 0-4, chords of up to 5 frets and
    no spacing limit.

    Args:
        notes: Note events to filter. NOTE(review): the spacing and sustain
            checks only compare against the last kept note, so this presumably
            expects notes sorted by tick — confirm with callers.
        difficulty: Key into ALLOWED_FRETS / MAX_CHORD_SIZE / MIN_NOTE_SPACING.
        resolution: Ticks per quarter note of the chart.

    Returns:
        Newly constructed NoteEvent objects; the input notes are not modified.
    """
    allowed = ALLOWED_FRETS.get(difficulty, {0, 1, 2, 3, 4})
    max_chord = MAX_CHORD_SIZE.get(difficulty, 5)
    # Spacing table is in ticks at 192 TPQ; rescale to this chart's resolution
    # (identity at the default resolution of 192).
    min_spacing = (
        MIN_NOTE_SPACING.get(difficulty, 0) * resolution // _SPACING_BASE_RESOLUTION
    )
    # Sustain trimming uses a sixteenth-note grid; clamp to one tick so a tiny
    # resolution cannot cause a division by zero.
    sixteenth = max(resolution // 4, 1)

    result: list[NoteEvent] = []
    for note in notes:
        frets = _remap_frets(note.fret_set, allowed, max_chord)
        if not frets:
            continue

        if result:
            last = result[-1]
            too_close = min_spacing > 0 and note.tick - last.tick < min_spacing
            inside_sustain = (
                last.sustain_ticks > 0
                and note.tick < last.tick + last.sustain_ticks
            )
            if too_close or inside_sustain:
                continue
            # Leave a sixteenth-note gap between the previous sustain and this note.
            result[-1] = _trim_sustain(last, note.tick, sixteenth)

        result.append(NoteEvent(
            tick=note.tick,
            fret_set=frets,
            sustain_ticks=note.sustain_ticks,
            is_hopo=note.is_hopo,
        ))
    return result