asigalov61 committed
Commit e3ffbcb · verified · 1 Parent(s): cba55e9

Upload Orpheus_Drums_Transformer.ipynb
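For reference, a minimal sketch for fetching the uploaded notebook from the Hub. The repo id is assumed here to be asigalov61/Orpheus-Music-Transformer (the model repo the notebook itself pulls its checkpoint from); the commit page does not state the repository explicitly:

    from huggingface_hub import hf_hub_download

    # Assumed repo_id; adjust to the actual repository this commit belongs to
    nb_path = hf_hub_download(repo_id='asigalov61/Orpheus-Music-Transformer',
                              filename='inference_code/Orpheus_Drums_Transformer.ipynb',
                              repo_type='model')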

inference_code/Orpheus_Drums_Transformer.ipynb ADDED
@@ -0,0 +1,554 @@
+ {
+ "cells": [
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "VGrGd6__l5ch"
+ },
+ "source": [
+ "# Orpheus Drums Transformer (ver. 1.0)\n",
+ "\n",
+ "***\n",
+ "\n",
+ "Powered by tegridy-tools: https://github.com/asigalov61/tegridy-tools\n",
+ "\n",
+ "***\n",
+ "\n",
+ "WARNING: This complete implementation is a functioning model of artificial intelligence. Please exercise great humility, care, and respect. https://www.nscai.gov/\n",
+ "\n",
+ "***\n",
+ "\n",
+ "#### Project Los Angeles\n",
+ "\n",
+ "#### Tegridy Code 2025\n",
+ "\n",
+ "***"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "shLrgoXdl5cj"
+ },
+ "source": [
+ "# GPU check"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "X3rABEpKCO02"
+ },
+ "outputs": [],
+ "source": [
+ "!nvidia-smi"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "0RcVC4btl5ck"
+ },
+ "source": [
+ "# Setup environment"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "viHgEaNACPTs"
+ },
+ "outputs": [],
+ "source": [
+ "!git clone --depth 1 https://github.com/asigalov61/tegridy-tools"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "vK40g6V_BTNj"
+ },
+ "outputs": [],
+ "source": [
+ "!pip install huggingface_hub\n",
+ "!pip install hf-transfer\n",
+ "!pip install ipywidgets\n",
+ "!pip install tqdm\n",
+ "\n",
+ "!pip install einx\n",
+ "!pip install einops\n",
+ "!pip install torch-summary"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "DzCOZU_gBiQV"
+ },
+ "outputs": [],
+ "source": [
+ "# Load modules\n",
+ "\n",
+ "print('Loading modules...')\n",
+ "\n",
+ "import os\n",
+ "\n",
+ "os.environ[\"HF_HUB_ENABLE_HF_TRANSFER\"] = \"1\"\n",
+ "\n",
+ "import pickle\n",
+ "import random\n",
+ "import secrets\n",
+ "import tqdm\n",
+ "import math\n",
+ "\n",
+ "import gc\n",
+ "\n",
+ "!set USE_FLASH_ATTENTION=1\n",
+ "os.environ['USE_FLASH_ATTENTION'] = '1'\n",
+ "\n",
+ "import torch\n",
+ "\n",
+ "import matplotlib.pyplot as plt\n",
+ "\n",
+ "from torchsummary import summary\n",
+ "from sklearn import metrics\n",
+ "\n",
+ "%cd /home/ubuntu/tegridy-tools/tegridy-tools/\n",
+ "\n",
+ "import TMIDIX\n",
+ "\n",
+ "%cd /home/ubuntu/tegridy-tools/tegridy-tools/X-Transformer\n",
+ "\n",
+ "from x_transformer_2_3_1 import *\n",
+ "\n",
+ "torch.set_float32_matmul_precision('high')\n",
+ "torch.backends.cuda.matmul.allow_tf32 = True # allow tf32 on matmul\n",
+ "torch.backends.cudnn.allow_tf32 = True # allow tf32 on cudnn\n",
+ "torch.backends.cuda.enable_flash_sdp(True)\n",
+ "torch.backends.cuda.enable_cudnn_sdp(False)\n",
+ "\n",
+ "!set USE_FLASH_ATTENTION=1\n",
+ "\n",
+ "%cd /home/ubuntu/\n",
+ "\n",
+ "import random\n",
+ "\n",
+ "from huggingface_hub import hf_hub_download\n",
+ "\n",
+ "print('Done')\n",
+ "\n",
+ "print('Torch version:', torch.__version__)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "feXay_Ed7mG5"
+ },
+ "source": [
+ "# Download model"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "SA8qQSzbWslM"
+ },
+ "outputs": [],
+ "source": [
+ "hf_hub_download(repo_id='asigalov61/Orpheus-Music-Transformer',\n",
+ " filename='Orpheus_Music_Transformer_Trained_Model_96332_steps_0.82_loss_0.748_acc.pth',\n",
+ " local_dir='/home/ubuntu/Models/',\n",
+ " repo_type='model'\n",
+ " )"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "# Load model"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "gSvqSRLaWslM"
+ },
+ "outputs": [],
+ "source": [
+ "SEQ_LEN = 8192\n",
+ "PAD_IDX = 18819\n",
+ "\n",
+ "model = TransformerWrapper(num_tokens = PAD_IDX+1,\n",
+ " max_seq_len = SEQ_LEN,\n",
+ " attn_layers = Decoder(dim = 2048,\n",
+ " depth = 8,\n",
+ " heads = 32,\n",
+ " rotary_pos_emb = True,\n",
+ " attn_flash = True\n",
+ " )\n",
+ " )\n",
+ "\n",
+ "model = AutoregressiveWrapper(model, ignore_index = PAD_IDX, pad_value=PAD_IDX)\n",
+ "\n",
+ "print('=' * 70)\n",
+ "print('Loading model checkpoint...')\n",
+ "\n",
+ "model_path = 'Models/Orpheus_Music_Transformer_Trained_Model_96332_steps_0.82_loss_0.748_acc.pth'\n",
+ "\n",
+ "model.load_state_dict(torch.load(model_path))\n",
+ "\n",
+ "print('=' * 70)\n",
+ "\n",
+ "model.cuda()\n",
+ "model.eval()\n",
+ "\n",
+ "print('Done!')\n",
+ "\n",
+ "summary(model)\n",
+ "\n",
+ "dtype = torch.bfloat16\n",
+ "\n",
+ "ctx = torch.amp.autocast(device_type='cuda', dtype=dtype)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "# Load MIDI"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "enHpaHxaWslM"
+ },
+ "outputs": [],
+ "source": [
+ "midi_file = 'tegridy-tools/tegridy-tools/seed2.mid'\n",
+ "\n",
+ "print('=' * 70)\n",
+ "print('Loading MIDI...')\n",
+ "\n",
+ "raw_score = TMIDIX.midi2single_track_ms_score(midi_file)\n",
+ "\n",
+ "escore_notes = TMIDIX.advanced_score_processor(raw_score, return_enhanced_score_notes=True, apply_sustain=True)\n",
+ "\n",
+ "if escore_notes:\n",
+ "\n",
+ " escore_notes = TMIDIX.augment_enhanced_score_notes(escore_notes[0], sort_drums_last=True)\n",
+ "\n",
+ " escore_notes = TMIDIX.recalculate_score_timings([e for e in escore_notes if e[3] != 9])\n",
+ " \n",
+ " dscore = TMIDIX.delta_score_notes(escore_notes)\n",
+ " \n",
+ " dcscore = TMIDIX.chordify_score([d[1:] for d in dscore])\n",
+ " \n",
+ " melody_chords = [18816]\n",
+ "\n",
+ " chords = []\n",
+ " \n",
+ " #=======================================================\n",
+ " # MAIN PROCESSING CYCLE\n",
+ " #=======================================================\n",
+ " \n",
+ " for i, c in enumerate(dcscore):\n",
+ " \n",
+ " delta_time = c[0][0]\n",
+ " \n",
+ " melody_chords.append(delta_time)\n",
+ "\n",
+ " cho = []\n",
+ " \n",
+ " cho.append(delta_time)\n",
+ " \n",
+ " for e in c:\n",
+ " \n",
+ " #=======================================================\n",
+ " \n",
+ " # Durations\n",
+ " dur = max(1, min(255, e[1]))\n",
+ " \n",
+ " # Patches\n",
+ " pat = max(0, min(128, e[5]))\n",
+ " \n",
+ " # Pitches\n",
+ " ptc = max(1, min(127, e[3]))\n",
+ " \n",
+ " # Velocities\n",
+ " # Calculating octo-velocity\n",
+ " \n",
+ " vel = max(8, min(127, e[4]))\n",
+ " velocity = round(vel / 15)-1\n",
+ " \n",
+ " #=======================================================\n",
+ " # FINAL NOTE SEQ\n",
+ " #=======================================================\n",
+ " \n",
+ " # Writing final note\n",
+ " pat_ptc = (128 * pat) + ptc \n",
+ " dur_vel = (8 * dur) + velocity\n",
+ " \n",
+ " melody_chords.extend([pat_ptc+256, dur_vel+16768]) # 18816\n",
+ " cho.extend([pat_ptc+256, dur_vel+16768])\n",
+ "\n",
+ " chords.append(cho)\n",
+ " \n",
+ " print('Done!')\n",
+ " print('=' * 70)\n",
+ " print('Score has', len(melody_chords), 'tokens')\n",
+ " print('Score has', len(chords), 'chords')\n",
+ " print('=' * 70)\n",
+ "\n",
+ "else:\n",
+ " print('Error! Check MIDI file!')"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "# Texture chords"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "w6Z3HJ313EL_"
+ },
+ "outputs": [],
+ "source": [
+ "model_temperature = 1.0\n",
+ "model_sampling_top_p = 0.96\n",
+ "\n",
+ "#==================================================================\n",
+ "\n",
+ "print('=' * 70)\n",
+ "print('Sample score tokens', melody_chords[:10])\n",
+ "\n",
+ "#==================================================================\n",
+ "\n",
+ "def gen_drums(seq):\n",
+ "\n",
+ " y = 16641\n",
+ " num_gen_drums = 0\n",
+ "\n",
+ " while y > 16640:\n",
+ " \n",
+ " x = torch.LongTensor(seq).cuda()\n",
+ " \n",
+ " with ctx:\n",
+ " out = model.generate(x,\n",
+ " 1,\n",
+ " temperature=model_temperature,\n",
+ " filter_logits_fn=top_p,\n",
+ " filter_kwargs={'thres': model_sampling_top_p},\n",
+ " return_prime=False,\n",
+ " eos_token=18818,\n",
+ " verbose=False)\n",
+ "\n",
+ " y = out.tolist()[0]\n",
+ "\n",
+ " if y > 16640:\n",
+ " seq.append(y)\n",
+ " num_gen_drums += 1\n",
+ "\n",
+ " if num_gen_drums == 10:\n",
+ " break\n",
+ "\n",
+ " return seq\n",
+ "\n",
+ "#==================================================================\n",
+ "\n",
+ "print('=' * 70)\n",
+ "print('Generating...')\n",
+ "print('=' * 70)\n",
+ "\n",
+ "final_song = [18816]\n",
+ "\n",
+ "for i in tqdm.tqdm(range(len(chords))):\n",
+ "\n",
+ " final_song.extend(chords[i])\n",
+ "\n",
+ " if i == 0:\n",
+ " final_song.append((128*128)+38+256) # Drum pitch/patch\n",
+ " final_song.append((8*8)+5+16768) # Drum dur/vel\n",
+ " \n",
+ " if (final_song[-2] < 16640 and i % 8 == 0):\n",
+ " final_song.append((128*128)+38+256) # Drum pitch/patch\n",
+ "\n",
+ " final_song = gen_drums(final_song)\n",
+ "\n",
+ "#==================================================================\n",
+ "\n",
+ "print('=' * 70)\n",
+ "print('Done!')\n",
+ "print('=' * 70)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "# Save to MIDI"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "tlBzqWpAnZna"
+ },
+ "outputs": [],
+ "source": [
+ "print('Sample INTs', final_song[:15])\n",
+ "\n",
+ "if len(final_song) != 0:\n",
+ "\n",
+ " song_f = []\n",
+ "\n",
+ " time = 0\n",
+ " dur = 1\n",
+ " vel = 90\n",
+ " pitch = 60\n",
+ " channel = 0\n",
+ " patch = 0\n",
+ "\n",
+ " patches = [-1] * 16\n",
+ "\n",
+ " channels = [0] * 16\n",
+ " channels[9] = 1\n",
+ "\n",
+ " for ss in final_song:\n",
+ "\n",
+ " if 0 <= ss < 256:\n",
+ "\n",
+ " time += ss * 16\n",
+ "\n",
+ " if 256 <= ss < 16768:\n",
+ "\n",
+ " patch = (ss-256) // 128\n",
+ "\n",
+ " if patch < 128:\n",
+ "\n",
+ " if patch not in patches:\n",
+ " if 0 in channels:\n",
+ " cha = channels.index(0)\n",
+ " channels[cha] = 1\n",
+ " else:\n",
+ " cha = 15\n",
+ "\n",
+ " patches[cha] = patch\n",
+ " channel = patches.index(patch)\n",
+ " else:\n",
+ " channel = patches.index(patch)\n",
+ "\n",
+ " if patch == 128:\n",
+ " channel = 9\n",
+ "\n",
+ " pitch = (ss-256) % 128\n",
+ "\n",
+ "\n",
+ " if 16768 <= ss < 18816:\n",
+ "\n",
+ " dur = ((ss-16768) // 8) * 16\n",
+ " vel = (((ss-16768) % 8)+1) * 15\n",
+ "\n",
+ " song_f.append(['note', time, dur, channel, pitch, vel, patch])\n",
+ "\n",
+ " patches = [0 if x==-1 else x for x in patches]\n",
+ "\n",
+ "output_score, patches, overflow_patches = TMIDIX.patch_enhanced_score_notes(song_f)\n",
+ "\n",
+ "fn1 = \"Orpheus-Drums-Transformer-Composition\"\n",
+ "\n",
+ "detailed_stats = TMIDIX.Tegridy_ms_SONG_to_MIDI_Converter(output_score,\n",
+ " output_signature = 'Orpheus Drums Transformer',\n",
+ " output_file_name = fn1,\n",
+ " track_name='Project Los Angeles',\n",
+ " list_of_MIDI_patches=patches\n",
+ " )\n",
+ "\n",
+ "print('Done!')"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "# Plot token embeddings"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "al3TDlH7T8m7"
+ },
+ "outputs": [],
+ "source": [
+ "tok_emb = model.net.token_emb.emb.weight.detach().cpu().tolist()\n",
+ "\n",
+ "cos_sim = metrics.pairwise_distances(\n",
+ " tok_emb, metric='cosine'\n",
+ ")\n",
+ "plt.figure(figsize=(7, 7))\n",
+ "plt.imshow(cos_sim, cmap=\"inferno\", interpolation=\"nearest\")\n",
+ "im_ratio = cos_sim.shape[0] / cos_sim.shape[1]\n",
+ "plt.colorbar(fraction=0.046 * im_ratio, pad=0.04)\n",
+ "plt.xlabel(\"Position\")\n",
+ "plt.ylabel(\"Position\")\n",
+ "plt.tight_layout()\n",
+ "plt.plot()\n",
+ "plt.savefig(\"/home/ubuntu/Orpheus-Drums-Transformer-Tokens-Embeddings-Plot.png\", bbox_inches=\"tight\")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "z87TlDTVl5cp"
+ },
+ "source": [
+ "# Congrats! You did it! :)"
+ ]
+ }
+ ],
+ "metadata": {
+ "accelerator": "GPU",
+ "colab": {
+ "gpuClass": "premium",
+ "gpuType": "T4",
+ "private_outputs": true,
+ "provenance": []
+ },
+ "kernelspec": {
+ "display_name": "Python 3 (ipykernel)",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.10.12"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 4
+ }