FJFehr commited on
Commit
f7e26a4
Β·
1 Parent(s): 9de8877

First iteration of a midi model intergration.

Browse files
Files changed (7) hide show
  1. .gitignore +6 -0
  2. README.md +12 -0
  3. engines.py +69 -49
  4. keyboard.html +1 -0
  5. midi_model.py +376 -0
  6. requirements.txt +5 -1
  7. static/keyboard.js +25 -0
.gitignore CHANGED
@@ -26,6 +26,12 @@ share/python-wheels/
26
  .installed.cfg
27
  *.egg
28
 
 
 
 
 
 
 
29
  # Virtual environments
30
  .venv/
31
  venv/
 
26
  .installed.cfg
27
  *.egg
28
 
29
+ # External model dependencies
30
+ external/
31
+
32
+ # Output files
33
+ output/
34
+
35
  # Virtual environments
36
  .venv/
37
  venv/
README.md CHANGED
@@ -14,6 +14,9 @@ short_description: Browser-based MIDI keyboard with recording and synthesis
14
 
15
  A minimal, responsive browser-based MIDI keyboard. Play live, record performances, and export as MIDI files. 🎹
16
 
 
 
 
17
 
18
  ## πŸ—‚οΈ Project Structure
19
 
@@ -22,6 +25,7 @@ A minimal, responsive browser-based MIDI keyboard. Play live, record performance
22
  β”œβ”€β”€ app.py # Gradio server & API endpoints
23
  β”œβ”€β”€ config.py # Centralized configuration
24
  β”œβ”€β”€ engines.py # MIDI processing engines
 
25
  β”œβ”€β”€ midi.py # MIDI file utilities
26
  β”œβ”€β”€ keyboard.html # HTML structure
27
  β”œβ”€β”€ static/
@@ -43,6 +47,13 @@ uv run python app.py
43
 
44
  Open **http://127.0.0.1:7861**
45
 
 
 
 
 
 
 
 
46
  ## 🌐 Deploy to Hugging Face Spaces
47
 
48
  ```bash
@@ -55,6 +66,7 @@ git push hf main
55
  - **Frontend**: Tone.js v6+ (Web Audio API)
56
  - **Backend**: Gradio 6.x + Python 3.10+
57
  - **MIDI**: mido library
 
58
 
59
  ## πŸ“ License
60
 
 
14
 
15
  A minimal, responsive browser-based MIDI keyboard. Play live, record performances, and export as MIDI files. 🎹
16
 
17
+ This build includes a **Godzilla** engine that can continue a short phrase using the
18
+ Godzilla Piano Transformer.
19
+
20
 
21
  ## πŸ—‚οΈ Project Structure
22
 
 
25
  β”œβ”€β”€ app.py # Gradio server & API endpoints
26
  β”œβ”€β”€ config.py # Centralized configuration
27
  β”œβ”€β”€ engines.py # MIDI processing engines
28
+ β”œβ”€β”€ midi_model.py # Godzilla model integration
29
  β”œβ”€β”€ midi.py # MIDI file utilities
30
  β”œβ”€β”€ keyboard.html # HTML structure
31
  β”œβ”€β”€ static/
 
47
 
48
  Open **http://127.0.0.1:7861**
49
 
50
+ ## 🎹 Godzilla Engine
51
+
52
+ Select **Godzilla** in the engine dropdown to generate a short continuation from your
53
+ recorded phrase. The model is downloaded on first use and cached locally.
54
+
55
+ Note: the engine filters generated notes to your on-screen keyboard range.
56
+
57
  ## 🌐 Deploy to Hugging Face Spaces
58
 
59
  ```bash
 
66
  - **Frontend**: Tone.js v6+ (Web Audio API)
67
  - **Backend**: Gradio 6.x + Python 3.10+
68
  - **MIDI**: mido library
69
+ - **Model**: Godzilla Piano Transformer (via Hugging Face)
70
 
71
  ## πŸ“ License
72
 
engines.py CHANGED
@@ -4,33 +4,13 @@ Virtual MIDI Keyboard - Engines
4
  MIDI processing engines that transform, analyze, or manipulate MIDI events.
5
  """
6
 
7
- from abc import ABC, abstractmethod
8
  from typing import List, Dict, Any
9
 
10
-
11
- # =============================================================================
12
- # BASE ENGINE CLASS
13
- # =============================================================================
14
-
15
-
16
- class MIDIEngine(ABC):
17
- """Abstract base class for MIDI engines"""
18
-
19
- def __init__(self, name: str):
20
- self.name = name
21
-
22
- @abstractmethod
23
- def process(self, events: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
24
- """
25
- Process MIDI events and return transformed events.
26
-
27
- Args:
28
- events: List of MIDI event dictionaries
29
-
30
- Returns:
31
- List of processed MIDI event dictionaries
32
- """
33
- pass
34
 
35
 
36
  # =============================================================================
@@ -38,7 +18,7 @@ class MIDIEngine(ABC):
38
  # =============================================================================
39
 
40
 
41
- class ParrotEngine(MIDIEngine):
42
  """
43
  Parrot Engine - plays back MIDI exactly as recorded.
44
 
@@ -46,7 +26,7 @@ class ParrotEngine(MIDIEngine):
46
  """
47
 
48
  def __init__(self):
49
- super().__init__("Parrot")
50
 
51
  def process(self, events: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
52
  """Return events unchanged"""
@@ -69,7 +49,7 @@ class ParrotEngine(MIDIEngine):
69
  # =============================================================================
70
 
71
 
72
- class ReverseParrotEngine(MIDIEngine):
73
  """
74
  Reverse Parrot Engine - plays back MIDI in reverse order.
75
 
@@ -78,7 +58,7 @@ class ReverseParrotEngine(MIDIEngine):
78
  """
79
 
80
  def __init__(self):
81
- super().__init__("Reverse Parrot")
82
 
83
  def process(self, events: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
84
  """Reverse the sequence of note numbers while keeping timing and event types"""
@@ -88,43 +68,79 @@ class ReverseParrotEngine(MIDIEngine):
88
  # Separate note_on and note_off events
89
  note_on_events = [e for e in events if e.get("type") == "note_on"]
90
  note_off_events = [e for e in events if e.get("type") == "note_off"]
91
-
92
  # Extract note numbers from note_on events and reverse them
93
  on_notes = [e.get("note") for e in note_on_events]
94
  reversed_on_notes = list(reversed(on_notes))
95
-
96
  # Extract note numbers from note_off events and reverse them
97
  off_notes = [e.get("note") for e in note_off_events]
98
  reversed_off_notes = list(reversed(off_notes))
99
-
100
  # Reconstruct events with reversed notes but original structure
101
  result = []
102
  on_index = 0
103
  off_index = 0
104
-
105
  for event in events:
106
  if event.get("type") == "note_on":
107
- result.append({
108
- "type": "note_on",
109
- "note": reversed_on_notes[on_index],
110
- "velocity": event.get("velocity"),
111
- "time": event.get("time"),
112
- "channel": event.get("channel", 0),
113
- })
 
 
114
  on_index += 1
115
  elif event.get("type") == "note_off":
116
- result.append({
117
- "type": "note_off",
118
- "note": reversed_off_notes[off_index],
119
- "velocity": event.get("velocity"),
120
- "time": event.get("time"),
121
- "channel": event.get("channel", 0),
122
- })
 
 
123
  off_index += 1
124
 
125
  return result
126
 
127
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
128
  # =============================================================================
129
  # ENGINE REGISTRY
130
  # =============================================================================
@@ -133,7 +149,11 @@ class ReverseParrotEngine(MIDIEngine):
133
  class EngineRegistry:
134
  """Registry for managing available MIDI engines"""
135
 
136
- _engines = {"parrot": ParrotEngine, "reverse_parrot": ReverseParrotEngine}
 
 
 
 
137
 
138
  @classmethod
139
  def register(cls, engine_id: str, engine_class: type):
@@ -141,7 +161,7 @@ class EngineRegistry:
141
  cls._engines[engine_id] = engine_class
142
 
143
  @classmethod
144
- def get_engine(cls, engine_id: str) -> MIDIEngine:
145
  """Get an engine instance by ID"""
146
  if engine_id not in cls._engines:
147
  raise ValueError(f"Unknown engine: {engine_id}")
 
4
  MIDI processing engines that transform, analyze, or manipulate MIDI events.
5
  """
6
 
 
7
  from typing import List, Dict, Any
8
 
9
+ from midi_model import (
10
+ count_out_of_range_events,
11
+ filter_events_to_keyboard_range,
12
+ get_model,
13
+ )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
14
 
15
 
16
  # =============================================================================
 
18
  # =============================================================================
19
 
20
 
21
+ class ParrotEngine:
22
  """
23
  Parrot Engine - plays back MIDI exactly as recorded.
24
 
 
26
  """
27
 
28
  def __init__(self):
29
+ self.name = "Parrot"
30
 
31
  def process(self, events: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
32
  """Return events unchanged"""
 
49
  # =============================================================================
50
 
51
 
52
+ class ReverseParrotEngine:
53
  """
54
  Reverse Parrot Engine - plays back MIDI in reverse order.
55
 
 
58
  """
59
 
60
  def __init__(self):
61
+ self.name = "Reverse Parrot"
62
 
63
  def process(self, events: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
64
  """Reverse the sequence of note numbers while keeping timing and event types"""
 
68
  # Separate note_on and note_off events
69
  note_on_events = [e for e in events if e.get("type") == "note_on"]
70
  note_off_events = [e for e in events if e.get("type") == "note_off"]
71
+
72
  # Extract note numbers from note_on events and reverse them
73
  on_notes = [e.get("note") for e in note_on_events]
74
  reversed_on_notes = list(reversed(on_notes))
75
+
76
  # Extract note numbers from note_off events and reverse them
77
  off_notes = [e.get("note") for e in note_off_events]
78
  reversed_off_notes = list(reversed(off_notes))
79
+
80
  # Reconstruct events with reversed notes but original structure
81
  result = []
82
  on_index = 0
83
  off_index = 0
84
+
85
  for event in events:
86
  if event.get("type") == "note_on":
87
+ result.append(
88
+ {
89
+ "type": "note_on",
90
+ "note": reversed_on_notes[on_index],
91
+ "velocity": event.get("velocity"),
92
+ "time": event.get("time"),
93
+ "channel": event.get("channel", 0),
94
+ }
95
+ )
96
  on_index += 1
97
  elif event.get("type") == "note_off":
98
+ result.append(
99
+ {
100
+ "type": "note_off",
101
+ "note": reversed_off_notes[off_index],
102
+ "velocity": event.get("velocity"),
103
+ "time": event.get("time"),
104
+ "channel": event.get("channel", 0),
105
+ }
106
+ )
107
  off_index += 1
108
 
109
  return result
110
 
111
 
112
+ # =============================================================================
113
+ # GODZILLA CONTINUATION ENGINE
114
+ # =============================================================================
115
+
116
+
117
+ class GodzillaContinuationEngine:
118
+ """
119
+ Continue a short MIDI phrase with the Godzilla Piano Transformer.
120
+
121
+ Generates a small continuation and appends it after the input events.
122
+ """
123
+
124
+ def __init__(self, generate_tokens: int = 32):
125
+ self.name = "Godzilla"
126
+ self.generate_tokens = generate_tokens
127
+
128
+ def process(self, events: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
129
+ if not events:
130
+ return []
131
+
132
+ model = get_model("godzilla")
133
+ new_events = model.generate_continuation(
134
+ events,
135
+ tokens=self.generate_tokens,
136
+ seed=None,
137
+ )
138
+ out_of_range = count_out_of_range_events(new_events)
139
+ if out_of_range:
140
+ print(f"Godzilla: dropped {out_of_range} out-of-range events")
141
+ return filter_events_to_keyboard_range(new_events)
142
+
143
+
144
  # =============================================================================
145
  # ENGINE REGISTRY
146
  # =============================================================================
 
149
  class EngineRegistry:
150
  """Registry for managing available MIDI engines"""
151
 
152
+ _engines = {
153
+ "parrot": ParrotEngine,
154
+ "reverse_parrot": ReverseParrotEngine,
155
+ "godzilla_continue": GodzillaContinuationEngine,
156
+ }
157
 
158
  @classmethod
159
  def register(cls, engine_id: str, engine_class: type):
 
161
  cls._engines[engine_id] = engine_class
162
 
163
  @classmethod
164
+ def get_engine(cls, engine_id: str):
165
  """Get an engine instance by ID"""
166
  if engine_id not in cls._engines:
167
  raise ValueError(f"Unknown engine: {engine_id}")
keyboard.html CHANGED
@@ -35,6 +35,7 @@
35
  <select id="engineSelect">
36
  <option value="parrot">Parrot</option>
37
  <option value="reverse_parrot">Reverse Parrot</option>
 
38
  </select>
39
  </label>
40
 
 
35
  <select id="engineSelect">
36
  <option value="parrot">Parrot</option>
37
  <option value="reverse_parrot">Reverse Parrot</option>
38
+ <option value="godzilla_continue">Godzilla</option>
39
  </select>
40
  </label>
41
 
midi_model.py ADDED
@@ -0,0 +1,376 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ from __future__ import annotations
3
+
4
+ import subprocess
5
+ import sys
6
+ from dataclasses import dataclass
7
+ from pathlib import Path
8
+ from typing import Iterable, Optional
9
+
10
+ from config import MIDI_DEFAULTS, KEYBOARD_BASE_MIDI, KEYBOARD_OCTAVES
11
+
12
+
13
+ DEFAULT_REPO = "asigalov61/Godzilla-Piano-Transformer"
14
+ DEFAULT_FILENAME = (
15
+ "Godzilla_Piano_Chords_Texturing_Transformer_Trained_Model_22708_steps_"
16
+ "0.7515_loss_0.7853_acc.pth"
17
+ )
18
+
19
+ _MODEL_CACHE: dict[str, object] = {}
20
+
21
+
22
+ @dataclass(frozen=True)
23
+ class MidiModel:
24
+ model_id: str
25
+ name: str
26
+
27
+ def generate_continuation(
28
+ self,
29
+ events: list[dict],
30
+ *,
31
+ tokens: int = 32,
32
+ seed: Optional[int] = None,
33
+ ) -> list[dict]:
34
+ raise NotImplementedError
35
+
36
+
37
+ def ensure_tegridy_tools(base_dir: Path) -> tuple[Path, Path]:
38
+ repo_dir = base_dir / "tegridy-tools"
39
+ tools_dir = repo_dir / "tegridy-tools"
40
+ x_transformer_dir = tools_dir / "X-Transformer"
41
+
42
+ if not x_transformer_dir.exists():
43
+ repo_url = "https://github.com/asigalov61/tegridy-tools"
44
+ repo_dir.parent.mkdir(parents=True, exist_ok=True)
45
+ try:
46
+ subprocess.check_call(
47
+ [
48
+ "git",
49
+ "clone",
50
+ "--depth",
51
+ "1",
52
+ repo_url,
53
+ str(repo_dir),
54
+ ]
55
+ )
56
+ except FileNotFoundError as exc:
57
+ raise RuntimeError("git is required to clone tegridy-tools") from exc
58
+
59
+ return tools_dir, x_transformer_dir
60
+
61
+
62
+ def add_sys_path(*paths: Path) -> None:
63
+ for path in paths:
64
+ path_str = str(path.resolve())
65
+ if path_str not in sys.path:
66
+ sys.path.insert(0, path_str)
67
+
68
+
69
+ def build_model(seq_len: int, pad_idx: int):
70
+ from x_transformer_2_3_1 import AutoregressiveWrapper, Decoder, TransformerWrapper
71
+
72
+ model = TransformerWrapper(
73
+ num_tokens=pad_idx + 1,
74
+ max_seq_len=seq_len,
75
+ attn_layers=Decoder(
76
+ dim=2048,
77
+ depth=8,
78
+ heads=32,
79
+ rotary_pos_emb=True,
80
+ attn_flash=True,
81
+ ),
82
+ )
83
+ return AutoregressiveWrapper(model, ignore_index=pad_idx, pad_value=pad_idx)
84
+
85
+
86
+ def resolve_device(requested: str) -> str:
87
+ import torch
88
+
89
+ if requested == "auto":
90
+ return "cuda" if torch.cuda.is_available() else "cpu"
91
+ return requested
92
+
93
+
94
+ def load_checkpoint(model, checkpoint_path: Path, device: str) -> None:
95
+ import torch
96
+
97
+ state = torch.load(checkpoint_path, map_location=device)
98
+ model.load_state_dict(state)
99
+
100
+
101
+ def events_to_score_tokens(events: list[dict]) -> list[int]:
102
+ if not events:
103
+ return []
104
+
105
+ active: dict[int, float] = {}
106
+ notes: list[tuple[float, float, int]] = []
107
+ sorted_events = sorted(events, key=lambda e: e.get("time", 0.0))
108
+
109
+ for event in sorted_events:
110
+ ev_type = event.get("type")
111
+ note = int(event.get("note", 0))
112
+ velocity = int(event.get("velocity", 0))
113
+ time_sec = float(event.get("time", 0.0))
114
+
115
+ if ev_type == "note_on" and velocity > 0:
116
+ active[note] = time_sec
117
+ elif ev_type in {"note_off", "note_on"}:
118
+ if note in active:
119
+ start = active.pop(note)
120
+ duration = max(0.0, time_sec - start)
121
+ notes.append((start, duration, note))
122
+
123
+ if not notes:
124
+ return []
125
+
126
+ notes.sort(key=lambda n: n[0])
127
+ tokens: list[int] = []
128
+ prev_start_ms = 0.0
129
+
130
+ for start, duration, pitch in notes:
131
+ start_ms = round(start * 1000.0)
132
+ delta_ms = max(0.0, start_ms - prev_start_ms)
133
+ prev_start_ms = start_ms
134
+
135
+ time_tok = max(0, min(127, int(round(delta_ms / 32.0))))
136
+ dur_tok = max(1, min(127, int(round((duration * 1000.0) / 32.0))))
137
+ pitch_tok = max(0, min(127, int(pitch)))
138
+
139
+ tokens.extend([time_tok, 128 + dur_tok, 256 + pitch_tok])
140
+
141
+ return tokens
142
+
143
+
144
+ def tokens_to_events(
145
+ tokens: Iterable[int],
146
+ *,
147
+ offset_ms: float = 0.0,
148
+ velocity: int | None = None,
149
+ ) -> list[dict]:
150
+ if velocity is None:
151
+ velocity = MIDI_DEFAULTS["velocity_default"]
152
+
153
+ events: list[dict] = []
154
+ time_ms = offset_ms
155
+ duration_ms = 1
156
+ pitch = 60
157
+
158
+ for tok in tokens:
159
+ if 0 <= tok < 128:
160
+ time_ms += tok * 32
161
+ elif 128 < tok < 256:
162
+ duration_ms = (tok - 128) * 32
163
+ elif 256 < tok < 384:
164
+ pitch = tok - 256
165
+ on_time = time_ms / 1000.0
166
+ off_time = (time_ms + duration_ms) / 1000.0
167
+ events.append(
168
+ {
169
+ "type": "note_on",
170
+ "note": pitch,
171
+ "velocity": velocity,
172
+ "time": on_time,
173
+ "channel": 0,
174
+ }
175
+ )
176
+ events.append(
177
+ {
178
+ "type": "note_off",
179
+ "note": pitch,
180
+ "velocity": 0,
181
+ "time": off_time,
182
+ "channel": 0,
183
+ }
184
+ )
185
+
186
+ return events
187
+
188
+
189
+ def keyboard_note_range() -> tuple[int, int]:
190
+ min_note = KEYBOARD_BASE_MIDI
191
+ max_note = KEYBOARD_BASE_MIDI + (KEYBOARD_OCTAVES * 12) - 1
192
+ return min_note, max_note
193
+
194
+
195
+ def count_out_of_range_events(events: list[dict]) -> int:
196
+ min_note, max_note = keyboard_note_range()
197
+ return sum(
198
+ 1
199
+ for event in events
200
+ if event.get("type") in {"note_on", "note_off"}
201
+ and int(event.get("note", min_note)) not in range(min_note, max_note + 1)
202
+ )
203
+
204
+
205
+ def filter_events_to_keyboard_range(events: list[dict]) -> list[dict]:
206
+ min_note, max_note = keyboard_note_range()
207
+ return [
208
+ event
209
+ for event in events
210
+ if event.get("type") not in {"note_on", "note_off"}
211
+ or min_note <= int(event.get("note", min_note)) <= max_note
212
+ ]
213
+
214
+
215
+ def build_prime_tokens(score_tokens: list[int], seq_len: int) -> list[int]:
216
+ prime = [705, 384, 706]
217
+ if score_tokens:
218
+ max_score = max(0, seq_len - len(prime))
219
+ prime.extend(score_tokens[-max_score:])
220
+ else:
221
+ prime.extend([0, 129, 316])
222
+ return prime
223
+
224
+
225
+ def load_model_cached(
226
+ *,
227
+ repo: str,
228
+ filename: str,
229
+ cache_dir: Path,
230
+ tegridy_dir: Path,
231
+ seq_len: int,
232
+ pad_idx: int,
233
+ device: str,
234
+ ) -> tuple[object, str, Path]:
235
+ from huggingface_hub import hf_hub_download
236
+ import torch
237
+
238
+ cache_dir.mkdir(parents=True, exist_ok=True)
239
+ resolved_device = resolve_device(device)
240
+ cache_key = f"{repo}:{filename}:{seq_len}:{pad_idx}:{resolved_device}"
241
+
242
+ if _MODEL_CACHE.get("key") == cache_key:
243
+ return (
244
+ _MODEL_CACHE["model"],
245
+ _MODEL_CACHE["device"],
246
+ _MODEL_CACHE["tools_dir"],
247
+ )
248
+
249
+ checkpoint_path = Path(
250
+ hf_hub_download(
251
+ repo_id=repo,
252
+ filename=filename,
253
+ local_dir=str(cache_dir),
254
+ repo_type="model",
255
+ )
256
+ )
257
+
258
+ tools_dir, x_transformer_dir = ensure_tegridy_tools(tegridy_dir)
259
+ add_sys_path(x_transformer_dir)
260
+
261
+ if resolved_device == "cuda":
262
+ torch.set_float32_matmul_precision("high")
263
+ torch.backends.cuda.matmul.allow_tf32 = True
264
+ torch.backends.cudnn.allow_tf32 = True
265
+
266
+ model = build_model(seq_len, pad_idx)
267
+ load_checkpoint(model, checkpoint_path, resolved_device)
268
+ model.to(resolved_device)
269
+ model.eval()
270
+
271
+ _MODEL_CACHE["key"] = cache_key
272
+ _MODEL_CACHE["model"] = model
273
+ _MODEL_CACHE["device"] = resolved_device
274
+ _MODEL_CACHE["tools_dir"] = tools_dir
275
+ _MODEL_CACHE["checkpoint_path"] = checkpoint_path
276
+
277
+ return model, resolved_device, tools_dir
278
+
279
+
280
+ def generate_from_events(
281
+ events: list[dict],
282
+ *,
283
+ generate_tokens: int,
284
+ seed: int | None,
285
+ repo: str,
286
+ filename: str,
287
+ cache_dir: Path,
288
+ tegridy_dir: Path,
289
+ seq_len: int,
290
+ pad_idx: int,
291
+ device: str,
292
+ ) -> tuple[list[dict], list[int]]:
293
+ import torch
294
+
295
+ model, resolved_device, _ = load_model_cached(
296
+ repo=repo,
297
+ filename=filename,
298
+ cache_dir=cache_dir,
299
+ tegridy_dir=tegridy_dir,
300
+ seq_len=seq_len,
301
+ pad_idx=pad_idx,
302
+ device=device,
303
+ )
304
+
305
+ if seed is not None:
306
+ torch.manual_seed(seed)
307
+ if resolved_device == "cuda":
308
+ torch.cuda.manual_seed_all(seed)
309
+
310
+ score_tokens = events_to_score_tokens(events)
311
+ prime = build_prime_tokens(score_tokens, seq_len)
312
+ prime_tensor = torch.tensor(prime, dtype=torch.long, device=resolved_device)
313
+
314
+ out = model.generate(
315
+ prime_tensor,
316
+ generate_tokens,
317
+ return_prime=True,
318
+ eos_token=707,
319
+ )
320
+
321
+ tokens = out.detach().cpu().tolist()
322
+ new_tokens = tokens[len(prime) :]
323
+
324
+ last_time_ms = 0.0
325
+ if events:
326
+ last_time_ms = max(float(e.get("time", 0.0)) for e in events) * 1000.0
327
+
328
+ new_events = tokens_to_events(new_tokens, offset_ms=last_time_ms)
329
+ return new_events, new_tokens
330
+
331
+
332
+ def generate_godzilla_continuation(
333
+ events: list[dict],
334
+ *,
335
+ generate_tokens: int = 32,
336
+ seed: int | None = None,
337
+ device: str = "auto",
338
+ ) -> tuple[list[dict], list[int]]:
339
+ return generate_from_events(
340
+ events,
341
+ generate_tokens=generate_tokens,
342
+ seed=seed,
343
+ repo=DEFAULT_REPO,
344
+ filename=DEFAULT_FILENAME,
345
+ cache_dir=Path(".cache/godzilla"),
346
+ tegridy_dir=Path("external"),
347
+ seq_len=1536,
348
+ pad_idx=708,
349
+ device=device,
350
+ )
351
+
352
+
353
+ class GodzillaMidiModel(MidiModel):
354
+ def __init__(self) -> None:
355
+ super().__init__(model_id="godzilla", name="Godzilla")
356
+
357
+ def generate_continuation(
358
+ self,
359
+ events: list[dict],
360
+ *,
361
+ tokens: int = 32,
362
+ seed: Optional[int] = None,
363
+ ) -> list[dict]:
364
+ new_events, _ = generate_godzilla_continuation(
365
+ events,
366
+ generate_tokens=tokens,
367
+ seed=seed,
368
+ device="auto",
369
+ )
370
+ return new_events
371
+
372
+
373
+ def get_model(model_id: str) -> MidiModel:
374
+ if model_id == "godzilla":
375
+ return GodzillaMidiModel()
376
+ raise ValueError(f"Unknown MIDI model: {model_id}")
requirements.txt CHANGED
@@ -1,2 +1,6 @@
1
  gradio
2
- mido
 
 
 
 
 
1
  gradio
2
+ mido
3
+ torch
4
+ huggingface_hub
5
+ einops>=0.6
6
+ einx
static/keyboard.js CHANGED
@@ -125,6 +125,23 @@ function buildInstruments(instrumentConfigs) {
125
 
126
  let instruments = {}; // Will be populated after config is fetched
127
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
128
  // =============================================================================
129
  // INITIALIZATION FROM SERVER CONFIG
130
  // =============================================================================
@@ -148,6 +165,9 @@ async function initializeFromConfig() {
148
  for (const [midiStr, key] of Object.entries(serverConfig.keyboard_shortcuts)) {
149
  window.keyMapFromServer[key.toLowerCase()] = parseInt(midiStr);
150
  }
 
 
 
151
 
152
  // Render keyboard after config is loaded
153
  buildKeyboard();
@@ -166,6 +186,11 @@ async function initializeFromConfig() {
166
  });
167
  window.keyboardShortcutsFromServer = keyShortcuts; // Use hardcoded as fallback
168
  window.keyMapFromServer = keyMap; // Use hardcoded as fallback
 
 
 
 
 
169
  buildKeyboard();
170
  }
171
  }
 
125
 
126
  let instruments = {}; // Will be populated after config is fetched
127
 
128
+ function populateEngineSelect(engines) {
129
+ if (!engineSelect || !Array.isArray(engines)) return;
130
+
131
+ engineSelect.innerHTML = '';
132
+ engines.forEach(engine => {
133
+ const option = document.createElement('option');
134
+ option.value = engine.id;
135
+ option.textContent = engine.name || engine.id;
136
+ engineSelect.appendChild(option);
137
+ });
138
+
139
+ if (engines.length > 0) {
140
+ selectedEngine = engines[0].id;
141
+ engineSelect.value = selectedEngine;
142
+ }
143
+ }
144
+
145
  // =============================================================================
146
  // INITIALIZATION FROM SERVER CONFIG
147
  // =============================================================================
 
165
  for (const [midiStr, key] of Object.entries(serverConfig.keyboard_shortcuts)) {
166
  window.keyMapFromServer[key.toLowerCase()] = parseInt(midiStr);
167
  }
168
+
169
+ // Populate engine dropdown from server config
170
+ populateEngineSelect(serverConfig.engines);
171
 
172
  // Render keyboard after config is loaded
173
  buildKeyboard();
 
186
  });
187
  window.keyboardShortcutsFromServer = keyShortcuts; // Use hardcoded as fallback
188
  window.keyMapFromServer = keyMap; // Use hardcoded as fallback
189
+ populateEngineSelect([
190
+ { id: 'parrot', name: 'Parrot' },
191
+ { id: 'reverse_parrot', name: 'Reverse Parrot' },
192
+ { id: 'godzilla_continue', name: 'Godzilla' }
193
+ ]);
194
  buildKeyboard();
195
  }
196
  }