kmaes commited on
Commit
b148e11
·
verified ·
1 Parent(s): 102df92

Upload 27 files

Browse files
.gitattributes CHANGED
@@ -33,3 +33,5 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
 
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ text2midi_repo/captions/captions.json filter=lfs diff=lfs merge=lfs -text
37
+ text2midi_repo/text2midi_architecture.jpg filter=lfs diff=lfs merge=lfs -text
README.md CHANGED
@@ -1,14 +1,118 @@
1
  ---
2
- title: TextToAudio
3
- emoji: 📊
4
- colorFrom: pink
5
- colorTo: yellow
6
  sdk: gradio
7
- sdk_version: 6.3.0
8
  app_file: app.py
 
9
  pinned: false
10
- license: mit
11
- short_description: TextToAudio
12
  ---
13
 
14
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  ---
2
+ title: VR Game Music Generator
3
+ emoji: 🎵
4
+ colorFrom: blue
5
+ colorTo: purple
6
  sdk: gradio
7
+ sdk_version: 4.44.0
8
  app_file: app.py
9
+ python_version: 3.11
10
  pinned: false
 
 
11
  ---
12
 
13
+ # VR Game Music Generator
14
+
15
+ Generate music from text descriptions using the text2midi AI model. Designed for integration with Unity and other game engines via the Gradio API.
16
+
17
+ ## Features
18
+
19
+ - Text-to-music generation using AI
20
+ - Real-time audio streaming (no file persistence)
21
+ - RESTful API for game engine integration
22
+ - Supports various music styles and instruments
23
+
24
+ ## API Usage
25
+
26
+ ### Endpoint
27
+ ```
28
+ POST https://YOUR-SPACE.hf.space/api/generate
29
+ ```
30
+
31
+ ### Request
32
+ ```json
33
+ {
34
+ "data": ["A cheerful pop song with piano and drums", 512, 0.9]
35
+ }
36
+ ```
37
+
38
+ Parameters:
39
+ - `data[0]`: Music prompt (string)
40
+ - `data[1]`: Max length in tokens (256-2048, default: 512)
41
+ - `data[2]`: Temperature (0.1-1.5, default: 0.9)
42
+
43
+ ### Response
44
+ ```json
45
+ {
46
+ "data": [
47
+ {"path": "/file=...", "url": "https://...", "orig_name": "audio.wav"},
48
+ "AI-generated audio for: 'A cheerful pop song...'"
49
+ ]
50
+ }
51
+ ```
52
+
53
+ ## Unity Integration
54
+
55
+ ```csharp
56
+ using UnityEngine;
57
+ using UnityEngine.Networking;
58
+ using System.Collections;
59
+
60
+ public class MusicGenerator : MonoBehaviour
61
+ {
62
+ private const string API_URL = "https://YOUR-SPACE.hf.space/api/generate";
63
+
64
+ public IEnumerator GenerateMusic(string prompt, System.Action<AudioClip> callback)
65
+ {
66
+ string json = $"{{\"data\": [\"{prompt}\", 512, 0.9]}}";
67
+
68
+ using (UnityWebRequest request = new UnityWebRequest(API_URL, "POST"))
69
+ {
70
+ byte[] bodyRaw = System.Text.Encoding.UTF8.GetBytes(json);
71
+ request.uploadHandler = new UploadHandlerRaw(bodyRaw);
72
+ request.downloadHandler = new DownloadHandlerBuffer();
73
+ request.SetRequestHeader("Content-Type", "application/json");
74
+
75
+ yield return request.SendWebRequest();
76
+
77
+ if (request.result == UnityWebRequest.Result.Success)
78
+ {
79
+ // Parse response and download audio from returned URL
80
+ var response = JsonUtility.FromJson<GradioResponse>(request.downloadHandler.text);
81
+ yield return DownloadAudio(response.data[0].url, callback);
82
+ }
83
+ }
84
+ }
85
+
86
+ private IEnumerator DownloadAudio(string url, System.Action<AudioClip> callback)
87
+ {
88
+ using (UnityWebRequest www = UnityWebRequestMultimedia.GetAudioClip(url, AudioType.WAV))
89
+ {
90
+ yield return www.SendWebRequest();
91
+ if (www.result == UnityWebRequest.Result.Success)
92
+ {
93
+ callback(DownloadHandlerAudioClip.GetContent(www));
94
+ }
95
+ }
96
+ }
97
+ }
98
+ ```
99
+
100
+ ## Example Prompts
101
+
102
+ - A cheerful and melodic pop Christmas song featuring piano, acoustic guitar, and drums
103
+ - An energetic electronic trance track with synth bass and drums at 138 BPM
104
+ - A slow and emotional classical piece featuring cello and violin in C minor
105
+ - A cinematic electronic soundtrack with an epic and dark atmosphere
106
+ - Happy medieval tavern music with lute and flute
107
+
108
+ ## Local Development
109
+
110
+ ```bash
111
+ pip install -r requirements.txt
112
+ python app.py
113
+ ```
114
+
115
+ ## Credits
116
+
117
+ - Model: [amaai-lab/text2midi](https://huggingface.co/amaai-lab/text2midi)
118
+ - Audio synthesis: FluidSynth with FluidR3 GM SoundFont
app.py ADDED
@@ -0,0 +1,277 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
"""
VR Music Generator - HuggingFace Spaces Version
Generates music from text descriptions using the text2midi AI model.
Exposes a Gradio API for Unity integration.
Audio is streamed directly - no files are persisted.
"""
import gradio as gr
import torch
import torch.nn as nn
import subprocess
import os
import sys
import pickle
import tempfile
import io
import numpy as np
from scipy.io import wavfile
from huggingface_hub import hf_hub_download

# Make the vendored text2midi sources importable (model.transformer_model).
sys.path.insert(0, "text2midi_repo")

# HuggingFace Hub repo hosting the pretrained weights and REMI vocabulary.
repo_id = "amaai-lab/text2midi"

# Prefer GPU when available; Spaces free tier typically runs on CPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(f"Using device: {device}")

# Global model handles, populated by load_text2midi_model().
text2midi_model = None  # Transformer decoder (text2midi checkpoint)
midi_tokenizer = None   # REMI tokenizer: decodes token ids -> MIDI object
text_tokenizer = None   # T5 tokenizer: encodes the text prompt
def load_text2midi_model():
    """Download and initialise the text2midi model and both tokenizers.

    Populates the module-level ``text2midi_model``, ``midi_tokenizer`` and
    ``text_tokenizer`` globals.

    Returns:
        True when the model is ready, False on any failure (missing
        dependency, download error, ...) so the app can fall back to
        simple placeholder MIDI generation.
    """
    global text2midi_model, midi_tokenizer, text_tokenizer

    try:
        from model.transformer_model import Transformer
        from transformers import T5Tokenizer

        print("Loading text2midi model...")

        # Fetch the checkpoint and the REMI vocabulary from the Hub cache.
        weights_file = hf_hub_download(repo_id=repo_id, filename="pytorch_model.bin")
        vocab_file = hf_hub_download(repo_id=repo_id, filename="vocab_remi.pkl")
        print(f"Model path: {weights_file}")
        print(f"Tokenizer path: {vocab_file}")

        # The MIDI tokenizer ships as a pickled REMI vocabulary object.
        with open(vocab_file, "rb") as fh:
            midi_tokenizer = pickle.load(fh)
        vocab_size = len(midi_tokenizer)
        print(f"Vocab size: {vocab_size}")

        # Architecture hyper-parameters must match the released checkpoint.
        text2midi_model = Transformer(
            vocab_size, 768, 8, 2048, 18, 1024, False, 8, device=device
        )
        text2midi_model.load_state_dict(
            torch.load(weights_file, map_location=device, weights_only=True)
        )
        text2midi_model.to(device)
        text2midi_model.eval()

        # Prompts are encoded with FLAN-T5's tokenizer.
        text_tokenizer = T5Tokenizer.from_pretrained("google/flan-t5-base")

        print("Text2midi model loaded successfully!")
        return True

    except Exception as e:
        print(f"Warning: Could not load text2midi model: {e}")
        import traceback
        traceback.print_exc()
        print("Falling back to simple MIDI generation...")
        return False
# Load the model eagerly at import time so the first request is fast;
# MODEL_LOADED gates the AI path vs. the simple-MIDI fallback.
MODEL_LOADED = load_text2midi_model()
def find_soundfont():
    """Return the path of the first General MIDI SoundFont found, else None.

    Checks the locations used by Debian/Ubuntu packages (packages.txt
    installs fluid-soundfont-gm) plus the current working directory.
    """
    candidates = (
        "/usr/share/sounds/sf2/FluidR3_GM.sf2",
        "/usr/share/soundfonts/FluidR3_GM.sf2",
        "/usr/share/sounds/sf2/default-GM.sf2",
        "FluidR3_GM.sf2",
    )
    return next((sf2 for sf2 in candidates if os.path.exists(sf2)), None)
# Resolved once at startup; None disables audio rendering entirely.
SOUNDFONT_PATH = find_soundfont()
print(f"SoundFont: {SOUNDFONT_PATH or 'Not found'}")
def generate_midi_with_model(prompt: str, output_path: str, max_len: int = 512, temperature: float = 0.9):
    """Generate a MIDI file from a text prompt with the text2midi model.

    Requires load_text2midi_model() to have succeeded (the three module
    globals must be populated).

    Args:
        prompt: Natural-language description of the desired music.
        output_path: Destination path for the generated .mid file.
        max_len: Maximum number of MIDI tokens to generate.
        temperature: Sampling temperature for the decoder.

    Returns:
        output_path, for caller convenience.
    """
    global text2midi_model, midi_tokenizer, text_tokenizer

    # Encode the prompt and move tensors to the compute device.
    encoded = text_tokenizer(prompt, return_tensors='pt', padding=True, truncation=True)
    token_ids = encoded.input_ids.to(device)
    mask = encoded.attention_mask.to(device)

    # Autoregressive sampling; no gradients needed at inference time.
    with torch.no_grad():
        generated = text2midi_model.generate(token_ids, mask, max_len=max_len, temperature=temperature)

    # Decode the token sequence back into a MIDI object and write it out.
    midi_obj = midi_tokenizer.decode(generated[0].tolist())
    midi_obj.dump_midi(output_path)

    return output_path
def midi_to_audio_bytes(midi_path: str, sample_rate: int = 44100) -> tuple:
    """Render a MIDI file to audio with FluidSynth, entirely in memory.

    FluidSynth writes raw interleaved 16-bit stereo PCM to stdout, so no
    intermediate audio file is ever created on disk.

    Args:
        midi_path: Path to the .mid file to render.
        sample_rate: Output sample rate in Hz.

    Returns:
        (sample_rate, float32 ndarray of shape (frames, 2) in [-1, 1]),
        suitable for a Gradio Audio component, or None on any failure
        (no SoundFont, FluidSynth missing/error/timeout, empty output).
    """
    if not SOUNDFONT_PATH:
        return None

    # -ni: non-interactive; -T raw: headerless PCM; -F -: write to stdout.
    cmd = [
        "fluidsynth",
        "-ni",
        "-T", "raw",
        "-F", "-",
        "-r", str(sample_rate),
        SOUNDFONT_PATH,
        midi_path,
    ]
    try:
        result = subprocess.run(cmd, capture_output=True, timeout=120)
    except subprocess.TimeoutExpired:
        print("FluidSynth error: rendering timed out after 120s")
        return None
    except FileNotFoundError:
        print("FluidSynth error: 'fluidsynth' binary not found")
        return None

    if result.returncode != 0:
        print(f"FluidSynth error: {result.stderr.decode()}")
        return None

    # Raw output is interleaved signed 16-bit stereo: L R L R ...
    samples = np.frombuffer(result.stdout, dtype=np.int16)
    if len(samples) == 0:
        return None

    # Drop a trailing odd sample (if any) and split into (frames, 2).
    # Gradio interprets a 2-D array as stereo; the previous flat interleaved
    # array was misread as mono at twice the real duration.
    frames = samples[: len(samples) // 2 * 2].reshape(-1, 2)

    # Normalize to float32 in [-1, 1] for Gradio.
    return (sample_rate, frames.astype(np.float32) / 32768.0)
def generate_music(prompt: str, max_length: int = 512, temperature: float = 0.9):
    """Generate music from a text prompt and return it as in-memory audio.

    Args:
        prompt: Text description of the music to generate.
        max_length: Maximum length in tokens (256-2048).
        temperature: Generation temperature (0.1-1.5).

    Returns:
        Tuple of (audio_data, status_message); audio_data is a
        (sample_rate, numpy_array) tuple for Gradio, or None on failure.
    """
    if not prompt or not prompt.strip():
        return None, "Please enter a music prompt"

    try:
        # delete=False: the file must outlive this 'with' so FluidSynth can
        # read it afterwards; we remove it ourselves in the finally below.
        with tempfile.NamedTemporaryFile(suffix='.mid', delete=False) as midi_file:
            midi_path = midi_file.name

        try:
            if MODEL_LOADED:
                status_prefix = "AI-generated"
                generate_midi_with_model(prompt, midi_path, max_len=int(max_length), temperature=temperature)
            else:
                # Fallback: write a trivial scale so the pipeline stays
                # usable even when the AI model failed to load.
                status_prefix = "Simple"
                from midiutil import MIDIFile
                midi = MIDIFile(1)
                midi.addTempo(0, 0, 120)
                notes = [60, 62, 64, 65, 67, 69, 71, 72]
                for i, note in enumerate(notes[:min(len(prompt.split()), 8)]):
                    midi.addNote(0, 0, note, i, 1, 100)
                with open(midi_path, "wb") as f:
                    midi.writeFile(f)

            # Convert MIDI to audio.
            if not SOUNDFONT_PATH:
                return None, "Error: FluidSynth/SoundFont not available"

            audio_result = midi_to_audio_bytes(midi_path)
            if audio_result:
                shown = f"'{prompt[:50]}...'" if len(prompt) > 50 else f"'{prompt}'"
                return audio_result, f"{status_prefix} audio for: {shown}"
            return None, "Error: FluidSynth conversion failed"

        finally:
            # Best-effort cleanup of the temporary MIDI file.
            try:
                os.unlink(midi_path)
            except OSError:
                pass

    except Exception as e:
        import traceback
        traceback.print_exc()
        return None, f"Error: {str(e)}"
# Build the Gradio UI. api_name="generate" also exposes the click handler
# as a REST endpoint (/api/generate) for Unity and other game engines.
with gr.Blocks(title="VR Music Generator") as demo:
    gr.Markdown("# VR Game Music Generator")
    gr.Markdown("Generate music from text descriptions using the text2midi AI model")

    # Surface degraded-mode warnings directly in the UI.
    if not MODEL_LOADED:
        gr.Markdown("**Warning:** AI model not loaded - using simple placeholder MIDI")
    if not SOUNDFONT_PATH:
        gr.Markdown("**Note:** FluidSynth not configured - audio generation disabled")

    with gr.Row():
        with gr.Column():
            prompt_input = gr.Textbox(
                label="Music Prompt",
                placeholder="A cheerful pop song with piano and drums in C major at 120 BPM",
                lines=3
            )
            with gr.Row():
                # Slider bounds mirror generate_music's documented ranges.
                max_length = gr.Slider(
                    minimum=256,
                    maximum=2048,
                    value=512,
                    step=256,
                    label="Max Length (tokens)"
                )
                temperature = gr.Slider(
                    minimum=0.1,
                    maximum=1.5,
                    value=0.9,
                    step=0.1,
                    label="Temperature"
                )
            generate_btn = gr.Button("Generate Music", variant="primary")

        with gr.Column():
            # type="numpy" matches generate_music's (sample_rate, array) return.
            audio_output = gr.Audio(label="Generated Music", type="numpy")
            status_output = gr.Textbox(label="Status", lines=2)

    generate_btn.click(
        fn=generate_music,
        inputs=[prompt_input, max_length, temperature],
        outputs=[audio_output, status_output],
        api_name="generate"  # Exposes as /api/generate endpoint
    )

    gr.Markdown("---")
    gr.Markdown("""
**Example prompts:**
- A cheerful and melodic pop Christmas song featuring piano, acoustic guitar, and drums
- An energetic electronic trance track with synth bass and drums at 138 BPM
- A slow and emotional classical piece featuring cello and violin in C minor
- A cinematic electronic soundtrack with an epic and dark atmosphere

**API Usage (for Unity):**
```csharp
// POST to: https://YOUR-SPACE.hf.space/api/generate
// Body: {"data": ["your music prompt", 512, 0.9]}
// Response: {"data": [{"path": "audio_url", ...}, "status"]}
```
""")

# For HuggingFace Spaces - launch() is called automatically
# For local testing, uncomment below:
# if __name__ == "__main__":
#     demo.launch()
packages.txt ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ fluidsynth
2
+ fluid-soundfont-gm
requirements.txt ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ gradio>=4.0.0
2
+ torch>=2.0.0
3
+ transformers>=4.30.0
4
+ huggingface-hub>=0.20.0
5
+ midiutil>=1.2.1
6
+ miditok>=3.0.0
7
+ scipy>=1.10.0
8
+ numpy>=1.24.0
9
+ tqdm>=4.65.0
text2midi_repo/.gitignore ADDED
@@ -0,0 +1,162 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Byte-compiled / optimized / DLL files
2
+ __pycache__/
3
+ *.py[cod]
4
+ *$py.class
5
+
6
+ # C extensions
7
+ *.so
8
+
9
+ # Distribution / packaging
10
+ .Python
11
+ build/
12
+ develop-eggs/
13
+ dist/
14
+ downloads/
15
+ eggs/
16
+ .eggs/
17
+ lib/
18
+ lib64/
19
+ parts/
20
+ sdist/
21
+ var/
22
+ wheels/
23
+ share/python-wheels/
24
+ *.egg-info/
25
+ .installed.cfg
26
+ *.egg
27
+ MANIFEST
28
+
29
+ # PyInstaller
30
+ # Usually these files are written by a python script from a template
31
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
32
+ *.manifest
33
+ *.spec
34
+
35
+ # Installer logs
36
+ pip-log.txt
37
+ pip-delete-this-directory.txt
38
+
39
+ # Unit test / coverage reports
40
+ htmlcov/
41
+ .tox/
42
+ .nox/
43
+ .coverage
44
+ .coverage.*
45
+ .cache
46
+ nosetests.xml
47
+ coverage.xml
48
+ *.cover
49
+ *.py,cover
50
+ .hypothesis/
51
+ .pytest_cache/
52
+ cover/
53
+
54
+ # Translations
55
+ *.mo
56
+ *.pot
57
+
58
+ # Django stuff:
59
+ *.log
60
+ local_settings.py
61
+ db.sqlite3
62
+ db.sqlite3-journal
63
+
64
+ # Flask stuff:
65
+ instance/
66
+ .webassets-cache
67
+
68
+ # Scrapy stuff:
69
+ .scrapy
70
+
71
+ # Sphinx documentation
72
+ docs/_build/
73
+
74
+ # PyBuilder
75
+ .pybuilder/
76
+ target/
77
+
78
+ # Jupyter Notebook
79
+ .ipynb_checkpoints
80
+
81
+ # IPython
82
+ profile_default/
83
+ ipython_config.py
84
+
85
+ # pyenv
86
+ # For a library or package, you might want to ignore these files since the code is
87
+ # intended to run in multiple environments; otherwise, check them in:
88
+ # .python-version
89
+
90
+ # pipenv
91
+ # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
92
+ # However, in case of collaboration, if having platform-specific dependencies or dependencies
93
+ # having no cross-platform support, pipenv may install dependencies that don't work, or not
94
+ # install all needed dependencies.
95
+ #Pipfile.lock
96
+
97
+ # poetry
98
+ # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
99
+ # This is especially recommended for binary packages to ensure reproducibility, and is more
100
+ # commonly ignored for libraries.
101
+ # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
102
+ #poetry.lock
103
+
104
+ # pdm
105
+ # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
106
+ #pdm.lock
107
+ # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
108
+ # in version control.
109
+ # https://pdm.fming.dev/latest/usage/project/#working-with-version-control
110
+ .pdm.toml
111
+ .pdm-python
112
+ .pdm-build/
113
+
114
+ # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
115
+ __pypackages__/
116
+
117
+ # Celery stuff
118
+ celerybeat-schedule
119
+ celerybeat.pid
120
+
121
+ # SageMath parsed files
122
+ *.sage.py
123
+
124
+ # Environments
125
+ .env
126
+ .venv
127
+ env/
128
+ venv/
129
+ ENV/
130
+ env.bak/
131
+ venv.bak/
132
+
133
+ # Spyder project settings
134
+ .spyderproject
135
+ .spyproject
136
+
137
+ # Rope project settings
138
+ .ropeproject
139
+
140
+ # mkdocs documentation
141
+ /site
142
+
143
+ # mypy
144
+ .mypy_cache/
145
+ .dmypy.json
146
+ dmypy.json
147
+
148
+ # Pyre type checker
149
+ .pyre/
150
+
151
+ # pytype static type analyzer
152
+ .pytype/
153
+
154
+ # Cython debug symbols
155
+ cython_debug/
156
+
157
+ # PyCharm
158
+ # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
159
+ # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
160
+ # and can be added to the global gitignore or merged into this file. For a more nuclear
161
+ # option (not recommended) you can uncomment the following to ignore the entire idea folder.
162
+ #.idea/
text2midi_repo/LICENSE ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ MIT License
2
+
3
+ Copyright (c) 2024 AMAAI Lab
4
+
5
+ Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ of this software and associated documentation files (the "Software"), to deal
7
+ in the Software without restriction, including without limitation the rights
8
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ copies of the Software, and to permit persons to whom the Software is
10
+ furnished to do so, subject to the following conditions:
11
+
12
+ The above copyright notice and this permission notice shall be included in all
13
+ copies or substantial portions of the Software.
14
+
15
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ SOFTWARE.
text2midi_repo/README.md ADDED
@@ -0,0 +1,162 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Text2midi: Generating Symbolic Music from Captions
2
+
3
+ [Demo](https://huggingface.co/spaces/amaai-lab/text2midi) | [Model](https://huggingface.co/amaai-lab/text2midi) | [Examples](https://amaai-lab.github.io/Text2midi/) | [Paper](https://arxiv.org/abs/2412.16526) | [Dataset](https://huggingface.co/datasets/amaai-lab/MidiCaps)
4
+
5
+ [![Hugging Face Spaces](https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-Spaces-blue)](https://huggingface.co/spaces/amaai-lab/text2midi)
6
+ </div>
7
+
8
+ **text2midi** is the first end-to-end model for generating MIDI files from textual descriptions. By leveraging pretrained large language models and a powerful autoregressive transformer decoder, **text2midi** allows users to create symbolic music that aligns with detailed textual prompts, including musical attributes like chords, tempo, and style. The details of the model are described in [this paper](https://arxiv.org/abs/2412.16526).
9
+
10
+ 🔥 Live demo available on [HuggingFace Spaces](https://huggingface.co/spaces/amaai-lab/text2midi).
11
+
12
+ 🔥 Update: Text2midi has been accepted at AAAI!
13
+
14
+ <div align="center">
15
+ <img src="text2midi_architecture.jpg" width="500"/>
16
+ </div>
17
+
18
+ ## Quickstart Guide
19
+
20
+ Generate symbolic music from a text prompt:
21
+
22
+ ```python
23
+ import pickle
24
+ import torch
25
+ import torch.nn as nn
26
+ from transformers import T5Tokenizer
27
+ from model.transformer_model import Transformer
28
+ from huggingface_hub import hf_hub_download
29
+
30
+ repo_id = "amaai-lab/text2midi"
31
+ # Download the model.bin file
32
+ model_path = hf_hub_download(repo_id=repo_id, filename="pytorch_model.bin")
33
+ # Download the vocab_remi.pkl file
34
+ tokenizer_path = hf_hub_download(repo_id=repo_id, filename="vocab_remi.pkl")
35
+
36
+ if torch.cuda.is_available():
37
+ device = 'cuda'
38
+ elif torch.backends.mps.is_available():
39
+ device = 'mps'
40
+ else:
41
+ device = 'cpu'
42
+
43
+ print(f"Using device: {device}")
44
+
45
+ # Load the tokenizer dictionary
46
+ with open(tokenizer_path, "rb") as f:
47
+ r_tokenizer = pickle.load(f)
48
+
49
+ # Get the vocab size
50
+ vocab_size = len(r_tokenizer)
51
+ print("Vocab size: ", vocab_size)
52
+ model = Transformer(vocab_size, 768, 8, 2048, 18, 1024, False, 8, device=device)
53
+ model.load_state_dict(torch.load(model_path, map_location=device))
54
+ model.eval()
55
+ tokenizer = T5Tokenizer.from_pretrained("google/flan-t5-base")
56
+
57
+ print('Model loaded.')
58
+
59
+
60
+ # Enter the text prompt and tokenize it
61
+ src = "A melodic electronic song with ambient elements, featuring piano, acoustic guitar, alto saxophone, string ensemble, and electric bass. Set in G minor with a 4/4 time signature, it moves at a lively Presto tempo. The composition evokes a blend of relaxation and darkness, with hints of happiness and a meditative quality."
62
+ print('Generating for prompt: ' + src)
63
+
64
+ inputs = tokenizer(src, return_tensors='pt', padding=True, truncation=True)
65
+ input_ids = nn.utils.rnn.pad_sequence(inputs.input_ids, batch_first=True, padding_value=0)
66
+ input_ids = input_ids.to(device)
67
+ attention_mask = nn.utils.rnn.pad_sequence(inputs.attention_mask, batch_first=True, padding_value=0)
68
+ attention_mask = attention_mask.to(device)
69
+
70
+ # Generate the midi
71
+ output = model.generate(input_ids, attention_mask, max_len=2000,temperature = 1.0)
72
+ output_list = output[0].tolist()
73
+ generated_midi = r_tokenizer.decode(output_list)
74
+ generated_midi.dump_midi("output.mid")
75
+ ```
76
+
77
+ ## Installation
78
+
79
+ If you have CUDA supported machine:
80
+ ```bash
81
+ git clone https://github.com/AMAAI-Lab/text2midi
82
+ cd text2midi
83
+ pip install -r requirements.txt
84
+ ```
85
+ Alternatively, if you have MPS supported machine:
86
+ ```bash
87
+ git clone https://github.com/AMAAI-Lab/text2midi
88
+ cd text2midi
89
+ pip install -r requirements-mac.txt
90
+ ```
91
+
92
+ ## Datasets
93
+
94
+ The model was trained using two datasets: [SymphonyNet](https://symphonynet.github.io/) for semi-supervised pretraining and MidiCaps for finetuning towards MIDI generation from captions.
95
+ The [MidiCaps dataset](https://huggingface.co/datasets/amaai-lab/MidiCaps) is a large-scale dataset of 168k MIDI files paired with rich text captions. These captions contain musical attributes such as key, tempo, style, and mood, making it ideal for text-to-MIDI generation tasks as described in [this paper](https://arxiv.org/abs/2406.02255).
96
+
97
+
98
+ ## Citation
99
+ If you use text2midi in your research, please cite:
100
+ ```
101
+ @inproceedings{bhandari2025text2midi,
102
+ title={text2midi: Generating Symbolic Music from Captions},
103
+ author={Keshav Bhandari and Abhinaba Roy and Kyra Wang and Geeta Puri and Simon Colton and Dorien Herremans},
104
+ booktitle={Proceedings of the 39th AAAI Conference on Artificial Intelligence (AAAI 2025)},
105
+ year={2025}
106
+ }
107
+ ```
108
+
109
+ ## Results of the Listening Study
110
+
111
+ Each question is rated on a Likert scale from 1 (very bad) to 7 (very good). The table shows the average ratings per question for each group of participants.
112
+
113
+ | Question | MidiCaps | text2midi | MuseCoco |
114
+ |---------------------|----------|-----------|----------|
115
+ | Musical Quality | 5.79 | 4.62 | 4.40 |
116
+ | Overall Matching | 5.42 | 4.67 | 4.07 |
117
+ | Genre Matching | 5.54 | 4.98 | 4.40 |
118
+ | Mood Matching | 5.70 | 5.00 | 4.32 |
119
+ | Key Matching | 4.61 | 3.64 | 3.36 |
120
+ | Chord Matching | 3.20 | 2.50 | 2.00 |
121
+ | Tempo Matching | 5.89 | 5.42 | 4.94 |
122
+
123
+
124
+ ## Objective Evaluations
125
+ Results of objective evaluation for *all* of the MidiCaps test set. Please note we have improved on all the numbers reported in the paper (the numbers in the paper are on a small subset of the MidiCaps test set).
126
+
127
+ | Metric | text2midi | MidiCaps | MuseCoco |
128
+ |---------------------|-----------|----------|----------|
129
+ | CR ↑ | 2.31 | 3.43 | 2.12 |
130
+ | CLAP ↑ | 0.22 | 0.26 | 0.21 |
131
+ | TB (%) ↑ | 39.70 | - | 21.71 |
132
+ | TBT (%) ↑ | 65.80 | - | 54.63 |
133
+ | CK (%) ↑ | 33.60 | - | 13.70 |
134
+ | CKD (%) ↑ | 35.60 | - | 14.59 |
135
+
136
+ **Note**:
137
+ CR = Compression ratio
138
+ CLAP = CLAP score
139
+ TB = Tempo Bin
140
+ TBT = Tempo Bin with Tolerance
141
+ CK = Correct Key
142
+ CKD = Correct Key with Duplicates
143
+ ↑ = Higher score is better.
144
+
145
+ ## Training
146
+ To train text2midi, we recommend using accelerate for multi-GPU support. First, configure accelerate by running:
147
+ ```bash
148
+ accelerate config
149
+ ```
150
+
151
+ Then, use the following command to start training:
152
+ ```bash
153
+ accelerate launch --multi_gpu --num_processes=4 train_accelerate.py --config ../config.yaml
154
+ ```
155
+
156
+ ## Inference
157
+ We support inference on CUDA, MPS and CPU. Please make sure you have pip installed the correct requirements file (requirements.txt for CUDA, requirements-mac.txt for MPS).
158
+ ```bash
159
+ python model/transformer_model.py --caption <your intended descriptions>
160
+ ```
161
+
162
+
text2midi_repo/artifacts/vocab.pkl ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:9b15330f5ab9c2cd32d359bcc64a1de320a7dc1227180a7658fd0b8f2d35e12c
3
+ size 239637
text2midi_repo/artifacts/vocab_remi.pkl ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:877d4511d6b9d5eea1c706199fe13a0de3d984c8f5d09c75d727ffe7f6f54ee6
3
+ size 27256
text2midi_repo/captions/captions.json ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:7c25b36b6196618ff79f111e24c52b97d0bde9e1b47d2c596650944ebe6dcac5
3
+ size 69068459
text2midi_repo/configs/config.yaml ADDED
@@ -0,0 +1,62 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ model:
2
+ text2midi_model:
3
+ decoder_max_sequence_length: 2048
4
+ decoder_num_layers: 18
5
+ decoder_num_heads: 8
6
+ decoder_d_model: 768
7
+ decoder_intermediate_size: 1024
8
+ use_moe: False
9
+ num_experts: 4
10
+ use_deepspeed: False
11
+ use_accelerate: True
12
+
13
+
14
+ training:
15
+ text2midi_model:
16
+ epochs: 140
17
+ batch_size: 1
18
+ learning_rate: 0.000001
19
+ weight_decay: 0.01
20
+ gradient_accumulation_steps: 4
21
+ with_tracking: True
22
+ checkpointing_steps: epoch
23
+ report_to: wandb
24
+ output_dir: /root/output_test_new
25
+ per_device_train_batch_size: 32
26
+ use_scheduler: True
27
+ lr_scheduler_type: cosine #choices=["linear", "cosine", "cosine_with_restarts", "polynomial", "constant", "constant_with_warmup"]
28
+ num_warmup_steps: 100
29
+ save_every: 5
30
+ max_train_steps: None
31
+ scheduled_sampling: False
32
+ epsilon: 0
33
+ c: -0.0161
34
+ k: -0.312
35
+
36
+ raw_data:
37
+ caption_dataset_path: /root/captions/train.json
38
+ raw_data_folders:
39
+ lmd:
40
+ folder_path: /import/c4dm-datasets-ext/lakhmidi
41
+ file_extension: midi
42
+ symphonynet:
43
+ folder_path: /root/text2midi/data/symphonynet/data/SymphonyNet_Dataset
44
+ file_extension: mid
45
+ maestro:
46
+ folder_path: /import/c4dm-datasets/maestro-v3.0.0
47
+ file_extension: midi
48
+ pop909:
49
+ folder_path: /import/c4dm-datasets-ext/POP909
50
+ file_extension: mid
51
+ pijama:
52
+ folder_path: /import/c4dm-datasets/PiJAMA/data/midi
53
+ file_extension: midi
54
+ midicaps:
55
+ folder_path: /root/data
56
+ file_extension: mid
57
+
58
+
59
+ deepspeed_config:
60
+ deepspeed_config_path: /root/test/text2midi/configs/ds_config.json
61
+
62
+ artifact_folder: ../artifacts
text2midi_repo/configs/ds_config.json ADDED
@@ -0,0 +1,33 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "train_micro_batch_size_per_gpu": 1,
3
+ "gradient_accumulation_steps": 1,
4
+ "optimizer": {
5
+ "type": "Adam",
6
+ "params": {
7
+ "lr": 1e-4
8
+ }
9
+ },
10
+ "bf16": {
11
+ "enabled": true
12
+ },
13
+ "zero_optimization": {
14
+ "stage": 1,
15
+ "offload_optimizer": {
16
+ "device": "cpu",
17
+ "pin_memory": true
18
+ },
19
+ "offload_param": {
20
+ "device": "cpu",
21
+ "pin_memory": true
22
+ },
23
+ "overlap_comm": true,
24
+ "contiguous_gradients": true,
25
+ "sub_group_size": 1e9
26
+ },
27
+ "activation_checkpointing": {
28
+ "partition_activations": true,
29
+ "number_checkpoints": null,
30
+ "contiguous_memory_optimization":true,
31
+ "cpu_checkpointing": true
32
+ }
33
+ }
text2midi_repo/model/__pycache__/transformer_model.cpython-314.pyc ADDED
Binary file (77.4 kB). View file
 
text2midi_repo/model/build_vocab.py ADDED
@@ -0,0 +1,79 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
import yaml
import os
import argparse
import pickle

# Parse command line arguments
parser = argparse.ArgumentParser()
parser.add_argument("--config", type=str, default=os.path.normpath("configs/config.yaml"),
                    help="Path to the config file")
args = parser.parse_args()

# Load config file
with open(args.config, 'r') as f:
    configs = yaml.safe_load(f)

artifact_folder = configs["artifact_folder"]
raw_data_folders = configs["raw_data"]["raw_data_folders"]


# Build the vocabulary. Token ids start at 1 so that id 0 stays free as the
# padding value (the data loaders pad label tensors with 0).
vocab = {}

instruments = ['piano', 'chromatic', 'organ', 'guitar', 'bass', 'strings', 'ensemble', 'brass', 'reed', 'pipe', 'synth_lead', 'synth_pad', 'synth_effect', 'ethnic', 'percussive', 'sfx', 'drum']

# Instrument prefix tokens
for i in instruments:
    vocab[('prefix', 'instrument', i)] = len(vocab) + 1

# MIDI velocity bins (full range is 0 to 127)
velocity = [0, 15, 30, 45, 60, 75, 90, 105, 120, 127]
# MIDI pitch range from 0 to 127
midi_pitch = list(range(0, 128))
# Onsets/durations are quantized in 10 millisecond steps up to 5 seconds
onset = list(range(0, 5001, 10))
duration = list(range(0, 5001, 10))

# Note tokens: (instrument, pitch, velocity) for every pitched instrument.
# Drums carry no velocity and get dedicated (drum, pitch) tokens below.
for v in velocity:
    for i in instruments:
        if i == "drum":
            continue
        for p in midi_pitch:
            vocab[(i, p, v)] = len(vocab) + 1

for p in midi_pitch:
    vocab[("drum", p)] = len(vocab) + 1

for o in onset:
    vocab[("onset", o)] = len(vocab) + 1
for d in duration:
    vocab[("dur", d)] = len(vocab) + 1

# Special / control tokens
vocab["<T>"] = len(vocab) + 1
vocab["<D>"] = len(vocab) + 1
vocab["<U>"] = len(vocab) + 1
vocab["<SS>"] = len(vocab) + 1
print('vocab[<SS>]', vocab['<SS>'])
vocab["<S>"] = len(vocab) + 1
vocab["<E>"] = len(vocab) + 1
vocab["SEP"] = len(vocab) + 1  # NOTE: no angle brackets, unlike the other specials

# Print the vocabulary length
print(f"Vocabulary length: {len(vocab)}")

# Save the vocabulary (create the artifact folder if it does not exist yet)
os.makedirs(artifact_folder, exist_ok=True)
vocab_path = os.path.join(artifact_folder, "vocab.pkl")
with open(vocab_path, 'wb') as f:
    pickle.dump(vocab, f)

print(f"Vocabulary saved to {vocab_path}")
text2midi_repo/model/build_vocab_remi.py ADDED
@@ -0,0 +1,69 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
import yaml
import os
import argparse
import pickle
from miditok import REMI, TokenizerConfig  # here we choose to use REMI
import jsonlines

# Parse command line arguments
parser = argparse.ArgumentParser()
parser.add_argument("--config", type=str, default=os.path.normpath("configs/config.yaml"),
                    help="Path to the config file")
args = parser.parse_args()

# Load config file
with open(args.config, 'r') as f:
    configs = yaml.safe_load(f)

artifact_folder = configs["artifact_folder"]
raw_data_folders = configs["raw_data"]["raw_data_folders"]
caption_dataset_path = configs["raw_data"]["caption_dataset_path"]
dataset_path = configs["raw_data"]["raw_data_folders"]["lmd"]["folder_path"]

# Our tokenizer parameters
BEAT_RES = {(0, 1): 12, (1, 2): 4, (2, 4): 2, (4, 8): 1}
TOKENIZER_PARAMS = {
    "pitch_range": (21, 109),
    "beat_res": BEAT_RES,
    "num_velocities": 32,
    "special_tokens": ["PAD", "BOS", "EOS", "MASK"],
    "use_chords": False,
    "use_rests": False,
    "use_tempos": True,
    "use_time_signatures": True,
    "use_programs": True,
    "num_tempos": 32,  # number of tempo bins
    "tempo_range": (40, 250),  # (min, max)
}
config = TokenizerConfig(**TOKENIZER_PARAMS)

# Creates the tokenizer
tokenizer = REMI(config)

# Load the caption dataset
with jsonlines.open(caption_dataset_path) as reader:
    captions = list(reader)

# MIDI files referenced by the captions; only consumed by the (currently
# disabled) BPE training step below.
midi_paths = [os.path.join(dataset_path, captions[i]['location']) for i in range(len(captions))][0:30000]

# Optionally train BPE merges on top of the base vocabulary:
# vocab_size = 30000
# tokenizer.train(vocab_size=vocab_size, files_paths=midi_paths)

# Print the vocabulary length
print(f"Vocabulary length: {tokenizer.vocab_size}")

# Save the tokenizer (create the artifact folder if it does not exist yet)
os.makedirs(artifact_folder, exist_ok=True)
vocab_path = os.path.join(artifact_folder, "vocab_remi.pkl")
with open(vocab_path, 'wb') as f:
    pickle.dump(tokenizer, f)

print(f"Vocabulary saved to {vocab_path}")
text2midi_repo/model/data_loader.py ADDED
@@ -0,0 +1,124 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from aria.data.midi import MidiDict
2
+ # from aria.tokenizer import AbsTokenizer
3
+ # aria_tokenizer = AbsTokenizer()
4
+ import yaml
5
+ import jsonlines
6
+ import glob
7
+ import random
8
+ import os
9
+ import sys
10
+ import pickle
11
+ import json
12
+ import argparse
13
+ import numpy as np
14
+ from copy import deepcopy
15
+ from torch.utils.data import Dataset
16
+ import torch
17
+ from torch.nn import functional as F
18
+ from transformers import T5Tokenizer
19
+ from spacy.lang.en import English
20
+
21
+ SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__))
22
+ sys.path.append(os.path.dirname(SCRIPT_DIR))
23
+
24
+
25
class Text2MusicDataset(Dataset):
    """Pairs natural-language captions with aria-tokenized MIDI targets.

    Each item is ``(input_ids, attention_mask, labels)``: the FLAN-T5
    encoding of the (possibly sentence-dropped) caption plus the tokenized
    MIDI sequence padded/truncated to the decoder's maximum length.
    """

    def __init__(self, configs, captions, aria_tokenizer, mode="train", shuffle=False):
        self.mode = mode
        # Copy before shuffling so the caller's list is not mutated in place
        # (the original shuffled the shared list passed in by the caller).
        self.captions = list(captions)
        if shuffle:
            random.shuffle(self.captions)

        # Path to the MidiCaps dataset root
        self.dataset_path = configs['raw_data']['raw_data_folders']['midicaps']['folder_path']

        # Artifact folder holding the pickled vocabulary
        self.artifact_folder = configs['artifact_folder']
        tokenizer_filepath = os.path.join(self.artifact_folder, "vocab.pkl")
        # MIDI event tokenizer, e.g. aria.tokenizer.AbsTokenizer
        self.aria_tokenizer = aria_tokenizer
        # token -> integer-id mapping produced by build_vocab.py
        with open(tokenizer_filepath, 'rb') as f:
            self.tokenizer = pickle.load(f)

        # Sentencizer used for random sentence dropping (caption augmentation)
        self.nlp = English()
        self.nlp.add_pipe('sentencizer')

        # FLAN-T5 tokenizer for the text encoder
        self.t5_tokenizer = T5Tokenizer.from_pretrained("google/flan-t5-base")

        # Maximum decoder (MIDI) sequence length
        self.decoder_max_sequence_length = configs['model']['text2midi_model']['decoder_max_sequence_length']

        # Print length of dataset
        print("Length of dataset: ", len(self.captions))

    def __len__(self):
        return len(self.captions)

    def __getitem__(self, idx):
        caption = self.captions[idx]['caption']
        midi_filepath = os.path.join(self.dataset_path, self.captions[idx]['location'])

        # Read and tokenize the MIDI file; empty files degenerate to the
        # start/end tokens only.
        midi = MidiDict.from_midi(midi_filepath)
        if len(midi.note_msgs) == 0:
            aria_tokenized_midi = ["<SS>", "<E>"]
        else:
            aria_tokenized_midi = ["<SS>"] + self.aria_tokenizer.tokenize(midi)

        # Caption augmentation: with probability 0.5, drop between 20 and 50
        # percent of the caption's sentences.
        do_drop = random.random() > 0.5
        if do_drop:
            sentences = list(self.nlp(caption).sents)
            sent_length = len(sentences)
            if sent_length < 4:
                # floor: very short captions may end up losing nothing
                how_many_to_drop = int(np.floor((20 + random.random() * 30) / 100 * sent_length))
            else:
                how_many_to_drop = int(np.ceil((20 + random.random() * 30) / 100 * sent_length))
            which_to_drop = np.random.choice(sent_length, how_many_to_drop, replace=False)
            kept = [sentences[i] for i in range(sent_length) if i not in which_to_drop.tolist()]
            # combine the surviving sentences back with a space
            new_sentences = " ".join(kept[i].text for i in range(len(kept)))
        else:
            new_sentences = caption

        # Tokenize the caption for the text encoder
        inputs = self.t5_tokenizer(new_sentences, return_tensors='pt', padding=True, truncation=True)
        input_ids = inputs['input_ids']
        attention_mask = inputs['attention_mask']

        # Map MIDI tokens to ids, silently skipping out-of-vocabulary tokens
        tokenized_midi = [self.tokenizer[token] for token in aria_tokenized_midi if token in self.tokenizer]

        # Pad with 0 up to the max length, or keep the LAST
        # decoder_max_sequence_length ids (truncation keeps the end of the piece)
        if len(tokenized_midi) < self.decoder_max_sequence_length:
            labels = F.pad(torch.tensor(tokenized_midi), (0, self.decoder_max_sequence_length - len(tokenized_midi))).to(torch.int64)
        else:
            labels = torch.tensor(tokenized_midi[-self.decoder_max_sequence_length:]).to(torch.int64)

        return input_ids, attention_mask, labels
104
+
105
+ if __name__ == "__main__":
106
+ # Parse command line arguments
107
+ parser = argparse.ArgumentParser()
108
+ parser.add_argument("--config", type=str, default=os.path.normpath("../configs/config.yaml"),
109
+ help="Path to the config file")
110
+ args = parser.parse_args()
111
+
112
+ # Load config file
113
+ with open(args.config, 'r') as f:
114
+ configs = yaml.safe_load(f)
115
+
116
+ caption_dataset_path = configs['raw_data']['caption_dataset_path']
117
+ # Load the caption dataset
118
+ with jsonlines.open(caption_dataset_path) as reader:
119
+ captions = list(reader)
120
+
121
+ # Load the dataset
122
+ dataset = Text2MusicDataset(configs, captions, mode="train", shuffle = True)
123
+ a,b,c = dataset[0]
124
+ print(c.shape)
text2midi_repo/model/data_loader_remi.py ADDED
@@ -0,0 +1,126 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import yaml
2
+ import jsonlines
3
+ import glob
4
+ import random
5
+ import os
6
+ import sys
7
+ import pickle
8
+ import json
9
+ import argparse
10
+ import numpy as np
11
+ from copy import deepcopy
12
+ from torch.utils.data import Dataset
13
+ import torch
14
+ from torch.nn import functional as F
15
+ from transformers import T5Tokenizer
16
+ from spacy.lang.en import English
17
+
18
+ SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__))
19
+ sys.path.append(os.path.dirname(SCRIPT_DIR))
20
+
21
+
22
class Text2MusicDataset(Dataset):
    """Caption/MIDI dataset using a miditok REMI tokenizer for the targets.

    Each item is ``(input_ids, attention_mask, labels)``: the FLAN-T5
    encoding of the (possibly sentence-dropped) caption plus the REMI token
    ids wrapped in BOS/EOS and padded/truncated to the decoder's maximum
    length.
    """

    def __init__(self, configs, captions, remi_tokenizer, mode="train", shuffle=False):
        self.mode = mode
        # Copy before shuffling so the caller's list is not mutated in place
        # (the original shuffled the shared list passed in by the caller).
        self.captions = list(captions)
        if shuffle:
            random.shuffle(self.captions)

        # Path to the MidiCaps dataset root
        self.dataset_path = configs['raw_data']['raw_data_folders']['midicaps']['folder_path']

        # Artifact folder (kept for parity with the aria-based loader)
        self.artifact_folder = configs['artifact_folder']

        # miditok REMI tokenizer built by build_vocab_remi.py
        self.remi_tokenizer = remi_tokenizer

        # Sentencizer used for random sentence dropping (caption augmentation)
        self.nlp = English()
        self.nlp.add_pipe('sentencizer')

        # FLAN-T5 tokenizer for the text encoder
        self.t5_tokenizer = T5Tokenizer.from_pretrained("google/flan-t5-base")

        # Maximum decoder (MIDI) sequence length
        self.decoder_max_sequence_length = configs['model']['text2midi_model']['decoder_max_sequence_length']

        # Print length of dataset
        print("Length of dataset: ", len(self.captions))

    def __len__(self):
        return len(self.captions)

    def __getitem__(self, idx):
        caption = self.captions[idx]['caption']
        midi_filepath = os.path.join(self.dataset_path, self.captions[idx]['location'])

        # Tokenize the MIDI file with REMI and wrap the ids in BOS/EOS;
        # empty files degenerate to the BOS/EOS pair only.
        tokens = self.remi_tokenizer(midi_filepath)
        if len(tokens.ids) == 0:
            tokenized_midi = [self.remi_tokenizer["BOS_None"], self.remi_tokenizer["EOS_None"]]
        else:
            tokenized_midi = [self.remi_tokenizer["BOS_None"]] + tokens.ids + [self.remi_tokenizer["EOS_None"]]

        # Caption augmentation: with probability 0.5, drop between 20 and 50
        # percent of the caption's sentences.
        do_drop = random.random() > 0.5
        if do_drop:
            sentences = list(self.nlp(caption).sents)
            sent_length = len(sentences)
            if sent_length < 4:
                # floor: very short captions may end up losing nothing
                how_many_to_drop = int(np.floor((20 + random.random() * 30) / 100 * sent_length))
            else:
                how_many_to_drop = int(np.ceil((20 + random.random() * 30) / 100 * sent_length))
            which_to_drop = np.random.choice(sent_length, how_many_to_drop, replace=False)
            kept = [sentences[i] for i in range(sent_length) if i not in which_to_drop.tolist()]
            # combine the surviving sentences back with a space
            new_sentences = " ".join(kept[i].text for i in range(len(kept)))
        else:
            new_sentences = caption

        # Tokenize the caption for the text encoder
        inputs = self.t5_tokenizer(new_sentences, return_tensors='pt', padding=True, truncation=True)
        input_ids = inputs['input_ids']
        attention_mask = inputs['attention_mask']

        # Pad with 0 up to the max length, or keep the FIRST
        # decoder_max_sequence_length ids (note: the aria loader keeps the
        # END of the piece instead).
        if len(tokenized_midi) < self.decoder_max_sequence_length:
            labels = F.pad(torch.tensor(tokenized_midi), (0, self.decoder_max_sequence_length - len(tokenized_midi))).to(torch.int64)
        else:
            labels = torch.tensor(tokenized_midi[0:self.decoder_max_sequence_length]).to(torch.int64)

        return input_ids, attention_mask, labels
97
+
98
+ if __name__ == "__main__":
99
+ # Parse command line arguments
100
+ parser = argparse.ArgumentParser()
101
+ parser.add_argument("--config", type=str, default=os.path.normpath("../configs/config.yaml"),
102
+ help="Path to the config file")
103
+ args = parser.parse_args()
104
+
105
+ tokenizer_filepath = "../artifacts/vocab_remi.pkl"
106
+ # Load the tokenizer dictionary
107
+ with open(tokenizer_filepath, "rb") as f:
108
+ tokenizer = pickle.load(f)
109
+ bos_token_number = tokenizer["PAD_None"]
110
+ print(f"bos_token_number: {bos_token_number}")
111
+
112
+ # Load config file
113
+ with open(args.config, 'r') as f:
114
+ configs = yaml.safe_load(f)
115
+ caption_dataset_path = configs['raw_data']['caption_dataset_path']
116
+ # Load the caption dataset
117
+ with jsonlines.open(caption_dataset_path) as reader:
118
+ captions = list(reader)
119
+
120
+ # Load the dataset
121
+ dataset = Text2MusicDataset(configs, captions, remi_tokenizer=tokenizer, mode="train", shuffle = True)
122
+ a,b,c = dataset[0]
123
+ print(type(a))
124
+ generated_midi = tokenizer.decode(c)
125
+ print(type(generated_midi))
126
+ generated_midi.dump_midi("decoded_midi.mid")
text2midi_repo/model/dict_output.txt ADDED
The diff for this file is too large to render. See raw diff
 
text2midi_repo/model/train.py ADDED
@@ -0,0 +1,185 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
import os
import argparse
import pickle

import yaml
import jsonlines
import torch
import torch.nn as nn
import torch.optim as optim
from tqdm import tqdm
from torch.utils.data import DataLoader

from data_loader import Text2MusicDataset
from transformer_model import Transformer

# Parse command line arguments (restores the previously commented-out CLI;
# parse_known_args tolerates extra launcher flags such as --local_rank,
# and the default preserves the old hardcoded path).
parser = argparse.ArgumentParser()
parser.add_argument("--config", type=str, default=os.path.normpath("../configs/config.yaml"),
                    help="Path to the config file")
args, _ = parser.parse_known_args()

# Load config file
with open(args.config, 'r') as f:
    configs = yaml.safe_load(f)

batch_size = configs['training']['text2midi_model']['batch_size']
learning_rate = configs['training']['text2midi_model']['learning_rate']
epochs = configs['training']['text2midi_model']['epochs']

# Artifact folder holding the vocabulary built by build_vocab.py
artifact_folder = configs['artifact_folder']
tokenizer_filepath = os.path.join(artifact_folder, "vocab.pkl")
# Load the tokenizer dictionary (token -> id)
with open(tokenizer_filepath, "rb") as f:
    tokenizer = pickle.load(f)

# +1 because token ids start at 1; id 0 is reserved for padding
vocab_size = len(tokenizer) + 1
print("Vocab size: ", vocab_size)

caption_dataset_path = configs['raw_data']['caption_dataset_path']
# Load the caption dataset (one JSON record per line)
with jsonlines.open(caption_dataset_path) as reader:
    captions = list(reader)
60
+
61
+
62
def collate_fn(batch):
    """Collate (input_ids, attention_mask, labels) tuples into one batch.

    Each element arrives with a leading singleton batch dimension on its
    tensors; squeeze it off, then right-pad every stream with 0 so the whole
    batch shares a single length per stream.
    """
    def _pad(sequences):
        return nn.utils.rnn.pad_sequence(list(sequences), batch_first=True, padding_value=0)

    ids, masks, labels = zip(*((i.squeeze(0), m.squeeze(0), l.squeeze(0)) for i, m, l in batch))
    return _pad(ids), _pad(masks), _pad(labels)
78
+
79
+
80
# Load the dataset. Text2MusicDataset requires an aria tokenizer instance as
# its third positional argument (the original call omitted it -> TypeError).
from aria.tokenizer import AbsTokenizer
dataset = Text2MusicDataset(configs, captions, AbsTokenizer(), mode="train", shuffle=True)
data_length = len(dataset)
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=4, collate_fn=collate_fn)


# Model hyperparameters (d_model must match the FLAN-T5 encoder output dim)
d_model = configs['model']['text2midi_model']['decoder_d_model']
nhead = configs['model']['text2midi_model']['decoder_num_heads']
num_layers = configs['model']['text2midi_model']['decoder_num_layers']
max_len = configs['model']['text2midi_model']['decoder_max_sequence_length']
use_moe = configs['model']['text2midi_model']['use_moe']
num_experts = configs['model']['text2midi_model']['num_experts']
dim_feedforward = configs['model']['text2midi_model']['decoder_intermediate_size']
use_deepspeed = configs['model']['text2midi_model']['use_deepspeed']
if use_deepspeed:
    ds_config = configs['deepspeed_config']['deepspeed_config_path']
    import deepspeed
    from deepspeed.accelerator import get_accelerator
    local_rank = int(os.environ['LOCAL_RANK'])
    device = (torch.device(get_accelerator().device_name(), local_rank) if (local_rank > -1)
              and get_accelerator().is_available() else torch.device("cpu"))
    deepspeed.init_distributed(dist_backend='nccl')
    # Disable SDPA kernels that conflict with this deepspeed setup
    torch.backends.cuda.enable_mem_efficient_sdp(False)
    torch.backends.cuda.enable_flash_sdp(False)
else:
    # NOTE(review): GPU index 3 is hardcoded — adjust for your machine
    device = torch.device("cuda:3" if torch.cuda.is_available() else "cpu")
print(f"Device: {device}")

print_every = 10  # refresh the progress-bar loss every N steps
model = Transformer(vocab_size, d_model, nhead, max_len, num_layers, dim_feedforward, use_moe, num_experts, device=device)
# Print number of parameters
num_params = sum(p.numel() for p in model.parameters())
print(f"Number of parameters: {num_params}")
# Print number of trainable parameters
num_trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f"Number of trainable parameters: {num_trainable_params}")
if not use_deepspeed:
    optimizer = optim.Adam(model.parameters(), lr=1e-4)
criterion = nn.CrossEntropyLoss()
torch.cuda.empty_cache()


def train_model(model, dataloader, criterion, num_epochs, optimizer=None, data_length=1000):
    """Train `model` on `dataloader` for `num_epochs` epochs.

    With deepspeed enabled, `optimizer` may be None (it then comes from the
    deepspeed JSON config); otherwise it is required. `data_length` only
    sizes the progress bar.
    """
    if use_deepspeed:
        model, optimizer, _, _ = deepspeed.initialize(model=model,
                                                      optimizer=optimizer,
                                                      model_parameters=model.parameters(),
                                                      config=ds_config)
    else:
        model = model.to(device)
    model.train()
    for epoch in range(num_epochs):
        total_loss = 0
        with tqdm(total=int(data_length / batch_size), desc=f"Epoch {epoch + 1}/{num_epochs}") as pbar:
            for step, batch in enumerate(dataloader):
                if use_deepspeed:
                    model.zero_grad()
                else:
                    optimizer.zero_grad()

                # Get the batch
                encoder_input, attention_mask, tgt = batch
                encoder_input = encoder_input.to(device)
                attention_mask = attention_mask.to(device)
                tgt = tgt.to(device)

                # Teacher forcing: predict token t+1 from tokens <= t
                tgt_input = tgt[:, :-1]
                tgt_output = tgt[:, 1:]

                if use_moe:
                    outputs, aux_loss = model(encoder_input, attention_mask, tgt_input)
                else:
                    outputs = model(encoder_input, attention_mask, tgt_input)
                    aux_loss = 0

                loss = criterion(outputs.view(-1, outputs.size(-1)), tgt_output.reshape(-1))
                loss += aux_loss
                if use_deepspeed:
                    model.backward(loss)
                    model.step()
                else:
                    loss.backward()
                    optimizer.step()

                total_loss += loss.item()
                if step % print_every == 0:
                    pbar.set_postfix({"Loss": loss.item()})
                pbar.update(1)

            pbar.set_postfix({"Loss": total_loss / len(dataloader)})
            pbar.update(1)

        print(f"Epoch {epoch + 1}/{num_epochs}, Loss: {total_loss / len(dataloader)}")


# Train the model (pass data_length in both branches so the deepspeed path's
# progress-bar total is correct too; the original left it at the 1000 default)
if use_deepspeed:
    train_model(model, dataloader, criterion, num_epochs=epochs, data_length=data_length)
else:
    train_model(model, dataloader, criterion, num_epochs=epochs, optimizer=optimizer, data_length=data_length)

# Save the trained model. NOTE(review): under deepspeed, `model` is the
# engine — prefer engine.save_checkpoint for sharded/offloaded states.
torch.save(model.state_dict(), "transformer_decoder_remi_plus.pth")
print("Model saved as transformer_decoder_remi_plus.pth")
text2midi_repo/model/train_accelerate.py ADDED
@@ -0,0 +1,224 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
import os
import json
import math
import time
import pickle
import logging

import yaml
import wandb
import jsonlines
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from tqdm import tqdm
from transformers import get_scheduler
from accelerate import DistributedDataParallelKwargs, Accelerator
from accelerate.logging import get_logger
from torch.utils.data import DataLoader

from data_loader_remi import Text2MusicDataset
from transformer_model import Transformer

logger = get_logger(__name__)

# Read the training configuration.
config_file = "../configs/config.yaml"
with open(config_file, 'r') as f:
    configs = yaml.safe_load(f)

# Shorthand for the section read repeatedly below.
_train_cfg = configs['training']['text2midi_model']

batch_size = _train_cfg['batch_size']
learning_rate = _train_cfg['learning_rate']
epochs = _train_cfg['epochs']
artifact_folder = configs['artifact_folder']

# REMI tokenizer pickled by build_vocab_remi.py.
tokenizer_filepath = os.path.join(artifact_folder, "vocab_remi.pkl")
with open(tokenizer_filepath, "rb") as f:
    tokenizer = pickle.load(f)
vocab_size = len(tokenizer)

# Caption records: one JSON object per line with 'caption' and 'location'.
caption_dataset_path = configs['raw_data']['caption_dataset_path']
with jsonlines.open(caption_dataset_path) as reader:
    captions = list(reader)
43
+
44
def collate_fn(batch):
    """Right-pad each of the three tensor streams in `batch` with zeros.

    Items are (input_ids, attention_mask, labels) tuples whose tensors carry
    a leading singleton batch dimension that is squeezed away before padding.
    """
    columns = ([item[k].squeeze(0) for item in batch] for k in range(3))
    padded = [nn.utils.rnn.pad_sequence(col, batch_first=True, padding_value=0) for col in columns]
    return padded[0], padded[1], padded[2]
52
+
53
# Model hyperparameters and training options from the config.
d_model = configs['model']['text2midi_model']['decoder_d_model']
nhead = configs['model']['text2midi_model']['decoder_num_heads']
num_layers = configs['model']['text2midi_model']['decoder_num_layers']
max_len = configs['model']['text2midi_model']['decoder_max_sequence_length']
use_moe = configs['model']['text2midi_model']['use_moe']
num_experts = configs['model']['text2midi_model']['num_experts']
dim_feedforward = configs['model']['text2midi_model']['decoder_intermediate_size']
gradient_accumulation_steps = configs['training']['text2midi_model']['gradient_accumulation_steps']
use_scheduler = configs['training']['text2midi_model']['use_scheduler']
checkpointing_steps = configs['training']['text2midi_model']['checkpointing_steps']
lr_scheduler_type = configs['training']['text2midi_model']['lr_scheduler_type']
num_warmup_steps = configs['training']['text2midi_model']['num_warmup_steps']
max_train_steps = configs['training']['text2midi_model']['max_train_steps']
with_tracking = configs['training']['text2midi_model']['with_tracking']
report_to = configs['training']['text2midi_model']['report_to']
output_dir = configs['training']['text2midi_model']['output_dir']
per_device_train_batch_size = configs['training']['text2midi_model']['per_device_train_batch_size']
save_every = configs['training']['text2midi_model']['save_every']

accelerator_log_kwargs = {}
if with_tracking:
    accelerator_log_kwargs["log_with"] = report_to
    # NOTE(review): `logging_dir` was removed in newer accelerate releases;
    # remove this argument (or switch to `project_dir`) if Accelerator()
    # raises a TypeError.
    accelerator_log_kwargs["logging_dir"] = output_dir
accelerator = Accelerator(gradient_accumulation_steps=gradient_accumulation_steps, mixed_precision='fp16', kwargs_handlers=[DistributedDataParallelKwargs(find_unused_parameters=True)], **accelerator_log_kwargs)
logging.basicConfig(
    format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
    datefmt="%m/%d/%Y %H:%M:%S",
    level=logging.INFO,
)
logger.info(accelerator.state, main_process_only=False)

# Only the main process creates directories and talks to wandb — running
# wandb.init on every rank would open one run per process.
if accelerator.is_main_process:
    if output_dir is None or output_dir == "":
        output_dir = "saved/" + str(int(time.time()))
        if not os.path.exists("saved"):
            os.makedirs("saved")
        os.makedirs(output_dir, exist_ok=True)
    elif output_dir is not None:
        os.makedirs(output_dir, exist_ok=True)
    os.makedirs("{}/{}".format(output_dir, "outputs"), exist_ok=True)
    accelerator.project_configuration.automatic_checkpoint_naming = False
    wandb.login()
    wandb.init(project="Text-2-Midi", settings=wandb.Settings(init_timeout=120))
accelerator.wait_for_everyone()
device = accelerator.device

# Build the dataset on the main process first so cached downloads (e.g. the
# T5 tokenizer) happen only once.
with accelerator.main_process_first():
    dataset = Text2MusicDataset(configs, captions, remi_tokenizer=tokenizer, mode="train", shuffle=True)
    dataloader = DataLoader(dataset, batch_size=per_device_train_batch_size, shuffle=True, num_workers=4, collate_fn=collate_fn, drop_last=True)

model = Transformer(vocab_size, d_model, nhead, max_len, num_layers, dim_feedforward, use_moe, num_experts, device=device)
# NOTE(review): hardcoded resume checkpoint — this pairs with the
# starting_epoch default in train_model_accelerate; both should move into
# the config.
model.load_state_dict(torch.load('/root/output_test_new/epoch_68/pytorch_model.bin', map_location=device))


def count_parameters(model):
    """Return the number of trainable parameters in `model`."""
    return sum(p.numel() for p in model.parameters() if p.requires_grad)


total_params = count_parameters(model)
print(f"Total number of trainable parameters: {total_params}")

# NOTE(review): the Adam lr is hardcoded to 1e-4 — the `learning_rate`
# value read from the config is ignored here.
optimizer = optim.Adam(model.parameters(), lr=1e-4)
overrode_max_train_steps = False
num_update_steps_per_epoch = math.ceil(len(dataloader) / gradient_accumulation_steps)
print("num_update_steps_per_epoch", num_update_steps_per_epoch)
print("max_train_steps", max_train_steps)
# The YAML stores the literal string "None" when no step limit is set.
if max_train_steps == 'None':
    max_train_steps = epochs * num_update_steps_per_epoch
    print("max_train_steps", max_train_steps)
    overrode_max_train_steps = True
    num_warmup_steps = 20000  # overrides the config value in this branch
elif isinstance(max_train_steps, str):
    max_train_steps = int(max_train_steps)
lr_scheduler = get_scheduler(
    name=lr_scheduler_type,
    optimizer=optimizer,
    num_warmup_steps=num_warmup_steps,
    num_training_steps=max_train_steps,
)
model, optimizer, lr_scheduler = accelerator.prepare(model, optimizer, lr_scheduler)
dataloader = accelerator.prepare(dataloader)
if overrode_max_train_steps:
    max_train_steps = epochs * num_update_steps_per_epoch
epochs = math.ceil(max_train_steps / num_update_steps_per_epoch)
total_batch_size = per_device_train_batch_size * accelerator.num_processes * gradient_accumulation_steps
logger.info("***** Running training *****")
logger.info(f" Num examples = {len(dataset)}")
logger.info(f" Num Epochs = {epochs}")
logger.info(f" Instantaneous batch size per device = {per_device_train_batch_size}")
logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
logger.info(f" Gradient Accumulation steps = {gradient_accumulation_steps}")
logger.info(f" Total optimization steps = {max_train_steps}")

criterion = nn.CrossEntropyLoss()
145
+
146
def train_model_accelerate(model, dataloader, criterion, num_epochs, max_train_steps, optimizer=None, out_dir=None, checkpointing_steps='epoch', with_tracking=False, save_every=5, device='cpu', starting_epoch=68):
    """Accelerate-based training loop with wandb logging and checkpointing.

    Args:
        model / dataloader / optimizer: objects already wrapped by
            ``accelerator.prepare``.
        criterion: loss applied to (flattened logits, target ids).
        num_epochs: upper bound on epochs; training also stops once
            ``max_train_steps`` optimizer updates have completed.
        checkpointing_steps: int (save every N steps), "best", or "epoch".
        starting_epoch: epoch to resume from. The default of 68 preserves the
            original hardcoded value, matching the epoch-68 checkpoint loaded
            at module level — pass 0 for a fresh run.
    """
    progress_bar = tqdm(range(max_train_steps), disable=not accelerator.is_local_main_process)
    completed_steps = 0
    model = model.to(device)
    model.train()
    best_loss = np.inf
    for epoch in range(starting_epoch, num_epochs):
        total_loss = 0
        for step, batch in enumerate(dataloader):
            with accelerator.accumulate(model):
                encoder_input, attention_mask, tgt = batch
                encoder_input = encoder_input.to(device)
                attention_mask = attention_mask.to(device)
                tgt = tgt.to(device)
                # Teacher forcing: predict token t+1 from tokens <= t
                tgt_input = tgt[:, :-1]
                tgt_output = tgt[:, 1:]
                if use_moe:
                    outputs, aux_loss = model(encoder_input, attention_mask, tgt_input)
                else:
                    outputs = model(encoder_input, attention_mask, tgt_input)
                    aux_loss = 0
                loss = criterion(outputs.view(-1, outputs.size(-1)), tgt_output.reshape(-1))
                loss += aux_loss
                total_loss += loss.detach().float()
                accelerator.backward(loss)
                optimizer.step()
                lr_scheduler.step()
                optimizer.zero_grad()
            # One "completed step" per real optimizer update (i.e. once the
            # gradient-accumulation window has synced).
            if accelerator.sync_gradients:
                progress_bar.set_postfix({"Loss": loss.item()})
                progress_bar.update(1)
                completed_steps += 1
                if accelerator.is_main_process:
                    result = {}
                    result["epoch"] = epoch + 1
                    result["step"] = completed_steps
                    result["train_loss"] = round(total_loss.item() / (gradient_accumulation_steps * completed_steps), 4)
                    wandb.log(result)
                if isinstance(checkpointing_steps, int):
                    if completed_steps % checkpointing_steps == 0:
                        output_dir = f"step_{completed_steps}"
                        if out_dir is not None:
                            output_dir = os.path.join(out_dir, output_dir)
                        accelerator.save_state(output_dir)
            if completed_steps >= max_train_steps:
                break
        # Per-epoch summary: console, summary.jsonl, and wandb-side log record.
        if accelerator.is_main_process:
            result = {}
            result["epoch"] = epoch + 1
            result["step"] = completed_steps
            result["train_loss"] = round(total_loss.item() / len(dataloader), 4)
            result_string = "Epoch: {}, Loss Train: {}\n".format(epoch, result["train_loss"])
            accelerator.print(result_string)
            with open("{}/summary.jsonl".format(out_dir), "a") as f:
                f.write(json.dumps(result) + "\n\n")
            logger.info(result)
        # Track the best epoch loss (main process only).
        if accelerator.is_main_process:
            if total_loss < best_loss:
                best_loss = total_loss
                save_checkpoint = True
            else:
                save_checkpoint = False
        accelerator.wait_for_everyone()
        if accelerator.is_main_process and checkpointing_steps == "best":
            if save_checkpoint:
                accelerator.save_state("{}/{}".format(out_dir, "best"))
            if (epoch + 1) % save_every == 0:
                logger.info("Saving checkpoint at epoch {}".format(epoch + 1))
                accelerator.save_state("{}/{}".format(out_dir, "epoch_" + str(epoch + 1)))
        if accelerator.is_main_process and checkpointing_steps == "epoch":
            accelerator.save_state("{}/{}".format(out_dir, "epoch_" + str(epoch + 1)))
218
+
219
# Kick off training with everything assembled above; checkpoints are written
# by the training loop itself.
train_model_accelerate(
    model,
    dataloader,
    criterion,
    num_epochs=epochs,
    max_train_steps=max_train_steps,
    optimizer=optimizer,
    out_dir=output_dir,
    checkpointing_steps=checkpointing_steps,
    with_tracking=with_tracking,
    save_every=save_every,
    device=device,
)
text2midi_repo/model/train_hf.py ADDED
@@ -0,0 +1,283 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from torch.cuda import is_available as cuda_available, is_bf16_supported
3
+ from torch.backends.mps import is_available as mps_available
4
+ import torch.nn as nn
5
+ import torch.optim as optim
6
+ import yaml
7
+ import json
8
+ import pickle
9
+ import os
10
+ import random
11
+ from tqdm import tqdm
12
+ from transformers import T5EncoderModel, BertModel, BertConfig, Trainer, TrainingArguments, PreTrainedModel, T5Config, T5EncoderModel, BertLMHeadModel
13
+ import torch
14
+ from torch import Tensor, argmax
15
+ from evaluate import load as load_metric
16
+ import sys
17
+ import argparse
18
+ import jsonlines
19
+ from data_loader_remi import Text2MusicDataset
20
+ from transformer_model import Transformer
21
+ from torch.utils.data import DataLoader
22
+
23
+
24
# ---- Script setup: CLI arguments, YAML config, tokenizer and captions ----
parser = argparse.ArgumentParser()
parser.add_argument(
    "--config",
    type=str,
    default=os.path.normpath("configs/config.yaml"),
    help="Path to the config file",
)
args = parser.parse_args()

# Load the YAML training configuration.
with open(args.config, 'r') as f:
    configs = yaml.safe_load(f)

t2m_train_cfg = configs['training']['text2midi_model']
batch_size = t2m_train_cfg['batch_size']
learning_rate = t2m_train_cfg['learning_rate']
epochs = t2m_train_cfg['epochs']

# Folder that holds trained artifacts (tokenizer vocab, checkpoints, ...).
artifact_folder = configs['artifact_folder']

# Load the pickled REMI tokenizer dictionary.
tokenizer_filepath = os.path.join(artifact_folder, "vocab_remi.pkl")
with open(tokenizer_filepath, "rb") as f:
    tokenizer = pickle.load(f)

# +1 over the tokenizer's range — presumably to reserve one extra id
# (e.g. padding); confirm against how the decoder vocab is used.
vocab_size = tokenizer.vocab_size + 1
print("Vocab size: ", vocab_size)

# Read the caption dataset: one JSON object per line.
caption_dataset_path = configs['raw_data']['caption_dataset_path']
with jsonlines.open(caption_dataset_path) as reader:
    captions = list(reader)
53
+
54
+
55
def collate_fn(batch):
    """
    Collate function for the DataLoader.

    Each dataset item is a (input_ids, attention_mask, labels) triple of
    tensors carrying a leading singleton batch dimension.  All three are
    right-padded with 0 to the longest sequence in the batch, and the padded
    label sequence is shifted by one to build teacher-forcing decoder inputs.

    :param batch: list of (input_ids, attention_mask, labels) tensor triples
    :return: dict with 'input_ids', 'attention_mask', 'decoder_input_ids'
             and 'labels'
    """
    def _pad(tensors):
        # Drop the singleton batch dim, then right-pad with zeros.
        return nn.utils.rnn.pad_sequence(
            [t.squeeze(0) for t in tensors], batch_first=True, padding_value=0
        )

    input_ids = _pad([item[0] for item in batch])
    attention_mask = _pad([item[1] for item in batch])
    padded_labels = _pad([item[2] for item in batch])

    # Teacher forcing: decoder sees labels[:-1] and predicts labels[1:].
    # NOTE(review): padding keeps label value 0 (not -100), so the LM loss
    # presumably also covers padded positions — confirm this is intended.
    return {
        'input_ids': input_ids,
        'attention_mask': attention_mask,
        'decoder_input_ids': padded_labels[:, :-1].contiguous(),
        'labels': padded_labels[:, 1:].contiguous(),
    }
79
+
80
# Deterministic shuffle, then an 80/20 train/test split of the captions.
random.seed(444)
random.shuffle(captions)
split_point = int(0.8 * len(captions))
train_captions, test_captions = captions[:split_point], captions[split_point:]

# Build the train and eval datasets from the split captions.
train_dataset = Text2MusicDataset(configs, train_captions, tokenizer, mode="train", shuffle=True)
print(f"Train Data length: {len(train_dataset)}")
test_dataset = Text2MusicDataset(configs, test_captions, tokenizer, mode="eval", shuffle=False)
print(f"Test Data length: {len(test_dataset)}")
92
+
93
+ # Dataloader
94
+ # train_dataset = DataLoader(train_dataset, batch_size=batch_size, collate_fn=collate_fn, num_workers=5)
95
+ # test_dataset = DataLoader(test_dataset, batch_size=batch_size, collate_fn=collate_fn, num_workers=5)
96
+
97
+ # Create the encoder-decoder model
98
class CustomEncoderDecoderModel(PreTrainedModel):
    """Glue a frozen text encoder to a causal-LM decoder via cross-attention.

    The encoder's last hidden states are handed to the decoder as
    ``encoder_hidden_states``; when ``labels`` is supplied the decoder
    computes the LM loss itself.
    """

    def __init__(self, encoder, decoder, encoder_config, decoder_config):
        # PreTrainedModel requires a config at construction; use the encoder's.
        super().__init__(encoder_config)
        self.encoder = encoder
        self.decoder = decoder
        self.encoder_config = encoder_config
        self.decoder_config = decoder_config

    def forward(self, input_ids, decoder_input_ids, attention_mask=None,
                decoder_attention_mask=None, labels=None, **kwargs):
        """Encode the text, then decode with cross-attention.

        :return: dict with 'loss' (None unless labels given) and 'logits'
        """
        hidden_states = self.encoder(
            input_ids=input_ids, attention_mask=attention_mask
        ).last_hidden_state

        # The decoder attends over the encoder states (cross-attention).
        decoder_out = self.decoder(
            input_ids=decoder_input_ids,
            attention_mask=decoder_attention_mask,
            encoder_hidden_states=hidden_states,
            encoder_attention_mask=attention_mask,
            labels=labels,
        )

        return {'loss': decoder_out.loss, 'logits': decoder_out.logits}
124
+
125
# Frozen FLAN-T5 encoder: only the decoder's weights are trained.
flan_t5_encoder = T5EncoderModel.from_pretrained('google/flan-t5-small')
for param in flan_t5_encoder.parameters():
    param.requires_grad = False

encoder_config = T5Config.from_pretrained('google/flan-t5-small')

# Configure a BERT-style decoder from the YAML model settings.
t2m_model_cfg = configs['model']['text2midi_model']
config_decoder = BertConfig()
config_decoder.vocab_size = vocab_size
config_decoder.max_position_embeddings = t2m_model_cfg['decoder_max_sequence_length']
config_decoder.max_length = t2m_model_cfg['decoder_max_sequence_length']
config_decoder.bos_token_id = tokenizer["BOS_None"]
config_decoder.eos_token_id = tokenizer["EOS_None"]
config_decoder.pad_token_id = 0
config_decoder.num_hidden_layers = t2m_model_cfg['decoder_num_layers']
config_decoder.num_attention_heads = t2m_model_cfg['decoder_num_heads']
config_decoder.hidden_size = t2m_model_cfg['decoder_d_model']
config_decoder.intermediate_size = t2m_model_cfg['decoder_intermediate_size']

# Causal-LM decoder with cross-attention onto the encoder outputs.
config_decoder.is_decoder = True
config_decoder.add_cross_attention = True
config_decoder.tie_encoder_decoder = False
config_decoder.tie_word_embeddings = False

custom_decoder = BertLMHeadModel(config_decoder)

# Assemble the full encoder-decoder model.
model = CustomEncoderDecoderModel(
    encoder=flan_t5_encoder,
    decoder=custom_decoder,
    encoder_config=encoder_config,
    decoder_config=config_decoder,
)

# Report the total parameter count (frozen encoder included).
num_params = sum(p.numel() for p in model.parameters())
print(f"Number of parameters in the model: {num_params}")
166
+
167
# Pick mixed-precision flags for the Trainer based on the hardware:
# bf16 on GPUs that support it, fp16 on other GPUs, full precision otherwise.
USE_CUDA = cuda_available()
print(f"USE_CUDA: {USE_CUDA}")
if not USE_CUDA:
    FP16 = FP16_EVAL = BF16 = BF16_EVAL = False
elif is_bf16_supported():
    BF16 = BF16_EVAL = True
    FP16 = FP16_EVAL = False
else:
    BF16 = BF16_EVAL = False
    FP16 = FP16_EVAL = True
USE_MPS = not USE_CUDA and mps_available()

# Evaluation metrics, keyed by name.
metrics = {metric: load_metric(metric) for metric in ["accuracy"]}
181
+
182
def compute_metrics(eval_pred):
    """
    Compute evaluation metrics (accuracy) over non-padding positions.

    Requires a preprocess_logits function that has already converted logits
    to predicted token ids (argmax or sampling).

    :param eval_pred: EvalPrediction containing predictions and labels
    :return: metrics dict
    """
    predictions, labels = eval_pred
    keep = labels != 0  # 0 is the padding id; exclude those positions
    return metrics["accuracy"].compute(
        predictions=predictions[keep].flatten(),
        references=labels[keep].flatten(),
    )
195
+
196
def preprocess_logits(logits: Tensor, _: Tensor) -> Tensor:
    """
    Reduce logits to predicted token ids (long dtype) before accumulation
    during evaluation.

    Collapsing the vocab axis here significantly reduces memory usage and
    keeps evaluation tractable.
    """
    return argmax(logits, dim=-1)
204
+
205
run_name = configs['training']['text2midi_model']['run_name']
model_dir = os.path.join(artifact_folder, run_name)
log_dir = os.path.join(model_dir, "logs")
# Clear any stale logs before training.  shutil.rmtree replaces the previous
# `os.system(f"rm -rf {log_dir}")`: it is portable (works on Windows), does
# not spawn a shell, and is immune to shell metacharacters in the path.
import shutil
shutil.rmtree(log_dir, ignore_errors=True)

# Hugging Face Trainer arguments; precision flags come from the hardware
# detection above, everything else from the YAML config.
training_args = TrainingArguments(
    output_dir=model_dir,
    per_device_train_batch_size=batch_size,
    per_device_eval_batch_size=batch_size,
    save_strategy="epoch",  # "steps" or "epoch"
    save_total_limit=1,
    learning_rate=learning_rate,
    lr_scheduler_type="cosine_with_restarts",
    warmup_ratio=0.3,
    max_grad_norm=3.0,
    weight_decay=configs['training']['text2midi_model']['weight_decay'],
    num_train_epochs=epochs,
    evaluation_strategy="epoch",
    gradient_accumulation_steps=configs['training']['text2midi_model']['gradient_accumulation_steps'],
    # gradient_checkpointing=True,
    optim="adafactor",
    seed=444,
    logging_strategy="steps",
    logging_steps=10,
    logging_dir=log_dir,
    no_cuda=not USE_CUDA,
    fp16=FP16,
    fp16_full_eval=FP16_EVAL,
    bf16=BF16,
    bf16_full_eval=BF16_EVAL,
    load_best_model_at_end=True,
    # metric_for_best_model="loss",
    greater_is_better=False,
    report_to="tensorboard",
    run_name=run_name,
    push_to_hub=False,
    dataloader_num_workers=5
)
245
+
246
+ # # Define the Trainer
247
+ # trainer = Trainer(
248
+ # model=model,
249
+ # args=training_args,
250
+ # train_dataset=train_dataset,
251
+ # eval_dataset=test_dataset,
252
+ # compute_metrics=compute_metrics,
253
+ # preprocess_logits_for_metrics=preprocess_logits,
254
+ # # callbacks=[EarlyStoppingCallback(early_stopping_patience=30)]
255
+ # )
256
+
257
class CustomTrainer(Trainer):
    """Trainer whose DataLoaders use the module-level collate_fn and a
    fixed worker count for train, eval and test alike."""

    def _build_loader(self, dataset):
        # Shared loader construction for all three phases.
        return DataLoader(dataset, batch_size=batch_size, collate_fn=collate_fn, num_workers=5)

    def get_train_dataloader(self):
        return self._build_loader(self.train_dataset)

    def get_eval_dataloader(self, eval_dataset):
        return self._build_loader(eval_dataset)

    def get_test_dataloader(self, test_dataset):
        return self._build_loader(test_dataset)
266
+
267
# Build the Trainer, run training, then persist model, metrics and state.
trainer = CustomTrainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=test_dataset,
    compute_metrics=compute_metrics,
    preprocess_logits_for_metrics=preprocess_logits,
)

train_result = trainer.train()
trainer.save_model()
trainer.log_metrics("train", train_result.metrics)
trainer.save_metrics("train", train_result.metrics)
trainer.save_state()
text2midi_repo/model/transformer_model.py ADDED
@@ -0,0 +1,1509 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # from aria.tokenizer import AbsTokenizer
2
+ # aria_tokenizer = AbsTokenizer()
3
+ import copy
4
+ import json
5
+ from typing import Optional, Any, Union, Callable
6
+ import torch.multiprocessing as mp
7
+ from torch.nn import DataParallel
8
+ import jsonlines
9
+ import math
10
+ import time
11
+ import torch
12
+ import os
13
+ import warnings
14
+ from tqdm import tqdm
15
+ from torch import Tensor
16
+ # from aria.tokenizer import AbsTokenizer
17
+ import pickle
18
+ from torch.nn import Module, LayerNorm, Dropout, Linear
19
+ from torch.nn.modules.container import ModuleList
20
+ from torch.nn.modules.activation import MultiheadAttention
21
+ from torch.multiprocessing import Process, set_start_method
22
+ from torch.nn.init import xavier_uniform_
23
+ import torch.nn.functional as F
24
+ import torch.nn as nn
25
+
26
+ from st_moe_pytorch import MoE
27
+ from st_moe_pytorch import SparseMoEBlock
28
+
29
+ from einops import rearrange
30
+
31
+ from transformers import T5Tokenizer, T5EncoderModel
32
+
33
+ import sys
34
+ import torch.distributed as dist
35
+ from torch.nn.parallel import DistributedDataParallel as DDP
36
+ from torch.utils.data import DataLoader, Dataset
37
+
38
+ import torch.profiler
39
+
40
+ from accelerate import Accelerator
41
+ import argparse # Add this import
42
+
43
class CaptionDataset(Dataset):
    """Minimal map-style Dataset over an in-memory list of caption records."""

    def __init__(self, captions):
        # Records are kept as-is; __getitem__ returns them unchanged.
        self.captions = captions

    def __len__(self):
        return len(self.captions)

    def __getitem__(self, idx):
        return self.captions[idx]
52
+
53
def custom_collate_fn(batch):
    """Split a batch of {'caption', 'location'} dicts into two parallel lists.

    :param batch: list of dicts, each with 'caption' and 'location' keys
    :return: (captions, locations) tuple of lists, order preserved
    """
    caption_list = [record['caption'] for record in batch]
    location_list = [record['location'] for record in batch]
    return caption_list, location_list
57
+
58
def ensure_log_dir_exists(log_dir):
    """Create *log_dir* (including parents) if it does not already exist.

    Uses ``os.makedirs(..., exist_ok=True)`` instead of the previous
    check-then-create sequence, which had a race: another process could
    create the directory between the ``os.path.exists`` check and
    ``os.makedirs``, making the latter raise ``FileExistsError``.
    """
    os.makedirs(log_dir, exist_ok=True)
61
+
62
+ __all__ = ['Transformer', 'TransformerEncoder', 'TransformerDecoder', 'TransformerEncoderLayer', 'TransformerDecoderLayer']
63
+
64
+ def _generate_square_subsequent_mask(
65
+ sz: int,
66
+ device: Optional[torch.device] = None,
67
+ dtype: Optional[torch.dtype] = None,
68
+ ) -> Tensor:
69
+ r"""Generate a square causal mask for the sequence.
70
+
71
+ The masked positions are filled with float('-inf'). Unmasked positions are filled with float(0.0).
72
+ """
73
+ if device is None:
74
+ device = torch.device('cpu')
75
+ if dtype is None:
76
+ dtype = torch.float32
77
+ return torch.triu(
78
+ torch.full((sz, sz), float('-inf'), dtype=dtype, device=device),
79
+ diagonal=1,
80
+ )
81
+
82
+
83
+ def _get_seq_len(
84
+ src: Tensor,
85
+ batch_first: bool
86
+ ) -> Optional[int]:
87
+
88
+ if src.is_nested:
89
+ return None
90
+ else:
91
+ src_size = src.size()
92
+ if len(src_size) == 2:
93
+ # unbatched: S, E
94
+ return src_size[0]
95
+ else:
96
+ # batched: B, S, E if batch_first else S, B, E
97
+ seq_len_pos = 1 if batch_first else 0
98
+ return src_size[seq_len_pos]
99
+
100
+
101
class PositionalEncoding(nn.Module):
    r"""Sinusoidal positional encoding added to token embeddings.

    The encodings share the embedding dimension so the two can be summed:
    .. math:
        \text{PosEncoder}(pos, 2i) = sin(pos/10000^(2i/d_model))
        \text{PosEncoder}(pos, 2i+1) = cos(pos/10000^(2i/d_model))
    where pos is the position and i the embedding index.

    Args:
        d_model: the embed dim (required).
        dropout: the dropout value (default=0.1).
        max_len: the max. length of the incoming sequence (default=5000).
    Examples:
        >>> pos_encoder = PositionalEncoding(d_model)
    """

    def __init__(self, d_model, dropout=0.1, max_len=5000):
        super(PositionalEncoding, self).__init__()
        self.dropout = nn.Dropout(p=dropout)

        positions = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(
            torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model)
        )
        table = torch.zeros(max_len, d_model)
        table[:, 0::2] = torch.sin(positions * div_term)
        table[:, 1::2] = torch.cos(positions * div_term)
        table = table.unsqueeze(0).transpose(0, 1)  # -> (max_len, 1, d_model)
        # Stored as a frozen Parameter rather than a buffer so it shows up in
        # .parameters() / state dicts while never receiving gradients.
        self.register_parameter('pe', nn.Parameter(table, requires_grad=False))

    def forward(self, x):
        r"""Add positional encodings to the input and apply dropout.

        Shape:
            x: [sequence length, batch size, embed dim]
            output: [sequence length, batch size, embed dim]
        Examples:
            >>> output = pos_encoder(x)
        """
        x = x + self.pe[:x.size(0), :]
        return self.dropout(x)
142
+
143
+
144
def precompute_freqs_cis(
    seq_len: int,
    n_elem: int,
    base: int = 10000,
    dtype: torch.dtype = torch.bfloat16,
):
    """Precompute the rotary-embedding (RoPE) cos/sin cache.

    Returns a tensor of shape (seq_len, n_elem // 2, 2) whose last axis
    holds (cos, sin) for every (position, frequency) pair.
    """
    half = n_elem // 2
    # Inverse frequencies: base^(-2i/n_elem) for i in [0, half).
    inv_freqs = 1.0 / (
        base ** (torch.arange(0, n_elem, 2)[:half].float() / n_elem)
    )
    positions = torch.arange(seq_len, device=inv_freqs.device)
    angles = torch.outer(positions, inv_freqs)
    # Unit complex numbers e^{i*angle}; split into real/imag = cos/sin.
    unit = torch.polar(torch.ones_like(angles), angles)
    cache = torch.stack([unit.real, unit.imag], dim=-1)

    return cache.to(dtype=dtype)
159
+
160
+
161
@torch.jit.script
def apply_rotary_emb(x: torch.Tensor, freqs_cis: torch.Tensor) -> torch.Tensor:
    """
    In-place RoPE. Credits to Katherine Crowson:
    x shape (b_sz, n_head, s_len, d_head).
    cos, sin shape (s_len, d_head // 2).
    """

    # Move heads after the sequence axis so positions broadcast cleanly.
    x = x.permute(0, 2, 1, 3)
    d = x.shape[-1] // 2
    # freqs_cis[..., 0] is cos and [..., 1] is sin; insert batch/head dims.
    cos = freqs_cis[..., 0][None, :, None]
    sin = freqs_cis[..., 1][None, :, None]
    x1, x2 = x[..., :d], x[..., d : d * 2]
    tmp = x1.clone()
    # x1.mul_(cos).addcmul_(x2, sin, value=-1)
    # x2.mul_(cos).addcmul_(tmp, sin, value=1) ##was throwing some error: RuntimeError: Output 0 of SliceBackward0 is a view and is being modified inplace. This view is the output of a function that returns multiple views. Such functions do not allow the output views to be modified inplace. You should replace the inplace operation by an out-of-place one.
    # Out-of-place rotation: (x1, x2) -> (x1*cos - x2*sin, x2*cos + x1*sin).
    x1_new = x1.mul(cos) - x2.mul(sin)
    x2_new = x2.mul(cos) + tmp.mul(sin)
    x = torch.cat((x1_new, x2_new), dim=-1)
    # Restore the original (batch, heads, seq, d_head) layout.
    x = x.permute(0, 2, 1, 3)

    return x
183
+
184
+
185
class MultiHeadSelfAttention(nn.Module):
    r"""Multi-head self-attention with a fused QKV projection.

    Args:
        embed_dim (int): The input embedding dimension.
        num_heads (int, optional): The number of attention heads (default: 4).
        dropout (float, optional): The dropout probability (default: 0.1).
        batch_first (bool, optional): If False, inputs are (seq, batch, dim).
        device (torch.device, optional): The device to use (default: None).
        dtype (torch.dtype, optional): The data type to use (default: None).

    Attributes:
        dim_head (int): Dimension of each attention head.
        scale (float): Scaling factor for attention scores.
        heads (int): Number of attention heads.
        to_qkv (nn.Linear): Projects input to query, key and value at once.
        to_out (nn.Linear): Projects attention output back to embed_dim.
        dropout (nn.Dropout): Dropout layer applied to the attention output.
    """

    def __init__(
        self,
        embed_dim: int,
        num_heads: int = 4,
        dropout: float = 0.1,
        batch_first: bool = True,
        device: Optional[torch.device] = None,
        dtype: Optional[torch.dtype] = None,
    ):
        factory_kwargs = {'device': device, 'dtype': dtype}
        super().__init__()
        self.embed_dim = embed_dim
        self.batch_first = batch_first
        self.dim_head = embed_dim // num_heads
        self.scale = self.dim_head ** -0.5
        self.heads = num_heads
        inner_dim = self.dim_head * num_heads
        # One matmul yields Q, K and V together.
        self.to_qkv = nn.Linear(embed_dim, inner_dim * 3, bias=False, **factory_kwargs)
        self.to_out = nn.Linear(inner_dim, embed_dim, bias=False, **factory_kwargs)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x: torch.Tensor, is_causal: bool = True) -> torch.Tensor:
        r"""Run self-attention over *x*.

        Args:
            x (torch.Tensor): (batch, seq, embed_dim) when batch_first,
                otherwise (seq, batch, embed_dim).
            is_causal (bool): apply a causal mask inside SDPA.

        Returns:
            torch.Tensor: attention output of shape (batch, seq, embed_dim).
        """
        if not self.batch_first:
            x = x.transpose(0, 1)
        bsz, seq, _ = x.size()
        fused = self.to_qkv(x)
        q, k, v = (
            t.contiguous().view(bsz, self.heads, seq, -1)
            for t in torch.chunk(fused, chunks=3, dim=-1)
        )

        # RoPE cache is rebuilt on every call; the rotary application below
        # is currently disabled, so this only refreshes self.freqs_cis.
        self.freqs_cis = precompute_freqs_cis(
            seq_len=seq,
            n_elem=self.embed_dim // self.heads,
            base=10000,
            dtype=x.dtype,
        ).to(x.device)
        freqs_cis = self.freqs_cis[: x.shape[1]]
        # q = apply_rotary_emb(q, freqs_cis)
        # k = apply_rotary_emb(k, freqs_cis)
        attn = torch.nn.functional.scaled_dot_product_attention(q, k, v, is_causal=is_causal)
        attn = attn.contiguous().view(bsz, seq, -1)
        attn = self.dropout(attn)
        return self.to_out(attn)
256
+
257
+
258
+ class Transformer(Module):
259
+ r"""A transformer model.
260
+
261
+ User is able to modify the attributes as needed. The architecture
262
+ is based on the paper "Attention Is All You Need". Ashish Vaswani, Noam Shazeer,
263
+ Niki Parmar, Jakob Uszkoreit, Llion Jones, Aidan N Gomez, Lukasz Kaiser, and
264
+ Illia Polosukhin. 2017. Attention is all you need. In Advances in Neural Information
265
+ Processing Systems, pages 6000-6010.
266
+
267
+ Args:
268
+ d_model: the number of expected features in the encoder/decoder inputs (default=512).
269
+ nhead: the number of heads in the multiheadattention models (default=8).
270
+ num_encoder_layers: the number of sub-encoder-layers in the encoder (default=6).
271
+ num_decoder_layers: the number of sub-decoder-layers in the decoder (default=6).
272
+ dim_feedforward: the dimension of the feedforward network model (default=2048).
273
+ use_moe: if True, use MoE instead of linear layer for feedforward network (default=False).
274
+ dropout: the dropout value (default=0.1).
275
+ activation: the activation function of encoder/decoder intermediate layer, can be a string
276
+ ("relu" or "gelu") or a unary callable. Default: relu
277
+ custom_encoder: custom encoder (default=None).
278
+ custom_decoder: custom decoder (default=None).
279
+ layer_norm_eps: the eps value in layer normalization components (default=1e-5).
280
+ batch_first: If ``True``, then the input and output tensors are provided
281
+ as (batch, seq, feature). Default: ``False`` (seq, batch, feature).
282
+ norm_first: if ``True``, encoder and decoder layers will perform LayerNorms before
283
+ other attention and feedforward operations, otherwise after. Default: ``False`` (after).
284
+ bias: If set to ``False``, ``Linear`` and ``LayerNorm`` layers will not learn an additive
285
+ bias. Default: ``True``.
286
+
287
+ Examples::
288
+ >>> transformer_model = nn.Transformer(nhead=16, num_encoder_layers=12)
289
+ >>> src = torch.rand((32, 512))
290
+ >>> tgt = torch.rand((32, 512, 30000))
291
+ >>> out = transformer_model(src, tgt)
292
+
293
+ Note: A full example to apply nn.Transformer module for the word language model is available in
294
+ https://github.com/pytorch/examples/tree/master/word_language_model
295
+ """
296
+
297
+ def __init__(self, n_vocab: int = 30000, d_model: int = 512, nhead: int = 8, max_len: int = 5000,
298
+ num_decoder_layers: int = 6, dim_feedforward: int = 2048, use_moe: bool = False,
299
+ num_experts: int = 16, dropout: float = 0.1,
300
+ activation: Union[str, Callable[[Tensor], Tensor]] = F.relu,
301
+ layer_norm_eps: float = 1e-5, batch_first: bool = True, norm_first: bool = False,
302
+ bias: bool = True, device=None, dtype=None) -> None:
303
+ factory_kwargs = {'device': device, 'dtype': dtype}
304
+ super().__init__()
305
+ torch._C._log_api_usage_once(f"torch.nn.modules.{self.__class__.__name__}")
306
+
307
+ self.use_moe = use_moe
308
+
309
+ self.input_emb = nn.Embedding(n_vocab, d_model, **factory_kwargs)
310
+ self.pos_encoder = PositionalEncoding(d_model, dropout, max_len).to(device)
311
+
312
+ # Load the FLAN-T5 encoder
313
+ self.encoder = T5EncoderModel.from_pretrained("google/flan-t5-base").to(device)
314
+ # Freeze the encoder
315
+ for param in self.encoder.parameters():
316
+ param.requires_grad = False
317
+
318
+ decoder_layer = TransformerDecoderLayer(d_model, nhead, dim_feedforward, use_moe, num_experts, dropout,
319
+ activation, layer_norm_eps, batch_first, norm_first,
320
+ bias, **factory_kwargs)
321
+ decoder_norm = LayerNorm(d_model, eps=layer_norm_eps, bias=bias, **factory_kwargs)
322
+ self.decoder = TransformerDecoder(decoder_layer, num_decoder_layers, use_moe, decoder_norm)
323
+
324
+ self.projection = nn.Linear(d_model, n_vocab).to(device)
325
+
326
+ self._reset_parameters()
327
+
328
+ self.d_model = d_model
329
+ self.nhead = nhead
330
+
331
+ self.batch_first = batch_first
332
+
333
+ def forward(self, src: Tensor, src_mask: Tensor, tgt: Tensor, memory_mask: Optional[Tensor] = None,
334
+ memory_key_padding_mask: Optional[Tensor] = None, tgt_is_causal: bool = True,
335
+ memory_is_causal: bool = False) -> Tensor:
336
+ r"""Take in and process masked source/target sequences.
337
+
338
+ .. note::
339
+
340
+ If a boolean tensor is provided for any of the [src/tgt/memory]_mask arguments, positions with a ``True`` value are
341
+ not allowed to participate in the attention,
342
+ which is the opposite of the definition for :attr:`attn_mask`
343
+ in :func:`torch.nn.functional.scaled_dot_product_attention`.
344
+
345
+ Args:
346
+ src: the sequence to the encoder (required).
347
+ src_attn_mask: the attention mask for the src sequence (required).
348
+ tgt: the sequence to the decoder (required).
349
+ tgt_mask: the additive mask for the tgt sequence (optional).
350
+ memory_mask: the additive mask for the encoder output (optional).
351
+ tgt_key_padding_mask: the Tensor mask for tgt keys per batch (optional).
352
+ memory_key_padding_mask: the Tensor mask for memory keys per batch (optional).
353
+ tgt_is_causal: If specified, applies a causal mask as ``tgt_mask``.
354
+ Default: ``None``; try to detect a causal mask.
355
+ Warning:
356
+ ``tgt_is_causal`` provides a hint that ``tgt_mask`` is
357
+ the causal mask. Providing incorrect hints can result in
358
+ incorrect execution, including forward and backward
359
+ compatibility.
360
+ memory_is_causal: If specified, applies a causal mask as
361
+ ``memory_mask``.
362
+ Default: ``False``.
363
+ Warning:
364
+ ``memory_is_causal`` provides a hint that
365
+ ``memory_mask`` is the causal mask. Providing incorrect
366
+ hints can result in incorrect execution, including
367
+ forward and backward compatibility.
368
+
369
+ Shape:
370
+ - src: :math:`(S, S)` for unbatched input, :math:`(S, N)` if `batch_first=False` or
371
+ `(N, S)` if `batch_first=True`.
372
+ - src_mask: :math:`(S, S)` or :math:`(N\cdot\text{num\_heads}, S, S)`.
373
+ - tgt: :math:`(T, E)` for unbatched input, :math:`(T, N, E)` if `batch_first=False` or
374
+ `(N, T, E)` if `batch_first=True`.
375
+ - tgt_mask: :math:`(T, T)` or :math:`(N\cdot\text{num\_heads}, T, T)`.
376
+ - memory_mask: :math:`(T, S)`.
377
+ - src_key_padding_mask: :math:`(S)` for unbatched input otherwise :math:`(N, S)`.
378
+ - tgt_key_padding_mask: :math:`(T)` for unbatched input otherwise :math:`(N, T)`.
379
+ - memory_key_padding_mask: :math:`(S)` for unbatched input otherwise :math:`(N, S)`.
380
+
381
+ Note: [src/tgt/memory]_mask ensures that position :math:`i` is allowed to attend the unmasked
382
+ positions. If a BoolTensor is provided, positions with ``True``
383
+ are not allowed to attend while ``False`` values will be unchanged. If a FloatTensor
384
+ is provided, it will be added to the attention weight.
385
+ [src/tgt/memory]_key_padding_mask provides specified elements in the key to be ignored by
386
+ the attention. If a BoolTensor is provided, the positions with the
387
+ value of ``True`` will be ignored while the position with the value of ``False`` will be unchanged.
388
+
389
+ - output: :math:`(T, E)` for unbatched input, :math:`(T, N, E)` if `batch_first=False` or
390
+ `(N, T, E)` if `batch_first=True`.
391
+
392
+ Note: Due to the multi-head attention architecture in the transformer model,
393
+ the output sequence length of a transformer is same as the input sequence
394
+ (i.e. target) length of the decoder.
395
+
396
+ where :math:`S` is the source sequence length, :math:`T` is the target sequence length, :math:`N` is the
397
+ batch size, :math:`E` is the feature number
398
+
399
+ Examples:
400
+ >>> # xdoctest: +SKIP
401
+ >>> output = transformer_model(src, tgt, src_mask=src_mask)
402
+ """
403
+ if src.dim() != tgt.dim():
404
+ raise RuntimeError("the number of dimensions in src and tgt must be equal")
405
+
406
+ memory = self.encoder(src, attention_mask=src_mask).last_hidden_state
407
+
408
+ tgt = self.input_emb(tgt) * math.sqrt(self.d_model)
409
+ tgt = self.pos_encoder(tgt)
410
+ # tgt = tgt + tgt_pos
411
+
412
+ if self.use_moe:
413
+ with torch.cuda.amp.autocast(enabled =False):
414
+ output, sum_total_aux_loss = self.decoder(tgt, memory, memory_mask=memory_mask,
415
+ memory_key_padding_mask=memory_key_padding_mask,
416
+ tgt_is_causal=tgt_is_causal, memory_is_causal=memory_is_causal)
417
+ else:
418
+ output = self.decoder(tgt, memory, memory_mask=memory_mask,
419
+ memory_key_padding_mask=memory_key_padding_mask,
420
+ tgt_is_causal=tgt_is_causal, memory_is_causal=memory_is_causal)
421
+
422
+ output = self.projection(output)
423
+ # output = F.log_softmax(output, dim=-1)
424
+
425
+ if self.use_moe:
426
+ return output, sum_total_aux_loss
427
+ else:
428
+ return output
429
+
430
+ def generate(self, src: Tensor, src_mask: Tensor, max_len: int = 100, temperature: float = 1.0):
431
+ ## ADD A START OF SEQUENCE TOKEN <SS> token to the src tensor
432
+ r"""Generate a sequence of tokens from the given inputs.
433
+
434
+ Args:
435
+ src: the sequence to the encoder (required).
436
+ src_mask: the attention mask for the src sequence (required).
437
+ max_len: the maximum length of the sequence to generate (default=100).
438
+ temperature: the temperature for the softmax (default=1.0).
439
+
440
+ Returns:
441
+ torch.Tensor: The generated sequence of tokens.
442
+
443
+ """
444
+ if src.dim() != 2:
445
+ raise RuntimeError("The src tensor should be 2-dimensional")
446
+ tgt_fin = torch.full((src.size(0), 1), 1, dtype=torch.long, device=src.device)
447
+ # values = [21631, 8, 10, 9, 6, 7, 17, 21632, 11474, 20626, 21151, 9426, 20627, 21143, 11476, 20640, 21143, 11477, 20655, 21145, 11476, 20669, 21145, 11477, 20683, 21145, 13527, 20697, 21146, 13529, 20712, 21145, 7013, 20769, 21143, 7006, 20769, 21143, 7006, 20769, 21141, 7009, 20769, 21143, 9426, 20797, 21144, 11474, 20797, 21173, 11476, 20812, 21144, 11477, 20826, 21145, 11476, 20840, 21145, 11477, 20855, 21145, 13527, 20869, 21144, 13529, 20883, 21143, 7006, 20940, 21139, 7013, 20940, 21140, 7006, 20940, 21147, 7009, 20940, 21147, 11474, 20969, 21144, 11474, 20969, 21170, 11476, 20983, 21144, 11477, 20997, 21145, 11476, 21012, 21144, 11477, 21026, 21144, 11479, 21040]
448
+ # values_tensor = torch.tensor(values, dtype=torch.long, device=src.device)
449
+ # tgt_fin = values_tensor.unsqueeze(0).repeat(src.size(0), 1)
450
+ for i in tqdm(range(max_len)):
451
+ max_index = tgt_fin.max()
452
+ # assert max_index < 21634, "tgt_fin contains index out of range. Adjust n_vocab or fix tgt_fin indices."
453
+ tgt = tgt_fin
454
+ if self.use_moe:
455
+ output, _ = self.froward(src, src_mask, tgt, memory_mask=None,
456
+ memory_key_padding_mask=None,
457
+ tgt_is_causal=True, memory_is_causal=False)
458
+ else:
459
+ output = self.forward(src, src_mask, tgt, memory_mask=None,
460
+ memory_key_padding_mask=None,
461
+ tgt_is_causal=True, memory_is_causal=False)
462
+ # logits = self.projection(output)
463
+ logits = output
464
+ output = F.log_softmax(logits/temperature, dim=-1)
465
+ output = output.view(-1, output.size(-1))
466
+ next_tokens = torch.multinomial(torch.exp(output), 1)[-1] # taking the last logit and adding to the sequence
467
+ tgt_fin = torch.cat((tgt_fin, next_tokens.unsqueeze(-1)), dim=1)
468
+ return tgt_fin[:, 1:]
469
+
470
+ @staticmethod
471
+ def generate_square_subsequent_mask(
472
+ sz: int,
473
+ device: Optional[torch.device] = None,
474
+ dtype: Optional[torch.dtype] = None,
475
+ ) -> Tensor:
476
+ r"""Generate a square causal mask for the sequence.
477
+
478
+ The masked positions are filled with float('-inf'). Unmasked positions are filled with float(0.0).
479
+ """
480
+ return _generate_square_subsequent_mask(sz, dtype=dtype, device=device)
481
+
482
+
483
+ def _reset_parameters(self):
484
+ r"""Initiate parameters in the transformer model."""
485
+ for p in self.parameters():
486
+ if p.dim() > 1:
487
+ xavier_uniform_(p)
488
+
489
+
490
+
491
+
492
class TransformerEncoder(Module):
    r"""TransformerEncoder is a stack of N encoder layers.

    Users can build the BERT(https://arxiv.org/abs/1810.04805) model with corresponding parameters.

    Args:
        encoder_layer: an instance of the TransformerEncoderLayer() class (required).
        num_layers: the number of sub-encoder-layers in the encoder (required).
        norm: the layer normalization component (optional).
        enable_nested_tensor: if True, input will automatically convert to nested tensor
            (and convert back on output). This will improve the overall performance of
            TransformerEncoder when padding rate is high. Default: ``True`` (enabled).
        mask_check: if True, checks at forward time that ``src_key_padding_mask`` is
            left-aligned before converting the input to a nested tensor.

    Examples::
        >>> encoder_layer = nn.TransformerEncoderLayer(d_model=512, nhead=8)
        >>> transformer_encoder = nn.TransformerEncoder(encoder_layer, num_layers=6)
        >>> src = torch.rand(10, 32, 512)
        >>> out = transformer_encoder(src)
    """

    __constants__ = ['norm']

    def __init__(
        self,
        encoder_layer: "TransformerEncoderLayer",
        num_layers: int,
        norm: Optional[Module] = None,
        enable_nested_tensor: bool = True,
        mask_check: bool = True
    ) -> None:
        super().__init__()
        torch._C._log_api_usage_once(f"torch.nn.modules.{self.__class__.__name__}")
        # Deep copies of the prototype layer: parameters are NOT shared across layers.
        self.layers = _get_clones(encoder_layer, num_layers)
        self.num_layers = num_layers
        self.norm = norm
        # this attribute saves the value provided at object construction
        self.enable_nested_tensor = enable_nested_tensor
        # this attribute controls whether nested tensors are used
        self.use_nested_tensor = enable_nested_tensor
        self.mask_check = mask_check

        # Decide once at construction whether the nested-tensor "sparsity fast
        # path" could ever apply; record the first disqualifying reason found.
        enc_layer = "encoder_layer"
        why_not_sparsity_fast_path = ''
        if not isinstance(encoder_layer, torch.nn.TransformerEncoderLayer):
            why_not_sparsity_fast_path = f"{enc_layer} was not TransformerEncoderLayer"
        elif encoder_layer.norm_first:
            why_not_sparsity_fast_path = f"{enc_layer}.norm_first was True"
        elif not encoder_layer.self_attn.batch_first:
            why_not_sparsity_fast_path = (f"{enc_layer}.self_attn.batch_first was not True" +
                                          "(use batch_first for better inference performance)")
        elif not encoder_layer.self_attn._qkv_same_embed_dim:
            why_not_sparsity_fast_path = f"{enc_layer}.self_attn._qkv_same_embed_dim was not True"
        elif encoder_layer.self_attn.in_proj_bias is None:
            why_not_sparsity_fast_path = f"{enc_layer}.self_attn was passed bias=False"
        elif not encoder_layer.activation_relu_or_gelu:
            why_not_sparsity_fast_path = f"{enc_layer}.activation_relu_or_gelu was not True"
        elif not (encoder_layer.norm1.eps == encoder_layer.norm2.eps):
            why_not_sparsity_fast_path = f"{enc_layer}.norm1.eps was not equal to {enc_layer}.norm2.eps"
        elif encoder_layer.self_attn.num_heads % 2 == 1:
            why_not_sparsity_fast_path = f"{enc_layer}.self_attn.num_heads is odd"

        if enable_nested_tensor and why_not_sparsity_fast_path:
            warnings.warn(f"enable_nested_tensor is True, but self.use_nested_tensor is False because {why_not_sparsity_fast_path}")
            self.use_nested_tensor = False

    def forward(
        self,
        src: Tensor,
        mask: Optional[Tensor] = None,
        src_key_padding_mask: Optional[Tensor] = None,
        is_causal: Optional[bool] = None) -> Tensor:
        r"""Pass the input through the encoder layers in turn.

        Args:
            src: the sequence to the encoder (required).
            mask: the mask for the src sequence (optional).
            src_key_padding_mask: the mask for the src keys per batch (optional).
            is_causal: If specified, applies a causal mask as ``mask``.
                Default: ``None``; try to detect a causal mask.
            Warning:
            ``is_causal`` provides a hint that ``mask`` is the
            causal mask. Providing incorrect hints can result in
            incorrect execution, including forward and backward
            compatibility.

        Shape:
            see the docs in :class:`~torch.nn.Transformer`.
        """
        # Normalize both masks to a canonical dtype (bool masks become float
        # masks where appropriate) so downstream kernels see a consistent form.
        src_key_padding_mask = F._canonical_mask(
            mask=src_key_padding_mask,
            mask_name="src_key_padding_mask",
            other_type=F._none_or_dtype(mask),
            other_name="mask",
            target_type=src.dtype
        )

        mask = F._canonical_mask(
            mask=mask,
            mask_name="mask",
            other_type=None,
            other_name="",
            target_type=src.dtype,
            check_other=False,
        )

        output = src
        convert_to_nested = False
        first_layer = self.layers[0]
        src_key_padding_mask_for_layers = src_key_padding_mask
        why_not_sparsity_fast_path = ''
        str_first_layer = "self.layers[0]"
        batch_first = first_layer.self_attn.batch_first
        # is_fastpath_enabled = torch.backends.mha.get_fastpath_enabled()

        # Runtime eligibility checks for the nested-tensor fast path; the first
        # failing check records why and disables the conversion below.
        # if not is_fastpath_enabled:
        #     why_not_sparsity_fast_path = "torch.backends.mha.get_fastpath_enabled() was not True"
        if not hasattr(self, "use_nested_tensor"):
            why_not_sparsity_fast_path = "use_nested_tensor attribute not present"
        elif not self.use_nested_tensor:
            why_not_sparsity_fast_path = "self.use_nested_tensor (set in init) was not True"
        elif first_layer.training:
            why_not_sparsity_fast_path = f"{str_first_layer} was in training mode"
        elif not src.dim() == 3:
            why_not_sparsity_fast_path = f"input not batched; expected src.dim() of 3 but got {src.dim()}"
        elif src_key_padding_mask is None:
            why_not_sparsity_fast_path = "src_key_padding_mask was None"
        elif (((not hasattr(self, "mask_check")) or self.mask_check)
                and not torch._nested_tensor_from_mask_left_aligned(src, src_key_padding_mask.logical_not())):
            why_not_sparsity_fast_path = "mask_check enabled, and src and src_key_padding_mask was not left aligned"
        elif output.is_nested:
            why_not_sparsity_fast_path = "NestedTensor input is not supported"
        elif mask is not None:
            why_not_sparsity_fast_path = "src_key_padding_mask and mask were both supplied"
        elif torch.is_autocast_enabled():
            why_not_sparsity_fast_path = "autocast is enabled"

        if not why_not_sparsity_fast_path:
            # All tensors the fused kernel would touch; used for the final
            # device / torch_function / requires_grad eligibility checks.
            tensor_args = (
                src,
                first_layer.self_attn.in_proj_weight,
                first_layer.self_attn.in_proj_bias,
                first_layer.self_attn.out_proj.weight,
                first_layer.self_attn.out_proj.bias,
                first_layer.norm1.weight,
                first_layer.norm1.bias,
                first_layer.norm2.weight,
                first_layer.norm2.bias,
                first_layer.linear1.weight,
                first_layer.linear1.bias,
                first_layer.linear2.weight,
                first_layer.linear2.bias,
            )
            _supported_device_type = ["cpu", "cuda", torch.utils.backend_registration._privateuse1_backend_name]
            if torch.overrides.has_torch_function(tensor_args):
                why_not_sparsity_fast_path = "some Tensor argument has_torch_function"
            elif src.device.type not in _supported_device_type:
                why_not_sparsity_fast_path = f"src device is neither one of {_supported_device_type}"
            elif torch.is_grad_enabled() and any(x.requires_grad for x in tensor_args):
                why_not_sparsity_fast_path = ("grad is enabled and at least one of query or the "
                                              "input/output projection weights or biases requires_grad")

            if (not why_not_sparsity_fast_path) and (src_key_padding_mask is not None):
                # Fast path taken: pack padded input into a nested tensor and
                # drop the padding mask (padding is now implicit in the layout).
                convert_to_nested = True
                output = torch._nested_tensor_from_mask(output, src_key_padding_mask.logical_not(), mask_check=False)
                src_key_padding_mask_for_layers = None

        seq_len = _get_seq_len(src, batch_first)
        is_causal = _detect_is_causal_mask(mask, is_causal, seq_len)

        # Standard path: thread the activation through each encoder layer.
        for mod in self.layers:
            output = mod(output, src_mask=mask, is_causal=is_causal, src_key_padding_mask=src_key_padding_mask_for_layers)

        if convert_to_nested:
            # Restore the dense, padded layout expected by callers.
            output = output.to_padded_tensor(0., src.size())

        if self.norm is not None:
            output = self.norm(output)

        return output
673
+
674
+
675
+
676
+
677
class TransformerDecoder(Module):
    r"""A stack of ``num_layers`` identical decoder layers.

    Args:
        decoder_layer: an instance of the TransformerDecoderLayer() class (required).
        num_layers: the number of sub-decoder-layers in the decoder (required).
        use_moe: when ``True``, each layer is expected to return auxiliary
            mixture-of-experts losses alongside its output, and ``forward``
            returns ``(output, summed_aux_loss)``.
        norm: the layer normalization component (optional).

    Examples::
        >>> decoder_layer = nn.TransformerDecoderLayer(d_model=512, nhead=8)
        >>> transformer_decoder = nn.TransformerDecoder(decoder_layer, num_layers=6)
        >>> memory = torch.rand(10, 32, 512)
        >>> tgt = torch.rand(20, 32, 512)
        >>> out = transformer_decoder(tgt, memory)
    """

    __constants__ = ['norm']

    def __init__(
        self,
        decoder_layer: "TransformerDecoderLayer",
        num_layers: int,
        use_moe: bool = False,
        norm: Optional[Module] = None
    ) -> None:
        super().__init__()
        torch._C._log_api_usage_once(f"torch.nn.modules.{self.__class__.__name__}")
        # Independent deep copies of the prototype layer (no weight sharing).
        self.layers = _get_clones(decoder_layer, num_layers)
        self.num_layers = num_layers
        self.norm = norm
        self.use_moe = use_moe

    def forward(self, tgt: Tensor, memory: Tensor, tgt_mask: Optional[Tensor] = None,
                memory_mask: Optional[Tensor] = None,
                memory_key_padding_mask: Optional[Tensor] = None, tgt_is_causal: Optional[bool] = None,
                memory_is_causal: bool = False) -> Tensor:
        r"""Pass the inputs (and mask) through each decoder layer in turn.

        Args:
            tgt: the sequence to the decoder (required).
            memory: the sequence from the last layer of the encoder (required).
            tgt_mask: the mask for the tgt sequence (optional).
            memory_mask: the mask for the memory sequence (optional).
            memory_key_padding_mask: the mask for the memory keys per batch (optional).
            tgt_is_causal: if specified, applies a causal mask as ``tgt mask``;
                ``None`` (default) tries to detect a causal mask. An incorrect
                hint can produce incorrect results.
            memory_is_causal: if ``True``, applies a causal mask as the memory
                mask (default ``False``). An incorrect hint can produce
                incorrect results.

        Returns:
            The decoded sequence; when ``self.use_moe`` is set, a tuple of
            ``(output, summed auxiliary MoE loss)`` instead.

        Shape:
            see the docs in :class:`~torch.nn.Transformer`.
        """
        out = tgt

        # Resolve the causality hint from the mask shape when not given.
        first_attn = self.layers[0].self_attn
        seq_len = _get_seq_len(tgt, first_attn.batch_first)
        causal = _detect_is_causal_mask(tgt_mask, tgt_is_causal, seq_len)

        aux_loss_total = 0
        for layer in self.layers:
            if self.use_moe:
                # MoE layers report (output, total_aux, balance, router_z);
                # only the total auxiliary loss is accumulated here.
                out, layer_aux_loss, _balance_loss, _router_z_loss = layer(
                    out, memory,
                    memory_mask=memory_mask,
                    memory_key_padding_mask=memory_key_padding_mask,
                    tgt_is_causal=causal,
                    memory_is_causal=memory_is_causal)
                aux_loss_total += layer_aux_loss
            else:
                out = layer(out, memory,
                            memory_mask=memory_mask,
                            memory_key_padding_mask=memory_key_padding_mask,
                            tgt_is_causal=causal,
                            memory_is_causal=memory_is_causal)

        if self.norm is not None:
            out = self.norm(out)

        return (out, aux_loss_total) if self.use_moe else out
771
+
772
+
773
+
774
class TransformerEncoderLayer(Module):
    r"""TransformerEncoderLayer is made up of self-attn and feedforward network.

    This standard encoder layer is based on the paper "Attention Is All You Need".
    Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit, Llion Jones, Aidan N Gomez,
    Lukasz Kaiser, and Illia Polosukhin. 2017. Attention is all you need. In Advances in
    Neural Information Processing Systems, pages 6000-6010. Users may modify or implement
    in a different way during application.

    TransformerEncoderLayer can handle either traditional torch.tensor inputs,
    or Nested Tensor inputs. Derived classes are expected to similarly accept
    both input formats. (Not all combinations of inputs are currently
    supported by TransformerEncoderLayer while Nested Tensor is in prototype
    state.)

    If you are implementing a custom layer, you may derive it either from
    the Module or TransformerEncoderLayer class. If your custom layer
    supports both torch.Tensors and Nested Tensors inputs, make its
    implementation a derived class of TransformerEncoderLayer. If your custom
    Layer supports only torch.Tensor inputs, derive its implementation from
    Module.

    Args:
        d_model: the number of expected features in the input (required).
        nhead: the number of heads in the multiheadattention models (required).
        dim_feedforward: the dimension of the feedforward network model (default=2048).
        dropout: the dropout value (default=0.1).
        activation: the activation function of the intermediate layer, can be a string
            ("relu" or "gelu") or a unary callable. Default: relu
        layer_norm_eps: the eps value in layer normalization components (default=1e-5).
        batch_first: If ``True``, then the input and output tensors are provided
            as (batch, seq, feature). Default: ``False`` (seq, batch, feature).
        norm_first: if ``True``, layer norm is done prior to attention and feedforward
            operations, respectively. Otherwise it's done after. Default: ``False`` (after).
        bias: If set to ``False``, ``Linear`` and ``LayerNorm`` layers will not learn an additive
            bias. Default: ``True``.

    Examples::
        >>> encoder_layer = nn.TransformerEncoderLayer(d_model=512, nhead=8)
        >>> src = torch.rand(10, 32, 512)
        >>> out = encoder_layer(src)

    Alternatively, when ``batch_first`` is ``True``:
        >>> encoder_layer = nn.TransformerEncoderLayer(d_model=512, nhead=8, batch_first=True)
        >>> src = torch.rand(32, 10, 512)
        >>> out = encoder_layer(src)

    Fast path:
        forward() will use a special optimized implementation described in
        `FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness`_ if all of the following
        conditions are met:

        - Either autograd is disabled (using ``torch.inference_mode`` or ``torch.no_grad``) or no tensor
          argument ``requires_grad``
        - training is disabled (using ``.eval()``)
        - batch_first is ``True`` and the input is batched (i.e., ``src.dim() == 3``)
        - activation is one of: ``"relu"``, ``"gelu"``, ``torch.functional.relu``, or ``torch.functional.gelu``
        - at most one of ``src_mask`` and ``src_key_padding_mask`` is passed
        - if src is a `NestedTensor <https://pytorch.org/docs/stable/nested.html>`_, neither ``src_mask``
          nor ``src_key_padding_mask`` is passed
        - the two ``LayerNorm`` instances have a consistent ``eps`` value (this will naturally be the case
          unless the caller has manually modified one without modifying the other)

        If the optimized implementation is in use, a
        `NestedTensor <https://pytorch.org/docs/stable/nested.html>`_ can be
        passed for ``src`` to represent padding more efficiently than using a padding
        mask. In this case, a `NestedTensor <https://pytorch.org/docs/stable/nested.html>`_ will be
        returned, and an additional speedup proportional to the fraction of the input that
        is padding can be expected.

    .. _`FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness`:
         https://arxiv.org/abs/2205.14135

    """

    __constants__ = ['norm_first']

    def __init__(self, d_model: int, nhead: int, dim_feedforward: int = 2048, dropout: float = 0.1,
                 activation: Union[str, Callable[[Tensor], Tensor]] = F.relu,
                 layer_norm_eps: float = 1e-5, batch_first: bool = False, norm_first: bool = False,
                 bias: bool = True, device=None, dtype=None) -> None:
        factory_kwargs = {'device': device, 'dtype': dtype}
        super().__init__()
        self.self_attn = MultiheadAttention(d_model, nhead, dropout=dropout,
                                            bias=bias, batch_first=batch_first,
                                            **factory_kwargs)
        # Implementation of Feedforward model
        self.linear1 = Linear(d_model, dim_feedforward, bias=bias, **factory_kwargs)
        self.dropout = Dropout(dropout)
        self.linear2 = Linear(dim_feedforward, d_model, bias=bias, **factory_kwargs)

        self.norm_first = norm_first
        self.norm1 = LayerNorm(d_model, eps=layer_norm_eps, bias=bias, **factory_kwargs)
        self.norm2 = LayerNorm(d_model, eps=layer_norm_eps, bias=bias, **factory_kwargs)
        self.dropout1 = Dropout(dropout)
        self.dropout2 = Dropout(dropout)

        # Legacy string support for activation function.
        if isinstance(activation, str):
            activation = _get_activation_fn(activation)

        # We can't test self.activation in forward() in TorchScript,
        # so stash some information about it instead.
        # (1 = relu, 2 = gelu, 0 = other; used by the fast-path checks.)
        if activation is F.relu or isinstance(activation, torch.nn.ReLU):
            self.activation_relu_or_gelu = 1
        elif activation is F.gelu or isinstance(activation, torch.nn.GELU):
            self.activation_relu_or_gelu = 2
        else:
            self.activation_relu_or_gelu = 0
        self.activation = activation

    def __setstate__(self, state):
        # Backward compatibility: older checkpoints may lack `activation`.
        super().__setstate__(state)
        if not hasattr(self, 'activation'):
            self.activation = F.relu

    def forward(
        self,
        src: Tensor,
        src_mask: Optional[Tensor] = None,
        src_key_padding_mask: Optional[Tensor] = None,
        is_causal: bool = False) -> Tensor:
        r"""Pass the input through the encoder layer.

        Args:
            src: the sequence to the encoder layer (required).
            src_mask: the mask for the src sequence (optional).
            src_key_padding_mask: the mask for the src keys per batch (optional).
            is_causal: If specified, applies a causal mask as ``src mask``.
                Default: ``False``.
            Warning:
            ``is_causal`` provides a hint that ``src_mask`` is the
            causal mask. Providing incorrect hints can result in
            incorrect execution, including forward and backward
            compatibility.

        Shape:
            see the docs in :class:`~torch.nn.Transformer`.
        """
        # Normalize masks to a canonical form before any dispatch decision.
        src_key_padding_mask = F._canonical_mask(
            mask=src_key_padding_mask,
            mask_name="src_key_padding_mask",
            other_type=F._none_or_dtype(src_mask),
            other_name="src_mask",
            target_type=src.dtype
        )

        src_mask = F._canonical_mask(
            mask=src_mask,
            mask_name="src_mask",
            other_type=None,
            other_name="",
            target_type=src.dtype,
            check_other=False,
        )

        # is_fastpath_enabled = torch.backends.mha.get_fastpath_enabled()

        # see Fig. 1 of https://arxiv.org/pdf/2002.04745v1.pdf
        # Runtime eligibility checks for the fused encoder-layer kernel; the
        # first failing check records why and forces the regular Python path.
        why_not_sparsity_fast_path = ''
        # if not is_fastpath_enabled:
        #     why_not_sparsity_fast_path = "torch.backends.mha.get_fastpath_enabled() was not True"
        if not src.dim() == 3:
            why_not_sparsity_fast_path = f"input not batched; expected src.dim() of 3 but got {src.dim()}"
        elif self.training:
            why_not_sparsity_fast_path = "training is enabled"
        elif not self.self_attn.batch_first:
            why_not_sparsity_fast_path = "self_attn.batch_first was not True"
        elif self.self_attn.in_proj_bias is None:
            why_not_sparsity_fast_path = "self_attn was passed bias=False"
        elif not self.self_attn._qkv_same_embed_dim:
            why_not_sparsity_fast_path = "self_attn._qkv_same_embed_dim was not True"
        elif not self.activation_relu_or_gelu:
            why_not_sparsity_fast_path = "activation_relu_or_gelu was not True"
        elif not (self.norm1.eps == self.norm2.eps):
            why_not_sparsity_fast_path = "norm1.eps is not equal to norm2.eps"
        elif src.is_nested and (src_key_padding_mask is not None or src_mask is not None):
            why_not_sparsity_fast_path = "neither src_key_padding_mask nor src_mask are not supported with NestedTensor input"
        elif self.self_attn.num_heads % 2 == 1:
            why_not_sparsity_fast_path = "num_head is odd"
        elif torch.is_autocast_enabled():
            why_not_sparsity_fast_path = "autocast is enabled"
        if not why_not_sparsity_fast_path:
            # Every tensor the fused kernel reads; used for device /
            # torch_function / requires_grad eligibility checks below.
            tensor_args = (
                src,
                self.self_attn.in_proj_weight,
                self.self_attn.in_proj_bias,
                self.self_attn.out_proj.weight,
                self.self_attn.out_proj.bias,
                self.norm1.weight,
                self.norm1.bias,
                self.norm2.weight,
                self.norm2.bias,
                self.linear1.weight,
                self.linear1.bias,
                self.linear2.weight,
                self.linear2.bias,
            )

            # We have to use list comprehensions below because TorchScript does not support
            # generator expressions.
            _supported_device_type = ["cpu", "cuda", torch.utils.backend_registration._privateuse1_backend_name]
            if torch.overrides.has_torch_function(tensor_args):
                why_not_sparsity_fast_path = "some Tensor argument has_torch_function"
            elif not all((x.device.type in _supported_device_type) for x in tensor_args):
                why_not_sparsity_fast_path = ("some Tensor argument's device is neither one of "
                                              f"{_supported_device_type}")
            elif torch.is_grad_enabled() and any(x.requires_grad for x in tensor_args):
                why_not_sparsity_fast_path = ("grad is enabled and at least one of query or the "
                                              "input/output projection weights or biases requires_grad")

            if not why_not_sparsity_fast_path:
                # Fast path: hand the whole layer to the fused C++ kernel and
                # return its result directly, skipping the Python blocks below.
                merged_mask, mask_type = self.self_attn.merge_masks(src_mask, src_key_padding_mask, src)
                return torch._transformer_encoder_layer_fwd(
                    src,
                    self.self_attn.embed_dim,
                    self.self_attn.num_heads,
                    self.self_attn.in_proj_weight,
                    self.self_attn.in_proj_bias,
                    self.self_attn.out_proj.weight,
                    self.self_attn.out_proj.bias,
                    self.activation_relu_or_gelu == 2,
                    self.norm_first,
                    self.norm1.eps,
                    self.norm1.weight,
                    self.norm1.bias,
                    self.norm2.weight,
                    self.norm2.bias,
                    self.linear1.weight,
                    self.linear1.bias,
                    self.linear2.weight,
                    self.linear2.bias,
                    merged_mask,
                    mask_type,
                )

        # Regular path: pre-norm or post-norm residual blocks.
        x = src
        if self.norm_first:
            x = x + self._sa_block(self.norm1(x), src_mask, src_key_padding_mask, is_causal=is_causal)
            x = x + self._ff_block(self.norm2(x))
        else:
            x = self.norm1(x + self._sa_block(x, src_mask, src_key_padding_mask, is_causal=is_causal))
            x = self.norm2(x + self._ff_block(x))

        return x

    # self-attention block
    def _sa_block(self, x: Tensor,
                  attn_mask: Optional[Tensor], key_padding_mask: Optional[Tensor], is_causal: bool = False) -> Tensor:
        x = self.self_attn(x, x, x,
                           attn_mask=attn_mask,
                           key_padding_mask=key_padding_mask,
                           need_weights=False, is_causal=is_causal)[0]
        return self.dropout1(x)

    # feed forward block
    def _ff_block(self, x: Tensor) -> Tensor:
        x = self.linear2(self.dropout(self.activation(self.linear1(x))))
        return self.dropout2(x)
1037
+
1038
+
1039
+
1040
+
1041
+ class TransformerDecoderLayer(Module):
1042
+ r"""TransformerDecoderLayer is made up of self-attn, multi-head-attn and feedforward network.
1043
+
1044
+ This standard decoder layer is based on the paper "Attention Is All You Need".
1045
+ Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit, Llion Jones, Aidan N Gomez,
1046
+ Lukasz Kaiser, and Illia Polosukhin. 2017. Attention is all you need. In Advances in
1047
+ Neural Information Processing Systems, pages 6000-6010. Users may modify or implement
1048
+ in a different way during application.
1049
+
1050
+ Args:
1051
+ d_model: the number of expected features in the input (required).
1052
+ nhead: the number of heads in the multiheadattention models (required).
1053
+ dim_feedforward: the dimension of the feedforward network model (default=2048).
1054
+ dropout: the dropout value (default=0.1).
1055
+ activation: the activation function of the intermediate layer, can be a string
1056
+ ("relu" or "gelu") or a unary callable. Default: relu
1057
+ layer_norm_eps: the eps value in layer normalization components (default=1e-5).
1058
+ batch_first: If ``True``, then the input and output tensors are provided
1059
+ as (batch, seq, feature). Default: ``False`` (seq, batch, feature).
1060
+ norm_first: if ``True``, layer norm is done prior to self attention, multihead
1061
+ attention and feedforward operations, respectively. Otherwise it's done after.
1062
+ Default: ``False`` (after).
1063
+ bias: If set to ``False``, ``Linear`` and ``LayerNorm`` layers will not learn an additive
1064
+ bias. Default: ``True``.
1065
+
1066
+ Examples::
1067
+ >>> decoder_layer = nn.TransformerDecoderLayer(d_model=512, nhead=8)
1068
+ >>> memory = torch.rand(10, 32, 512)
1069
+ >>> tgt = torch.rand(20, 32, 512)
1070
+ >>> out = decoder_layer(tgt, memory)
1071
+
1072
+ Alternatively, when ``batch_first`` is ``True``:
1073
+ >>> decoder_layer = nn.TransformerDecoderLayer(d_model=512, nhead=8, batch_first=True)
1074
+ >>> memory = torch.rand(32, 10, 512)
1075
+ >>> tgt = torch.rand(32, 20, 512)
1076
+ >>> out = decoder_layer(tgt, memory)
1077
+ """
1078
+
1079
+ __constants__ = ['norm_first']
1080
+
1081
    def __init__(self, d_model: int, nhead: int, dim_feedforward: int = 2048, use_moe: bool = False, num_experts: int = 16,
                 dropout: float = 0.1, activation: Union[str, Callable[[Tensor], Tensor]] = F.relu,
                 layer_norm_eps: float = 1e-5, batch_first: bool = False, norm_first: bool = False,
                 bias: bool = True, device=None, dtype=None) -> None:
        """Build one decoder layer: self-attention, cross-attention, and either
        a dense feed-forward block or a sparse mixture-of-experts block.

        Args:
            d_model: expected feature size of the inputs.
            nhead: number of attention heads.
            dim_feedforward: hidden size of the dense feed-forward block
                (unused when ``use_moe`` is True).
            use_moe: replace the feed-forward block with a SparseMoEBlock.
            num_experts: number of experts in the MoE block.
            dropout / activation / layer_norm_eps / batch_first / norm_first /
            bias / device / dtype: as in ``nn.TransformerDecoderLayer``.
        """
        factory_kwargs = {'device': device, 'dtype': dtype}
        super().__init__()

        # Decoder self-attention uses the project's MultiHeadSelfAttention;
        # cross-attention over encoder memory uses torch's MultiheadAttention.
        self.self_attn = MultiHeadSelfAttention(d_model, nhead, dropout=dropout, batch_first=batch_first, **factory_kwargs)
        self.multihead_attn = MultiheadAttention(d_model, nhead, dropout=dropout, batch_first=batch_first,
                                                 bias=bias, **factory_kwargs)
        self.use_moe = use_moe

        if use_moe:
            self.moe = MoE(
                dim = d_model,
                num_experts = num_experts,          # increase the experts (# parameters) of your model without increasing computation
                gating_top_n = 2,                   # default to top 2 gating, but can also be more (3 was tested in the paper with a lower threshold)
                threshold_train = 0.2,              # at what threshold to accept a token to be routed to second expert and beyond - 0.2 was optimal for 2 expert routing, and apparently should be lower for 3
                threshold_eval = 0.2,
                capacity_factor_train = 1.25,       # experts have fixed capacity per batch. we need some extra capacity in case gating is not perfectly balanced.
                capacity_factor_eval = 2.,          # capacity_factor_* should be set to a value >=1
                balance_loss_coef = 1e-2,           # multiplier on the auxiliary expert balancing auxiliary loss
                router_z_loss_coef = 1e-3,          # loss weight for router z-loss
            ).to(device)
            self.moe_block = SparseMoEBlock(
                self.moe,
                add_ff_before = True,
                add_ff_after = True
            ).to(device)
        else:
            # Implementation of Feedforward model
            self.linear1 = Linear(d_model, dim_feedforward, bias=bias, **factory_kwargs)
            self.dropout = Dropout(dropout)
            self.linear2 = Linear(dim_feedforward, d_model, bias=bias, **factory_kwargs)

        self.norm_first = norm_first
        self.norm1 = LayerNorm(d_model, eps=layer_norm_eps, bias=bias, **factory_kwargs)
        self.norm2 = LayerNorm(d_model, eps=layer_norm_eps, bias=bias, **factory_kwargs)
        self.norm3 = LayerNorm(d_model, eps=layer_norm_eps, bias=bias, **factory_kwargs)
        self.dropout1 = Dropout(dropout)
        self.dropout2 = Dropout(dropout)
        self.dropout3 = Dropout(dropout)

        # Legacy string support for activation function.
        if isinstance(activation, str):
            self.activation = _get_activation_fn(activation)
        else:
            self.activation = activation
1129
+
1130
+ def __setstate__(self, state):
1131
+ if 'activation' not in state:
1132
+ state['activation'] = F.relu
1133
+ super().__setstate__(state)
1134
+
1135
+
1136
+ def forward(
1137
+ self,
1138
+ tgt: Tensor,
1139
+ memory: Tensor,
1140
+ memory_mask: Optional[Tensor] = None,
1141
+ memory_key_padding_mask: Optional[Tensor] = None,
1142
+ tgt_is_causal: bool = False,
1143
+ memory_is_causal: bool = False,
1144
+ ) -> Tensor:
1145
+ r"""Pass the inputs (and mask) through the decoder layer.
1146
+
1147
+ Args:
1148
+ tgt: the sequence to the decoder layer (required).
1149
+ memory: the sequence from the last layer of the encoder (required).
1150
+ memory_mask: the mask for the memory sequence (optional).
1151
+ memory_key_padding_mask: the mask for the memory keys per batch (optional).
1152
+ tgt_is_causal: If specified, applies a causal mask as ``tgt mask``.
1153
+ Default: ``False``.
1154
+ Warning:
1155
+ ``tgt_is_causal`` provides a hint that ``tgt_mask`` is
1156
+ the causal mask. Providing incorrect hints can result in
1157
+ incorrect execution, including forward and backward
1158
+ compatibility.
1159
+ memory_is_causal: If specified, applies a causal mask as
1160
+ ``memory mask``.
1161
+ Default: ``False``.
1162
+ Warning:
1163
+ ``memory_is_causal`` provides a hint that
1164
+ ``memory_mask`` is the causal mask. Providing incorrect
1165
+ hints can result in incorrect execution, including
1166
+ forward and backward compatibility.
1167
+
1168
+ Shape:
1169
+ see the docs in :class:`~torch.nn.Transformer`.
1170
+ """
1171
+ # see Fig. 1 of https://arxiv.org/pdf/2002.04745v1.pdf
1172
+
1173
+ x = tgt
1174
+ # print(f'target is causal: {tgt_is_causal}')
1175
+ if self.norm_first:
1176
+ x = x + self._sa_block(self.norm1(x), tgt_is_causal)
1177
+ x = x + self._mha_block(self.norm2(x), memory, memory_mask, memory_key_padding_mask, memory_is_causal)
1178
+ if self.use_moe:
1179
+ m, total_aux_loss, balance_loss, router_z_loss = self.moe_block(x)
1180
+ x = x + m
1181
+ else:
1182
+ x = x + self._ff_block(self.norm3(x))
1183
+ else:
1184
+ x = self.norm1(x + self._sa_block(x, tgt_is_causal))
1185
+ x = self.norm2(x + self._mha_block(x, memory, memory_mask, memory_key_padding_mask, memory_is_causal))
1186
+ if self.use_moe:
1187
+ m, total_aux_loss, balance_loss, router_z_loss = self.moe_block(x)
1188
+ x = x + m
1189
+ else:
1190
+ x = self.norm3(x + self._ff_block(x))
1191
+
1192
+ if self.use_moe:
1193
+ return x, total_aux_loss, balance_loss, router_z_loss
1194
+ else:
1195
+ return x
1196
+
1197
+
1198
+ # self-attention block
1199
+ def _sa_block(self, x: Tensor,
1200
+ is_causal: bool = False) -> Tensor:
1201
+ x = self.self_attn(x, is_causal=is_causal)
1202
+ return self.dropout1(x)
1203
+
1204
+ # multihead attention block
1205
+ def _mha_block(self, x: Tensor, mem: Tensor,
1206
+ attn_mask: Optional[Tensor], key_padding_mask: Optional[Tensor], is_causal: bool = False) -> Tensor:
1207
+ x = self.multihead_attn(x, mem, mem,
1208
+ attn_mask=attn_mask,
1209
+ key_padding_mask=key_padding_mask,
1210
+ is_causal=is_causal,
1211
+ need_weights=False)[0]
1212
+ return self.dropout2(x)
1213
+
1214
+ # feed forward block
1215
+ def _ff_block(self, x: Tensor) -> Tensor:
1216
+ x = self.linear2(self.dropout(self.activation(self.linear1(x))))
1217
+ return self.dropout3(x)
1218
+
1219
+
1220
+
1221
+ def _get_clones(module, N):
1222
+ # FIXME: copy.deepcopy() is not defined on nn.module
1223
+ return ModuleList([copy.deepcopy(module) for i in range(N)])
1224
+
1225
+
1226
+ def _get_activation_fn(activation: str) -> Callable[[Tensor], Tensor]:
1227
+ if activation == "relu":
1228
+ return F.relu
1229
+ elif activation == "gelu":
1230
+ return F.gelu
1231
+
1232
+ raise RuntimeError(f"activation should be relu/gelu, not {activation}")
1233
+
1234
+
1235
+ def _detect_is_causal_mask(
1236
+ mask: Optional[Tensor],
1237
+ is_causal: Optional[bool] = None,
1238
+ size: Optional[int] = None,
1239
+ ) -> bool:
1240
+ """Return whether the given attention mask is causal.
1241
+
1242
+ Warning:
1243
+ If ``is_causal`` is not ``None``, its value will be returned as is. If a
1244
+ user supplies an incorrect ``is_causal`` hint,
1245
+
1246
+ ``is_causal=False`` when the mask is in fact a causal attention.mask
1247
+ may lead to reduced performance relative to what would be achievable
1248
+ with ``is_causal=True``;
1249
+ ``is_causal=True`` when the mask is in fact not a causal attention.mask
1250
+ may lead to incorrect and unpredictable execution - in some scenarios,
1251
+ a causal mask may be applied based on the hint, in other execution
1252
+ scenarios the specified mask may be used. The choice may not appear
1253
+ to be deterministic, in that a number of factors like alignment,
1254
+ hardware SKU, etc influence the decision whether to use a mask or
1255
+ rely on the hint.
1256
+ ``size`` if not None, check whether the mask is a causal mask of the provided size
1257
+ Otherwise, checks for any causal mask.
1258
+ """
1259
+ # Prevent type refinement
1260
+ make_causal = (is_causal is True)
1261
+
1262
+ if is_causal is None and mask is not None:
1263
+ sz = size if size is not None else mask.size(-2)
1264
+ causal_comparison = _generate_square_subsequent_mask(
1265
+ sz, device=mask.device, dtype=mask.dtype)
1266
+
1267
+ # Do not use `torch.equal` so we handle batched masks by
1268
+ # broadcasting the comparison.
1269
+ if mask.size() == causal_comparison.size():
1270
+ make_causal = bool((mask == causal_comparison).all())
1271
+ else:
1272
+ make_causal = False
1273
+
1274
+ return make_causal
1275
+
1276
def check_instruments(genereated_seq):
    """Prepend instrument-prefix tokens for each instrument found in a
    generated token sequence, and append an end token.

    Scans ``genereated_seq`` (parameter name kept misspelled for backward
    compatibility) for tokens whose instrument field matches a known
    instrument family, collects up to 15 distinct
    ``('prefix', 'instrument', name)`` tuples, and returns the sequence with
    those prefixes plus a ``'<S>'`` start token prepended and ``'<E>'``
    appended. If no instrument is found, only ``'<E>'`` is appended.
    """
    instrument_list = [
        "piano", "chromatic", "organ", "guitar", "bass", "strings", "ensemble",
        "brass", "reed", "drum", "pipe", "synth_lead", "synth_pad",
        "synth_effect", "ethnic", "percussive", "sfx",
    ]
    ins_present = []
    for token in genereated_seq:
        # Tokens may be (instrument, pitch, velocity), (instrument, pitch) or a
        # bare instrument token. Only unpack real tuples/lists: the previous
        # try/except unpacking would split a 2- or 3-character *string* token
        # into characters (e.g. "sfx" -> 's', 'f', 'x') and miss the instrument.
        if isinstance(token, (tuple, list)) and len(token) in (2, 3):
            ins = token[0]
        else:
            ins = token
        if str(ins) in instrument_list:
            prefix = ('prefix', 'instrument', str(ins))
            # Cap at 15 distinct instrument prefixes.
            if prefix not in ins_present and len(ins_present) < 15:
                print(f'adding instrument {ins}')
                ins_present.append(prefix)
    if ins_present:
        genereated_seq = ins_present + ['<S>'] + genereated_seq + ['<E>']
    else:
        genereated_seq = genereated_seq + ['<E>']
    print(genereated_seq)
    return genereated_seq
1302
+
1303
def process_caption(gpu_id, captions, model, tokenizer, r_tokenizer):
    """Generate a MIDI file for each caption dict on the best available device.

    Each entry of ``captions`` is expected to hold a text prompt under
    ``'caption'`` and an output path fragment under ``'location'``; the
    decoded MIDI is written to ``../res/<location>``.
    """
    # Device preference: CUDA (on the requested GPU id), then Apple MPS, then CPU.
    if torch.cuda.is_available():
        device = torch.device(f"cuda:{gpu_id}")
        torch.cuda.set_device(gpu_id)
        print(f"Using CUDA on GPU {gpu_id}")
    elif torch.backends.mps.is_available():
        device = torch.device("mps")
        print("Using MPS on macOS")
    else:
        device = torch.device("cpu")
        print("Using CPU")

    model.to(device)
    model.eval()

    for caption in captions:
        src = caption['caption']
        location = caption['location']

        # Tokenize the text prompt and move the padded batch to the device.
        inputs = tokenizer(src, return_tensors='pt', padding=True, truncation=True)
        input_ids = nn.utils.rnn.pad_sequence(
            inputs.input_ids, batch_first=True, padding_value=0).to(device)
        attention_mask = nn.utils.rnn.pad_sequence(
            inputs.attention_mask, batch_first=True, padding_value=0).to(device)

        # Autoregressively sample token ids, then decode them to MIDI and save.
        output = model.generate(input_ids, attention_mask, max_len=5000, temperature=0.9)
        token_ids = output[0].tolist()
        generated_midi = r_tokenizer.decode(token_ids)
        generated_midi.dump_midi(f"../res/{location}")
1338
+
1339
+ # def process_caption(gpu_id, captions, model, tokenizer, r_tokenizer):
1340
+ # device = gpu_id
1341
+ # torch.cuda.set_device(gpu_id)
1342
+ # model.to(gpu_id)
1343
+ # model.eval()
1344
+ # for caption in captions:
1345
+ # src = caption['caption']
1346
+ # location = caption['location']
1347
+ # #src = "A cinematic electronic soundtrack that evokes an epic and dark atmosphere, featuring cello, contrabass, and drums. The song is set in A minor with a moderate tempo and a 4/4 time signature, creating an emotional and action-packed ambiance suitable for film."
1348
+ # '''
1349
+ # example 1: "A cheerful and melodic pop Christmas song featuring piano, acoustic guitar, vibraphone, bass, and drums, set in the key of Eb minor with a fast tempo of 123 bpm and a 4/4 time signature, creating a joyful and relaxing atmosphere."lmd_full/1/1b9f5f325c2080d345d877f590aa3dbe.mid
1350
+ # example 2: "A melodic electronic song with ambient elements, featuring piano, acoustic guitar, alto saxophone, string ensemble, and electric bass. Set in G minor with a 4/4 time signature, it moves at a lively Presto tempo. The composition evokes a blend of relaxation and darkness, with hints of happiness and a meditative quality."lmd_full/1/152891ac63017b234c33e75e4a4a28c5.mid
1351
+ # example 3: "This motivational electronic and pop song features a clean electric guitar, rock organ, synth voice, acoustic guitar, and vibraphone, creating a melodic and uplifting atmosphere. Set in the key of G# minor with a 4/4 time signature, the track moves at an energetic Allegro tempo of 120 beats per minute. The chord progression of Bbm7 and F# adds to the song's inspiring and corporate feel." lmd_full/1/14347e50e9e8149a9da09f49b188180b.mid
1352
+ # example 4: "This short electronic song in C minor features a brass section, string ensemble, tenor saxophone, clean electric guitar, and slap bass, creating a melodic and slightly dark atmosphere. With a tempo of 124 BPM (Allegro) and a 4/4 time signature, the track incorporates a chord progression of C7/E, Eb6, and Bbm6, adding a touch of corporate and motivational vibes to the overall composition." lmd_full/1/1dc4cd50a5509d8042d27d80bc7e668e.mid
1353
+ # example 5: "An energetic and melodic electronic trance track with a space and retro vibe, featuring drums, distortion guitar, flute, synth bass, and slap bass. Set in A minor with a fast tempo of 138 BPM, the song maintains a 4/4 time signature throughout its duration." lmd_full/3/3328b854ebe7a2fc9a746ede74c410ae.mid
1354
+ # example 6: "A short but energetic rock fragment in C minor, featuring overdriven guitars, electric bass, and drums, with a vivacious tempo of 155 BPM and a 4/4 time signature, evoking a blend of dark and melodic tones." lmd_full/4/4c2232688c5f869b8470a408d197f5e3.mid
1355
+ # example 7: "A classical piece with a cinematic flair, this composition is characterized by its fast tempo and 4/4 time signature. The soprano saxophone and flute take turns leading the melody, supported by the lush tones of the string ensemble, acoustic bass, and pan flute. Set in the key of F minor, the harmonic landscape is painted with the chords Gm7b5, Cm7b5, Fm7, Eaug, and Ab/Eb. The overall mood evokes images of film, with hints of Christmas, drama, documentary, and adventure." lmd_full/9/95bce1b489a11829b4fef39200291f60.mid
1356
+ # exmaple 8: "A slow, dark, and emotional classical piece featuring cello, violin, and viola, likely to be used in a dramatic film soundtrack. The composition is in the key of C minor with a 4/4 time signature, and the main chord progression consists of Cm, G, Cm, and Fm." lmd_full/a/a22aad98ecfe4b3d8a353c2a72132834.mid
1357
+ # example 9: "A slow and emotional classical piece, likely used in a film soundtrack, featuring a church organ as the sole instrument. Written in the key of Eb major with a 3/4 time signature, it evokes a sense of drama and romance. The chord progression of Bb7, Eb, and Ab contributes to the relaxing atmosphere throughout the song." lmd_full/a/af4302a036c9df71e0435df9b08f8c4b.mid
1358
+ # example 10: "A cinematic electronic soundtrack that evokes an epic and dark atmosphere, featuring cello, contrabass, and drums. The song is set in A minor with a moderate tempo and a 4/4 time signature, creating an emotional and action-packed ambiance suitable for film." lmd_full/d/d920b6f451d7a72ae06f154e7c06c4c1.mid
1359
+ # '''
1360
+ # inputs = tokenizer(src, return_tensors='pt', padding=True, truncation=True)
1361
+ # input_ids = nn.utils.rnn.pad_sequence(inputs.input_ids, batch_first=True, padding_value=0)
1362
+ # input_ids = input_ids.to(device)
1363
+ # attention_mask =nn.utils.rnn.pad_sequence(inputs.attention_mask, batch_first=True, padding_value=0)
1364
+ # attention_mask = attention_mask.to(device)
1365
+ # output = model.generate(input_ids, attention_mask,max_len=5000,temperature = 0.9)
1366
+ # output_list = output[0].tolist()
1367
+ # print(type(output_list))
1368
+ # # generated_sequences = [dict_tokenizer[token] for token in output_list[0]]
1369
+ # # generated_sequences = check_instruments(generated_sequences)
1370
+ # # # generated_sequences = [('prefix', 'instrument', 'bass'), ('prefix', 'instrument', 'guitar'), ('prefix', 'instrument', 'piano'), ('prefix', 'instrument', 'guitar'), '<S>' ]+ generated_sequences +['<E>']
1371
+ # # generated_sequences = [token for token in generated_sequences]# if token not in ["<SS>", "<S>", "<E>", "<SEP>"]]
1372
+ # # # print("Generated sequences:", generated_sequences)
1373
+ # # with open('../../generated_seq.pkl', 'wb') as f:
1374
+ # # pickle.dump(generated_sequences, f)
1375
+ # # mid_dict = aria_tokenizer.detokenize(generated_sequences)
1376
+ # # mid = mid_dict.to_midi()
1377
+ # generated_midi = r_tokenizer.decode(output_list)
1378
+ # # print(type(generated_midi))
1379
+ # generated_midi.dump_midi(f"../res/{location}")
1380
+
1381
def test_generate(caption):
    """Generate a single MIDI file from ``caption``.

    Loads the REMI vocabulary, the trained Transformer checkpoint and the
    FLAN-T5 text tokenizer, samples up to 2000 tokens, and writes the decoded
    MIDI to ``../../output_christmas_2.mid``.
    """
    # Device preference: CUDA, then Apple MPS, then CPU.
    if torch.cuda.is_available():
        device = torch.device("cuda")
        print("Using CUDA on NVIDIA GPU")
    elif torch.backends.mps.is_available():
        device = torch.device("mps")
        print("Using MPS on macOS")
    else:
        device = torch.device("cpu")
        print("Using CPU")

    artifact_folder = '../artifacts'
    tokenizer_filepath = os.path.join(artifact_folder, "vocab_remi.pkl")

    # Load the REMI tokenizer vocabulary (its length sizes the output layer).
    with open(tokenizer_filepath, "rb") as f:
        r_tokenizer = pickle.load(f)
    vocab_size = len(r_tokenizer)
    print("Vocab size: ", vocab_size)

    # Initialize the model and load the trained weights.
    # NOTE(review): the checkpoint path is hard-coded; consider parameterizing it.
    model = Transformer(vocab_size, 768, 8, 2048, 18, 1024, False, 8, device=device)
    model.load_state_dict(torch.load('/root/test/text2midi/output_new/epoch_30/pytorch_model.bin', map_location=device))
    model.to(device)
    model.eval()

    tokenizer = T5Tokenizer.from_pretrained("google/flan-t5-base")

    # Tokenize the caption and move the padded batch to the device.
    inputs = tokenizer(caption, return_tensors='pt', padding=True, truncation=True)
    input_ids = nn.utils.rnn.pad_sequence(
        inputs.input_ids, batch_first=True, padding_value=0).to(device)
    attention_mask = nn.utils.rnn.pad_sequence(
        inputs.attention_mask, batch_first=True, padding_value=0).to(device)

    # Autoregressively sample token ids, then decode to MIDI and save.
    output = model.generate(input_ids, attention_mask, max_len=2000, temperature=0.9)
    output_list = output[0].tolist()
    generated_midi = r_tokenizer.decode(output_list)
    generated_midi.dump_midi("../../output_christmas_2.mid")
1447
+
1448
def load_model_and_tokenizer(accelerator, model_path, vocab_size, tokenizer_filepath):
    """Load the Transformer checkpoint plus both tokenizers.

    Returns ``(model, text_tokenizer, r_tokenizer)`` with the model placed on
    the accelerator's device and set to eval mode.
    """
    device = accelerator.device
    with open(tokenizer_filepath, "rb") as f:
        r_tokenizer = pickle.load(f)
    model = Transformer(vocab_size, 768, 8, 2048, 18, 1024, False, 8, device=device)
    model.load_state_dict(torch.load(model_path, map_location=device))
    model.to(device)
    model.eval()
    text_tokenizer = T5Tokenizer.from_pretrained("google/flan-t5-base")
    return model, text_tokenizer, r_tokenizer
1458
+
1459
def process_example(accelerator, model, tokenizer, r_tokenizer, example, location, output_path):
    """Generate MIDI for a single caption ``example`` and write it to ``output_path``."""
    device = accelerator.device
    encoded = tokenizer(example, return_tensors='pt', padding=True, truncation=True).to(device)
    with torch.no_grad():
        # The model is wrapped by accelerate, so generate lives on model.module.
        output = model.module.generate(
            encoded['input_ids'], encoded['attention_mask'], max_len=2000, temperature=0.9)
    token_ids = output[0].tolist()
    r_tokenizer.decode(token_ids).dump_midi(output_path)
1469
+
1470
def run_accelerate_generation():
    """Batch-generate MIDI for all test-set captions using accelerate.

    Reads captions flagged ``test_set`` from the captions JSONL file, loads
    the model and tokenizers, and writes each decoded MIDI under
    ``/root/Text2midi/res_acc/<location>``.
    """
    accelerator = Accelerator()
    artifact_folder = '../artifacts'
    tokenizer_filepath = os.path.join(artifact_folder, "vocab_remi.pkl")
    model_path = '/root/output_test_new/epoch_30/pytorch_model.bin'
    captions_path = '/root/captions/train.json'

    with jsonlines.open(captions_path) as reader:
        selected_captions = [line for line in reader if line.get('test_set') is True]

    # Load the REMI vocabulary once just to size the model's output layer;
    # load_model_and_tokenizer re-reads the same pickle to build the tokenizer.
    with open(tokenizer_filepath, "rb") as f:
        remi_vocab = pickle.load(f)

    model, tokenizer, r_tokenizer = load_model_and_tokenizer(
        accelerator, model_path, len(remi_vocab), tokenizer_filepath)
    model = accelerator.prepare(model)

    dataset = CaptionDataset(selected_captions)
    dataloader = DataLoader(dataset, batch_size=8, num_workers=4, shuffle=False,
                            collate_fn=custom_collate_fn)
    dataloader = accelerator.prepare(dataloader)

    for captions, locations in dataloader:
        for example, location in zip(captions, locations):
            output_path = os.path.join('/root/Text2midi/res_acc', location)
            os.makedirs(os.path.dirname(output_path), exist_ok=True)
            process_example(accelerator, model, tokenizer, r_tokenizer,
                            example, location, output_path)
1496
+ # run_accelerate_generation() #uncomment this and comment __main__ to run accelerate generation
1497
+
1498
def main():
    """CLI entry point: parse ``--caption`` and generate a MIDI file from it."""
    parser = argparse.ArgumentParser(description="Generate MIDI from caption")
    parser.add_argument('--caption', type=str, required=True,
                        help='Caption to generate MIDI from')
    args = parser.parse_args()
    test_generate(args.caption)
1503
+
1504
+ '''
1505
+ comment out the next section function and uncomment the run_accelerate_generation() function to run the accelerate generation
1506
+ '''
1507
if __name__ == "__main__":
    # Script entry point; swap for run_accelerate_generation() for batch mode.
    main()
    print("Done")
text2midi_repo/requirements-mac.txt ADDED
@@ -0,0 +1,368 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #
2
+ # This file is autogenerated by pip-compile with Python 3.10
3
+ # by the following command:
4
+ #
5
+ # pip-compile requirements.in
6
+ #
7
+ accelerate==0.18.0
8
+ # via -r requirements.in
9
+ aiohappyeyeballs==2.4.4
10
+ # via aiohttp
11
+ aiohttp==3.11.10
12
+ # via
13
+ # datasets
14
+ # fsspec
15
+ aiosignal==1.3.1
16
+ # via aiohttp
17
+ annotated-types==0.7.0
18
+ # via pydantic
19
+ async-timeout==5.0.1
20
+ # via aiohttp
21
+ attrs==24.2.0
22
+ # via
23
+ # aiohttp
24
+ # jsonlines
25
+ beartype==0.19.0
26
+ # via st-moe-pytorch
27
+ blis==1.0.1
28
+ # via thinc
29
+ catalogue==2.0.10
30
+ # via
31
+ # spacy
32
+ # srsly
33
+ # thinc
34
+ certifi==2024.8.30
35
+ # via
36
+ # requests
37
+ # sentry-sdk
38
+ charset-normalizer==3.4.0
39
+ # via requests
40
+ click==8.1.7
41
+ # via
42
+ # typer
43
+ # wandb
44
+ cloudpathlib==0.20.0
45
+ # via weasel
46
+ colt5-attention==0.11.1
47
+ # via st-moe-pytorch
48
+ confection==0.1.5
49
+ # via
50
+ # thinc
51
+ # weasel
52
+ cymem==2.0.10
53
+ # via
54
+ # preshed
55
+ # spacy
56
+ # thinc
57
+ datasets==3.1.0
58
+ # via evaluate
59
+ dill==0.3.8
60
+ # via
61
+ # datasets
62
+ # evaluate
63
+ # multiprocess
64
+ docker-pycreds==0.4.0
65
+ # via wandb
66
+ einops==0.8.0
67
+ # via
68
+ # -r requirements.in
69
+ # colt5-attention
70
+ # local-attention
71
+ # st-moe-pytorch
72
+ evaluate==0.4.3
73
+ # via -r requirements.in
74
+ filelock==3.16.1
75
+ # via
76
+ # datasets
77
+ # huggingface-hub
78
+ # torch
79
+ # transformers
80
+ # triton
81
+ frozenlist==1.5.0
82
+ # via
83
+ # aiohttp
84
+ # aiosignal
85
+ fsspec[http]==2024.9.0
86
+ # via
87
+ # datasets
88
+ # evaluate
89
+ # huggingface-hub
90
+ # torch
91
+ gitdb==4.0.11
92
+ # via gitpython
93
+ gitpython==3.1.43
94
+ # via wandb
95
+ huggingface-hub==0.26.3
96
+ # via
97
+ # accelerate
98
+ # datasets
99
+ # evaluate
100
+ # miditok
101
+ # tokenizers
102
+ # transformers
103
+ idna==3.10
104
+ # via
105
+ # requests
106
+ # yarl
107
+ jinja2==3.1.4
108
+ # via
109
+ # spacy
110
+ # torch
111
+ jsonlines==4.0.0
112
+ # via -r requirements.in
113
+ langcodes==3.5.0
114
+ # via spacy
115
+ language-data==1.3.0
116
+ # via langcodes
117
+ local-attention==1.9.15
118
+ # via colt5-attention
119
+ marisa-trie==1.2.1
120
+ # via language-data
121
+ markdown-it-py==3.0.0
122
+ # via rich
123
+ markupsafe==3.0.2
124
+ # via jinja2
125
+ mdurl==0.1.2
126
+ # via markdown-it-py
127
+ miditok==3.0.3
128
+ # via -r requirements.in
129
+ mpmath==1.3.0
130
+ # via sympy
131
+ multidict==6.1.0
132
+ # via
133
+ # aiohttp
134
+ # yarl
135
+ multiprocess==0.70.16
136
+ # via
137
+ # datasets
138
+ # evaluate
139
+ murmurhash==1.0.11
140
+ # via
141
+ # preshed
142
+ # spacy
143
+ # thinc
144
+ networkx==3.4.2
145
+ # via torch
146
+ numpy==2.0.2
147
+ # via
148
+ # -r requirements.in
149
+ # accelerate
150
+ # blis
151
+ # datasets
152
+ # evaluate
153
+ # miditok
154
+ # pandas
155
+ # spacy
156
+ # symusic
157
+ # thinc
158
+ # transformers
159
+ #nvidia-cublas-cu12==12.4.5.8
160
+ # via
161
+ # nvidia-cudnn-cu12
162
+ # nvidia-cusolver-cu12
163
+ # torch
164
+ #nvidia-cuda-cupti-cu12==12.4.127
165
+ # via torch
166
+ #nvidia-cuda-nvrtc-cu12==12.4.127
167
+ # via torch
168
+ #nvidia-cuda-runtime-cu12==12.4.127
169
+ # via torch
170
+ #nvidia-cudnn-cu12==9.1.0.70
171
+ # via torch
172
+ #nvidia-cufft-cu12==11.2.1.3
173
+ # via torch
174
+ #nvidia-curand-cu12==10.3.5.147
175
+ # via torch
176
+ #nvidia-cusolver-cu12==11.6.1.9
177
+ # via torch
178
+ #nvidia-cusparse-cu12==12.3.1.170
179
+ # via
180
+ # nvidia-cusolver-cu12
181
+ # torch
182
+ #nvidia-nccl-cu12==2.21.5
183
+ # via torch
184
+ #nvidia-nvjitlink-cu12==12.4.127
185
+ # via
186
+ # nvidia-cusolver-cu12
187
+ # nvidia-cusparse-cu12
188
+ # torch
189
+ #nvidia-nvtx-cu12==12.4.127
190
+ # via torch
191
+ packaging==24.2
192
+ # via
193
+ # accelerate
194
+ # colt5-attention
195
+ # datasets
196
+ # evaluate
197
+ # huggingface-hub
198
+ # spacy
199
+ # thinc
200
+ # transformers
201
+ # weasel
202
+ pandas==2.2.3
203
+ # via
204
+ # datasets
205
+ # evaluate
206
+ platformdirs==4.3.6
207
+ # via
208
+ # symusic
209
+ # wandb
210
+ preshed==3.0.9
211
+ # via
212
+ # spacy
213
+ # thinc
214
+ propcache==0.2.1
215
+ # via
216
+ # aiohttp
217
+ # yarl
218
+ protobuf==5.29.1
219
+ # via wandb
220
+ psutil==6.1.0
221
+ # via
222
+ # accelerate
223
+ # wandb
224
+ pyarrow==18.1.0
225
+ # via datasets
226
+ pydantic==2.10.3
227
+ # via
228
+ # confection
229
+ # spacy
230
+ # thinc
231
+ # wandb
232
+ # weasel
233
+ pydantic-core==2.27.1
234
+ # via pydantic
235
+ pygments==2.18.0
236
+ # via rich
237
+ pysmartdl==1.3.4
238
+ # via symusic
239
+ python-dateutil==2.9.0.post0
240
+ # via pandas
241
+ pytz==2024.2
242
+ # via pandas
243
+ pyyaml==6.0.2
244
+ # via
245
+ # -r requirements.in
246
+ # accelerate
247
+ # datasets
248
+ # huggingface-hub
249
+ # transformers
250
+ # wandb
251
+ regex==2024.11.6
252
+ # via transformers
253
+ requests==2.32.3
254
+ # via
255
+ # datasets
256
+ # evaluate
257
+ # huggingface-hub
258
+ # spacy
259
+ # transformers
260
+ # wandb
261
+ # weasel
262
+ rich==13.9.4
263
+ # via typer
264
+ safetensors==0.4.5
265
+ # via
266
+ # accelerate
267
+ # transformers
268
+ sentry-sdk==2.19.2
269
+ # via wandb
270
+ sentencepiece==0.2.0
271
+
272
+ setproctitle==1.3.4
273
+ # via wandb
274
+ shellingham==1.5.4
275
+ # via typer
276
+ six==1.17.0
277
+ # via
278
+ # docker-pycreds
279
+ # python-dateutil
280
+ smart-open==7.0.5
281
+ # via weasel
282
+ smmap==5.0.1
283
+ # via gitdb
284
+ spacy==3.8.2
285
+ # via -r requirements.in
286
+ spacy-legacy==3.0.12
287
+ # via spacy
288
+ spacy-loggers==1.0.5
289
+ # via spacy
290
+ srsly==2.4.8
291
+ # via
292
+ # confection
293
+ # spacy
294
+ # thinc
295
+ # weasel
296
+ st-moe-pytorch==0.1.8
297
+ # via -r requirements.in
298
+ sympy==1.13.1
299
+ # via torch
300
+ symusic==0.5.5
301
+ # via miditok
302
+ thinc==8.3.2
303
+ # via spacy
304
+ tokenizers==0.21.0
305
+ # via
306
+ # miditok
307
+ # transformers
308
+ torch==2.5.1
309
+ # via
310
+ # -r requirements.in
311
+ # accelerate
312
+ # colt5-attention
313
+ # local-attention
314
+ # st-moe-pytorch
315
+ tqdm==4.67.1
316
+ # via
317
+ # -r requirements.in
318
+ # datasets
319
+ # evaluate
320
+ # huggingface-hub
321
+ # miditok
322
+ # spacy
323
+ # transformers
324
+ transformers==4.47.0
325
+ # via -r requirements.in
326
+ #triton==3.1.0
327
+ # via torch
328
+ typer==0.15.1
329
+ # via
330
+ # spacy
331
+ # weasel
332
+ typing-extensions==4.12.2
333
+ # via
334
+ # cloudpathlib
335
+ # huggingface-hub
336
+ # multidict
337
+ # pydantic
338
+ # pydantic-core
339
+ # rich
340
+ # torch
341
+ # typer
342
+ # wandb
343
+ tzdata==2024.2
344
+ # via pandas
345
+ urllib3==2.2.3
346
+ # via
347
+ # requests
348
+ # sentry-sdk
349
+ wandb==0.19.0
350
+ # via -r requirements.in
351
+ wasabi==1.1.3
352
+ # via
353
+ # spacy
354
+ # thinc
355
+ # weasel
356
+ weasel==0.4.1
357
+ # via spacy
358
+ wrapt==1.17.0
359
+ # via smart-open
360
+ xxhash==3.5.0
361
+ # via
362
+ # datasets
363
+ # evaluate
364
+ yarl==1.18.3
365
+ # via aiohttp
366
+
367
+ # The following packages are considered to be unsafe in a requirements file:
368
+ # setuptools
text2midi_repo/requirements.txt ADDED
@@ -0,0 +1,368 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #
2
+ # This file is autogenerated by pip-compile with Python 3.10
3
+ # by the following command:
4
+ #
5
+ # pip-compile requirements.in
6
+ #
7
+ accelerate==0.18.0
8
+ # via -r requirements.in
9
+ aiohappyeyeballs==2.4.4
10
+ # via aiohttp
11
+ aiohttp==3.11.10
12
+ # via
13
+ # datasets
14
+ # fsspec
15
+ aiosignal==1.3.1
16
+ # via aiohttp
17
+ annotated-types==0.7.0
18
+ # via pydantic
19
+ async-timeout==5.0.1
20
+ # via aiohttp
21
+ attrs==24.2.0
22
+ # via
23
+ # aiohttp
24
+ # jsonlines
25
+ beartype==0.19.0
26
+ # via st-moe-pytorch
27
+ blis==1.0.1
28
+ # via thinc
29
+ catalogue==2.0.10
30
+ # via
31
+ # spacy
32
+ # srsly
33
+ # thinc
34
+ certifi==2024.8.30
35
+ # via
36
+ # requests
37
+ # sentry-sdk
38
+ charset-normalizer==3.4.0
39
+ # via requests
40
+ click==8.1.7
41
+ # via
42
+ # typer
43
+ # wandb
44
+ cloudpathlib==0.20.0
45
+ # via weasel
46
+ colt5-attention==0.11.1
47
+ # via st-moe-pytorch
48
+ confection==0.1.5
49
+ # via
50
+ # thinc
51
+ # weasel
52
+ cymem==2.0.10
53
+ # via
54
+ # preshed
55
+ # spacy
56
+ # thinc
57
+ datasets==3.1.0
58
+ # via evaluate
59
+ dill==0.3.8
60
+ # via
61
+ # datasets
62
+ # evaluate
63
+ # multiprocess
64
+ docker-pycreds==0.4.0
65
+ # via wandb
66
+ einops==0.8.0
67
+ # via
68
+ # -r requirements.in
69
+ # colt5-attention
70
+ # local-attention
71
+ # st-moe-pytorch
72
+ evaluate==0.4.3
73
+ # via -r requirements.in
74
+ filelock==3.16.1
75
+ # via
76
+ # datasets
77
+ # huggingface-hub
78
+ # torch
79
+ # transformers
80
+ # triton
81
+ frozenlist==1.5.0
82
+ # via
83
+ # aiohttp
84
+ # aiosignal
85
+ fsspec[http]==2024.9.0
86
+ # via
87
+ # datasets
88
+ # evaluate
89
+ # huggingface-hub
90
+ # torch
91
+ gitdb==4.0.11
92
+ # via gitpython
93
+ gitpython==3.1.43
94
+ # via wandb
95
+ huggingface-hub==0.26.3
96
+ # via
97
+ # accelerate
98
+ # datasets
99
+ # evaluate
100
+ # miditok
101
+ # tokenizers
102
+ # transformers
103
+ idna==3.10
104
+ # via
105
+ # requests
106
+ # yarl
107
+ jinja2==3.1.4
108
+ # via
109
+ # spacy
110
+ # torch
111
+ jsonlines==4.0.0
112
+ # via -r requirements.in
113
+ langcodes==3.5.0
114
+ # via spacy
115
+ language-data==1.3.0
116
+ # via langcodes
117
+ local-attention==1.9.15
118
+ # via colt5-attention
119
+ marisa-trie==1.2.1
120
+ # via language-data
121
+ markdown-it-py==3.0.0
122
+ # via rich
123
+ markupsafe==3.0.2
124
+ # via jinja2
125
+ mdurl==0.1.2
126
+ # via markdown-it-py
127
+ miditok==3.0.3
128
+ # via -r requirements.in
129
+ mpmath==1.3.0
130
+ # via sympy
131
+ multidict==6.1.0
132
+ # via
133
+ # aiohttp
134
+ # yarl
135
+ multiprocess==0.70.16
136
+ # via
137
+ # datasets
138
+ # evaluate
139
+ murmurhash==1.0.11
140
+ # via
141
+ # preshed
142
+ # spacy
143
+ # thinc
144
+ networkx==3.4.2
145
+ # via torch
146
+ numpy==2.0.2
147
+ # via
148
+ # -r requirements.in
149
+ # accelerate
150
+ # blis
151
+ # datasets
152
+ # evaluate
153
+ # miditok
154
+ # pandas
155
+ # spacy
156
+ # symusic
157
+ # thinc
158
+ # transformers
159
+ nvidia-cublas-cu12==12.4.5.8
160
+ # via
161
+ # nvidia-cudnn-cu12
162
+ # nvidia-cusolver-cu12
163
+ # torch
164
+ nvidia-cuda-cupti-cu12==12.4.127
165
+ # via torch
166
+ nvidia-cuda-nvrtc-cu12==12.4.127
167
+ # via torch
168
+ nvidia-cuda-runtime-cu12==12.4.127
169
+ # via torch
170
+ nvidia-cudnn-cu12==9.1.0.70
171
+ # via torch
172
+ nvidia-cufft-cu12==11.2.1.3
173
+ # via torch
174
+ nvidia-curand-cu12==10.3.5.147
175
+ # via torch
176
+ nvidia-cusolver-cu12==11.6.1.9
177
+ # via torch
178
+ nvidia-cusparse-cu12==12.3.1.170
179
+ # via
180
+ # nvidia-cusolver-cu12
181
+ # torch
182
+ nvidia-nccl-cu12==2.21.5
183
+ # via torch
184
+ nvidia-nvjitlink-cu12==12.4.127
185
+ # via
186
+ # nvidia-cusolver-cu12
187
+ # nvidia-cusparse-cu12
188
+ # torch
189
+ nvidia-nvtx-cu12==12.4.127
190
+ # via torch
191
+ packaging==24.2
192
+ # via
193
+ # accelerate
194
+ # colt5-attention
195
+ # datasets
196
+ # evaluate
197
+ # huggingface-hub
198
+ # spacy
199
+ # thinc
200
+ # transformers
201
+ # weasel
202
+ pandas==2.2.3
203
+ # via
204
+ # datasets
205
+ # evaluate
206
+ platformdirs==4.3.6
207
+ # via
208
+ # symusic
209
+ # wandb
210
+ preshed==3.0.9
211
+ # via
212
+ # spacy
213
+ # thinc
214
+ propcache==0.2.1
215
+ # via
216
+ # aiohttp
217
+ # yarl
218
+ protobuf==5.29.1
219
+ # via wandb
220
+ psutil==6.1.0
221
+ # via
222
+ # accelerate
223
+ # wandb
224
+ pyarrow==18.1.0
225
+ # via datasets
226
+ pydantic==2.10.3
227
+ # via
228
+ # confection
229
+ # spacy
230
+ # thinc
231
+ # wandb
232
+ # weasel
233
+ pydantic-core==2.27.1
234
+ # via pydantic
235
+ pygments==2.18.0
236
+ # via rich
237
+ pysmartdl==1.3.4
238
+ # via symusic
239
+ python-dateutil==2.9.0.post0
240
+ # via pandas
241
+ pytz==2024.2
242
+ # via pandas
243
+ pyyaml==6.0.2
244
+ # via
245
+ # -r requirements.in
246
+ # accelerate
247
+ # datasets
248
+ # huggingface-hub
249
+ # transformers
250
+ # wandb
251
+ regex==2024.11.6
252
+ # via transformers
253
+ requests==2.32.3
254
+ # via
255
+ # datasets
256
+ # evaluate
257
+ # huggingface-hub
258
+ # spacy
259
+ # transformers
260
+ # wandb
261
+ # weasel
262
+ rich==13.9.4
263
+ # via typer
264
+ safetensors==0.4.5
265
+ # via
266
+ # accelerate
267
+ # transformers
268
+ sentry-sdk==2.19.2
269
+ # via wandb
270
+ sentencepiece==0.2.0
271
+
272
+ setproctitle==1.3.4
273
+ # via wandb
274
+ shellingham==1.5.4
275
+ # via typer
276
+ six==1.17.0
277
+ # via
278
+ # docker-pycreds
279
+ # python-dateutil
280
+ smart-open==7.0.5
281
+ # via weasel
282
+ smmap==5.0.1
283
+ # via gitdb
284
+ spacy==3.8.2
285
+ # via -r requirements.in
286
+ spacy-legacy==3.0.12
287
+ # via spacy
288
+ spacy-loggers==1.0.5
289
+ # via spacy
290
+ srsly==2.4.8
291
+ # via
292
+ # confection
293
+ # spacy
294
+ # thinc
295
+ # weasel
296
+ st-moe-pytorch==0.1.8
297
+ # via -r requirements.in
298
+ sympy==1.13.1
299
+ # via torch
300
+ symusic==0.5.5
301
+ # via miditok
302
+ thinc==8.3.2
303
+ # via spacy
304
+ tokenizers==0.21.0
305
+ # via
306
+ # miditok
307
+ # transformers
308
+ torch==2.5.1
309
+ # via
310
+ # -r requirements.in
311
+ # accelerate
312
+ # colt5-attention
313
+ # local-attention
314
+ # st-moe-pytorch
315
+ tqdm==4.67.1
316
+ # via
317
+ # -r requirements.in
318
+ # datasets
319
+ # evaluate
320
+ # huggingface-hub
321
+ # miditok
322
+ # spacy
323
+ # transformers
324
+ transformers==4.47.0
325
+ # via -r requirements.in
326
+ triton==3.1.0
327
+ # via torch
328
+ typer==0.15.1
329
+ # via
330
+ # spacy
331
+ # weasel
332
+ typing-extensions==4.12.2
333
+ # via
334
+ # cloudpathlib
335
+ # huggingface-hub
336
+ # multidict
337
+ # pydantic
338
+ # pydantic-core
339
+ # rich
340
+ # torch
341
+ # typer
342
+ # wandb
343
+ tzdata==2024.2
344
+ # via pandas
345
+ urllib3==2.2.3
346
+ # via
347
+ # requests
348
+ # sentry-sdk
349
+ wandb==0.19.0
350
+ # via -r requirements.in
351
+ wasabi==1.1.3
352
+ # via
353
+ # spacy
354
+ # thinc
355
+ # weasel
356
+ weasel==0.4.1
357
+ # via spacy
358
+ wrapt==1.17.0
359
+ # via smart-open
360
+ xxhash==3.5.0
361
+ # via
362
+ # datasets
363
+ # evaluate
364
+ yarl==1.18.3
365
+ # via aiohttp
366
+
367
+ # The following packages are considered to be unsafe in a requirements file:
368
+ # setuptools
text2midi_repo/text2midi_architecture.jpg ADDED

Git LFS Details

  • SHA256: 732af27208de46a6f0ab508605597b1dc1f569daa0e256ce695773fd2466eb6f
  • Pointer size: 131 Bytes
  • Size of remote file: 179 kB
text2midi_repo/utils/midi_to_wav.py ADDED
@@ -0,0 +1,55 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import subprocess
3
+ from multiprocessing import Pool, cpu_count
4
+ from tqdm import tqdm
5
+
6
# Path to the SoundFont fluidsynth uses for synthesis (host-specific; TODO make configurable).
soundfont_filepath = "/root/soundfont/soundfont.sf"
7
+
8
def save_wav(midi_filepath, wav_filepath):
    """Render a MIDI file to WAV using fluidsynth.

    Skips rendering when ``wav_filepath`` already exists.

    Args:
        midi_filepath: Path to the input ``.mid`` file.
        wav_filepath: Destination path for the ``.wav`` output.

    Returns:
        ``wav_filepath`` (whether newly created, pre-existing, or failed).
    """
    # Reuse an existing render instead of paying the synthesis cost again.
    if os.path.isfile(wav_filepath):
        print(f"{wav_filepath} already exists, skipping")
        return wav_filepath

    print(f"Creating {wav_filepath} from {midi_filepath}")

    # Build the command as an argument list and run with shell=False: the
    # original f-string + shell=True broke on paths containing spaces and
    # was vulnerable to shell injection via the file names.
    command = [
        "fluidsynth",
        "-r", "48000",
        soundfont_filepath,
        "-g", "1.0",
        "--quiet",
        "--no-shell",
        midi_filepath,
        "-T", "wav",
        "-F", wav_filepath,
    ]
    print(f"Running command: {' '.join(command)}")
    result = subprocess.run(command, capture_output=True)

    if result.returncode != 0:
        print(f"Error converting {midi_filepath} to {wav_filepath}: {result.stderr.decode('utf-8')}")
    else:
        print(f"Successfully created {wav_filepath}")

    return wav_filepath
28
+
29
def process_midi_file(midi_filepath):
    """Convert one MIDI file under the result tree into a mirrored WAV path.

    Mirrors the file's location relative to ``/root/Text2midi/res_acc`` into
    ``/root/wav`` and swaps the extension to ``.wav``.
    """
    relative_path = os.path.relpath(midi_filepath, "/root/Text2midi/res_acc")

    # splitext replaces only the final extension; the original
    # str.replace('.mid', '.wav') also rewrote any '.mid' occurring earlier
    # in the path (e.g. a directory named 'a.midi' became 'a.wavi').
    stem, _ = os.path.splitext(relative_path)
    wav_filepath = os.path.join("/root/wav", stem + ".wav")
    wav_directory = os.path.dirname(wav_filepath)

    # Ensure the mirrored directory exists before fluidsynth writes into it.
    os.makedirs(wav_directory, exist_ok=True)

    # Convert the MIDI file to WAV
    save_wav(midi_filepath, wav_filepath)
40
+
41
def main():
    """Collect every .mid file under the result tree and convert in parallel."""
    source_root = "/root/Text2midi/res_acc"

    # Gather all MIDI paths in one pass over the tree.
    midi_files = [
        os.path.join(dirpath, filename)
        for dirpath, _, filenames in os.walk(source_root)
        for filename in filenames
        if filename.endswith(".mid")
    ]

    # Use half of the available CPU cores for the worker pool.
    worker_count = cpu_count() // 2
    with Pool(worker_count) as worker_pool:
        conversions = worker_pool.imap(process_midi_file, midi_files)
        # Drain the iterator through tqdm so progress is shown as workers finish.
        list(tqdm(conversions, total=len(midi_files), desc="Processing MIDI files"))
53
+
54
# Script entry point: run the batch MIDI-to-WAV conversion.
if __name__ == "__main__":
    main()
text2midi_repo/utils/split_caption.py ADDED
@@ -0,0 +1,27 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ import os
3
+ import random
4
+ import jsonlines
5
+
6
def select_and_split_captions(input_path, output_dir, num_splits=6):
    """Filter test-set captions from a JSON-lines file and write them in splits.

    Args:
        input_path: JSON-lines file; each record is a caption dict.
        output_dir: Directory receiving ``selected_captions_<i>.json`` files.
        num_splits: Number of roughly equal output files; the last split
            absorbs the remainder so no caption is dropped.

    Raises:
        ValueError: If ``num_splits`` is less than 1 (the original code
            raised ZeroDivisionError instead).
    """
    if num_splits < 1:
        raise ValueError("num_splits must be >= 1")

    with jsonlines.open(input_path) as reader:
        # Keep only records explicitly flagged as test-set members.
        selected_captions = [rec for rec in reader if rec.get('test_set') is True]

    # Create the destination directory so open() below cannot fail on it.
    os.makedirs(output_dir, exist_ok=True)

    # Split the selected captions into num_splits groups.
    split_size = len(selected_captions) // num_splits
    for i in range(num_splits):
        start_idx = i * split_size
        end_idx = (i + 1) * split_size if i != num_splits - 1 else len(selected_captions)
        split_captions = selected_captions[start_idx:end_idx]

        output_path = os.path.join(output_dir, f'selected_captions_{i}.json')
        with open(output_path, 'w') as f:
            json.dump(split_captions, f, indent=4)
        print(f'Saved {len(split_captions)} captions to {output_path}')
23
+
24
# Script entry point: split the training-caption file into per-worker chunks.
if __name__ == "__main__":
    input_path = '/root/captions/train.json'
    output_dir = '/root/captions/'
    select_and_split_captions(input_path, output_dir)