projectlosangeles commited on
Commit
dcfdcda
·
verified ·
1 Parent(s): 8cb16ca

Upload 6 files

Browse files
Files changed (6) hide show
  1. README.md +17 -7
  2. TMIDIX.py +0 -0
  3. app.py +440 -0
  4. packages.txt +1 -0
  5. requirements.txt +7 -0
  6. x_transformer_2_3_1.py +0 -0
README.md CHANGED
@@ -1,14 +1,24 @@
1
  ---
2
- title: MuseCraft Piano
3
- emoji:
4
  colorFrom: gray
5
- colorTo: red
6
  sdk: gradio
7
- sdk_version: 5.33.0
8
  app_file: app.py
9
- pinned: false
10
  license: apache-2.0
11
- short_description: Solo Piano music transformer for MuseCraft project
 
 
 
 
 
 
 
 
 
 
12
  ---
13
 
14
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
1
  ---
2
+ title: Godzilla Piano Transformer
3
+ emoji: 🎹
4
  colorFrom: gray
5
+ colorTo: gray
6
  sdk: gradio
7
+ sdk_version: 5.32.0
8
  app_file: app.py
9
+ pinned: true
10
  license: apache-2.0
11
+ short_description: Fast 807M 4k solo Piano transformer trained on 1.14M+ MIDIs
12
+ tags:
13
+ - music
14
+ - music ai
15
+ - music transformer
16
+ - MIDI
17
+ - piano
18
+ - piano transformer
19
+ - godzil
20
+ thumbnail: >-
21
+ https://cdn-uploads.huggingface.co/production/uploads/5f57ea2d3f32f12a3c0692e6/LPIRJ14nakflySIcyhOUq.png
22
  ---
23
 
24
+ Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
TMIDIX.py ADDED
The diff for this file is too large to render. See raw diff
 
app.py ADDED
@@ -0,0 +1,440 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #====================================================================
2
+ # https://huggingface.co/spaces/asigalov61/Godzilla-Piano-Transformer
3
+ #====================================================================
4
+
5
+ """
6
+ Godzilla Piano Transformer Gradio App - Single Model, Simplified Version
7
+ Fast 807M 4k solo Piano music transformer trained on 1.14M+ MIDIs (2.7M+ samples)
8
+ Using only one model: "without velocity - 3 epochs"
9
+ """
10
+
11
+ import os
12
+
13
+ os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"
14
+
15
+ import time as reqtime
16
+ import datetime
17
+ from pytz import timezone
18
+
19
+ import torch
20
+ import matplotlib.pyplot as plt
21
+ import gradio as gr
22
+ import spaces
23
+
24
+ from huggingface_hub import hf_hub_download
25
+ import TMIDIX
26
+ from midi_to_colab_audio import midi_to_colab_audio
27
+ from x_transformer_2_3_1 import TransformerWrapper, AutoregressiveWrapper, Decoder
28
+
29
+ # -----------------------------
30
+ # CONFIGURATION & GLOBALS
31
+ # -----------------------------
32
+ SEP = '=' * 70
33
+ PDT = timezone('US/Pacific')
34
+
35
+ MODEL_CHECKPOINT = 'Godzilla_Piano_Transformer_No_Velocity_Trained_Model_21113_steps_0.3454_loss_0.895_acc.pth'
36
+ SOUDFONT_PATH = 'SGM-v2.01-YamahaGrand-Guit-Bass-v2.7.sf2'
37
+ NUM_OUT_BATCHES = 12
38
+ PREVIEW_LENGTH = 120 # in tokens
39
+
40
+ # -----------------------------
41
+ # PRINT START-UP INFO
42
+ # -----------------------------
43
+ def print_sep():
44
+ print(SEP)
45
+
46
+ print_sep()
47
+ print("Godzilla Piano Transformer Gradio App")
48
+ print_sep()
49
+ print("Loading modules...")
50
+
51
+ # -----------------------------
52
+ # ENVIRONMENT & PyTorch Settings
53
+ # -----------------------------
54
+ os.environ['USE_FLASH_ATTENTION'] = '1'
55
+
56
+ torch.set_float32_matmul_precision('high')
57
+ torch.backends.cuda.matmul.allow_tf32 = True
58
+ torch.backends.cudnn.allow_tf32 = True
59
+ torch.backends.cuda.enable_mem_efficient_sdp(True)
60
+ torch.backends.cuda.enable_math_sdp(True)
61
+ torch.backends.cuda.enable_flash_sdp(True)
62
+ torch.backends.cuda.enable_cudnn_sdp(True)
63
+
64
+ print_sep()
65
+ print("PyTorch version:", torch.__version__)
66
+ print("Done loading modules!")
67
+ print_sep()
68
+
69
+ # -----------------------------
70
+ # MODEL INITIALIZATION
71
+ # -----------------------------
72
+ print_sep()
73
+ print("Instantiating model...")
74
+
75
+ device_type = 'cuda'
76
+ dtype = 'bfloat16'
77
+ ptdtype = {'bfloat16': torch.bfloat16, 'float16': torch.float16}[dtype]
78
+ ctx = torch.amp.autocast(device_type=device_type, dtype=ptdtype)
79
+
80
+ SEQ_LEN = 4096
81
+ PAD_IDX = 384
82
+
83
+ model = TransformerWrapper(
84
+ num_tokens=PAD_IDX + 1,
85
+ max_seq_len=SEQ_LEN,
86
+ attn_layers=Decoder(
87
+ dim=2048,
88
+ depth=16,
89
+ heads=32,
90
+ rotary_pos_emb=True,
91
+ attn_flash=True
92
+ )
93
+ )
94
+ model = AutoregressiveWrapper(model, ignore_index=PAD_IDX, pad_value=PAD_IDX)
95
+
96
+ print_sep()
97
+ print("Loading model checkpoint...")
98
+ checkpoint = hf_hub_download(
99
+ repo_id='asigalov61/Godzilla-Piano-Transformer',
100
+ filename=MODEL_CHECKPOINT
101
+ )
102
+ model.load_state_dict(torch.load(checkpoint, map_location='cuda', weights_only=True))
103
+ model = torch.compile(model, mode='max-autotune')
104
+ print_sep()
105
+ print("Done!")
106
+ print("Model will use", dtype, "precision...")
107
+ print_sep()
108
+
109
+ model.cuda()
110
+ model.eval()
111
+
112
+ # -----------------------------
113
+ # HELPER FUNCTIONS
114
+ # -----------------------------
115
+ def render_midi_output(final_composition):
116
+ """Generate MIDI score, plot, and audio from final composition."""
117
+ fname, midi_score = save_midi(final_composition)
118
+ time_val = midi_score[-1][1] / 1000 # seconds marker from last note
119
+ midi_plot = TMIDIX.plot_ms_SONG(
120
+ midi_score,
121
+ plot_title='Godzilla Piano Transformer Composition',
122
+ block_lines_times_list=[],
123
+ return_plt=True
124
+ )
125
+ midi_audio = midi_to_colab_audio(
126
+ fname + '.mid',
127
+ soundfont_path=SOUDFONT_PATH,
128
+ sample_rate=16000,
129
+ output_for_gradio=True
130
+ )
131
+ return (16000, midi_audio), midi_plot, fname + '.mid', time_val
132
+
133
+ # -----------------------------
134
+ # MIDI PROCESSING FUNCTIONS
135
+ # -----------------------------
136
+ def load_midi(input_midi):
137
+ """Process the input MIDI file and create a token sequence using without velocity logic."""
138
+ raw_score = TMIDIX.midi2single_track_ms_score(input_midi.name)
139
+ escore_notes = TMIDIX.advanced_score_processor(
140
+ raw_score, return_enhanced_score_notes=True, apply_sustain=True
141
+ )[0]
142
+ sp_escore_notes = TMIDIX.solo_piano_escore_notes(escore_notes)
143
+ zscore = TMIDIX.recalculate_score_timings(sp_escore_notes)
144
+ zscore = TMIDIX.augment_enhanced_score_notes(zscore, timings_divider=32)
145
+ fscore = TMIDIX.fix_escore_notes_durations(zscore)
146
+ cscore = TMIDIX.chordify_score([1000, fscore])
147
+
148
+ score = []
149
+ prev_chord = cscore[0]
150
+ for chord in cscore:
151
+ # Time difference token.
152
+ score.append(max(0, min(127, chord[0][1] - prev_chord[0][1])))
153
+ for note in chord:
154
+ score.extend([
155
+ max(1, min(127, note[2])) + 128,
156
+ max(1, min(127, note[4])) + 256
157
+ ])
158
+ prev_chord = chord
159
+ return score
160
+
161
+ def save_midi(tokens, batch_number=None):
162
+ """Convert token sequence back to a MIDI score and write it using TMIDIX (without velocity).
163
+ The output MIDI file name incorporates a date-time stamp.
164
+ """
165
+ song_events = []
166
+ time_marker = 0
167
+ duration = 0
168
+ pitch = 0
169
+ patches = [0] * 16
170
+
171
+ for token in tokens:
172
+ if 0 <= token < 128:
173
+ time_marker += token * 32
174
+ elif 128 <= token < 256:
175
+ duration = (token - 128) * 32
176
+ elif 256 <= token < 384:
177
+ pitch = token - 256
178
+ song_events.append(['note', time_marker, duration, 0, pitch, max(40, pitch), 0])
179
+ # No velocity tokens are used.
180
+
181
+ # Generate a time stamp using the PDT timezone.
182
+ timestamp = datetime.datetime.now(PDT).strftime("%Y%m%d_%H%M%S")
183
+
184
+ '''if batch_number is None:
185
+ fname = f"Godzilla-Piano-Transformer-Music-Composition_{timestamp}"
186
+ else:
187
+ fname = f"Godzilla-Piano-Transformer-Music-Composition_{timestamp}_Batch_{batch_number}"'''
188
+
189
+ fname = f"Godzilla-Piano-Transformer-Music-Composition"
190
+
191
+ TMIDIX.Tegridy_ms_SONG_to_MIDI_Converter(
192
+ song_events,
193
+ output_signature='Godzilla Piano Transformer',
194
+ output_file_name=fname,
195
+ track_name='Project Los Angeles',
196
+ list_of_MIDI_patches=patches,
197
+ verbose=False
198
+ )
199
+ return fname, song_events
200
+
201
+ # -----------------------------
202
+ # MUSIC GENERATION FUNCTION (Combined)
203
+ # -----------------------------
204
+ @spaces.GPU
205
+ def generate_music(prime, num_gen_tokens, num_mem_tokens, num_gen_batches, model_temperature):
206
+ """Generate music tokens given prime tokens and parameters."""
207
+ inputs = prime[-num_mem_tokens:] if prime else [0]
208
+ print("Generating...")
209
+ inp = torch.LongTensor([inputs] * num_gen_batches).cuda()
210
+ with ctx:
211
+ out = model.generate(
212
+ inp,
213
+ num_gen_tokens,
214
+ temperature=model_temperature,
215
+ return_prime=False,
216
+ verbose=False
217
+ )
218
+ print("Done!")
219
+ print_sep()
220
+ return out.tolist()
221
+
222
+ def generate_music_and_state(input_midi, num_prime_tokens, num_gen_tokens, num_mem_tokens,
223
+ model_temperature, final_composition, generated_batches, block_lines):
224
+ """
225
+ Generate tokens using the model, update the composition state, and prepare outputs.
226
+ This function combines seed loading, token generation, and UI output packaging.
227
+ """
228
+ print_sep()
229
+ print("Request start time:", datetime.datetime.now(PDT).strftime("%Y-%m-%d %H:%M:%S"))
230
+
231
+ print('=' * 70)
232
+ if input_midi is not None:
233
+ fn = os.path.basename(input_midi.name)
234
+ fn1 = fn.split('.')[0]
235
+ print('Input file name:', fn)
236
+
237
+ print('Num prime tokens:', num_prime_tokens)
238
+ print('Num gen tokens:', num_gen_tokens)
239
+ print('Num mem tokens:', num_mem_tokens)
240
+
241
+ print('Model temp:', model_temperature)
242
+ print('=' * 70)
243
+
244
+ # Load seed from MIDI if there is no existing composition.
245
+ if not final_composition and input_midi is not None:
246
+ final_composition = load_midi(input_midi)[:num_prime_tokens]
247
+ midi_fname, midi_score = save_midi(final_composition)
248
+ # Use the last note's time as a marker.
249
+ TMIDIX.Tegridy_ms_SONG_to_MIDI_Converter(
250
+ midi_score,
251
+ output_signature='Godzilla Piano Transformer',
252
+ output_file_name=midi_fname,
253
+ track_name='Project Los Angeles',
254
+ list_of_MIDI_patches=[0]*16,
255
+ verbose=False
256
+ )
257
+ block_lines.append(midi_score[-1][1] / 1000 if final_composition else 0)
258
+
259
+ batched_gen_tokens = generate_music(final_composition, num_gen_tokens, num_mem_tokens,
260
+ NUM_OUT_BATCHES, model_temperature)
261
+
262
+ output_batches = []
263
+ for i, tokens in enumerate(batched_gen_tokens):
264
+ preview_tokens = final_composition[-PREVIEW_LENGTH:]
265
+ midi_fname, midi_score = save_midi(preview_tokens + tokens, batch_number=i)
266
+ plot_kwargs = {'plot_title': f'Batch # {i}', 'return_plt': True}
267
+ if len(final_composition) > PREVIEW_LENGTH:
268
+ plot_kwargs['preview_length_in_notes'] = len([t for t in preview_tokens if t > 256])
269
+ TMIDIX.Tegridy_ms_SONG_to_MIDI_Converter(
270
+ midi_score,
271
+ output_signature='Godzilla Piano Transformer',
272
+ output_file_name=midi_fname,
273
+ track_name='Project Los Angeles',
274
+ list_of_MIDI_patches=[0]*16,
275
+ verbose=False
276
+ )
277
+ midi_plot = TMIDIX.plot_ms_SONG(midi_score, **plot_kwargs)
278
+ midi_audio = midi_to_colab_audio(midi_fname + '.mid',
279
+ soundfont_path=SOUDFONT_PATH,
280
+ sample_rate=16000,
281
+ output_for_gradio=True)
282
+ output_batches.append([(16000, midi_audio), midi_plot, tokens])
283
+
284
+ # Update generated_batches (for use by add/remove functions)
285
+ generated_batches = batched_gen_tokens
286
+
287
+ print("Request end time:", datetime.datetime.now(PDT).strftime("%Y-%m-%d %H:%M:%S"))
288
+ print_sep()
289
+
290
+ # Flatten outputs: states then audio and plots for each batch.
291
+ outputs_flat = []
292
+ for batch in output_batches:
293
+ outputs_flat.extend([batch[0], batch[1]])
294
+ return [final_composition, generated_batches, block_lines] + outputs_flat
295
+
296
+ # -----------------------------
297
+ # BATCH HANDLING FUNCTIONS
298
+ # -----------------------------
299
+ def add_batch(batch_number, final_composition, generated_batches, block_lines):
300
+ """Add tokens from the specified batch to the final composition and update outputs."""
301
+ if generated_batches:
302
+ final_composition.extend(generated_batches[batch_number])
303
+ midi_fname, midi_score = save_midi(final_composition)
304
+ block_lines.append(midi_score[-1][1] / 1000 if final_composition else 0)
305
+ TMIDIX.Tegridy_ms_SONG_to_MIDI_Converter(
306
+ midi_score,
307
+ output_signature='Godzilla Piano Transformer',
308
+ output_file_name=midi_fname,
309
+ track_name='Project Los Angeles',
310
+ list_of_MIDI_patches=[0]*16,
311
+ verbose=False
312
+ )
313
+ midi_plot = TMIDIX.plot_ms_SONG(
314
+ midi_score,
315
+ plot_title='Godzilla Piano Transformer Composition',
316
+ block_lines_times_list=block_lines[:-1],
317
+ return_plt=True
318
+ )
319
+ midi_audio = midi_to_colab_audio(midi_fname + '.mid',
320
+ soundfont_path=SOUDFONT_PATH,
321
+ sample_rate=16000,
322
+ output_for_gradio=True)
323
+ print("Added batch #", batch_number)
324
+ print_sep()
325
+ return (16000, midi_audio), midi_plot, midi_fname + '.mid', final_composition, generated_batches, block_lines
326
+ else:
327
+ return None, None, None, [], [], []
328
+
329
+ def remove_batch(batch_number, num_tokens, final_composition, generated_batches, block_lines):
330
+ """Remove tokens from the final composition and update outputs."""
331
+ if final_composition and len(final_composition) > num_tokens:
332
+ final_composition = final_composition[:-num_tokens]
333
+ if block_lines:
334
+ block_lines.pop()
335
+ midi_fname, midi_score = save_midi(final_composition)
336
+ TMIDIX.Tegridy_ms_SONG_to_MIDI_Converter(
337
+ midi_score,
338
+ output_signature='Godzilla Piano Transformer',
339
+ output_file_name=midi_fname,
340
+ track_name='Project Los Angeles',
341
+ list_of_MIDI_patches=[0]*16,
342
+ verbose=False
343
+ )
344
+ midi_plot = TMIDIX.plot_ms_SONG(
345
+ midi_score,
346
+ plot_title='Godzilla Piano Transformer Composition',
347
+ block_lines_times_list=block_lines[:-1],
348
+ return_plt=True
349
+ )
350
+ midi_audio = midi_to_colab_audio(midi_fname + '.mid',
351
+ soundfont_path=SOUDFONT_PATH,
352
+ sample_rate=16000,
353
+ output_for_gradio=True)
354
+ print("Removed batch #", batch_number)
355
+ print_sep()
356
+ return (16000, midi_audio), midi_plot, midi_fname + '.mid', final_composition, generated_batches, block_lines
357
+ else:
358
+ return None, None, None, [], [], []
359
+
360
+ def clear():
361
+ """Clear outputs and reset state."""
362
+ return None, None, None, [], []
363
+
364
+ def reset(final_composition=[], generated_batches=[], block_lines=[]):
365
+ """Reset composition state."""
366
+ return [], [], []
367
+
368
+ # -----------------------------
369
+ # GRADIO INTERFACE SETUP
370
+ # -----------------------------
371
+ with gr.Blocks() as demo:
372
+
373
+ gr.Markdown("<h1 style='text-align: left; margin-bottom: 1rem'>Godzilla Piano Transformer</h1>")
374
+ gr.Markdown("<h1 style='text-align: left; margin-bottom: 1rem'>Fast 807M 4k solo Piano music transformer trained on 1.14M+ MIDIs (2.7M+ samples)</h1>")
375
+ gr.HTML("""
376
+ Check out <a href="https://huggingface.co/datasets/asigalov61/Godzilla-Piano">Godzilla Piano dataset</a> on Hugging Face
377
+ <p>
378
+ <a href="https://huggingface.co/spaces/asigalov61/Godzilla-Piano-Transformer?duplicate=true">
379
+ <img src="https://huggingface.co/datasets/huggingface/badges/resolve/main/duplicate-this-space-md.svg" alt="Duplicate in Hugging Face">
380
+ </a>
381
+ </p>
382
+ for faster execution and endless generation!
383
+ """)
384
+
385
+ # Global state variables for composition
386
+ final_composition = gr.State([])
387
+ generated_batches = gr.State([])
388
+ block_lines = gr.State([])
389
+
390
+ gr.Markdown("## Upload seed MIDI or click 'Generate' for a random output")
391
+ input_midi = gr.File(label="Input MIDI", file_types=[".midi", ".mid", ".kar"])
392
+ input_midi.upload(reset, [final_composition, generated_batches, block_lines],
393
+ [final_composition, generated_batches, block_lines])
394
+
395
+ gr.Markdown("## Generate")
396
+ num_prime_tokens = gr.Slider(15, 3072, value=3072, step=1, label="Number of prime tokens")
397
+ num_gen_tokens = gr.Slider(15, 1024, value=512, step=1, label="Number of tokens to generate")
398
+ num_mem_tokens = gr.Slider(15, 4096, value=4096, step=1, label="Number of memory tokens")
399
+ model_temperature = gr.Slider(0.1, 1, value=0.9, step=0.01, label="Model temperature")
400
+ generate_btn = gr.Button("Generate", variant="primary")
401
+
402
+ gr.Markdown("## Batch Previews")
403
+ outputs = [final_composition, generated_batches, block_lines]
404
+ # Two outputs (audio and plot) for each batch
405
+ for i in range(NUM_OUT_BATCHES):
406
+ with gr.Tab(f"Batch # {i}"):
407
+ audio_output = gr.Audio(label=f"Batch # {i} MIDI Audio", format="mp3")
408
+ plot_output = gr.Plot(label=f"Batch # {i} MIDI Plot")
409
+ outputs.extend([audio_output, plot_output])
410
+ generate_btn.click(
411
+ generate_music_and_state,
412
+ [input_midi, num_prime_tokens, num_gen_tokens, num_mem_tokens, model_temperature,
413
+ final_composition, generated_batches, block_lines],
414
+ outputs
415
+ )
416
+
417
+ gr.Markdown("## Add/Remove Batch")
418
+ batch_number = gr.Slider(0, NUM_OUT_BATCHES - 1, value=0, step=1, label="Batch number to add/remove")
419
+ add_btn = gr.Button("Add batch", variant="primary")
420
+ remove_btn = gr.Button("Remove batch", variant="stop")
421
+ clear_btn = gr.ClearButton()
422
+
423
+ final_audio_output = gr.Audio(label="Final MIDI audio", format="mp3")
424
+ final_plot_output = gr.Plot(label="Final MIDI plot")
425
+ final_file_output = gr.File(label="Final MIDI file")
426
+
427
+ add_btn.click(
428
+ add_batch,
429
+ [batch_number, final_composition, generated_batches, block_lines],
430
+ [final_audio_output, final_plot_output, final_file_output, final_composition, generated_batches, block_lines]
431
+ )
432
+ remove_btn.click(
433
+ remove_batch,
434
+ [batch_number, num_gen_tokens, final_composition, generated_batches, block_lines],
435
+ [final_audio_output, final_plot_output, final_file_output, final_composition, generated_batches, block_lines]
436
+ )
437
+ clear_btn.click(clear, inputs=None,
438
+ outputs=[final_audio_output, final_plot_output, final_file_output, final_composition, block_lines])
439
+
440
+ demo.launch()
packages.txt ADDED
@@ -0,0 +1 @@
 
 
1
+ fluidsynth
requirements.txt ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ torch
2
+ huggingface_hub
3
+ gradio
4
+ einops
5
+ einx
6
+ matplotlib
7
+ tqdm
x_transformer_2_3_1.py ADDED
The diff for this file is too large to render. See raw diff