Spaces:
Sleeping
Sleeping
Upload 27 files
Browse files- .gitattributes +2 -0
- README.md +112 -8
- app.py +277 -0
- packages.txt +2 -0
- requirements.txt +9 -0
- text2midi_repo/.gitignore +162 -0
- text2midi_repo/LICENSE +21 -0
- text2midi_repo/README.md +162 -0
- text2midi_repo/artifacts/vocab.pkl +3 -0
- text2midi_repo/artifacts/vocab_remi.pkl +3 -0
- text2midi_repo/captions/captions.json +3 -0
- text2midi_repo/configs/config.yaml +62 -0
- text2midi_repo/configs/ds_config.json +33 -0
- text2midi_repo/model/__pycache__/transformer_model.cpython-314.pyc +0 -0
- text2midi_repo/model/build_vocab.py +79 -0
- text2midi_repo/model/build_vocab_remi.py +69 -0
- text2midi_repo/model/data_loader.py +124 -0
- text2midi_repo/model/data_loader_remi.py +126 -0
- text2midi_repo/model/dict_output.txt +0 -0
- text2midi_repo/model/train.py +185 -0
- text2midi_repo/model/train_accelerate.py +224 -0
- text2midi_repo/model/train_hf.py +283 -0
- text2midi_repo/model/transformer_model.py +1509 -0
- text2midi_repo/requirements-mac.txt +368 -0
- text2midi_repo/requirements.txt +368 -0
- text2midi_repo/text2midi_architecture.jpg +3 -0
- text2midi_repo/utils/midi_to_wav.py +55 -0
- text2midi_repo/utils/split_caption.py +27 -0
.gitattributes
CHANGED
|
@@ -33,3 +33,5 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
|
| 33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
|
|
|
| 33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
| 36 |
+
text2midi_repo/captions/captions.json filter=lfs diff=lfs merge=lfs -text
|
| 37 |
+
text2midi_repo/text2midi_architecture.jpg filter=lfs diff=lfs merge=lfs -text
|
README.md
CHANGED
|
@@ -1,14 +1,118 @@
|
|
| 1 |
---
|
| 2 |
-
title:
|
| 3 |
-
emoji:
|
| 4 |
-
colorFrom:
|
| 5 |
-
colorTo:
|
| 6 |
sdk: gradio
|
| 7 |
-
sdk_version:
|
| 8 |
app_file: app.py
|
|
|
|
| 9 |
pinned: false
|
| 10 |
-
license: mit
|
| 11 |
-
short_description: TextToAudio
|
| 12 |
---
|
| 13 |
|
| 14 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
---
|
| 2 |
+
title: VR Game Music Generator
|
| 3 |
+
emoji: 🎵
|
| 4 |
+
colorFrom: blue
|
| 5 |
+
colorTo: purple
|
| 6 |
sdk: gradio
|
| 7 |
+
sdk_version: 4.44.0
|
| 8 |
app_file: app.py
|
| 9 |
+
python_version: 3.11
|
| 10 |
pinned: false
|
|
|
|
|
|
|
| 11 |
---
|
| 12 |
|
| 13 |
+
# VR Game Music Generator
|
| 14 |
+
|
| 15 |
+
Generate music from text descriptions using the text2midi AI model. Designed for integration with Unity and other game engines via the Gradio API.
|
| 16 |
+
|
| 17 |
+
## Features
|
| 18 |
+
|
| 19 |
+
- Text-to-music generation using AI
|
| 20 |
+
- Real-time audio streaming (no file persistence)
|
| 21 |
+
- RESTful API for game engine integration
|
| 22 |
+
- Supports various music styles and instruments
|
| 23 |
+
|
| 24 |
+
## API Usage
|
| 25 |
+
|
| 26 |
+
### Endpoint
|
| 27 |
+
```
|
| 28 |
+
POST https://YOUR-SPACE.hf.space/api/generate
|
| 29 |
+
```
|
| 30 |
+
|
| 31 |
+
### Request
|
| 32 |
+
```json
|
| 33 |
+
{
|
| 34 |
+
"data": ["A cheerful pop song with piano and drums", 512, 0.9]
|
| 35 |
+
}
|
| 36 |
+
```
|
| 37 |
+
|
| 38 |
+
Parameters:
|
| 39 |
+
- `data[0]`: Music prompt (string)
|
| 40 |
+
- `data[1]`: Max length in tokens (256-2048, default: 512)
|
| 41 |
+
- `data[2]`: Temperature (0.1-1.5, default: 0.9)
|
| 42 |
+
|
| 43 |
+
### Response
|
| 44 |
+
```json
|
| 45 |
+
{
|
| 46 |
+
"data": [
|
| 47 |
+
{"path": "/file=...", "url": "https://...", "orig_name": "audio.wav"},
|
| 48 |
+
"AI-generated audio for: 'A cheerful pop song...'"
|
| 49 |
+
]
|
| 50 |
+
}
|
| 51 |
+
```
|
| 52 |
+
|
| 53 |
+
## Unity Integration
|
| 54 |
+
|
| 55 |
+
```csharp
|
| 56 |
+
using UnityEngine;
|
| 57 |
+
using UnityEngine.Networking;
|
| 58 |
+
using System.Collections;
|
| 59 |
+
|
| 60 |
+
public class MusicGenerator : MonoBehaviour
|
| 61 |
+
{
|
| 62 |
+
private const string API_URL = "https://YOUR-SPACE.hf.space/api/generate";
|
| 63 |
+
|
| 64 |
+
public IEnumerator GenerateMusic(string prompt, System.Action<AudioClip> callback)
|
| 65 |
+
{
|
| 66 |
+
string json = $"{{\"data\": [\"{prompt}\", 512, 0.9]}}";
|
| 67 |
+
|
| 68 |
+
using (UnityWebRequest request = new UnityWebRequest(API_URL, "POST"))
|
| 69 |
+
{
|
| 70 |
+
byte[] bodyRaw = System.Text.Encoding.UTF8.GetBytes(json);
|
| 71 |
+
request.uploadHandler = new UploadHandlerRaw(bodyRaw);
|
| 72 |
+
request.downloadHandler = new DownloadHandlerBuffer();
|
| 73 |
+
request.SetRequestHeader("Content-Type", "application/json");
|
| 74 |
+
|
| 75 |
+
yield return request.SendWebRequest();
|
| 76 |
+
|
| 77 |
+
if (request.result == UnityWebRequest.Result.Success)
|
| 78 |
+
{
|
| 79 |
+
// Parse response and download audio from returned URL
|
| 80 |
+
var response = JsonUtility.FromJson<GradioResponse>(request.downloadHandler.text);
|
| 81 |
+
yield return DownloadAudio(response.data[0].url, callback);
|
| 82 |
+
}
|
| 83 |
+
}
|
| 84 |
+
}
|
| 85 |
+
|
| 86 |
+
private IEnumerator DownloadAudio(string url, System.Action<AudioClip> callback)
|
| 87 |
+
{
|
| 88 |
+
using (UnityWebRequest www = UnityWebRequestMultimedia.GetAudioClip(url, AudioType.WAV))
|
| 89 |
+
{
|
| 90 |
+
yield return www.SendWebRequest();
|
| 91 |
+
if (www.result == UnityWebRequest.Result.Success)
|
| 92 |
+
{
|
| 93 |
+
callback(DownloadHandlerAudioClip.GetContent(www));
|
| 94 |
+
}
|
| 95 |
+
}
|
| 96 |
+
}
|
| 97 |
+
}
|
| 98 |
+
```
|
| 99 |
+
|
| 100 |
+
## Example Prompts
|
| 101 |
+
|
| 102 |
+
- A cheerful and melodic pop Christmas song featuring piano, acoustic guitar, and drums
|
| 103 |
+
- An energetic electronic trance track with synth bass and drums at 138 BPM
|
| 104 |
+
- A slow and emotional classical piece featuring cello and violin in C minor
|
| 105 |
+
- A cinematic electronic soundtrack with an epic and dark atmosphere
|
| 106 |
+
- Happy medieval tavern music with lute and flute
|
| 107 |
+
|
| 108 |
+
## Local Development
|
| 109 |
+
|
| 110 |
+
```bash
|
| 111 |
+
pip install -r requirements.txt
|
| 112 |
+
python app.py
|
| 113 |
+
```
|
| 114 |
+
|
| 115 |
+
## Credits
|
| 116 |
+
|
| 117 |
+
- Model: [amaai-lab/text2midi](https://huggingface.co/amaai-lab/text2midi)
|
| 118 |
+
- Audio synthesis: FluidSynth with FluidR3 GM SoundFont
|
app.py
ADDED
|
@@ -0,0 +1,277 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
VR Music Generator - HuggingFace Spaces Version
|
| 3 |
+
Generates music from text descriptions using the text2midi AI model.
|
| 4 |
+
Exposes a Gradio API for Unity integration.
|
| 5 |
+
Audio is streamed directly - no files are persisted.
|
| 6 |
+
"""
|
| 7 |
+
import gradio as gr
|
| 8 |
+
import torch
|
| 9 |
+
import torch.nn as nn
|
| 10 |
+
import subprocess
|
| 11 |
+
import os
|
| 12 |
+
import sys
|
| 13 |
+
import pickle
|
| 14 |
+
import tempfile
|
| 15 |
+
import io
|
| 16 |
+
import numpy as np
|
| 17 |
+
from scipy.io import wavfile
|
| 18 |
+
from huggingface_hub import hf_hub_download
|
| 19 |
+
|
| 20 |
+
# Add text2midi model to path
|
| 21 |
+
sys.path.insert(0, "text2midi_repo")
|
| 22 |
+
|
| 23 |
+
repo_id = "amaai-lab/text2midi"
|
| 24 |
+
|
| 25 |
+
# Detect device
|
| 26 |
+
device = 'cuda' if torch.cuda.is_available() else 'cpu'
|
| 27 |
+
print(f"Using device: {device}")
|
| 28 |
+
|
| 29 |
+
# Global model variables
|
| 30 |
+
text2midi_model = None
|
| 31 |
+
midi_tokenizer = None
|
| 32 |
+
text_tokenizer = None
|
| 33 |
+
|
| 34 |
+
def load_text2midi_model():
|
| 35 |
+
"""Load the text2midi model and tokenizers."""
|
| 36 |
+
global text2midi_model, midi_tokenizer, text_tokenizer
|
| 37 |
+
|
| 38 |
+
try:
|
| 39 |
+
from model.transformer_model import Transformer
|
| 40 |
+
from transformers import T5Tokenizer
|
| 41 |
+
|
| 42 |
+
print("Loading text2midi model...")
|
| 43 |
+
|
| 44 |
+
# Download model files
|
| 45 |
+
model_path = hf_hub_download(repo_id=repo_id, filename="pytorch_model.bin")
|
| 46 |
+
tokenizer_path = hf_hub_download(repo_id=repo_id, filename="vocab_remi.pkl")
|
| 47 |
+
|
| 48 |
+
print(f"Model path: {model_path}")
|
| 49 |
+
print(f"Tokenizer path: {tokenizer_path}")
|
| 50 |
+
|
| 51 |
+
# Load MIDI tokenizer
|
| 52 |
+
with open(tokenizer_path, "rb") as f:
|
| 53 |
+
midi_tokenizer = pickle.load(f)
|
| 54 |
+
|
| 55 |
+
vocab_size = len(midi_tokenizer)
|
| 56 |
+
print(f"Vocab size: {vocab_size}")
|
| 57 |
+
|
| 58 |
+
# Initialize and load model
|
| 59 |
+
text2midi_model = Transformer(vocab_size, 768, 8, 2048, 18, 1024, False, 8, device=device)
|
| 60 |
+
text2midi_model.load_state_dict(torch.load(model_path, map_location=device, weights_only=True))
|
| 61 |
+
text2midi_model.to(device)
|
| 62 |
+
text2midi_model.eval()
|
| 63 |
+
|
| 64 |
+
# Load T5 tokenizer for text encoding
|
| 65 |
+
text_tokenizer = T5Tokenizer.from_pretrained("google/flan-t5-base")
|
| 66 |
+
|
| 67 |
+
print("Text2midi model loaded successfully!")
|
| 68 |
+
return True
|
| 69 |
+
|
| 70 |
+
except Exception as e:
|
| 71 |
+
print(f"Warning: Could not load text2midi model: {e}")
|
| 72 |
+
import traceback
|
| 73 |
+
traceback.print_exc()
|
| 74 |
+
print("Falling back to simple MIDI generation...")
|
| 75 |
+
return False
|
| 76 |
+
|
| 77 |
+
# Try to load the model
|
| 78 |
+
MODEL_LOADED = load_text2midi_model()
|
| 79 |
+
|
| 80 |
+
def find_soundfont():
|
| 81 |
+
"""Find a SoundFont file on the system."""
|
| 82 |
+
common_paths = [
|
| 83 |
+
"/usr/share/sounds/sf2/FluidR3_GM.sf2",
|
| 84 |
+
"/usr/share/soundfonts/FluidR3_GM.sf2",
|
| 85 |
+
"/usr/share/sounds/sf2/default-GM.sf2",
|
| 86 |
+
"FluidR3_GM.sf2",
|
| 87 |
+
]
|
| 88 |
+
for path in common_paths:
|
| 89 |
+
if os.path.exists(path):
|
| 90 |
+
return path
|
| 91 |
+
return None
|
| 92 |
+
|
| 93 |
+
SOUNDFONT_PATH = find_soundfont()
|
| 94 |
+
print(f"SoundFont: {SOUNDFONT_PATH or 'Not found'}")
|
| 95 |
+
|
| 96 |
+
def generate_midi_with_model(prompt: str, output_path: str, max_len: int = 512, temperature: float = 0.9):
|
| 97 |
+
"""Generate MIDI using the text2midi model."""
|
| 98 |
+
global text2midi_model, midi_tokenizer, text_tokenizer
|
| 99 |
+
|
| 100 |
+
# Tokenize input text
|
| 101 |
+
inputs = text_tokenizer(prompt, return_tensors='pt', padding=True, truncation=True)
|
| 102 |
+
input_ids = inputs.input_ids.to(device)
|
| 103 |
+
attention_mask = inputs.attention_mask.to(device)
|
| 104 |
+
|
| 105 |
+
# Generate MIDI tokens
|
| 106 |
+
with torch.no_grad():
|
| 107 |
+
output = text2midi_model.generate(input_ids, attention_mask, max_len=max_len, temperature=temperature)
|
| 108 |
+
|
| 109 |
+
output_list = output[0].tolist()
|
| 110 |
+
|
| 111 |
+
# Decode to MIDI
|
| 112 |
+
generated_midi = midi_tokenizer.decode(output_list)
|
| 113 |
+
generated_midi.dump_midi(output_path)
|
| 114 |
+
|
| 115 |
+
return output_path
|
| 116 |
+
|
| 117 |
+
def midi_to_audio_bytes(midi_path: str, sample_rate: int = 44100) -> tuple:
|
| 118 |
+
"""
|
| 119 |
+
Convert MIDI to audio using FluidSynth, returning numpy array.
|
| 120 |
+
Uses stdout piping to avoid creating intermediate files.
|
| 121 |
+
"""
|
| 122 |
+
if not SOUNDFONT_PATH:
|
| 123 |
+
return None
|
| 124 |
+
|
| 125 |
+
# Use FluidSynth to render MIDI to raw audio via stdout
|
| 126 |
+
# -T raw outputs raw audio, -F - outputs to stdout
|
| 127 |
+
result = subprocess.run([
|
| 128 |
+
"fluidsynth",
|
| 129 |
+
"-ni", # No interactive mode
|
| 130 |
+
"-T", "raw", # Output raw audio format
|
| 131 |
+
"-F", "-", # Output to stdout
|
| 132 |
+
"-r", str(sample_rate), # Sample rate
|
| 133 |
+
SOUNDFONT_PATH, # SoundFont file
|
| 134 |
+
midi_path, # MIDI file
|
| 135 |
+
], capture_output=True, timeout=120)
|
| 136 |
+
|
| 137 |
+
if result.returncode != 0:
|
| 138 |
+
print(f"FluidSynth error: {result.stderr.decode()}")
|
| 139 |
+
return None
|
| 140 |
+
|
| 141 |
+
# Convert raw audio bytes to numpy array (16-bit signed, stereo)
|
| 142 |
+
audio_data = np.frombuffer(result.stdout, dtype=np.int16)
|
| 143 |
+
|
| 144 |
+
# FluidSynth outputs stereo by default, reshape if needed
|
| 145 |
+
if len(audio_data) > 0:
|
| 146 |
+
# Convert to float32 normalized [-1, 1] for Gradio
|
| 147 |
+
audio_float = audio_data.astype(np.float32) / 32768.0
|
| 148 |
+
return (sample_rate, audio_float)
|
| 149 |
+
|
| 150 |
+
return None
|
| 151 |
+
|
| 152 |
+
def generate_music(prompt: str, max_length: int = 512, temperature: float = 0.9):
|
| 153 |
+
"""
|
| 154 |
+
Generate music from text prompt.
|
| 155 |
+
Returns audio data directly without saving files.
|
| 156 |
+
|
| 157 |
+
Args:
|
| 158 |
+
prompt: Text description of the music to generate
|
| 159 |
+
max_length: Maximum length in tokens (256-2048)
|
| 160 |
+
temperature: Generation temperature (0.1-1.5)
|
| 161 |
+
|
| 162 |
+
Returns:
|
| 163 |
+
Tuple of (audio_data, status_message)
|
| 164 |
+
audio_data is (sample_rate, numpy_array) for Gradio
|
| 165 |
+
"""
|
| 166 |
+
if not prompt or not prompt.strip():
|
| 167 |
+
return None, "Please enter a music prompt"
|
| 168 |
+
|
| 169 |
+
try:
|
| 170 |
+
# Create temporary MIDI file (auto-deleted when closed)
|
| 171 |
+
with tempfile.NamedTemporaryFile(suffix='.mid', delete=False) as midi_file:
|
| 172 |
+
midi_path = midi_file.name
|
| 173 |
+
|
| 174 |
+
try:
|
| 175 |
+
# Generate MIDI using the model or fallback
|
| 176 |
+
if MODEL_LOADED:
|
| 177 |
+
status_prefix = "AI-generated"
|
| 178 |
+
generate_midi_with_model(prompt, midi_path, max_len=int(max_length), temperature=temperature)
|
| 179 |
+
else:
|
| 180 |
+
status_prefix = "Simple"
|
| 181 |
+
# Fallback: create simple MIDI
|
| 182 |
+
from midiutil import MIDIFile
|
| 183 |
+
midi = MIDIFile(1)
|
| 184 |
+
midi.addTempo(0, 0, 120)
|
| 185 |
+
notes = [60, 62, 64, 65, 67, 69, 71, 72]
|
| 186 |
+
for i, note in enumerate(notes[:min(len(prompt.split()), 8)]):
|
| 187 |
+
midi.addNote(0, 0, note, i, 1, 100)
|
| 188 |
+
with open(midi_path, "wb") as f:
|
| 189 |
+
midi.writeFile(f)
|
| 190 |
+
|
| 191 |
+
# Convert MIDI to audio
|
| 192 |
+
if SOUNDFONT_PATH:
|
| 193 |
+
audio_result = midi_to_audio_bytes(midi_path)
|
| 194 |
+
if audio_result:
|
| 195 |
+
return audio_result, f"{status_prefix} audio for: '{prompt[:50]}...'" if len(prompt) > 50 else f"{status_prefix} audio for: '{prompt}'"
|
| 196 |
+
else:
|
| 197 |
+
return None, f"Error: FluidSynth conversion failed"
|
| 198 |
+
else:
|
| 199 |
+
return None, f"Error: FluidSynth/SoundFont not available"
|
| 200 |
+
|
| 201 |
+
finally:
|
| 202 |
+
# Clean up temporary MIDI file
|
| 203 |
+
try:
|
| 204 |
+
os.unlink(midi_path)
|
| 205 |
+
except:
|
| 206 |
+
pass
|
| 207 |
+
|
| 208 |
+
except Exception as e:
|
| 209 |
+
import traceback
|
| 210 |
+
traceback.print_exc()
|
| 211 |
+
return None, f"Error: {str(e)}"
|
| 212 |
+
|
| 213 |
+
# Create Gradio interface with API enabled
|
| 214 |
+
with gr.Blocks(title="VR Music Generator") as demo:
|
| 215 |
+
gr.Markdown("# VR Game Music Generator")
|
| 216 |
+
gr.Markdown("Generate music from text descriptions using the text2midi AI model")
|
| 217 |
+
|
| 218 |
+
if not MODEL_LOADED:
|
| 219 |
+
gr.Markdown("**Warning:** AI model not loaded - using simple placeholder MIDI")
|
| 220 |
+
if not SOUNDFONT_PATH:
|
| 221 |
+
gr.Markdown("**Note:** FluidSynth not configured - audio generation disabled")
|
| 222 |
+
|
| 223 |
+
with gr.Row():
|
| 224 |
+
with gr.Column():
|
| 225 |
+
prompt_input = gr.Textbox(
|
| 226 |
+
label="Music Prompt",
|
| 227 |
+
placeholder="A cheerful pop song with piano and drums in C major at 120 BPM",
|
| 228 |
+
lines=3
|
| 229 |
+
)
|
| 230 |
+
with gr.Row():
|
| 231 |
+
max_length = gr.Slider(
|
| 232 |
+
minimum=256,
|
| 233 |
+
maximum=2048,
|
| 234 |
+
value=512,
|
| 235 |
+
step=256,
|
| 236 |
+
label="Max Length (tokens)"
|
| 237 |
+
)
|
| 238 |
+
temperature = gr.Slider(
|
| 239 |
+
minimum=0.1,
|
| 240 |
+
maximum=1.5,
|
| 241 |
+
value=0.9,
|
| 242 |
+
step=0.1,
|
| 243 |
+
label="Temperature"
|
| 244 |
+
)
|
| 245 |
+
generate_btn = gr.Button("Generate Music", variant="primary")
|
| 246 |
+
|
| 247 |
+
with gr.Column():
|
| 248 |
+
audio_output = gr.Audio(label="Generated Music", type="numpy")
|
| 249 |
+
status_output = gr.Textbox(label="Status", lines=2)
|
| 250 |
+
|
| 251 |
+
generate_btn.click(
|
| 252 |
+
fn=generate_music,
|
| 253 |
+
inputs=[prompt_input, max_length, temperature],
|
| 254 |
+
outputs=[audio_output, status_output],
|
| 255 |
+
api_name="generate" # Exposes as /api/generate endpoint
|
| 256 |
+
)
|
| 257 |
+
|
| 258 |
+
gr.Markdown("---")
|
| 259 |
+
gr.Markdown("""
|
| 260 |
+
**Example prompts:**
|
| 261 |
+
- A cheerful and melodic pop Christmas song featuring piano, acoustic guitar, and drums
|
| 262 |
+
- An energetic electronic trance track with synth bass and drums at 138 BPM
|
| 263 |
+
- A slow and emotional classical piece featuring cello and violin in C minor
|
| 264 |
+
- A cinematic electronic soundtrack with an epic and dark atmosphere
|
| 265 |
+
|
| 266 |
+
**API Usage (for Unity):**
|
| 267 |
+
```csharp
|
| 268 |
+
// POST to: https://YOUR-SPACE.hf.space/api/generate
|
| 269 |
+
// Body: {"data": ["your music prompt", 512, 0.9]}
|
| 270 |
+
// Response: {"data": [{"path": "audio_url", ...}, "status"]}
|
| 271 |
+
```
|
| 272 |
+
""")
|
| 273 |
+
|
| 274 |
+
# For HuggingFace Spaces - launch() is called automatically
|
| 275 |
+
# For local testing, uncomment below:
|
| 276 |
+
# if __name__ == "__main__":
|
| 277 |
+
# demo.launch()
|
packages.txt
ADDED
|
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
|
|
|
| 1 |
+
fluidsynth
|
| 2 |
+
fluid-soundfont-gm
|
requirements.txt
ADDED
|
@@ -0,0 +1,9 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
gradio>=4.0.0
|
| 2 |
+
torch>=2.0.0
|
| 3 |
+
transformers>=4.30.0
|
| 4 |
+
huggingface-hub>=0.20.0
|
| 5 |
+
midiutil>=1.2.1
|
| 6 |
+
miditok>=3.0.0
|
| 7 |
+
scipy>=1.10.0
|
| 8 |
+
numpy>=1.24.0
|
| 9 |
+
tqdm>=4.65.0
|
text2midi_repo/.gitignore
ADDED
|
@@ -0,0 +1,162 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Byte-compiled / optimized / DLL files
|
| 2 |
+
__pycache__/
|
| 3 |
+
*.py[cod]
|
| 4 |
+
*$py.class
|
| 5 |
+
|
| 6 |
+
# C extensions
|
| 7 |
+
*.so
|
| 8 |
+
|
| 9 |
+
# Distribution / packaging
|
| 10 |
+
.Python
|
| 11 |
+
build/
|
| 12 |
+
develop-eggs/
|
| 13 |
+
dist/
|
| 14 |
+
downloads/
|
| 15 |
+
eggs/
|
| 16 |
+
.eggs/
|
| 17 |
+
lib/
|
| 18 |
+
lib64/
|
| 19 |
+
parts/
|
| 20 |
+
sdist/
|
| 21 |
+
var/
|
| 22 |
+
wheels/
|
| 23 |
+
share/python-wheels/
|
| 24 |
+
*.egg-info/
|
| 25 |
+
.installed.cfg
|
| 26 |
+
*.egg
|
| 27 |
+
MANIFEST
|
| 28 |
+
|
| 29 |
+
# PyInstaller
|
| 30 |
+
# Usually these files are written by a python script from a template
|
| 31 |
+
# before PyInstaller builds the exe, so as to inject date/other infos into it.
|
| 32 |
+
*.manifest
|
| 33 |
+
*.spec
|
| 34 |
+
|
| 35 |
+
# Installer logs
|
| 36 |
+
pip-log.txt
|
| 37 |
+
pip-delete-this-directory.txt
|
| 38 |
+
|
| 39 |
+
# Unit test / coverage reports
|
| 40 |
+
htmlcov/
|
| 41 |
+
.tox/
|
| 42 |
+
.nox/
|
| 43 |
+
.coverage
|
| 44 |
+
.coverage.*
|
| 45 |
+
.cache
|
| 46 |
+
nosetests.xml
|
| 47 |
+
coverage.xml
|
| 48 |
+
*.cover
|
| 49 |
+
*.py,cover
|
| 50 |
+
.hypothesis/
|
| 51 |
+
.pytest_cache/
|
| 52 |
+
cover/
|
| 53 |
+
|
| 54 |
+
# Translations
|
| 55 |
+
*.mo
|
| 56 |
+
*.pot
|
| 57 |
+
|
| 58 |
+
# Django stuff:
|
| 59 |
+
*.log
|
| 60 |
+
local_settings.py
|
| 61 |
+
db.sqlite3
|
| 62 |
+
db.sqlite3-journal
|
| 63 |
+
|
| 64 |
+
# Flask stuff:
|
| 65 |
+
instance/
|
| 66 |
+
.webassets-cache
|
| 67 |
+
|
| 68 |
+
# Scrapy stuff:
|
| 69 |
+
.scrapy
|
| 70 |
+
|
| 71 |
+
# Sphinx documentation
|
| 72 |
+
docs/_build/
|
| 73 |
+
|
| 74 |
+
# PyBuilder
|
| 75 |
+
.pybuilder/
|
| 76 |
+
target/
|
| 77 |
+
|
| 78 |
+
# Jupyter Notebook
|
| 79 |
+
.ipynb_checkpoints
|
| 80 |
+
|
| 81 |
+
# IPython
|
| 82 |
+
profile_default/
|
| 83 |
+
ipython_config.py
|
| 84 |
+
|
| 85 |
+
# pyenv
|
| 86 |
+
# For a library or package, you might want to ignore these files since the code is
|
| 87 |
+
# intended to run in multiple environments; otherwise, check them in:
|
| 88 |
+
# .python-version
|
| 89 |
+
|
| 90 |
+
# pipenv
|
| 91 |
+
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
|
| 92 |
+
# However, in case of collaboration, if having platform-specific dependencies or dependencies
|
| 93 |
+
# having no cross-platform support, pipenv may install dependencies that don't work, or not
|
| 94 |
+
# install all needed dependencies.
|
| 95 |
+
#Pipfile.lock
|
| 96 |
+
|
| 97 |
+
# poetry
|
| 98 |
+
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
|
| 99 |
+
# This is especially recommended for binary packages to ensure reproducibility, and is more
|
| 100 |
+
# commonly ignored for libraries.
|
| 101 |
+
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
|
| 102 |
+
#poetry.lock
|
| 103 |
+
|
| 104 |
+
# pdm
|
| 105 |
+
# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
|
| 106 |
+
#pdm.lock
|
| 107 |
+
# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
|
| 108 |
+
# in version control.
|
| 109 |
+
# https://pdm.fming.dev/latest/usage/project/#working-with-version-control
|
| 110 |
+
.pdm.toml
|
| 111 |
+
.pdm-python
|
| 112 |
+
.pdm-build/
|
| 113 |
+
|
| 114 |
+
# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
|
| 115 |
+
__pypackages__/
|
| 116 |
+
|
| 117 |
+
# Celery stuff
|
| 118 |
+
celerybeat-schedule
|
| 119 |
+
celerybeat.pid
|
| 120 |
+
|
| 121 |
+
# SageMath parsed files
|
| 122 |
+
*.sage.py
|
| 123 |
+
|
| 124 |
+
# Environments
|
| 125 |
+
.env
|
| 126 |
+
.venv
|
| 127 |
+
env/
|
| 128 |
+
venv/
|
| 129 |
+
ENV/
|
| 130 |
+
env.bak/
|
| 131 |
+
venv.bak/
|
| 132 |
+
|
| 133 |
+
# Spyder project settings
|
| 134 |
+
.spyderproject
|
| 135 |
+
.spyproject
|
| 136 |
+
|
| 137 |
+
# Rope project settings
|
| 138 |
+
.ropeproject
|
| 139 |
+
|
| 140 |
+
# mkdocs documentation
|
| 141 |
+
/site
|
| 142 |
+
|
| 143 |
+
# mypy
|
| 144 |
+
.mypy_cache/
|
| 145 |
+
.dmypy.json
|
| 146 |
+
dmypy.json
|
| 147 |
+
|
| 148 |
+
# Pyre type checker
|
| 149 |
+
.pyre/
|
| 150 |
+
|
| 151 |
+
# pytype static type analyzer
|
| 152 |
+
.pytype/
|
| 153 |
+
|
| 154 |
+
# Cython debug symbols
|
| 155 |
+
cython_debug/
|
| 156 |
+
|
| 157 |
+
# PyCharm
|
| 158 |
+
# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
|
| 159 |
+
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
|
| 160 |
+
# and can be added to the global gitignore or merged into this file. For a more nuclear
|
| 161 |
+
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
|
| 162 |
+
#.idea/
|
text2midi_repo/LICENSE
ADDED
|
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
MIT License
|
| 2 |
+
|
| 3 |
+
Copyright (c) 2024 AMAAI Lab
|
| 4 |
+
|
| 5 |
+
Permission is hereby granted, free of charge, to any person obtaining a copy
|
| 6 |
+
of this software and associated documentation files (the "Software"), to deal
|
| 7 |
+
in the Software without restriction, including without limitation the rights
|
| 8 |
+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
| 9 |
+
copies of the Software, and to permit persons to whom the Software is
|
| 10 |
+
furnished to do so, subject to the following conditions:
|
| 11 |
+
|
| 12 |
+
The above copyright notice and this permission notice shall be included in all
|
| 13 |
+
copies or substantial portions of the Software.
|
| 14 |
+
|
| 15 |
+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
| 16 |
+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
| 17 |
+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
| 18 |
+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
| 19 |
+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
| 20 |
+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
| 21 |
+
SOFTWARE.
|
text2midi_repo/README.md
ADDED
|
@@ -0,0 +1,162 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Text2midi: Generating Symbolic Music from Captions
|
| 2 |
+
|
| 3 |
+
[Demo](https://huggingface.co/spaces/amaai-lab/text2midi) | [Model](https://huggingface.co/amaai-lab/text2midi) | [Examples](https://amaai-lab.github.io/Text2midi/) | [Paper](https://arxiv.org/abs/2412.16526) | [Dataset](https://huggingface.co/datasets/amaai-lab/MidiCaps)
|
| 4 |
+
|
| 5 |
+
[](https://huggingface.co/spaces/amaai-lab/text2midi)
|
| 6 |
+
</div>
|
| 7 |
+
|
| 8 |
+
**text2midi** is the first end-to-end model for generating MIDI files from textual descriptions. By leveraging pretrained large language models and a powerful autoregressive transformer decoder, **text2midi** allows users to create symbolic music that aligns with detailed textual prompts, including musical attributes like chords, tempo, and style. The details of the model are described in [this paper](https://arxiv.org/abs/2412.16526).
|
| 9 |
+
|
| 10 |
+
🔥 Live demo available on [HuggingFace Spaces](https://huggingface.co/spaces/amaai-lab/text2midi).
|
| 11 |
+
|
| 12 |
+
🔥 Update: Text2midi has been accepted at AAAI!
|
| 13 |
+
|
| 14 |
+
<div align="center">
|
| 15 |
+
<img src="text2midi_architecture.jpg" width="500"/>
|
| 16 |
+
</div>
|
| 17 |
+
|
| 18 |
+
## Quickstart Guide
|
| 19 |
+
|
| 20 |
+
Generate symbolic music from a text prompt:
|
| 21 |
+
|
| 22 |
+
```python
|
| 23 |
+
import pickle
|
| 24 |
+
import torch
|
| 25 |
+
import torch.nn as nn
|
| 26 |
+
from transformers import T5Tokenizer
|
| 27 |
+
from model.transformer_model import Transformer
|
| 28 |
+
from huggingface_hub import hf_hub_download
|
| 29 |
+
|
| 30 |
+
repo_id = "amaai-lab/text2midi"
|
| 31 |
+
# Download the model.bin file
|
| 32 |
+
model_path = hf_hub_download(repo_id=repo_id, filename="pytorch_model.bin")
|
| 33 |
+
# Download the vocab_remi.pkl file
|
| 34 |
+
tokenizer_path = hf_hub_download(repo_id=repo_id, filename="vocab_remi.pkl")
|
| 35 |
+
|
| 36 |
+
if torch.cuda.is_available():
|
| 37 |
+
device = 'cuda'
|
| 38 |
+
elif torch.backends.mps.is_available():
|
| 39 |
+
device = 'mps'
|
| 40 |
+
else:
|
| 41 |
+
device = 'cpu'
|
| 42 |
+
|
| 43 |
+
print(f"Using device: {device}")
|
| 44 |
+
|
| 45 |
+
# Load the tokenizer dictionary
|
| 46 |
+
with open(tokenizer_path, "rb") as f:
|
| 47 |
+
r_tokenizer = pickle.load(f)
|
| 48 |
+
|
| 49 |
+
# Get the vocab size
|
| 50 |
+
vocab_size = len(r_tokenizer)
|
| 51 |
+
print("Vocab size: ", vocab_size)
|
| 52 |
+
model = Transformer(vocab_size, 768, 8, 2048, 18, 1024, False, 8, device=device)
|
| 53 |
+
model.load_state_dict(torch.load(model_path, map_location=device))
|
| 54 |
+
model.eval()
|
| 55 |
+
tokenizer = T5Tokenizer.from_pretrained("google/flan-t5-base")
|
| 56 |
+
|
| 57 |
+
print('Model loaded.')
|
| 58 |
+
|
| 59 |
+
|
| 60 |
+
# Enter the text prompt and tokenize it
|
| 61 |
+
src = "A melodic electronic song with ambient elements, featuring piano, acoustic guitar, alto saxophone, string ensemble, and electric bass. Set in G minor with a 4/4 time signature, it moves at a lively Presto tempo. The composition evokes a blend of relaxation and darkness, with hints of happiness and a meditative quality."
|
| 62 |
+
print('Generating for prompt: ' + src)
|
| 63 |
+
|
| 64 |
+
inputs = tokenizer(src, return_tensors='pt', padding=True, truncation=True)
|
| 65 |
+
input_ids = nn.utils.rnn.pad_sequence(inputs.input_ids, batch_first=True, padding_value=0)
|
| 66 |
+
input_ids = input_ids.to(device)
|
| 67 |
+
attention_mask =nn.utils.rnn.pad_sequence(inputs.attention_mask, batch_first=True, padding_value=0)
|
| 68 |
+
attention_mask = attention_mask.to(device)
|
| 69 |
+
|
| 70 |
+
# Generate the midi
|
| 71 |
+
output = model.generate(input_ids, attention_mask, max_len=2000,temperature = 1.0)
|
| 72 |
+
output_list = output[0].tolist()
|
| 73 |
+
generated_midi = r_tokenizer.decode(output_list)
|
| 74 |
+
generated_midi.dump_midi("output.mid")
|
| 75 |
+
```
|
| 76 |
+
|
| 77 |
+
## Installation
|
| 78 |
+
|
| 79 |
+
If you have CUDA supported machine:
|
| 80 |
+
```bash
|
| 81 |
+
git clone https://github.com/AMAAI-Lab/text2midi
|
| 82 |
+
cd text2midi
|
| 83 |
+
pip install -r requirements.txt
|
| 84 |
+
```
|
| 85 |
+
Alternatively, if you have MPS supported machine:
|
| 86 |
+
```bash
|
| 87 |
+
git clone https://github.com/AMAAI-Lab/text2midi
|
| 88 |
+
cd text2midi
|
| 89 |
+
pip install -r requirements-mac.txt
|
| 90 |
+
```
|
| 91 |
+
|
| 92 |
+
## Datasets
|
| 93 |
+
|
| 94 |
+
The model was trained using two datasets: [SymphonyNet](https://symphonynet.github.io/) for semi-supervised pretraining and MidiCaps for finetuning towards MIDI generation from captions.
|
| 95 |
+
The [MidiCaps dataset](https://huggingface.co/datasets/amaai-lab/MidiCaps) is a large-scale dataset of 168k MIDI files paired with rich text captions. These captions contain musical attributes such as key, tempo, style, and mood, making it ideal for text-to-MIDI generation tasks as described in [this paper](https://arxiv.org/abs/2406.02255).
|
| 96 |
+
|
| 97 |
+
|
| 98 |
+
## Citation
|
| 99 |
+
If you use text2midi in your research, please cite:
|
| 100 |
+
```
|
| 101 |
+
@inproceedings{bhandari2025text2midi,
|
| 102 |
+
title={text2midi: Generating Symbolic Music from Captions},
|
| 103 |
+
author={Keshav Bhandari and Abhinaba Roy and Kyra Wang and Geeta Puri and Simon Colton and Dorien Herremans},
|
| 104 |
+
booktitle={Proceedings of the 39th AAAI Conference on Artificial Intelligence (AAAI 2025)},
|
| 105 |
+
year={2025}
|
| 106 |
+
}
|
| 107 |
+
```
|
| 108 |
+
|
| 109 |
+
## Results of the Listening Study
|
| 110 |
+
|
| 111 |
+
Each question is rated on a Likert scale from 1 (very bad) to 7 (very good). The table shows the average ratings per question for each group of participants.
|
| 112 |
+
|
| 113 |
+
| Question | MidiCaps | text2midi | MuseCoco |
|
| 114 |
+
|---------------------|----------|-----------|----------|
|
| 115 |
+
| Musical Quality | 5.79 | 4.62 | 4.40 |
|
| 116 |
+
| Overall Matching | 5.42 | 4.67 | 4.07 |
|
| 117 |
+
| Genre Matching | 5.54 | 4.98 | 4.40 |
|
| 118 |
+
| Mood Matching | 5.70 | 5.00 | 4.32 |
|
| 119 |
+
| Key Matching | 4.61 | 3.64 | 3.36 |
|
| 120 |
+
| Chord Matching | 3.20 | 2.50 | 2.00 |
|
| 121 |
+
| Tempo Matching | 5.89 | 5.42 | 4.94 |
|
| 122 |
+
|
| 123 |
+
|
| 124 |
+
## Objective Evaluations
|
| 125 |
+
Results of objective evaluation for *all* of MidiCaps test set. Please not we have improved from all the numbers written in the paper (the numbers in paper are on a small subset of MidiCaps test set).
|
| 126 |
+
|
| 127 |
+
| Metric | text2midi | MidiCaps | MuseCoco |
|
| 128 |
+
|---------------------|-----------|----------|----------|
|
| 129 |
+
| CR ↑ | 2.31 | 3.43 | 2.12 |
|
| 130 |
+
| CLAP ↑ | 0.22 | 0.26 | 0.21 |
|
| 131 |
+
| TB (%) ↑ | 39.70 | - | 21.71 |
|
| 132 |
+
| TBT (%) ↑ | 65.80 | - | 54.63 |
|
| 133 |
+
| CK (%) ↑ | 33.60 | - | 13.70 |
|
| 134 |
+
| CKD (%) ↑ | 35.60 | - | 14.59 |
|
| 135 |
+
|
| 136 |
+
**Note**:
|
| 137 |
+
CR = Compression ratio
|
| 138 |
+
CLAP = CLAP score
|
| 139 |
+
TB = Tempo Bin
|
| 140 |
+
TBT = Tempo Bin with Tolerance
|
| 141 |
+
CK = Correct Key
|
| 142 |
+
CKD = Correct Key with Duplicates
|
| 143 |
+
↑ = Higher score is better.
|
| 144 |
+
|
| 145 |
+
## Training
|
| 146 |
+
To train text2midi, we recommend using accelerate for multi-GPU support. First, configure accelerate by running:
|
| 147 |
+
```bash
|
| 148 |
+
accelerate config
|
| 149 |
+
```
|
| 150 |
+
|
| 151 |
+
Then, use the following command to start training:
|
| 152 |
+
```bash
|
| 153 |
+
accelerate launch --multi_gpu --num_processes=4 train_accelerate.py --config ../config.yaml
|
| 154 |
+
```
|
| 155 |
+
|
| 156 |
+
## Inference
|
| 157 |
+
We support inference on CUDA, MPS and cpu. Please make sure you have pip installed the correct requirement file (requirments.txt for CUDA, requirements-mac.txt for MPS)
|
| 158 |
+
```bash
|
| 159 |
+
python model/transformer_model.py --caption <your intended descriptions>
|
| 160 |
+
```
|
| 161 |
+
|
| 162 |
+
|
text2midi_repo/artifacts/vocab.pkl
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:9b15330f5ab9c2cd32d359bcc64a1de320a7dc1227180a7658fd0b8f2d35e12c
|
| 3 |
+
size 239637
|
text2midi_repo/artifacts/vocab_remi.pkl
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:877d4511d6b9d5eea1c706199fe13a0de3d984c8f5d09c75d727ffe7f6f54ee6
|
| 3 |
+
size 27256
|
text2midi_repo/captions/captions.json
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:7c25b36b6196618ff79f111e24c52b97d0bde9e1b47d2c596650944ebe6dcac5
|
| 3 |
+
size 69068459
|
text2midi_repo/configs/config.yaml
ADDED
|
@@ -0,0 +1,62 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
model:
|
| 2 |
+
text2midi_model:
|
| 3 |
+
decoder_max_sequence_length: 2048
|
| 4 |
+
decoder_num_layers: 18
|
| 5 |
+
decoder_num_heads: 8
|
| 6 |
+
decoder_d_model: 768
|
| 7 |
+
decoder_intermediate_size: 1024
|
| 8 |
+
use_moe: False
|
| 9 |
+
num_experts: 4
|
| 10 |
+
use_deepspeed: False
|
| 11 |
+
use_accelerate: True
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
training:
|
| 15 |
+
text2midi_model:
|
| 16 |
+
epochs: 140
|
| 17 |
+
batch_size: 1
|
| 18 |
+
learning_rate: 0.000001
|
| 19 |
+
weight_decay: 0.01
|
| 20 |
+
gradient_accumulation_steps: 4
|
| 21 |
+
with_tracking: True
|
| 22 |
+
checkpointing_steps: epoch
|
| 23 |
+
report_to: wandb
|
| 24 |
+
output_dir: /root/output_test_new
|
| 25 |
+
per_device_train_batch_size: 32
|
| 26 |
+
use_scheduler: True
|
| 27 |
+
lr_scheduler_type: cosine #choices=["linear", "cosine", "cosine_with_restarts", "polynomial", "constant", "constant_with_warmup"]
|
| 28 |
+
num_warmup_steps: 100
|
| 29 |
+
save_every: 5
|
| 30 |
+
max_train_steps: None
|
| 31 |
+
scheduled_sampling: False
|
| 32 |
+
epsilon: 0
|
| 33 |
+
c: -0.0161
|
| 34 |
+
k: -0.312
|
| 35 |
+
|
| 36 |
+
raw_data:
|
| 37 |
+
caption_dataset_path: /root/captions/train.json
|
| 38 |
+
raw_data_folders:
|
| 39 |
+
lmd:
|
| 40 |
+
folder_path: /import/c4dm-datasets-ext/lakhmidi
|
| 41 |
+
file_extension: midi
|
| 42 |
+
symphonynet:
|
| 43 |
+
folder_path: /root/text2midi/data/symphonynet/data/SymphonyNet_Dataset
|
| 44 |
+
file_extension: mid
|
| 45 |
+
maestro:
|
| 46 |
+
folder_path: /import/c4dm-datasets/maestro-v3.0.0
|
| 47 |
+
file_extension: midi
|
| 48 |
+
pop909:
|
| 49 |
+
folder_path: /import/c4dm-datasets-ext/POP909
|
| 50 |
+
file_extension: mid
|
| 51 |
+
pijama:
|
| 52 |
+
folder_path: /import/c4dm-datasets/PiJAMA/data/midi
|
| 53 |
+
file_extension: midi
|
| 54 |
+
midicaps:
|
| 55 |
+
folder_path: /root/data
|
| 56 |
+
file_extension: mid
|
| 57 |
+
|
| 58 |
+
|
| 59 |
+
deepspeed_config:
|
| 60 |
+
deepspeed_config_path: /root/test/text2midi/configs/ds_config.json
|
| 61 |
+
|
| 62 |
+
artifact_folder: ../artifacts
|
text2midi_repo/configs/ds_config.json
ADDED
|
@@ -0,0 +1,33 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"train_micro_batch_size_per_gpu": 1,
|
| 3 |
+
"gradient_accumulation_steps": 1,
|
| 4 |
+
"optimizer": {
|
| 5 |
+
"type": "Adam",
|
| 6 |
+
"params": {
|
| 7 |
+
"lr": 1e-4
|
| 8 |
+
}
|
| 9 |
+
},
|
| 10 |
+
"bf16": {
|
| 11 |
+
"enabled": true
|
| 12 |
+
},
|
| 13 |
+
"zero_optimization": {
|
| 14 |
+
"stage": 1,
|
| 15 |
+
"offload_optimizer": {
|
| 16 |
+
"device": "cpu",
|
| 17 |
+
"pin_memory": true
|
| 18 |
+
},
|
| 19 |
+
"offload_param": {
|
| 20 |
+
"device": "cpu",
|
| 21 |
+
"pin_memory": true
|
| 22 |
+
},
|
| 23 |
+
"overlap_comm": true,
|
| 24 |
+
"contiguous_gradients": true,
|
| 25 |
+
"sub_group_size": 1e9
|
| 26 |
+
},
|
| 27 |
+
"activation_checkpointing": {
|
| 28 |
+
"partition_activations": true,
|
| 29 |
+
"number_checkpoints": null,
|
| 30 |
+
"contiguous_memory_optimization":true,
|
| 31 |
+
"cpu_checkpointing": true
|
| 32 |
+
}
|
| 33 |
+
}
|
text2midi_repo/model/__pycache__/transformer_model.cpython-314.pyc
ADDED
|
Binary file (77.4 kB). View file
|
|
|
text2midi_repo/model/build_vocab.py
ADDED
|
@@ -0,0 +1,79 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import yaml
|
| 2 |
+
import os
|
| 3 |
+
import argparse
|
| 4 |
+
import pickle
|
| 5 |
+
import glob
|
| 6 |
+
import numpy as np
|
| 7 |
+
import json
|
| 8 |
+
from tqdm import tqdm
|
| 9 |
+
import random
|
| 10 |
+
from copy import deepcopy
|
| 11 |
+
import sys
|
| 12 |
+
import pickle
|
| 13 |
+
|
| 14 |
+
# Parse command line arguments
|
| 15 |
+
parser = argparse.ArgumentParser()
|
| 16 |
+
parser.add_argument("--config", type=str, default=os.path.normpath("configs/config.yaml"),
|
| 17 |
+
help="Path to the config file")
|
| 18 |
+
args = parser.parse_args()
|
| 19 |
+
|
| 20 |
+
# Load config file
|
| 21 |
+
with open(args.config, 'r') as f:
|
| 22 |
+
configs = yaml.safe_load(f)
|
| 23 |
+
|
| 24 |
+
artifact_folder = configs["artifact_folder"]
|
| 25 |
+
raw_data_folders = configs["raw_data"]["raw_data_folders"]
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
# Build the vocabulary
|
| 29 |
+
vocab = {}
|
| 30 |
+
|
| 31 |
+
instruments = ['piano', 'chromatic', 'organ', 'guitar', 'bass', 'strings', 'ensemble', 'brass', 'reed', 'pipe', 'synth_lead', 'synth_pad', 'synth_effect', 'ethnic', 'percussive', 'sfx', 'drum']
|
| 32 |
+
|
| 33 |
+
# Special tokens
|
| 34 |
+
for i in instruments:
|
| 35 |
+
vocab[('prefix', 'instrument', i)] = len(vocab) + 1
|
| 36 |
+
|
| 37 |
+
# MIDI velocity range from 0 to 127
|
| 38 |
+
velocity = [0, 15, 30, 45, 60, 75, 90, 105, 120, 127]
|
| 39 |
+
# MIDI pitch range from 0 to 127
|
| 40 |
+
midi_pitch = list(range(0, 128))
|
| 41 |
+
# Onsets are quantized in 10 milliseconds up to 5 seconds
|
| 42 |
+
onset = list(range(0, 5001, 10))
|
| 43 |
+
duration = list(range(0, 5001, 10))
|
| 44 |
+
|
| 45 |
+
# Add the instrument tokens to the vocabulary
|
| 46 |
+
for v in velocity:
|
| 47 |
+
for i in instruments:
|
| 48 |
+
for p in midi_pitch:
|
| 49 |
+
if i == "drum":
|
| 50 |
+
continue
|
| 51 |
+
else:
|
| 52 |
+
vocab[(i, p, v)] = len(vocab) + 1
|
| 53 |
+
|
| 54 |
+
for p in midi_pitch:
|
| 55 |
+
vocab[("drum", p)] = len(vocab) + 1
|
| 56 |
+
|
| 57 |
+
for o in onset:
|
| 58 |
+
vocab[("onset", o)] = len(vocab) + 1
|
| 59 |
+
for d in duration:
|
| 60 |
+
vocab[("dur", d)] = len(vocab) + 1
|
| 61 |
+
|
| 62 |
+
vocab["<T>"] = len(vocab) + 1
|
| 63 |
+
vocab["<D>"] = len(vocab) + 1
|
| 64 |
+
vocab["<U>"] = len(vocab) + 1
|
| 65 |
+
vocab["<SS>"] = len(vocab) + 1
|
| 66 |
+
print('vocab[<ss>]', vocab['<SS>'])
|
| 67 |
+
vocab["<S>"] = len(vocab) + 1
|
| 68 |
+
vocab["<E>"] = len(vocab) + 1
|
| 69 |
+
vocab["SEP"] = len(vocab) + 1
|
| 70 |
+
|
| 71 |
+
# Print the vocabulary length
|
| 72 |
+
print(f"Vocabulary length: {len(vocab)}")
|
| 73 |
+
|
| 74 |
+
# Save the vocabulary
|
| 75 |
+
vocab_path = os.path.join(artifact_folder, "vocab.pkl")
|
| 76 |
+
with open(vocab_path, 'wb') as f:
|
| 77 |
+
pickle.dump(vocab, f)
|
| 78 |
+
|
| 79 |
+
print(f"Vocabulary saved to {vocab_path}")
|
text2midi_repo/model/build_vocab_remi.py
ADDED
|
@@ -0,0 +1,69 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import yaml
|
| 2 |
+
import os
|
| 3 |
+
import argparse
|
| 4 |
+
import pickle
|
| 5 |
+
import glob
|
| 6 |
+
import numpy as np
|
| 7 |
+
import json
|
| 8 |
+
from tqdm import tqdm
|
| 9 |
+
import random
|
| 10 |
+
from copy import deepcopy
|
| 11 |
+
import sys
|
| 12 |
+
import pickle
|
| 13 |
+
from miditok import REMI, TokenizerConfig # here we choose to use REMI
|
| 14 |
+
import jsonlines
|
| 15 |
+
|
| 16 |
+
# Parse command line arguments
|
| 17 |
+
parser = argparse.ArgumentParser()
|
| 18 |
+
parser.add_argument("--config", type=str, default=os.path.normpath("configs/config.yaml"),
|
| 19 |
+
help="Path to the config file")
|
| 20 |
+
args = parser.parse_args()
|
| 21 |
+
|
| 22 |
+
# Load config file
|
| 23 |
+
with open(args.config, 'r') as f:
|
| 24 |
+
configs = yaml.safe_load(f)
|
| 25 |
+
|
| 26 |
+
artifact_folder = configs["artifact_folder"]
|
| 27 |
+
raw_data_folders = configs["raw_data"]["raw_data_folders"]
|
| 28 |
+
caption_dataset_path = configs["raw_data"]["caption_dataset_path"]
|
| 29 |
+
dataset_path = configs["raw_data"]["raw_data_folders"]["lmd"]["folder_path"]
|
| 30 |
+
|
| 31 |
+
# Our parameters
|
| 32 |
+
BEAT_RES = {(0, 1): 12, (1, 2): 4, (2, 4): 2, (4, 8): 1}
|
| 33 |
+
TOKENIZER_PARAMS = {
|
| 34 |
+
"pitch_range": (21, 109),
|
| 35 |
+
"beat_res": BEAT_RES,
|
| 36 |
+
"num_velocities": 32,
|
| 37 |
+
"special_tokens": ["PAD", "BOS", "EOS", "MASK"],
|
| 38 |
+
"use_chords": False,
|
| 39 |
+
"use_rests": False,
|
| 40 |
+
"use_tempos": True,
|
| 41 |
+
"use_time_signatures": True,
|
| 42 |
+
"use_programs": True,
|
| 43 |
+
"num_tempos": 32, # number of tempo bins
|
| 44 |
+
"tempo_range": (40, 250), # (min, max)
|
| 45 |
+
}
|
| 46 |
+
config = TokenizerConfig(**TOKENIZER_PARAMS)
|
| 47 |
+
|
| 48 |
+
# Creates the tokenizer
|
| 49 |
+
tokenizer = REMI(config)
|
| 50 |
+
|
| 51 |
+
# Load the caption dataset
|
| 52 |
+
with jsonlines.open(caption_dataset_path) as reader:
|
| 53 |
+
captions = list(reader)
|
| 54 |
+
|
| 55 |
+
midi_paths = [os.path.join(dataset_path, captions[i]['location']) for i in range(len(captions))][0:30000]
|
| 56 |
+
|
| 57 |
+
# Builds the vocabulary with BPE
|
| 58 |
+
# vocab_size = 30000
|
| 59 |
+
# tokenizer.train(vocab_size=vocab_size, files_paths=midi_paths)
|
| 60 |
+
|
| 61 |
+
# Print the vocabulary length
|
| 62 |
+
print(f"Vocabulary length: {tokenizer.vocab_size}")
|
| 63 |
+
|
| 64 |
+
# Save the vocabulary
|
| 65 |
+
vocab_path = os.path.join(artifact_folder, "vocab_remi.pkl")
|
| 66 |
+
with open(vocab_path, 'wb') as f:
|
| 67 |
+
pickle.dump(tokenizer, f)
|
| 68 |
+
|
| 69 |
+
print(f"Vocabulary saved to {vocab_path}")
|
text2midi_repo/model/data_loader.py
ADDED
|
@@ -0,0 +1,124 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from aria.data.midi import MidiDict
|
| 2 |
+
# from aria.tokenizer import AbsTokenizer
|
| 3 |
+
# aria_tokenizer = AbsTokenizer()
|
| 4 |
+
import yaml
|
| 5 |
+
import jsonlines
|
| 6 |
+
import glob
|
| 7 |
+
import random
|
| 8 |
+
import os
|
| 9 |
+
import sys
|
| 10 |
+
import pickle
|
| 11 |
+
import json
|
| 12 |
+
import argparse
|
| 13 |
+
import numpy as np
|
| 14 |
+
from copy import deepcopy
|
| 15 |
+
from torch.utils.data import Dataset
|
| 16 |
+
import torch
|
| 17 |
+
from torch.nn import functional as F
|
| 18 |
+
from transformers import T5Tokenizer
|
| 19 |
+
from spacy.lang.en import English
|
| 20 |
+
|
| 21 |
+
SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__))
|
| 22 |
+
sys.path.append(os.path.dirname(SCRIPT_DIR))
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
class Text2MusicDataset(Dataset):
|
| 26 |
+
def __init__(self, configs, captions, aria_tokenizer, mode="train", shuffle = False):
|
| 27 |
+
self.mode = mode
|
| 28 |
+
self.captions = captions
|
| 29 |
+
if shuffle:
|
| 30 |
+
random.shuffle(self.captions)
|
| 31 |
+
|
| 32 |
+
# Path to dataset
|
| 33 |
+
self.dataset_path = configs['raw_data']['raw_data_folders']['midicaps']['folder_path']
|
| 34 |
+
|
| 35 |
+
# Artifact folder
|
| 36 |
+
self.artifact_folder = configs['artifact_folder']
|
| 37 |
+
# Load encoder tokenizer json file dictionary
|
| 38 |
+
tokenizer_filepath = os.path.join(self.artifact_folder, "vocab.pkl")
|
| 39 |
+
self.aria_tokenizer = aria_tokenizer #AbsTokenizer()
|
| 40 |
+
# Load the pickled tokenizer dictionary
|
| 41 |
+
with open(tokenizer_filepath, 'rb') as f:
|
| 42 |
+
self.tokenizer = pickle.load(f)
|
| 43 |
+
|
| 44 |
+
# Load the sentencizer
|
| 45 |
+
self.nlp = English()
|
| 46 |
+
self.nlp.add_pipe('sentencizer')
|
| 47 |
+
|
| 48 |
+
# Load the FLAN-T5 tokenizer and encoder
|
| 49 |
+
self.t5_tokenizer = T5Tokenizer.from_pretrained("google/flan-t5-base")
|
| 50 |
+
|
| 51 |
+
# Get the maximum sequence length
|
| 52 |
+
self.decoder_max_sequence_length = configs['model']['text2midi_model']['decoder_max_sequence_length']
|
| 53 |
+
|
| 54 |
+
# Print length of dataset
|
| 55 |
+
print("Length of dataset: ", len(self.captions))
|
| 56 |
+
|
| 57 |
+
def __len__(self):
|
| 58 |
+
return len(self.captions)
|
| 59 |
+
|
| 60 |
+
def __getitem__(self, idx):
|
| 61 |
+
caption = self.captions[idx]['caption']
|
| 62 |
+
midi_filepath = os.path.join(self.dataset_path, self.captions[idx]['location'])
|
| 63 |
+
|
| 64 |
+
# Read the MIDI file
|
| 65 |
+
midi = MidiDict.from_midi(midi_filepath)
|
| 66 |
+
if len(midi.note_msgs) == 0:
|
| 67 |
+
aria_tokenized_midi = ["<SS>", "<E>"]
|
| 68 |
+
else:
|
| 69 |
+
# Get the tokenized MIDI file
|
| 70 |
+
aria_tokenized_midi = self.aria_tokenizer.tokenize(midi)
|
| 71 |
+
# Add the start token
|
| 72 |
+
aria_tokenized_midi = ["<SS>"] + aria_tokenized_midi
|
| 73 |
+
|
| 74 |
+
# Drop a random number of sentences from the caption
|
| 75 |
+
do_drop = random.random() > 0.5
|
| 76 |
+
if do_drop:
|
| 77 |
+
sentences = list(self.nlp(caption).sents)
|
| 78 |
+
sent_length = len(sentences)
|
| 79 |
+
if sent_length<4:
|
| 80 |
+
how_many_to_drop = int(np.floor((20 + random.random()*30)/100*sent_length)) # between 20 and 50 percent of sentences
|
| 81 |
+
else:
|
| 82 |
+
how_many_to_drop = int(np.ceil((20 + random.random()*30)/100*sent_length)) # between 20 and 50 percent of sentences
|
| 83 |
+
which_to_drop = np.random.choice(sent_length, how_many_to_drop, replace=False)
|
| 84 |
+
new_sentences = [sentences[i] for i in range(sent_length) if i not in which_to_drop.tolist()]
|
| 85 |
+
new_sentences = " ".join([new_sentences[i].text for i in range(len(new_sentences))]) # combine sentences back with a space
|
| 86 |
+
else:
|
| 87 |
+
new_sentences = caption
|
| 88 |
+
|
| 89 |
+
# Tokenize the caption
|
| 90 |
+
inputs = self.t5_tokenizer(new_sentences, return_tensors='pt', padding=True, truncation=True)
|
| 91 |
+
input_ids = inputs['input_ids']
|
| 92 |
+
attention_mask = inputs['attention_mask']
|
| 93 |
+
|
| 94 |
+
# Tokenize the midi file
|
| 95 |
+
tokenized_midi = [self.tokenizer[token] for token in aria_tokenized_midi if token in self.tokenizer]
|
| 96 |
+
|
| 97 |
+
# Convert the tokenized MIDI file to a tensor and pad it to the maximum sequence length
|
| 98 |
+
if len(tokenized_midi) < self.decoder_max_sequence_length:
|
| 99 |
+
labels = F.pad(torch.tensor(tokenized_midi), (0, self.decoder_max_sequence_length - len(tokenized_midi))).to(torch.int64)
|
| 100 |
+
else:
|
| 101 |
+
labels = torch.tensor(tokenized_midi[-self.decoder_max_sequence_length:]).to(torch.int64)
|
| 102 |
+
|
| 103 |
+
return input_ids, attention_mask, labels
|
| 104 |
+
|
| 105 |
+
if __name__ == "__main__":
|
| 106 |
+
# Parse command line arguments
|
| 107 |
+
parser = argparse.ArgumentParser()
|
| 108 |
+
parser.add_argument("--config", type=str, default=os.path.normpath("../configs/config.yaml"),
|
| 109 |
+
help="Path to the config file")
|
| 110 |
+
args = parser.parse_args()
|
| 111 |
+
|
| 112 |
+
# Load config file
|
| 113 |
+
with open(args.config, 'r') as f:
|
| 114 |
+
configs = yaml.safe_load(f)
|
| 115 |
+
|
| 116 |
+
caption_dataset_path = configs['raw_data']['caption_dataset_path']
|
| 117 |
+
# Load the caption dataset
|
| 118 |
+
with jsonlines.open(caption_dataset_path) as reader:
|
| 119 |
+
captions = list(reader)
|
| 120 |
+
|
| 121 |
+
# Load the dataset
|
| 122 |
+
dataset = Text2MusicDataset(configs, captions, mode="train", shuffle = True)
|
| 123 |
+
a,b,c = dataset[0]
|
| 124 |
+
print(c.shape)
|
text2midi_repo/model/data_loader_remi.py
ADDED
|
@@ -0,0 +1,126 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import yaml
|
| 2 |
+
import jsonlines
|
| 3 |
+
import glob
|
| 4 |
+
import random
|
| 5 |
+
import os
|
| 6 |
+
import sys
|
| 7 |
+
import pickle
|
| 8 |
+
import json
|
| 9 |
+
import argparse
|
| 10 |
+
import numpy as np
|
| 11 |
+
from copy import deepcopy
|
| 12 |
+
from torch.utils.data import Dataset
|
| 13 |
+
import torch
|
| 14 |
+
from torch.nn import functional as F
|
| 15 |
+
from transformers import T5Tokenizer
|
| 16 |
+
from spacy.lang.en import English
|
| 17 |
+
|
| 18 |
+
SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__))
|
| 19 |
+
sys.path.append(os.path.dirname(SCRIPT_DIR))
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
class Text2MusicDataset(Dataset):
|
| 23 |
+
def __init__(self, configs, captions, remi_tokenizer, mode="train", shuffle = False):
|
| 24 |
+
self.mode = mode
|
| 25 |
+
self.captions = captions
|
| 26 |
+
if shuffle:
|
| 27 |
+
random.shuffle(self.captions)
|
| 28 |
+
|
| 29 |
+
# Path to dataset
|
| 30 |
+
self.dataset_path = configs['raw_data']['raw_data_folders']['midicaps']['folder_path']
|
| 31 |
+
|
| 32 |
+
# Artifact folder
|
| 33 |
+
self.artifact_folder = configs['artifact_folder']
|
| 34 |
+
# Load encoder tokenizer json file dictionary
|
| 35 |
+
# tokenizer_filepath = os.path.join(self.artifact_folder, "vocab.pkl")
|
| 36 |
+
# Load the pickled tokenizer dictionary
|
| 37 |
+
# with open(tokenizer_filepath, 'rb') as f:
|
| 38 |
+
# self.tokenizer = pickle.load(f)
|
| 39 |
+
|
| 40 |
+
self.remi_tokenizer = remi_tokenizer
|
| 41 |
+
|
| 42 |
+
# Load the sentencizer
|
| 43 |
+
self.nlp = English()
|
| 44 |
+
self.nlp.add_pipe('sentencizer')
|
| 45 |
+
|
| 46 |
+
# Load the FLAN-T5 tokenizer and encoder
|
| 47 |
+
self.t5_tokenizer = T5Tokenizer.from_pretrained("google/flan-t5-base")
|
| 48 |
+
|
| 49 |
+
# Get the maximum sequence length
|
| 50 |
+
self.decoder_max_sequence_length = configs['model']['text2midi_model']['decoder_max_sequence_length']
|
| 51 |
+
|
| 52 |
+
# Print length of dataset
|
| 53 |
+
print("Length of dataset: ", len(self.captions))
|
| 54 |
+
|
| 55 |
+
def __len__(self):
|
| 56 |
+
return len(self.captions)
|
| 57 |
+
|
| 58 |
+
def __getitem__(self, idx):
|
| 59 |
+
caption = self.captions[idx]['caption']
|
| 60 |
+
midi_filepath = os.path.join(self.dataset_path, self.captions[idx]['location'])
|
| 61 |
+
# print(f'midi filepath: {midi_filepath}')
|
| 62 |
+
# Read the MIDI file
|
| 63 |
+
tokens = self.remi_tokenizer(midi_filepath)
|
| 64 |
+
|
| 65 |
+
if len(tokens.ids) == 0:
|
| 66 |
+
tokenized_midi = [self.remi_tokenizer["BOS_None"], self.remi_tokenizer["EOS_None"]]
|
| 67 |
+
else:
|
| 68 |
+
tokenized_midi = [self.remi_tokenizer["BOS_None"]] + tokens.ids + [self.remi_tokenizer["EOS_None"]]
|
| 69 |
+
|
| 70 |
+
# Drop a random number of sentences from the caption
|
| 71 |
+
do_drop = random.random() > 0.5
|
| 72 |
+
if do_drop:
|
| 73 |
+
sentences = list(self.nlp(caption).sents)
|
| 74 |
+
sent_length = len(sentences)
|
| 75 |
+
if sent_length<4:
|
| 76 |
+
how_many_to_drop = int(np.floor((20 + random.random()*30)/100*sent_length)) # between 20 and 50 percent of sentences
|
| 77 |
+
else:
|
| 78 |
+
how_many_to_drop = int(np.ceil((20 + random.random()*30)/100*sent_length)) # between 20 and 50 percent of sentences
|
| 79 |
+
which_to_drop = np.random.choice(sent_length, how_many_to_drop, replace=False)
|
| 80 |
+
new_sentences = [sentences[i] for i in range(sent_length) if i not in which_to_drop.tolist()]
|
| 81 |
+
new_sentences = " ".join([new_sentences[i].text for i in range(len(new_sentences))]) # combine sentences back with a space
|
| 82 |
+
else:
|
| 83 |
+
new_sentences = caption
|
| 84 |
+
|
| 85 |
+
# Tokenize the caption
|
| 86 |
+
inputs = self.t5_tokenizer(new_sentences, return_tensors='pt', padding=True, truncation=True)
|
| 87 |
+
input_ids = inputs['input_ids']
|
| 88 |
+
attention_mask = inputs['attention_mask']
|
| 89 |
+
|
| 90 |
+
# Convert the tokenized MIDI file to a tensor and pad it to the maximum sequence length
|
| 91 |
+
if len(tokenized_midi) < self.decoder_max_sequence_length:
|
| 92 |
+
labels = F.pad(torch.tensor(tokenized_midi), (0, self.decoder_max_sequence_length - len(tokenized_midi))).to(torch.int64)
|
| 93 |
+
else:
|
| 94 |
+
labels = torch.tensor(tokenized_midi[0:self.decoder_max_sequence_length]).to(torch.int64)
|
| 95 |
+
|
| 96 |
+
return input_ids, attention_mask, labels
|
| 97 |
+
|
| 98 |
+
if __name__ == "__main__":
|
| 99 |
+
# Parse command line arguments
|
| 100 |
+
parser = argparse.ArgumentParser()
|
| 101 |
+
parser.add_argument("--config", type=str, default=os.path.normpath("../configs/config.yaml"),
|
| 102 |
+
help="Path to the config file")
|
| 103 |
+
args = parser.parse_args()
|
| 104 |
+
|
| 105 |
+
tokenizer_filepath = "../artifacts/vocab_remi.pkl"
|
| 106 |
+
# Load the tokenizer dictionary
|
| 107 |
+
with open(tokenizer_filepath, "rb") as f:
|
| 108 |
+
tokenizer = pickle.load(f)
|
| 109 |
+
bos_token_number = tokenizer["PAD_None"]
|
| 110 |
+
print(f"bos_token_number: {bos_token_number}")
|
| 111 |
+
|
| 112 |
+
# Load config file
|
| 113 |
+
with open(args.config, 'r') as f:
|
| 114 |
+
configs = yaml.safe_load(f)
|
| 115 |
+
caption_dataset_path = configs['raw_data']['caption_dataset_path']
|
| 116 |
+
# Load the caption dataset
|
| 117 |
+
with jsonlines.open(caption_dataset_path) as reader:
|
| 118 |
+
captions = list(reader)
|
| 119 |
+
|
| 120 |
+
# Load the dataset
|
| 121 |
+
dataset = Text2MusicDataset(configs, captions, remi_tokenizer=tokenizer, mode="train", shuffle = True)
|
| 122 |
+
a,b,c = dataset[0]
|
| 123 |
+
print(type(a))
|
| 124 |
+
generated_midi = tokenizer.decode(c)
|
| 125 |
+
print(type(generated_midi))
|
| 126 |
+
generated_midi.dump_midi("decoded_midi.mid")
|
text2midi_repo/model/dict_output.txt
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
text2midi_repo/model/train.py
ADDED
|
@@ -0,0 +1,185 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
# print("CUDA_VISIBLE_DEVICES:", os.environ["CUDA_VISIBLE_DEVICES"])
|
| 3 |
+
# import torch
|
| 4 |
+
# print("CUDA device count:", torch.cuda.device_count())
|
| 5 |
+
# print("CUDA current device:", torch.cuda.current_device())
|
| 6 |
+
# print("CUDA device name:", torch.cuda.get_device_name(torch.cuda.current_device()))
|
| 7 |
+
# os.environ['CUDA_VISIBLE_DEVICES']="2,3"
|
| 8 |
+
from torch.cuda import is_available as cuda_available, is_bf16_supported
|
| 9 |
+
from torch.backends.mps import is_available as mps_available
|
| 10 |
+
import torch.nn as nn
|
| 11 |
+
import torch.optim as optim
|
| 12 |
+
import yaml
|
| 13 |
+
import json
|
| 14 |
+
import pickle
|
| 15 |
+
import os
|
| 16 |
+
import random
|
| 17 |
+
import deepspeed
|
| 18 |
+
from tqdm import tqdm
|
| 19 |
+
import torch
|
| 20 |
+
from torch import Tensor, argmax
|
| 21 |
+
from evaluate import load as load_metric
|
| 22 |
+
import sys
|
| 23 |
+
import argparse
|
| 24 |
+
import jsonlines
|
| 25 |
+
from data_loader import Text2MusicDataset
|
| 26 |
+
from transformer_model import Transformer
|
| 27 |
+
from torch.utils.data import DataLoader
|
| 28 |
+
|
| 29 |
+
# Parse command line arguments
|
| 30 |
+
# parser = argparse.ArgumentParser()
|
| 31 |
+
# parser.add_argument("--config", type=str, default=os.path.normpath("configs/config.yaml"),
|
| 32 |
+
# help="Path to the config file")
|
| 33 |
+
# parser = deepspeed.add_config_arguments(parser)
|
| 34 |
+
# args = parser.parse_args()
|
| 35 |
+
config_file = "../configs/config.yaml"
|
| 36 |
+
# Load config file
|
| 37 |
+
with open(config_file, 'r') as f: ##args.config
|
| 38 |
+
configs = yaml.safe_load(f)
|
| 39 |
+
|
| 40 |
+
batch_size = configs['training']['text2midi_model']['batch_size']
|
| 41 |
+
learning_rate = configs['training']['text2midi_model']['learning_rate']
|
| 42 |
+
epochs = configs['training']['text2midi_model']['epochs']
|
| 43 |
+
|
| 44 |
+
# Artifact folder
|
| 45 |
+
artifact_folder = configs['artifact_folder']
|
| 46 |
+
# Load encoder tokenizer json file dictionary
|
| 47 |
+
tokenizer_filepath = os.path.join(artifact_folder, "vocab.pkl")
|
| 48 |
+
# Load the tokenizer dictionary
|
| 49 |
+
with open(tokenizer_filepath, "rb") as f:
|
| 50 |
+
tokenizer = pickle.load(f)
|
| 51 |
+
|
| 52 |
+
# Get the vocab size
|
| 53 |
+
vocab_size = len(tokenizer)+1
|
| 54 |
+
print("Vocab size: ", vocab_size)
|
| 55 |
+
|
| 56 |
+
caption_dataset_path = configs['raw_data']['caption_dataset_path']
|
| 57 |
+
# Load the caption dataset
|
| 58 |
+
with jsonlines.open(caption_dataset_path) as reader:
|
| 59 |
+
captions = list(reader)
|
| 60 |
+
|
| 61 |
+
|
| 62 |
+
def collate_fn(batch):
|
| 63 |
+
"""
|
| 64 |
+
Collate function for the DataLoader
|
| 65 |
+
:param batch: The batch
|
| 66 |
+
:return: The collated batch
|
| 67 |
+
"""
|
| 68 |
+
input_ids = [item[0].squeeze(0) for item in batch]
|
| 69 |
+
# Pad or trim batch to the same length
|
| 70 |
+
input_ids = nn.utils.rnn.pad_sequence(input_ids, batch_first=True, padding_value=0)
|
| 71 |
+
attention_mask = [item[1].squeeze(0) for item in batch]
|
| 72 |
+
# Pad or trim batch to the same length
|
| 73 |
+
attention_mask = nn.utils.rnn.pad_sequence(attention_mask, batch_first=True, padding_value=0)
|
| 74 |
+
labels = [item[2].squeeze(0) for item in batch]
|
| 75 |
+
# Pad or trim batch to the same length
|
| 76 |
+
labels = nn.utils.rnn.pad_sequence(labels, batch_first=True, padding_value=0)
|
| 77 |
+
return input_ids, attention_mask, labels
|
| 78 |
+
|
| 79 |
+
|
| 80 |
+
# Load the dataset
|
| 81 |
+
dataset = Text2MusicDataset(configs, captions, mode="train", shuffle = True)
|
| 82 |
+
data_length = len(dataset)
|
| 83 |
+
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=4, collate_fn=collate_fn)
|
| 84 |
+
|
| 85 |
+
|
| 86 |
+
# Create the encoder-decoder model
|
| 87 |
+
# Initialize the model
|
| 88 |
+
d_model = configs['model']['text2midi_model']['decoder_d_model'] # Model dimension (same as FLAN-T5 encoder output dimension)
|
| 89 |
+
nhead = configs['model']['text2midi_model']['decoder_num_heads'] # Number of heads in the multiheadattention models
|
| 90 |
+
num_layers = configs['model']['text2midi_model']['decoder_num_layers'] # Number of decoder layers
|
| 91 |
+
max_len = configs['model']['text2midi_model']['decoder_max_sequence_length'] # Maximum length of the input sequence
|
| 92 |
+
use_moe = configs['model']['text2midi_model']['use_moe'] # Use mixture of experts
|
| 93 |
+
num_experts = configs['model']['text2midi_model']['num_experts'] # Number of experts in the mixture of experts
|
| 94 |
+
dim_feedforward = configs['model']['text2midi_model']['decoder_intermediate_size'] # Dimension of the feedforward network model
|
| 95 |
+
use_deepspeed = configs['model']['text2midi_model']['use_deepspeed'] # Use deepspeed
|
| 96 |
+
if use_deepspeed:
|
| 97 |
+
ds_config = configs['deepspeed_config']['deepspeed_config_path']
|
| 98 |
+
import deepspeed
|
| 99 |
+
from deepspeed.accelerator import get_accelerator
|
| 100 |
+
local_rank = int(os.environ['LOCAL_RANK'])
|
| 101 |
+
device = (torch.device(get_accelerator().device_name(), local_rank) if (local_rank > -1)
|
| 102 |
+
and get_accelerator().is_available() else torch.device("cpu"))
|
| 103 |
+
deepspeed.init_distributed(dist_backend='nccl')
|
| 104 |
+
torch.backends.cuda.enable_mem_efficient_sdp(False)
|
| 105 |
+
torch.backends.cuda.enable_flash_sdp(False)
|
| 106 |
+
else:
|
| 107 |
+
device = torch.device("cuda:3" if torch.cuda.is_available() else "cpu")
|
| 108 |
+
print(f"Device: {device}")
|
| 109 |
+
|
| 110 |
+
print_every = 10
|
| 111 |
+
model = Transformer(vocab_size, d_model, nhead, max_len, num_layers, dim_feedforward, use_moe, num_experts, device=device)
|
| 112 |
+
# Print number of parameters
|
| 113 |
+
num_params = sum(p.numel() for p in model.parameters())
|
| 114 |
+
print(f"Number of parameters: {num_params}")
|
| 115 |
+
# Print number of trainable parameters
|
| 116 |
+
num_trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
|
| 117 |
+
print(f"Number of trainable parameters: {num_trainable_params}")
|
| 118 |
+
if not use_deepspeed:
|
| 119 |
+
optimizer = optim.Adam(model.parameters(), lr=1e-4)
|
| 120 |
+
criterion = nn.CrossEntropyLoss()
|
| 121 |
+
torch.cuda.empty_cache()
|
| 122 |
+
def train_model(model, dataloader, criterion, num_epochs, optimizer=None, data_length=1000):
|
| 123 |
+
if use_deepspeed:
|
| 124 |
+
parameters = filter(lambda p: p.requires_grad, model.parameters())
|
| 125 |
+
model, optimizer, _, _ = deepspeed.initialize(model=model,
|
| 126 |
+
optimizer=optimizer,
|
| 127 |
+
model_parameters=model.parameters(),
|
| 128 |
+
config=ds_config)
|
| 129 |
+
else:
|
| 130 |
+
model = model.to(device)
|
| 131 |
+
model.train()
|
| 132 |
+
for epoch in range(num_epochs):
|
| 133 |
+
total_loss = 0
|
| 134 |
+
with tqdm(total=int(data_length/batch_size), desc=f"Epoch {epoch + 1}/{num_epochs}") as pbar:
|
| 135 |
+
for step, batch in enumerate(dataloader):
|
| 136 |
+
if use_deepspeed:
|
| 137 |
+
model.zero_grad()
|
| 138 |
+
else:
|
| 139 |
+
optimizer.zero_grad()
|
| 140 |
+
|
| 141 |
+
# Get the batch
|
| 142 |
+
encoder_input, attention_mask, tgt = batch
|
| 143 |
+
# print(encoder_input.shape)
|
| 144 |
+
encoder_input = encoder_input.to(device)
|
| 145 |
+
attention_mask = attention_mask.to(device)
|
| 146 |
+
tgt = tgt.to(device)
|
| 147 |
+
|
| 148 |
+
tgt_input = tgt[:, :-1]
|
| 149 |
+
tgt_output = tgt[:, 1:]
|
| 150 |
+
|
| 151 |
+
if use_moe:
|
| 152 |
+
outputs, aux_loss = model(encoder_input, attention_mask, tgt_input)
|
| 153 |
+
else:
|
| 154 |
+
outputs = model(encoder_input, attention_mask, tgt_input)
|
| 155 |
+
aux_loss = 0
|
| 156 |
+
|
| 157 |
+
loss = criterion(outputs.view(-1, outputs.size(-1)), tgt_output.reshape(-1))
|
| 158 |
+
loss += aux_loss
|
| 159 |
+
if use_deepspeed:
|
| 160 |
+
model.backward(loss)
|
| 161 |
+
model.step()
|
| 162 |
+
else:
|
| 163 |
+
loss.backward()
|
| 164 |
+
optimizer.step()
|
| 165 |
+
|
| 166 |
+
total_loss += loss.item()
|
| 167 |
+
if step % print_every == 0:
|
| 168 |
+
pbar.set_postfix({"Loss": loss.item()})
|
| 169 |
+
pbar.update(1)
|
| 170 |
+
|
| 171 |
+
pbar.set_postfix({"Loss": total_loss / len(dataloader)})
|
| 172 |
+
pbar.update(1)
|
| 173 |
+
|
| 174 |
+
print(f"Epoch {epoch + 1}/{num_epochs}, Loss: {total_loss / len(dataloader)}")
|
| 175 |
+
|
| 176 |
+
|
| 177 |
+
# Train the model
|
| 178 |
+
if use_deepspeed:
|
| 179 |
+
train_model(model, dataloader, criterion, num_epochs=epochs)
|
| 180 |
+
else:
|
| 181 |
+
train_model(model, dataloader, criterion, num_epochs=epochs, optimizer=optimizer, data_length=data_length)
|
| 182 |
+
|
| 183 |
+
# Save the trained model
|
| 184 |
+
torch.save(model.state_dict(), "transformer_decoder_remi_plus.pth")
|
| 185 |
+
print("Model saved as transformer_decoder_remi_plus.pth")
|
text2midi_repo/model/train_accelerate.py
ADDED
|
@@ -0,0 +1,224 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import torch.nn as nn
|
| 3 |
+
import torch.optim as optim
|
| 4 |
+
import yaml
|
| 5 |
+
import math
|
| 6 |
+
import time
|
| 7 |
+
from transformers import get_scheduler
|
| 8 |
+
import wandb
|
| 9 |
+
import pickle
|
| 10 |
+
import numpy as np
|
| 11 |
+
import json
|
| 12 |
+
import jsonlines
|
| 13 |
+
from tqdm import tqdm
|
| 14 |
+
import torch
|
| 15 |
+
from accelerate import DistributedDataParallelKwargs, Accelerator
|
| 16 |
+
from accelerate.logging import get_logger
|
| 17 |
+
from data_loader_remi import Text2MusicDataset
|
| 18 |
+
from transformer_model import Transformer
|
| 19 |
+
from torch.utils.data import DataLoader
|
| 20 |
+
import logging
|
| 21 |
+
|
| 22 |
+
logger = get_logger(__name__)
|
| 23 |
+
|
| 24 |
+
# Load config file
|
| 25 |
+
config_file = "../configs/config.yaml"
|
| 26 |
+
with open(config_file, 'r') as f:
|
| 27 |
+
configs = yaml.safe_load(f)
|
| 28 |
+
|
| 29 |
+
batch_size = configs['training']['text2midi_model']['batch_size']
|
| 30 |
+
learning_rate = configs['training']['text2midi_model']['learning_rate']
|
| 31 |
+
epochs = configs['training']['text2midi_model']['epochs']
|
| 32 |
+
artifact_folder = configs['artifact_folder']
|
| 33 |
+
tokenizer_filepath = os.path.join(artifact_folder, "vocab_remi.pkl")
|
| 34 |
+
with open(tokenizer_filepath, "rb") as f:
|
| 35 |
+
tokenizer = pickle.load(f)
|
| 36 |
+
vocab_size = len(tokenizer)
|
| 37 |
+
caption_dataset_path = configs['raw_data']['caption_dataset_path']
|
| 38 |
+
|
| 39 |
+
# Load the caption dataset
|
| 40 |
+
with jsonlines.open(caption_dataset_path) as reader:
|
| 41 |
+
captions = list(reader)
|
| 42 |
+
# captions = list(reader)
|
| 43 |
+
|
| 44 |
+
def collate_fn(batch):
|
| 45 |
+
input_ids = [item[0].squeeze(0) for item in batch]
|
| 46 |
+
input_ids = nn.utils.rnn.pad_sequence(input_ids, batch_first=True, padding_value=0)
|
| 47 |
+
attention_mask = [item[1].squeeze(0) for item in batch]
|
| 48 |
+
attention_mask = nn.utils.rnn.pad_sequence(attention_mask, batch_first=True, padding_value=0)
|
| 49 |
+
labels = [item[2].squeeze(0) for item in batch]
|
| 50 |
+
labels = nn.utils.rnn.pad_sequence(labels, batch_first=True, padding_value=0)
|
| 51 |
+
return input_ids, attention_mask, labels
|
| 52 |
+
|
| 53 |
+
d_model = configs['model']['text2midi_model']['decoder_d_model']
|
| 54 |
+
nhead = configs['model']['text2midi_model']['decoder_num_heads']
|
| 55 |
+
num_layers = configs['model']['text2midi_model']['decoder_num_layers']
|
| 56 |
+
max_len = configs['model']['text2midi_model']['decoder_max_sequence_length']
|
| 57 |
+
use_moe = configs['model']['text2midi_model']['use_moe']
|
| 58 |
+
num_experts = configs['model']['text2midi_model']['num_experts']
|
| 59 |
+
dim_feedforward = configs['model']['text2midi_model']['decoder_intermediate_size']
|
| 60 |
+
gradient_accumulation_steps = configs['training']['text2midi_model']['gradient_accumulation_steps']
|
| 61 |
+
use_scheduler = configs['training']['text2midi_model']['use_scheduler']
|
| 62 |
+
checkpointing_steps = configs['training']['text2midi_model']['checkpointing_steps']
|
| 63 |
+
lr_scheduler_type = configs['training']['text2midi_model']['lr_scheduler_type']
|
| 64 |
+
num_warmup_steps = configs['training']['text2midi_model']['num_warmup_steps']
|
| 65 |
+
max_train_steps = configs['training']['text2midi_model']['max_train_steps']
|
| 66 |
+
with_tracking = configs['training']['text2midi_model']['with_tracking']
|
| 67 |
+
report_to = configs['training']['text2midi_model']['report_to']
|
| 68 |
+
output_dir = configs['training']['text2midi_model']['output_dir']
|
| 69 |
+
per_device_train_batch_size = configs['training']['text2midi_model']['per_device_train_batch_size']
|
| 70 |
+
save_every = configs['training']['text2midi_model']['save_every']
|
| 71 |
+
|
| 72 |
+
accelerator_log_kwargs = {}
|
| 73 |
+
if with_tracking:
|
| 74 |
+
accelerator_log_kwargs["log_with"] = report_to
|
| 75 |
+
# Remove the logging_dir argument in case of error
|
| 76 |
+
accelerator_log_kwargs["logging_dir"] = output_dir
|
| 77 |
+
accelerator = Accelerator(gradient_accumulation_steps=gradient_accumulation_steps, mixed_precision='fp16', kwargs_handlers=[DistributedDataParallelKwargs(find_unused_parameters=True)], **accelerator_log_kwargs)
|
| 78 |
+
logging.basicConfig(
|
| 79 |
+
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
|
| 80 |
+
datefmt="%m/%d/%Y %H:%M:%S",
|
| 81 |
+
level=logging.INFO,
|
| 82 |
+
)
|
| 83 |
+
logger.info(accelerator.state, main_process_only=False)
|
| 84 |
+
if accelerator.is_main_process:
|
| 85 |
+
if output_dir is None or output_dir == "":
|
| 86 |
+
output_dir = "saved/" + str(int(time.time()))
|
| 87 |
+
if not os.path.exists("saved"):
|
| 88 |
+
os.makedirs("saved")
|
| 89 |
+
os.makedirs(output_dir, exist_ok=True)
|
| 90 |
+
elif output_dir is not None:
|
| 91 |
+
os.makedirs(output_dir, exist_ok=True)
|
| 92 |
+
os.makedirs("{}/{}".format(output_dir, "outputs"), exist_ok=True)
|
| 93 |
+
accelerator.project_configuration.automatic_checkpoint_naming = False
|
| 94 |
+
wandb.login()
|
| 95 |
+
wandb.init(project="Text-2-Midi", settings=wandb.Settings(init_timeout=120))
|
| 96 |
+
accelerator.wait_for_everyone()
|
| 97 |
+
device = accelerator.device
|
| 98 |
+
|
| 99 |
+
with accelerator.main_process_first():
|
| 100 |
+
dataset = Text2MusicDataset(configs, captions, remi_tokenizer=tokenizer, mode="train", shuffle=True)
|
| 101 |
+
dataloader = DataLoader(dataset, batch_size=per_device_train_batch_size, shuffle=True, num_workers=4, collate_fn=collate_fn, drop_last=True)
|
| 102 |
+
|
| 103 |
+
model = Transformer(vocab_size, d_model, nhead, max_len, num_layers, dim_feedforward, use_moe, num_experts, device=device)
|
| 104 |
+
model.load_state_dict(torch.load('/root/output_test_new/epoch_68/pytorch_model.bin', map_location=device))
|
| 105 |
+
def count_parameters(model):
|
| 106 |
+
return sum(p.numel() for p in model.parameters() if p.requires_grad)
|
| 107 |
+
|
| 108 |
+
total_params = count_parameters(model)
|
| 109 |
+
print(f"Total number of trainable parameters: {total_params}")
|
| 110 |
+
|
| 111 |
+
optimizer = optim.Adam(model.parameters(), lr=1e-4)
|
| 112 |
+
overrode_max_train_steps = False
|
| 113 |
+
num_update_steps_per_epoch = math.ceil(len(dataloader) / gradient_accumulation_steps)
|
| 114 |
+
print("num_update_steps_per_epoch", num_update_steps_per_epoch)
|
| 115 |
+
print("max_train_steps", max_train_steps)
|
| 116 |
+
if max_train_steps == 'None':
|
| 117 |
+
max_train_steps = epochs * num_update_steps_per_epoch
|
| 118 |
+
print("max_train_steps", max_train_steps)
|
| 119 |
+
overrode_max_train_steps = True
|
| 120 |
+
num_warmup_steps = 20000
|
| 121 |
+
elif isinstance(max_train_steps, str):
|
| 122 |
+
max_train_steps = int(max_train_steps)
|
| 123 |
+
lr_scheduler = get_scheduler(
|
| 124 |
+
name=lr_scheduler_type,
|
| 125 |
+
optimizer=optimizer,
|
| 126 |
+
num_warmup_steps=num_warmup_steps,
|
| 127 |
+
num_training_steps=max_train_steps,
|
| 128 |
+
)
|
| 129 |
+
model, optimizer, lr_scheduler = accelerator.prepare(model, optimizer, lr_scheduler)
|
| 130 |
+
dataloader = accelerator.prepare(dataloader)
|
| 131 |
+
if overrode_max_train_steps:
|
| 132 |
+
max_train_steps = epochs * num_update_steps_per_epoch
|
| 133 |
+
epochs = math.ceil(max_train_steps / num_update_steps_per_epoch)
|
| 134 |
+
# checkpointing_steps = checkpointing_steps if checkpointing_steps.isdigit() else None
|
| 135 |
+
total_batch_size = per_device_train_batch_size * accelerator.num_processes * gradient_accumulation_steps
|
| 136 |
+
logger.info("***** Running training *****")
|
| 137 |
+
logger.info(f" Num examples = {len(dataset)}")
|
| 138 |
+
logger.info(f" Num Epochs = {epochs}")
|
| 139 |
+
logger.info(f" Instantaneous batch size per device = {per_device_train_batch_size}")
|
| 140 |
+
logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
|
| 141 |
+
logger.info(f" Gradient Accumulation steps = {gradient_accumulation_steps}")
|
| 142 |
+
logger.info(f" Total optimization steps = {max_train_steps}")
|
| 143 |
+
|
| 144 |
+
criterion = nn.CrossEntropyLoss()
|
| 145 |
+
|
| 146 |
+
def train_model_accelerate(model, dataloader, criterion, num_epochs, max_train_steps, optimizer=None, out_dir=None, checkpointing_steps='epoch', with_tracking=False, save_every=5, device='cpu'):
|
| 147 |
+
progress_bar = tqdm(range(max_train_steps), disable=not accelerator.is_local_main_process)
|
| 148 |
+
completed_steps = 0
|
| 149 |
+
starting_epoch = 68
|
| 150 |
+
model = model.to(device)
|
| 151 |
+
model.train()
|
| 152 |
+
best_loss = np.inf
|
| 153 |
+
for epoch in range(starting_epoch, num_epochs):
|
| 154 |
+
total_loss = 0
|
| 155 |
+
for step, batch in enumerate(dataloader):
|
| 156 |
+
with accelerator.accumulate(model):
|
| 157 |
+
encoder_input, attention_mask, tgt = batch
|
| 158 |
+
encoder_input = encoder_input.to(device)
|
| 159 |
+
attention_mask = attention_mask.to(device)
|
| 160 |
+
tgt = tgt.to(device)
|
| 161 |
+
tgt_input = tgt[:, :-1]
|
| 162 |
+
tgt_output = tgt[:, 1:]
|
| 163 |
+
if use_moe:
|
| 164 |
+
outputs, aux_loss = model(encoder_input, attention_mask, tgt_input)
|
| 165 |
+
else:
|
| 166 |
+
outputs = model(encoder_input, attention_mask, tgt_input)
|
| 167 |
+
aux_loss = 0
|
| 168 |
+
loss = criterion(outputs.view(-1, outputs.size(-1)), tgt_output.reshape(-1))
|
| 169 |
+
loss += aux_loss
|
| 170 |
+
total_loss += loss.detach().float()
|
| 171 |
+
accelerator.backward(loss)
|
| 172 |
+
optimizer.step()
|
| 173 |
+
lr_scheduler.step()
|
| 174 |
+
optimizer.zero_grad()
|
| 175 |
+
if accelerator.sync_gradients:
|
| 176 |
+
progress_bar.set_postfix({"Loss": loss.item()})
|
| 177 |
+
progress_bar.update(1)
|
| 178 |
+
completed_steps += 1
|
| 179 |
+
if accelerator.is_main_process:
|
| 180 |
+
result = {}
|
| 181 |
+
result["epoch"] = epoch+1
|
| 182 |
+
result["step"] = completed_steps
|
| 183 |
+
result["train_loss"] = round(total_loss.item()/(gradient_accumulation_steps*completed_steps),4)
|
| 184 |
+
wandb.log(result)
|
| 185 |
+
if isinstance(checkpointing_steps, int):
|
| 186 |
+
if completed_steps % checkpointing_steps == 0:
|
| 187 |
+
output_dir = f"step_{completed_steps }"
|
| 188 |
+
if out_dir is not None:
|
| 189 |
+
output_dir = os.path.join(out_dir, output_dir)
|
| 190 |
+
accelerator.save_state(output_dir)
|
| 191 |
+
if completed_steps >= max_train_steps:
|
| 192 |
+
break
|
| 193 |
+
if accelerator.is_main_process:
|
| 194 |
+
result = {}
|
| 195 |
+
result["epoch"] = epoch+1
|
| 196 |
+
result["step"] = completed_steps
|
| 197 |
+
result["train_loss"] = round(total_loss.item()/len(dataloader), 4)
|
| 198 |
+
result_string = "Epoch: {}, Loss Train: {}\n".format(epoch, result["train_loss"])
|
| 199 |
+
accelerator.print(result_string)
|
| 200 |
+
with open("{}/summary.jsonl".format(out_dir), "a") as f:
|
| 201 |
+
f.write(json.dumps(result) + "\n\n")
|
| 202 |
+
logger.info(result)
|
| 203 |
+
if accelerator.is_main_process:
|
| 204 |
+
if total_loss < best_loss:
|
| 205 |
+
best_loss = total_loss
|
| 206 |
+
save_checkpoint = True
|
| 207 |
+
else:
|
| 208 |
+
save_checkpoint = False
|
| 209 |
+
accelerator.wait_for_everyone()
|
| 210 |
+
if accelerator.is_main_process and checkpointing_steps == "best":
|
| 211 |
+
if save_checkpoint:
|
| 212 |
+
accelerator.save_state("{}/{}".format(out_dir, "best"))
|
| 213 |
+
if (epoch + 1) % save_every == 0:
|
| 214 |
+
logger.info("Saving checkpoint at epoch {}".format(epoch+1))
|
| 215 |
+
accelerator.save_state("{}/{}".format(out_dir, "epoch_" + str(epoch+1)))
|
| 216 |
+
if accelerator.is_main_process and checkpointing_steps == "epoch":
|
| 217 |
+
accelerator.save_state("{}/{}".format(out_dir, "epoch_" + str(epoch+1)))
|
| 218 |
+
|
| 219 |
+
train_model_accelerate(model, dataloader, criterion, num_epochs=epochs, max_train_steps=max_train_steps,
|
| 220 |
+
optimizer=optimizer, out_dir=output_dir, checkpointing_steps=checkpointing_steps,
|
| 221 |
+
with_tracking=with_tracking, save_every=save_every, device=device)
|
| 222 |
+
|
| 223 |
+
# torch.save(model.state_dict(), "transformer_decoder_remi_plus.pth")
|
| 224 |
+
# print("Model saved as transformer_decoder_remi_plus.pth")
|
text2midi_repo/model/train_hf.py
ADDED
|
@@ -0,0 +1,283 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
from torch.cuda import is_available as cuda_available, is_bf16_supported
|
| 3 |
+
from torch.backends.mps import is_available as mps_available
|
| 4 |
+
import torch.nn as nn
|
| 5 |
+
import torch.optim as optim
|
| 6 |
+
import yaml
|
| 7 |
+
import json
|
| 8 |
+
import pickle
|
| 9 |
+
import os
|
| 10 |
+
import random
|
| 11 |
+
from tqdm import tqdm
|
| 12 |
+
from transformers import T5EncoderModel, BertModel, BertConfig, Trainer, TrainingArguments, PreTrainedModel, T5Config, T5EncoderModel, BertLMHeadModel
|
| 13 |
+
import torch
|
| 14 |
+
from torch import Tensor, argmax
|
| 15 |
+
from evaluate import load as load_metric
|
| 16 |
+
import sys
|
| 17 |
+
import argparse
|
| 18 |
+
import jsonlines
|
| 19 |
+
from data_loader_remi import Text2MusicDataset
|
| 20 |
+
from transformer_model import Transformer
|
| 21 |
+
from torch.utils.data import DataLoader
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
parser = argparse.ArgumentParser()
|
| 25 |
+
parser.add_argument("--config", type=str, default=os.path.normpath("configs/config.yaml"),
|
| 26 |
+
help="Path to the config file")
|
| 27 |
+
args = parser.parse_args()
|
| 28 |
+
|
| 29 |
+
# Load config file
|
| 30 |
+
with open(args.config, 'r') as f: ##args.config
|
| 31 |
+
configs = yaml.safe_load(f)
|
| 32 |
+
|
| 33 |
+
batch_size = configs['training']['text2midi_model']['batch_size']
|
| 34 |
+
learning_rate = configs['training']['text2midi_model']['learning_rate']
|
| 35 |
+
epochs = configs['training']['text2midi_model']['epochs']
|
| 36 |
+
|
| 37 |
+
# Artifact folder
|
| 38 |
+
artifact_folder = configs['artifact_folder']
|
| 39 |
+
# Load remi tokenizer
|
| 40 |
+
tokenizer_filepath = os.path.join(artifact_folder, "vocab_remi.pkl")
|
| 41 |
+
# Load the tokenizer dictionary
|
| 42 |
+
with open(tokenizer_filepath, "rb") as f:
|
| 43 |
+
tokenizer = pickle.load(f)
|
| 44 |
+
|
| 45 |
+
# Get the vocab size
|
| 46 |
+
vocab_size = tokenizer.vocab_size + 1
|
| 47 |
+
print("Vocab size: ", vocab_size)
|
| 48 |
+
|
| 49 |
+
caption_dataset_path = configs['raw_data']['caption_dataset_path']
|
| 50 |
+
# Load the caption dataset
|
| 51 |
+
with jsonlines.open(caption_dataset_path) as reader:
|
| 52 |
+
captions = list(reader)
|
| 53 |
+
|
| 54 |
+
|
| 55 |
+
def collate_fn(batch):
|
| 56 |
+
"""
|
| 57 |
+
Collate function for the DataLoader
|
| 58 |
+
:param batch: The batch
|
| 59 |
+
:return: The collated batch
|
| 60 |
+
"""
|
| 61 |
+
input_ids = [item[0].squeeze(0) for item in batch]
|
| 62 |
+
# Pad or trim batch to the same length
|
| 63 |
+
input_ids = nn.utils.rnn.pad_sequence(input_ids, batch_first=True, padding_value=0)
|
| 64 |
+
attention_mask = [item[1].squeeze(0) for item in batch]
|
| 65 |
+
# Pad or trim batch to the same length
|
| 66 |
+
attention_mask = nn.utils.rnn.pad_sequence(attention_mask, batch_first=True, padding_value=0)
|
| 67 |
+
labels = [item[2].squeeze(0) for item in batch]
|
| 68 |
+
# Pad or trim batch to the same length
|
| 69 |
+
labels = nn.utils.rnn.pad_sequence(labels, batch_first=True, padding_value=0)
|
| 70 |
+
decoder_input_ids = labels[:, :-1].contiguous()
|
| 71 |
+
labels = labels[:, 1:].contiguous()
|
| 72 |
+
# return input_ids, attention_mask, labels
|
| 73 |
+
return {
|
| 74 |
+
'input_ids': input_ids,
|
| 75 |
+
'attention_mask': attention_mask,
|
| 76 |
+
'decoder_input_ids': decoder_input_ids,
|
| 77 |
+
'labels': labels
|
| 78 |
+
}
|
| 79 |
+
|
| 80 |
+
# Train test split captions
|
| 81 |
+
random.seed(444)
|
| 82 |
+
random.shuffle(captions)
|
| 83 |
+
train_size = int(0.8 * len(captions))
|
| 84 |
+
train_captions = captions[:train_size]
|
| 85 |
+
test_captions = captions[train_size:]
|
| 86 |
+
|
| 87 |
+
# Load the dataset
|
| 88 |
+
train_dataset = Text2MusicDataset(configs, train_captions, tokenizer, mode="train", shuffle = True)
|
| 89 |
+
print(f"Train Data length: {len(train_dataset)}")
|
| 90 |
+
test_dataset = Text2MusicDataset(configs, test_captions, tokenizer, mode="eval", shuffle = False)
|
| 91 |
+
print(f"Test Data length: {len(test_dataset)}")
|
| 92 |
+
|
| 93 |
+
# Dataloader
|
| 94 |
+
# train_dataset = DataLoader(train_dataset, batch_size=batch_size, collate_fn=collate_fn, num_workers=5)
|
| 95 |
+
# test_dataset = DataLoader(test_dataset, batch_size=batch_size, collate_fn=collate_fn, num_workers=5)
|
| 96 |
+
|
| 97 |
+
# Create the encoder-decoder model
|
| 98 |
+
class CustomEncoderDecoderModel(PreTrainedModel):
|
| 99 |
+
def __init__(self, encoder, decoder, encoder_config, decoder_config):
|
| 100 |
+
super().__init__(encoder_config)
|
| 101 |
+
self.encoder = encoder
|
| 102 |
+
self.decoder = decoder
|
| 103 |
+
self.encoder_config = encoder_config
|
| 104 |
+
self.decoder_config = decoder_config
|
| 105 |
+
|
| 106 |
+
def forward(self, input_ids, decoder_input_ids, attention_mask=None, decoder_attention_mask=None, labels=None, **kwargs):
|
| 107 |
+
encoder_outputs = self.encoder(input_ids=input_ids, attention_mask=attention_mask)
|
| 108 |
+
encoder_hidden_states = encoder_outputs.last_hidden_state
|
| 109 |
+
|
| 110 |
+
# Assume the decoder can take encoder hidden states as inputs
|
| 111 |
+
output = self.decoder(
|
| 112 |
+
input_ids=decoder_input_ids,
|
| 113 |
+
attention_mask=decoder_attention_mask,
|
| 114 |
+
encoder_hidden_states=encoder_hidden_states,
|
| 115 |
+
encoder_attention_mask=attention_mask,
|
| 116 |
+
labels=labels
|
| 117 |
+
)
|
| 118 |
+
|
| 119 |
+
logits = output.logits
|
| 120 |
+
|
| 121 |
+
loss = output.loss
|
| 122 |
+
|
| 123 |
+
return {'loss': loss, 'logits': logits}
|
| 124 |
+
|
| 125 |
+
# Load the pre-trained FLAN T5 encoder and freeze its parameters
|
| 126 |
+
flan_t5_encoder = T5EncoderModel.from_pretrained('google/flan-t5-small')
|
| 127 |
+
for param in flan_t5_encoder.parameters():
|
| 128 |
+
param.requires_grad = False
|
| 129 |
+
|
| 130 |
+
# Load the configurations
|
| 131 |
+
encoder_config = T5Config.from_pretrained('google/flan-t5-small')
|
| 132 |
+
|
| 133 |
+
# Define a configuration for the BERT decoder
|
| 134 |
+
config_decoder = BertConfig()
|
| 135 |
+
config_decoder.vocab_size = vocab_size
|
| 136 |
+
config_decoder.max_position_embeddings = configs['model']['text2midi_model']['decoder_max_sequence_length']
|
| 137 |
+
config_decoder.max_length = configs['model']['text2midi_model']['decoder_max_sequence_length']
|
| 138 |
+
config_decoder.bos_token_id = tokenizer["BOS_None"]
|
| 139 |
+
config_decoder.eos_token_id = tokenizer["EOS_None"]
|
| 140 |
+
config_decoder.pad_token_id = 0
|
| 141 |
+
config_decoder.num_hidden_layers = configs['model']['text2midi_model']['decoder_num_layers']
|
| 142 |
+
config_decoder.num_attention_heads = configs['model']['text2midi_model']['decoder_num_heads']
|
| 143 |
+
config_decoder.hidden_size = configs['model']['text2midi_model']['decoder_d_model']
|
| 144 |
+
config_decoder.intermediate_size = configs['model']['text2midi_model']['decoder_intermediate_size']
|
| 145 |
+
|
| 146 |
+
# set decoder config to causal lm
|
| 147 |
+
config_decoder.is_decoder = True
|
| 148 |
+
config_decoder.add_cross_attention = True
|
| 149 |
+
config_decoder.tie_encoder_decoder = False
|
| 150 |
+
config_decoder.tie_word_embeddings = False
|
| 151 |
+
|
| 152 |
+
# Create a BERT model based on the configuration
|
| 153 |
+
custom_decoder = BertLMHeadModel(config_decoder)
|
| 154 |
+
|
| 155 |
+
# Initialize the custom model
|
| 156 |
+
model = CustomEncoderDecoderModel(
|
| 157 |
+
encoder=flan_t5_encoder,
|
| 158 |
+
decoder=custom_decoder,
|
| 159 |
+
encoder_config=encoder_config,
|
| 160 |
+
decoder_config=config_decoder
|
| 161 |
+
)
|
| 162 |
+
|
| 163 |
+
# Print the number of parameters in the model
|
| 164 |
+
num_params = sum(p.numel() for p in model.parameters())
|
| 165 |
+
print(f"Number of parameters in the model: {num_params}")
|
| 166 |
+
|
| 167 |
+
# Create config for the Trainer
|
| 168 |
+
USE_CUDA = cuda_available()
|
| 169 |
+
print(f"USE_CUDA: {USE_CUDA}")
|
| 170 |
+
if not cuda_available():
|
| 171 |
+
FP16 = FP16_EVAL = BF16 = BF16_EVAL = False
|
| 172 |
+
elif is_bf16_supported():
|
| 173 |
+
BF16 = BF16_EVAL = True
|
| 174 |
+
FP16 = FP16_EVAL = False
|
| 175 |
+
else:
|
| 176 |
+
BF16 = BF16_EVAL = False
|
| 177 |
+
FP16 = FP16_EVAL = True
|
| 178 |
+
USE_MPS = not USE_CUDA and mps_available()
|
| 179 |
+
|
| 180 |
+
metrics = {metric: load_metric(metric) for metric in ["accuracy"]}
|
| 181 |
+
|
| 182 |
+
def compute_metrics(eval_pred):
|
| 183 |
+
"""
|
| 184 |
+
Compute metrics for pretraining.
|
| 185 |
+
|
| 186 |
+
Must use preprocess_logits function that converts logits to predictions (argmax or sampling).
|
| 187 |
+
|
| 188 |
+
:param eval_pred: EvalPrediction containing predictions and labels
|
| 189 |
+
:return: metrics
|
| 190 |
+
"""
|
| 191 |
+
predictions, labels = eval_pred
|
| 192 |
+
not_pad_mask = labels != 0
|
| 193 |
+
labels, predictions = labels[not_pad_mask], predictions[not_pad_mask]
|
| 194 |
+
return metrics["accuracy"].compute(predictions=predictions.flatten(), references=labels.flatten())
|
| 195 |
+
|
| 196 |
+
def preprocess_logits(logits: Tensor, _: Tensor) -> Tensor:
|
| 197 |
+
"""
|
| 198 |
+
Preprocess the logits before accumulating them during evaluation.
|
| 199 |
+
|
| 200 |
+
This allows to significantly reduce the memory usage and make the training tractable.
|
| 201 |
+
"""
|
| 202 |
+
pred_ids = argmax(logits, dim=-1) # long dtype
|
| 203 |
+
return pred_ids
|
| 204 |
+
|
| 205 |
+
run_name = configs['training']['text2midi_model']['run_name']
|
| 206 |
+
model_dir = os.path.join(artifact_folder, run_name)
|
| 207 |
+
log_dir = os.path.join(model_dir, "logs")
|
| 208 |
+
# Clear the logs directory before training
|
| 209 |
+
os.system(f"rm -rf {log_dir}")
|
| 210 |
+
|
| 211 |
+
# Define the training arguments
|
| 212 |
+
training_args = TrainingArguments(
|
| 213 |
+
output_dir=model_dir,
|
| 214 |
+
per_device_train_batch_size=batch_size,
|
| 215 |
+
per_device_eval_batch_size=batch_size,
|
| 216 |
+
save_strategy="epoch", # "steps" or "epoch"
|
| 217 |
+
save_total_limit=1,
|
| 218 |
+
learning_rate=learning_rate,
|
| 219 |
+
lr_scheduler_type="cosine_with_restarts",
|
| 220 |
+
warmup_ratio=0.3,
|
| 221 |
+
max_grad_norm=3.0,
|
| 222 |
+
weight_decay= configs['training']['text2midi_model']['weight_decay'],
|
| 223 |
+
num_train_epochs=epochs,
|
| 224 |
+
evaluation_strategy="epoch",
|
| 225 |
+
gradient_accumulation_steps=configs['training']['text2midi_model']['gradient_accumulation_steps'],
|
| 226 |
+
# gradient_checkpointing=True,
|
| 227 |
+
optim="adafactor",
|
| 228 |
+
seed=444,
|
| 229 |
+
logging_strategy="steps",
|
| 230 |
+
logging_steps=10,
|
| 231 |
+
logging_dir=log_dir,
|
| 232 |
+
no_cuda=not USE_CUDA,
|
| 233 |
+
fp16=FP16,
|
| 234 |
+
fp16_full_eval=FP16_EVAL,
|
| 235 |
+
bf16=BF16,
|
| 236 |
+
bf16_full_eval=BF16_EVAL,
|
| 237 |
+
load_best_model_at_end=True,
|
| 238 |
+
# metric_for_best_model="loss",
|
| 239 |
+
greater_is_better=False,
|
| 240 |
+
report_to="tensorboard",
|
| 241 |
+
run_name=run_name,
|
| 242 |
+
push_to_hub=False,
|
| 243 |
+
dataloader_num_workers=5
|
| 244 |
+
)
|
| 245 |
+
|
| 246 |
+
# # Define the Trainer
|
| 247 |
+
# trainer = Trainer(
|
| 248 |
+
# model=model,
|
| 249 |
+
# args=training_args,
|
| 250 |
+
# train_dataset=train_dataset,
|
| 251 |
+
# eval_dataset=test_dataset,
|
| 252 |
+
# compute_metrics=compute_metrics,
|
| 253 |
+
# preprocess_logits_for_metrics=preprocess_logits,
|
| 254 |
+
# # callbacks=[EarlyStoppingCallback(early_stopping_patience=30)]
|
| 255 |
+
# )
|
| 256 |
+
|
| 257 |
+
class CustomTrainer(Trainer):
|
| 258 |
+
def get_train_dataloader(self):
|
| 259 |
+
return DataLoader(self.train_dataset, batch_size=batch_size, collate_fn=collate_fn, num_workers=5)
|
| 260 |
+
|
| 261 |
+
def get_eval_dataloader(self, eval_dataset):
|
| 262 |
+
return DataLoader(eval_dataset, batch_size=batch_size, collate_fn=collate_fn, num_workers=5)
|
| 263 |
+
|
| 264 |
+
def get_test_dataloader(self, test_dataset):
|
| 265 |
+
return DataLoader(test_dataset, batch_size=batch_size, collate_fn=collate_fn, num_workers=5)
|
| 266 |
+
|
| 267 |
+
# Define the Trainer
|
| 268 |
+
trainer = CustomTrainer(
|
| 269 |
+
model=model,
|
| 270 |
+
args=training_args,
|
| 271 |
+
train_dataset=train_dataset,
|
| 272 |
+
eval_dataset=test_dataset,
|
| 273 |
+
compute_metrics=compute_metrics,
|
| 274 |
+
preprocess_logits_for_metrics=preprocess_logits,
|
| 275 |
+
# callbacks=[EarlyStoppingCallback(early_stopping_patience=30)]
|
| 276 |
+
)
|
| 277 |
+
|
| 278 |
+
# Train and save the model
|
| 279 |
+
train_result = trainer.train()
|
| 280 |
+
trainer.save_model()
|
| 281 |
+
trainer.log_metrics("train", train_result.metrics)
|
| 282 |
+
trainer.save_metrics("train", train_result.metrics)
|
| 283 |
+
trainer.save_state()
|
text2midi_repo/model/transformer_model.py
ADDED
|
@@ -0,0 +1,1509 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# from aria.tokenizer import AbsTokenizer
|
| 2 |
+
# aria_tokenizer = AbsTokenizer()
|
| 3 |
+
import copy
|
| 4 |
+
import json
|
| 5 |
+
from typing import Optional, Any, Union, Callable
|
| 6 |
+
import torch.multiprocessing as mp
|
| 7 |
+
from torch.nn import DataParallel
|
| 8 |
+
import jsonlines
|
| 9 |
+
import math
|
| 10 |
+
import time
|
| 11 |
+
import torch
|
| 12 |
+
import os
|
| 13 |
+
import warnings
|
| 14 |
+
from tqdm import tqdm
|
| 15 |
+
from torch import Tensor
|
| 16 |
+
# from aria.tokenizer import AbsTokenizer
|
| 17 |
+
import pickle
|
| 18 |
+
from torch.nn import Module, LayerNorm, Dropout, Linear
|
| 19 |
+
from torch.nn.modules.container import ModuleList
|
| 20 |
+
from torch.nn.modules.activation import MultiheadAttention
|
| 21 |
+
from torch.multiprocessing import Process, set_start_method
|
| 22 |
+
from torch.nn.init import xavier_uniform_
|
| 23 |
+
import torch.nn.functional as F
|
| 24 |
+
import torch.nn as nn
|
| 25 |
+
|
| 26 |
+
from st_moe_pytorch import MoE
|
| 27 |
+
from st_moe_pytorch import SparseMoEBlock
|
| 28 |
+
|
| 29 |
+
from einops import rearrange
|
| 30 |
+
|
| 31 |
+
from transformers import T5Tokenizer, T5EncoderModel
|
| 32 |
+
|
| 33 |
+
import sys
|
| 34 |
+
import torch.distributed as dist
|
| 35 |
+
from torch.nn.parallel import DistributedDataParallel as DDP
|
| 36 |
+
from torch.utils.data import DataLoader, Dataset
|
| 37 |
+
|
| 38 |
+
import torch.profiler
|
| 39 |
+
|
| 40 |
+
from accelerate import Accelerator
|
| 41 |
+
import argparse # Add this import
|
| 42 |
+
|
| 43 |
+
class CaptionDataset(Dataset):
|
| 44 |
+
def __init__(self, captions):
|
| 45 |
+
self.captions = captions
|
| 46 |
+
|
| 47 |
+
def __len__(self):
|
| 48 |
+
return len(self.captions)
|
| 49 |
+
|
| 50 |
+
def __getitem__(self, idx):
|
| 51 |
+
return self.captions[idx]
|
| 52 |
+
|
| 53 |
+
def custom_collate_fn(batch):
|
| 54 |
+
captions = [item['caption'] for item in batch]
|
| 55 |
+
locations = [item['location'] for item in batch]
|
| 56 |
+
return captions, locations
|
| 57 |
+
|
| 58 |
+
def ensure_log_dir_exists(log_dir):
|
| 59 |
+
if not os.path.exists(log_dir):
|
| 60 |
+
os.makedirs(log_dir)
|
| 61 |
+
|
| 62 |
+
__all__ = ['Transformer', 'TransformerEncoder', 'TransformerDecoder', 'TransformerEncoderLayer', 'TransformerDecoderLayer']
|
| 63 |
+
|
| 64 |
+
def _generate_square_subsequent_mask(
|
| 65 |
+
sz: int,
|
| 66 |
+
device: Optional[torch.device] = None,
|
| 67 |
+
dtype: Optional[torch.dtype] = None,
|
| 68 |
+
) -> Tensor:
|
| 69 |
+
r"""Generate a square causal mask for the sequence.
|
| 70 |
+
|
| 71 |
+
The masked positions are filled with float('-inf'). Unmasked positions are filled with float(0.0).
|
| 72 |
+
"""
|
| 73 |
+
if device is None:
|
| 74 |
+
device = torch.device('cpu')
|
| 75 |
+
if dtype is None:
|
| 76 |
+
dtype = torch.float32
|
| 77 |
+
return torch.triu(
|
| 78 |
+
torch.full((sz, sz), float('-inf'), dtype=dtype, device=device),
|
| 79 |
+
diagonal=1,
|
| 80 |
+
)
|
| 81 |
+
|
| 82 |
+
|
| 83 |
+
def _get_seq_len(
|
| 84 |
+
src: Tensor,
|
| 85 |
+
batch_first: bool
|
| 86 |
+
) -> Optional[int]:
|
| 87 |
+
|
| 88 |
+
if src.is_nested:
|
| 89 |
+
return None
|
| 90 |
+
else:
|
| 91 |
+
src_size = src.size()
|
| 92 |
+
if len(src_size) == 2:
|
| 93 |
+
# unbatched: S, E
|
| 94 |
+
return src_size[0]
|
| 95 |
+
else:
|
| 96 |
+
# batched: B, S, E if batch_first else S, B, E
|
| 97 |
+
seq_len_pos = 1 if batch_first else 0
|
| 98 |
+
return src_size[seq_len_pos]
|
| 99 |
+
|
| 100 |
+
|
| 101 |
+
class PositionalEncoding(nn.Module):
|
| 102 |
+
r"""Inject some information about the relative or absolute position of the tokens in the sequence.
|
| 103 |
+
The positional encodings have the same dimension as the embeddings, so that the two can be summed.
|
| 104 |
+
Here, we use sine and cosine functions of different frequencies.
|
| 105 |
+
.. math:
|
| 106 |
+
\text{PosEncoder}(pos, 2i) = sin(pos/10000^(2i/d_model))
|
| 107 |
+
\text{PosEncoder}(pos, 2i+1) = cos(pos/10000^(2i/d_model))
|
| 108 |
+
\text{where pos is the word position and i is the embed idx)
|
| 109 |
+
Args:
|
| 110 |
+
d_model: the embed dim (required).
|
| 111 |
+
dropout: the dropout value (default=0.1).
|
| 112 |
+
max_len: the max. length of the incoming sequence (default=5000).
|
| 113 |
+
Examples:
|
| 114 |
+
>>> pos_encoder = PositionalEncoding(d_model)
|
| 115 |
+
"""
|
| 116 |
+
|
| 117 |
+
def __init__(self, d_model, dropout=0.1, max_len=5000):
|
| 118 |
+
super(PositionalEncoding, self).__init__()
|
| 119 |
+
self.dropout = nn.Dropout(p=dropout)
|
| 120 |
+
|
| 121 |
+
pe = torch.zeros(max_len, d_model)
|
| 122 |
+
position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
|
| 123 |
+
div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
|
| 124 |
+
pe[:, 0::2] = torch.sin(position * div_term)
|
| 125 |
+
pe[:, 1::2] = torch.cos(position * div_term)
|
| 126 |
+
pe = pe.unsqueeze(0).transpose(0, 1)
|
| 127 |
+
# self.register_buffer('pe', pe)
|
| 128 |
+
self.register_parameter('pe', nn.Parameter(pe, requires_grad=False))
|
| 129 |
+
|
| 130 |
+
def forward(self, x):
|
| 131 |
+
r"""Inputs of forward function
|
| 132 |
+
Args:
|
| 133 |
+
x: the sequence fed to the positional encoder model (required).
|
| 134 |
+
Shape:
|
| 135 |
+
x: [sequence length, batch size, embed dim]
|
| 136 |
+
output: [sequence length, batch size, embed dim]
|
| 137 |
+
Examples:
|
| 138 |
+
>>> output = pos_encoder(x)
|
| 139 |
+
"""
|
| 140 |
+
x = x + self.pe[:x.size(0), :]
|
| 141 |
+
return self.dropout(x)
|
| 142 |
+
|
| 143 |
+
|
| 144 |
+
def precompute_freqs_cis(
|
| 145 |
+
seq_len: int,
|
| 146 |
+
n_elem: int,
|
| 147 |
+
base: int = 10000,
|
| 148 |
+
dtype: torch.dtype = torch.bfloat16,
|
| 149 |
+
):
|
| 150 |
+
freqs = 1.0 / (
|
| 151 |
+
base ** (torch.arange(0, n_elem, 2)[: (n_elem // 2)].float() / n_elem)
|
| 152 |
+
)
|
| 153 |
+
t = torch.arange(seq_len, device=freqs.device)
|
| 154 |
+
freqs = torch.outer(t, freqs)
|
| 155 |
+
freqs_cis = torch.polar(torch.ones_like(freqs), freqs)
|
| 156 |
+
cache = torch.stack([freqs_cis.real, freqs_cis.imag], dim=-1)
|
| 157 |
+
|
| 158 |
+
return cache.to(dtype=dtype)
|
| 159 |
+
|
| 160 |
+
|
| 161 |
+
@torch.jit.script
|
| 162 |
+
def apply_rotary_emb(x: torch.Tensor, freqs_cis: torch.Tensor) -> torch.Tensor:
|
| 163 |
+
"""
|
| 164 |
+
In-place RoPE. Credits to Katherine Crowson:
|
| 165 |
+
x shape (b_sz, n_head, s_len, d_head).
|
| 166 |
+
cos, sin shape (s_len, d_head // 2).
|
| 167 |
+
"""
|
| 168 |
+
|
| 169 |
+
x = x.permute(0, 2, 1, 3)
|
| 170 |
+
d = x.shape[-1] // 2
|
| 171 |
+
cos = freqs_cis[..., 0][None, :, None]
|
| 172 |
+
sin = freqs_cis[..., 1][None, :, None]
|
| 173 |
+
x1, x2 = x[..., :d], x[..., d : d * 2]
|
| 174 |
+
tmp = x1.clone()
|
| 175 |
+
# x1.mul_(cos).addcmul_(x2, sin, value=-1)
|
| 176 |
+
# x2.mul_(cos).addcmul_(tmp, sin, value=1) ##was throwing some error: RuntimeError: Output 0 of SliceBackward0 is a view and is being modified inplace. This view is the output of a function that returns multiple views. Such functions do not allow the output views to be modified inplace. You should replace the inplace operation by an out-of-place one.
|
| 177 |
+
x1_new = x1.mul(cos) - x2.mul(sin)
|
| 178 |
+
x2_new = x2.mul(cos) + tmp.mul(sin)
|
| 179 |
+
x = torch.cat((x1_new, x2_new), dim=-1)
|
| 180 |
+
x = x.permute(0, 2, 1, 3)
|
| 181 |
+
|
| 182 |
+
return x
|
| 183 |
+
|
| 184 |
+
|
| 185 |
+
class MultiHeadSelfAttention(nn.Module):
|
| 186 |
+
r"""Multi-head self-attention module.
|
| 187 |
+
|
| 188 |
+
Args:
|
| 189 |
+
embed_dim (int): The input embedding dimension.
|
| 190 |
+
num_heads (int, optional): The number of attention heads (default: 4).
|
| 191 |
+
dropout (float, optional): The dropout probability (default: 0.1).
|
| 192 |
+
device (torch.device, optional): The device to use (default: None).
|
| 193 |
+
dtype (torch.dtype, optional): The data type to use (default: None).
|
| 194 |
+
|
| 195 |
+
Attributes:
|
| 196 |
+
dim_head (int): The dimension of each attention head.
|
| 197 |
+
scale (float): The scaling factor for attention scores.
|
| 198 |
+
heads (int): The number of attention heads.
|
| 199 |
+
to_qkv (nn.Linear): Linear layer for projecting input to query, key, and value.
|
| 200 |
+
to_out (nn.Linear): Linear layer for projecting attention output to the original embedding dimension.
|
| 201 |
+
dropout (nn.Dropout): Dropout layer.
|
| 202 |
+
|
| 203 |
+
"""
|
| 204 |
+
|
| 205 |
+
def __init__(
|
| 206 |
+
self,
|
| 207 |
+
embed_dim: int,
|
| 208 |
+
num_heads: int = 4,
|
| 209 |
+
dropout: float = 0.1,
|
| 210 |
+
batch_first: bool = True,
|
| 211 |
+
device: Optional[torch.device] = None,
|
| 212 |
+
dtype: Optional[torch.dtype] = None,
|
| 213 |
+
):
|
| 214 |
+
factory_kwargs = {'device': device, 'dtype': dtype}
|
| 215 |
+
super().__init__()
|
| 216 |
+
self.embed_dim = embed_dim
|
| 217 |
+
self.batch_first = batch_first
|
| 218 |
+
self.dim_head = embed_dim // num_heads
|
| 219 |
+
self.scale = self.dim_head ** -0.5
|
| 220 |
+
self.heads = num_heads
|
| 221 |
+
hidden_dim = self.dim_head * num_heads
|
| 222 |
+
self.to_qkv = nn.Linear(embed_dim, hidden_dim * 3, bias=False, **factory_kwargs)
|
| 223 |
+
self.to_out = nn.Linear(hidden_dim, embed_dim, bias=False, **factory_kwargs)
|
| 224 |
+
self.dropout = nn.Dropout(dropout)
|
| 225 |
+
|
| 226 |
+
def forward(self, x: torch.Tensor, is_causal: bool = True) -> torch.Tensor:
|
| 227 |
+
|
| 228 |
+
r"""Forward pass of the multi-head self-attention module.
|
| 229 |
+
|
| 230 |
+
Args:
|
| 231 |
+
x (torch.Tensor): The input tensor of shape (batch_size, sequence_length, embed_dim).
|
| 232 |
+
|
| 233 |
+
Returns:
|
| 234 |
+
torch.Tensor: The output tensor of shape (batch_size, sequence_length, embed_dim).
|
| 235 |
+
|
| 236 |
+
"""
|
| 237 |
+
if not self.batch_first:
|
| 238 |
+
x = x.transpose(0, 1)
|
| 239 |
+
b, n, _ = x.size()
|
| 240 |
+
q, k, v = torch.chunk(self.to_qkv(x), chunks=3, dim=-1)
|
| 241 |
+
q, k, v = map(lambda t: t.contiguous().view(b, self.heads, n, -1), (q, k, v))
|
| 242 |
+
|
| 243 |
+
self.freqs_cis = precompute_freqs_cis(
|
| 244 |
+
seq_len=n,
|
| 245 |
+
n_elem=self.embed_dim // self.heads,
|
| 246 |
+
base=10000,
|
| 247 |
+
dtype=x.dtype,
|
| 248 |
+
).to(x.device)
|
| 249 |
+
freqs_cis = self.freqs_cis[: x.shape[1]]
|
| 250 |
+
# q = apply_rotary_emb(q, freqs_cis)
|
| 251 |
+
# k = apply_rotary_emb(k, freqs_cis)
|
| 252 |
+
out = torch.nn.functional.scaled_dot_product_attention(q, k, v, is_causal=is_causal)
|
| 253 |
+
out = out.contiguous().view(b, n, -1)
|
| 254 |
+
out = self.dropout(out)
|
| 255 |
+
return self.to_out(out)
|
| 256 |
+
|
| 257 |
+
|
| 258 |
+
class Transformer(Module):
|
| 259 |
+
r"""A transformer model.
|
| 260 |
+
|
| 261 |
+
User is able to modify the attributes as needed. The architecture
|
| 262 |
+
is based on the paper "Attention Is All You Need". Ashish Vaswani, Noam Shazeer,
|
| 263 |
+
Niki Parmar, Jakob Uszkoreit, Llion Jones, Aidan N Gomez, Lukasz Kaiser, and
|
| 264 |
+
Illia Polosukhin. 2017. Attention is all you need. In Advances in Neural Information
|
| 265 |
+
Processing Systems, pages 6000-6010.
|
| 266 |
+
|
| 267 |
+
Args:
|
| 268 |
+
d_model: the number of expected features in the encoder/decoder inputs (default=512).
|
| 269 |
+
nhead: the number of heads in the multiheadattention models (default=8).
|
| 270 |
+
num_encoder_layers: the number of sub-encoder-layers in the encoder (default=6).
|
| 271 |
+
num_decoder_layers: the number of sub-decoder-layers in the decoder (default=6).
|
| 272 |
+
dim_feedforward: the dimension of the feedforward network model (default=2048).
|
| 273 |
+
use_moe: if True, use MoE instead of linear layer for feedforward network (default=False).
|
| 274 |
+
dropout: the dropout value (default=0.1).
|
| 275 |
+
activation: the activation function of encoder/decoder intermediate layer, can be a string
|
| 276 |
+
("relu" or "gelu") or a unary callable. Default: relu
|
| 277 |
+
custom_encoder: custom encoder (default=None).
|
| 278 |
+
custom_decoder: custom decoder (default=None).
|
| 279 |
+
layer_norm_eps: the eps value in layer normalization components (default=1e-5).
|
| 280 |
+
batch_first: If ``True``, then the input and output tensors are provided
|
| 281 |
+
as (batch, seq, feature). Default: ``False`` (seq, batch, feature).
|
| 282 |
+
norm_first: if ``True``, encoder and decoder layers will perform LayerNorms before
|
| 283 |
+
other attention and feedforward operations, otherwise after. Default: ``False`` (after).
|
| 284 |
+
bias: If set to ``False``, ``Linear`` and ``LayerNorm`` layers will not learn an additive
|
| 285 |
+
bias. Default: ``True``.
|
| 286 |
+
|
| 287 |
+
Examples::
|
| 288 |
+
>>> transformer_model = nn.Transformer(nhead=16, num_encoder_layers=12)
|
| 289 |
+
>>> src = torch.rand((32, 512))
|
| 290 |
+
>>> tgt = torch.rand((32, 512, 30000))
|
| 291 |
+
>>> out = transformer_model(src, tgt)
|
| 292 |
+
|
| 293 |
+
Note: A full example to apply nn.Transformer module for the word language model is available in
|
| 294 |
+
https://github.com/pytorch/examples/tree/master/word_language_model
|
| 295 |
+
"""
|
| 296 |
+
|
| 297 |
+
def __init__(self, n_vocab: int = 30000, d_model: int = 512, nhead: int = 8, max_len: int = 5000,
|
| 298 |
+
num_decoder_layers: int = 6, dim_feedforward: int = 2048, use_moe: bool = False,
|
| 299 |
+
num_experts: int = 16, dropout: float = 0.1,
|
| 300 |
+
activation: Union[str, Callable[[Tensor], Tensor]] = F.relu,
|
| 301 |
+
layer_norm_eps: float = 1e-5, batch_first: bool = True, norm_first: bool = False,
|
| 302 |
+
bias: bool = True, device=None, dtype=None) -> None:
|
| 303 |
+
factory_kwargs = {'device': device, 'dtype': dtype}
|
| 304 |
+
super().__init__()
|
| 305 |
+
torch._C._log_api_usage_once(f"torch.nn.modules.{self.__class__.__name__}")
|
| 306 |
+
|
| 307 |
+
self.use_moe = use_moe
|
| 308 |
+
|
| 309 |
+
self.input_emb = nn.Embedding(n_vocab, d_model, **factory_kwargs)
|
| 310 |
+
self.pos_encoder = PositionalEncoding(d_model, dropout, max_len).to(device)
|
| 311 |
+
|
| 312 |
+
# Load the FLAN-T5 encoder
|
| 313 |
+
self.encoder = T5EncoderModel.from_pretrained("google/flan-t5-base").to(device)
|
| 314 |
+
# Freeze the encoder
|
| 315 |
+
for param in self.encoder.parameters():
|
| 316 |
+
param.requires_grad = False
|
| 317 |
+
|
| 318 |
+
decoder_layer = TransformerDecoderLayer(d_model, nhead, dim_feedforward, use_moe, num_experts, dropout,
|
| 319 |
+
activation, layer_norm_eps, batch_first, norm_first,
|
| 320 |
+
bias, **factory_kwargs)
|
| 321 |
+
decoder_norm = LayerNorm(d_model, eps=layer_norm_eps, bias=bias, **factory_kwargs)
|
| 322 |
+
self.decoder = TransformerDecoder(decoder_layer, num_decoder_layers, use_moe, decoder_norm)
|
| 323 |
+
|
| 324 |
+
self.projection = nn.Linear(d_model, n_vocab).to(device)
|
| 325 |
+
|
| 326 |
+
self._reset_parameters()
|
| 327 |
+
|
| 328 |
+
self.d_model = d_model
|
| 329 |
+
self.nhead = nhead
|
| 330 |
+
|
| 331 |
+
self.batch_first = batch_first
|
| 332 |
+
|
| 333 |
+
def forward(self, src: Tensor, src_mask: Tensor, tgt: Tensor, memory_mask: Optional[Tensor] = None,
|
| 334 |
+
memory_key_padding_mask: Optional[Tensor] = None, tgt_is_causal: bool = True,
|
| 335 |
+
memory_is_causal: bool = False) -> Tensor:
|
| 336 |
+
r"""Take in and process masked source/target sequences.
|
| 337 |
+
|
| 338 |
+
.. note::
|
| 339 |
+
|
| 340 |
+
If a boolean tensor is provided for any of the [src/tgt/memory]_mask arguments, positions with a ``True`` value are
|
| 341 |
+
not allowed to participate in the attention,
|
| 342 |
+
which is the opposite of the definition for :attr:`attn_mask`
|
| 343 |
+
in :func:`torch.nn.functional.scaled_dot_product_attention`.
|
| 344 |
+
|
| 345 |
+
Args:
|
| 346 |
+
src: the sequence to the encoder (required).
|
| 347 |
+
src_attn_mask: the attention mask for the src sequence (required).
|
| 348 |
+
tgt: the sequence to the decoder (required).
|
| 349 |
+
tgt_mask: the additive mask for the tgt sequence (optional).
|
| 350 |
+
memory_mask: the additive mask for the encoder output (optional).
|
| 351 |
+
tgt_key_padding_mask: the Tensor mask for tgt keys per batch (optional).
|
| 352 |
+
memory_key_padding_mask: the Tensor mask for memory keys per batch (optional).
|
| 353 |
+
tgt_is_causal: If specified, applies a causal mask as ``tgt_mask``.
|
| 354 |
+
Default: ``None``; try to detect a causal mask.
|
| 355 |
+
Warning:
|
| 356 |
+
``tgt_is_causal`` provides a hint that ``tgt_mask`` is
|
| 357 |
+
the causal mask. Providing incorrect hints can result in
|
| 358 |
+
incorrect execution, including forward and backward
|
| 359 |
+
compatibility.
|
| 360 |
+
memory_is_causal: If specified, applies a causal mask as
|
| 361 |
+
``memory_mask``.
|
| 362 |
+
Default: ``False``.
|
| 363 |
+
Warning:
|
| 364 |
+
``memory_is_causal`` provides a hint that
|
| 365 |
+
``memory_mask`` is the causal mask. Providing incorrect
|
| 366 |
+
hints can result in incorrect execution, including
|
| 367 |
+
forward and backward compatibility.
|
| 368 |
+
|
| 369 |
+
Shape:
|
| 370 |
+
- src: :math:`(S, S)` for unbatched input, :math:`(S, N)` if `batch_first=False` or
|
| 371 |
+
`(N, S)` if `batch_first=True`.
|
| 372 |
+
- src_mask: :math:`(S, S)` or :math:`(N\cdot\text{num\_heads}, S, S)`.
|
| 373 |
+
- tgt: :math:`(T, E)` for unbatched input, :math:`(T, N, E)` if `batch_first=False` or
|
| 374 |
+
`(N, T, E)` if `batch_first=True`.
|
| 375 |
+
- tgt_mask: :math:`(T, T)` or :math:`(N\cdot\text{num\_heads}, T, T)`.
|
| 376 |
+
- memory_mask: :math:`(T, S)`.
|
| 377 |
+
- src_key_padding_mask: :math:`(S)` for unbatched input otherwise :math:`(N, S)`.
|
| 378 |
+
- tgt_key_padding_mask: :math:`(T)` for unbatched input otherwise :math:`(N, T)`.
|
| 379 |
+
- memory_key_padding_mask: :math:`(S)` for unbatched input otherwise :math:`(N, S)`.
|
| 380 |
+
|
| 381 |
+
Note: [src/tgt/memory]_mask ensures that position :math:`i` is allowed to attend the unmasked
|
| 382 |
+
positions. If a BoolTensor is provided, positions with ``True``
|
| 383 |
+
are not allowed to attend while ``False`` values will be unchanged. If a FloatTensor
|
| 384 |
+
is provided, it will be added to the attention weight.
|
| 385 |
+
[src/tgt/memory]_key_padding_mask provides specified elements in the key to be ignored by
|
| 386 |
+
the attention. If a BoolTensor is provided, the positions with the
|
| 387 |
+
value of ``True`` will be ignored while the position with the value of ``False`` will be unchanged.
|
| 388 |
+
|
| 389 |
+
- output: :math:`(T, E)` for unbatched input, :math:`(T, N, E)` if `batch_first=False` or
|
| 390 |
+
`(N, T, E)` if `batch_first=True`.
|
| 391 |
+
|
| 392 |
+
Note: Due to the multi-head attention architecture in the transformer model,
|
| 393 |
+
the output sequence length of a transformer is same as the input sequence
|
| 394 |
+
(i.e. target) length of the decoder.
|
| 395 |
+
|
| 396 |
+
where :math:`S` is the source sequence length, :math:`T` is the target sequence length, :math:`N` is the
|
| 397 |
+
batch size, :math:`E` is the feature number
|
| 398 |
+
|
| 399 |
+
Examples:
|
| 400 |
+
>>> # xdoctest: +SKIP
|
| 401 |
+
>>> output = transformer_model(src, tgt, src_mask=src_mask)
|
| 402 |
+
"""
|
| 403 |
+
if src.dim() != tgt.dim():
|
| 404 |
+
raise RuntimeError("the number of dimensions in src and tgt must be equal")
|
| 405 |
+
|
| 406 |
+
memory = self.encoder(src, attention_mask=src_mask).last_hidden_state
|
| 407 |
+
|
| 408 |
+
tgt = self.input_emb(tgt) * math.sqrt(self.d_model)
|
| 409 |
+
tgt = self.pos_encoder(tgt)
|
| 410 |
+
# tgt = tgt + tgt_pos
|
| 411 |
+
|
| 412 |
+
if self.use_moe:
|
| 413 |
+
with torch.cuda.amp.autocast(enabled =False):
|
| 414 |
+
output, sum_total_aux_loss = self.decoder(tgt, memory, memory_mask=memory_mask,
|
| 415 |
+
memory_key_padding_mask=memory_key_padding_mask,
|
| 416 |
+
tgt_is_causal=tgt_is_causal, memory_is_causal=memory_is_causal)
|
| 417 |
+
else:
|
| 418 |
+
output = self.decoder(tgt, memory, memory_mask=memory_mask,
|
| 419 |
+
memory_key_padding_mask=memory_key_padding_mask,
|
| 420 |
+
tgt_is_causal=tgt_is_causal, memory_is_causal=memory_is_causal)
|
| 421 |
+
|
| 422 |
+
output = self.projection(output)
|
| 423 |
+
# output = F.log_softmax(output, dim=-1)
|
| 424 |
+
|
| 425 |
+
if self.use_moe:
|
| 426 |
+
return output, sum_total_aux_loss
|
| 427 |
+
else:
|
| 428 |
+
return output
|
| 429 |
+
|
| 430 |
+
def generate(self, src: Tensor, src_mask: Tensor, max_len: int = 100, temperature: float = 1.0):
|
| 431 |
+
## ADD A START OF SEQUENCE TOKEN <SS> token to the src tensor
|
| 432 |
+
r"""Generate a sequence of tokens from the given inputs.
|
| 433 |
+
|
| 434 |
+
Args:
|
| 435 |
+
src: the sequence to the encoder (required).
|
| 436 |
+
src_mask: the attention mask for the src sequence (required).
|
| 437 |
+
max_len: the maximum length of the sequence to generate (default=100).
|
| 438 |
+
temperature: the temperature for the softmax (default=1.0).
|
| 439 |
+
|
| 440 |
+
Returns:
|
| 441 |
+
torch.Tensor: The generated sequence of tokens.
|
| 442 |
+
|
| 443 |
+
"""
|
| 444 |
+
if src.dim() != 2:
|
| 445 |
+
raise RuntimeError("The src tensor should be 2-dimensional")
|
| 446 |
+
tgt_fin = torch.full((src.size(0), 1), 1, dtype=torch.long, device=src.device)
|
| 447 |
+
# values = [21631, 8, 10, 9, 6, 7, 17, 21632, 11474, 20626, 21151, 9426, 20627, 21143, 11476, 20640, 21143, 11477, 20655, 21145, 11476, 20669, 21145, 11477, 20683, 21145, 13527, 20697, 21146, 13529, 20712, 21145, 7013, 20769, 21143, 7006, 20769, 21143, 7006, 20769, 21141, 7009, 20769, 21143, 9426, 20797, 21144, 11474, 20797, 21173, 11476, 20812, 21144, 11477, 20826, 21145, 11476, 20840, 21145, 11477, 20855, 21145, 13527, 20869, 21144, 13529, 20883, 21143, 7006, 20940, 21139, 7013, 20940, 21140, 7006, 20940, 21147, 7009, 20940, 21147, 11474, 20969, 21144, 11474, 20969, 21170, 11476, 20983, 21144, 11477, 20997, 21145, 11476, 21012, 21144, 11477, 21026, 21144, 11479, 21040]
|
| 448 |
+
# values_tensor = torch.tensor(values, dtype=torch.long, device=src.device)
|
| 449 |
+
# tgt_fin = values_tensor.unsqueeze(0).repeat(src.size(0), 1)
|
| 450 |
+
for i in tqdm(range(max_len)):
|
| 451 |
+
max_index = tgt_fin.max()
|
| 452 |
+
# assert max_index < 21634, "tgt_fin contains index out of range. Adjust n_vocab or fix tgt_fin indices."
|
| 453 |
+
tgt = tgt_fin
|
| 454 |
+
if self.use_moe:
|
| 455 |
+
output, _ = self.froward(src, src_mask, tgt, memory_mask=None,
|
| 456 |
+
memory_key_padding_mask=None,
|
| 457 |
+
tgt_is_causal=True, memory_is_causal=False)
|
| 458 |
+
else:
|
| 459 |
+
output = self.forward(src, src_mask, tgt, memory_mask=None,
|
| 460 |
+
memory_key_padding_mask=None,
|
| 461 |
+
tgt_is_causal=True, memory_is_causal=False)
|
| 462 |
+
# logits = self.projection(output)
|
| 463 |
+
logits = output
|
| 464 |
+
output = F.log_softmax(logits/temperature, dim=-1)
|
| 465 |
+
output = output.view(-1, output.size(-1))
|
| 466 |
+
next_tokens = torch.multinomial(torch.exp(output), 1)[-1] # taking the last logit and adding to the sequence
|
| 467 |
+
tgt_fin = torch.cat((tgt_fin, next_tokens.unsqueeze(-1)), dim=1)
|
| 468 |
+
return tgt_fin[:, 1:]
|
| 469 |
+
|
| 470 |
+
@staticmethod
|
| 471 |
+
def generate_square_subsequent_mask(
|
| 472 |
+
sz: int,
|
| 473 |
+
device: Optional[torch.device] = None,
|
| 474 |
+
dtype: Optional[torch.dtype] = None,
|
| 475 |
+
) -> Tensor:
|
| 476 |
+
r"""Generate a square causal mask for the sequence.
|
| 477 |
+
|
| 478 |
+
The masked positions are filled with float('-inf'). Unmasked positions are filled with float(0.0).
|
| 479 |
+
"""
|
| 480 |
+
return _generate_square_subsequent_mask(sz, dtype=dtype, device=device)
|
| 481 |
+
|
| 482 |
+
|
| 483 |
+
def _reset_parameters(self):
|
| 484 |
+
r"""Initiate parameters in the transformer model."""
|
| 485 |
+
for p in self.parameters():
|
| 486 |
+
if p.dim() > 1:
|
| 487 |
+
xavier_uniform_(p)
|
| 488 |
+
|
| 489 |
+
|
| 490 |
+
|
| 491 |
+
|
| 492 |
+
class TransformerEncoder(Module):
|
| 493 |
+
r"""TransformerEncoder is a stack of N encoder layers.
|
| 494 |
+
|
| 495 |
+
Users can build the BERT(https://arxiv.org/abs/1810.04805) model with corresponding parameters.
|
| 496 |
+
|
| 497 |
+
Args:
|
| 498 |
+
encoder_layer: an instance of the TransformerEncoderLayer() class (required).
|
| 499 |
+
num_layers: the number of sub-encoder-layers in the encoder (required).
|
| 500 |
+
norm: the layer normalization component (optional).
|
| 501 |
+
enable_nested_tensor: if True, input will automatically convert to nested tensor
|
| 502 |
+
(and convert back on output). This will improve the overall performance of
|
| 503 |
+
TransformerEncoder when padding rate is high. Default: ``True`` (enabled).
|
| 504 |
+
|
| 505 |
+
Examples::
|
| 506 |
+
>>> encoder_layer = nn.TransformerEncoderLayer(d_model=512, nhead=8)
|
| 507 |
+
>>> transformer_encoder = nn.TransformerEncoder(encoder_layer, num_layers=6)
|
| 508 |
+
>>> src = torch.rand(10, 32, 512)
|
| 509 |
+
>>> out = transformer_encoder(src)
|
| 510 |
+
"""
|
| 511 |
+
|
| 512 |
+
__constants__ = ['norm']
|
| 513 |
+
|
| 514 |
+
def __init__(
|
| 515 |
+
self,
|
| 516 |
+
encoder_layer: "TransformerEncoderLayer",
|
| 517 |
+
num_layers: int,
|
| 518 |
+
norm: Optional[Module] = None,
|
| 519 |
+
enable_nested_tensor: bool = True,
|
| 520 |
+
mask_check: bool = True
|
| 521 |
+
) -> None:
|
| 522 |
+
super().__init__()
|
| 523 |
+
torch._C._log_api_usage_once(f"torch.nn.modules.{self.__class__.__name__}")
|
| 524 |
+
self.layers = _get_clones(encoder_layer, num_layers)
|
| 525 |
+
self.num_layers = num_layers
|
| 526 |
+
self.norm = norm
|
| 527 |
+
# this attribute saves the value providedat object construction
|
| 528 |
+
self.enable_nested_tensor = enable_nested_tensor
|
| 529 |
+
# this attribute controls whether nested tensors are used
|
| 530 |
+
self.use_nested_tensor = enable_nested_tensor
|
| 531 |
+
self.mask_check = mask_check
|
| 532 |
+
|
| 533 |
+
enc_layer = "encoder_layer"
|
| 534 |
+
why_not_sparsity_fast_path = ''
|
| 535 |
+
if not isinstance(encoder_layer, torch.nn.TransformerEncoderLayer):
|
| 536 |
+
why_not_sparsity_fast_path = f"{enc_layer} was not TransformerEncoderLayer"
|
| 537 |
+
elif encoder_layer.norm_first :
|
| 538 |
+
why_not_sparsity_fast_path = f"{enc_layer}.norm_first was True"
|
| 539 |
+
elif not encoder_layer.self_attn.batch_first:
|
| 540 |
+
why_not_sparsity_fast_path = (f"{enc_layer}.self_attn.batch_first was not True" +
|
| 541 |
+
"(use batch_first for better inference performance)")
|
| 542 |
+
elif not encoder_layer.self_attn._qkv_same_embed_dim:
|
| 543 |
+
why_not_sparsity_fast_path = f"{enc_layer}.self_attn._qkv_same_embed_dim was not True"
|
| 544 |
+
elif encoder_layer.self_attn.in_proj_bias is None:
|
| 545 |
+
why_not_sparsity_fast_path = f"{enc_layer}.self_attn was passed bias=False"
|
| 546 |
+
elif not encoder_layer.activation_relu_or_gelu:
|
| 547 |
+
why_not_sparsity_fast_path = f"{enc_layer}.activation_relu_or_gelu was not True"
|
| 548 |
+
elif not (encoder_layer.norm1.eps == encoder_layer.norm2.eps) :
|
| 549 |
+
why_not_sparsity_fast_path = f"{enc_layer}.norm1.eps was not equal to {enc_layer}.norm2.eps"
|
| 550 |
+
elif encoder_layer.self_attn.num_heads % 2 == 1:
|
| 551 |
+
why_not_sparsity_fast_path = f"{enc_layer}.self_attn.num_heads is odd"
|
| 552 |
+
|
| 553 |
+
if enable_nested_tensor and why_not_sparsity_fast_path:
|
| 554 |
+
warnings.warn(f"enable_nested_tensor is True, but self.use_nested_tensor is False because {why_not_sparsity_fast_path}")
|
| 555 |
+
self.use_nested_tensor = False
|
| 556 |
+
|
| 557 |
+
|
| 558 |
+
|
| 559 |
+
def forward(
|
| 560 |
+
self,
|
| 561 |
+
src: Tensor,
|
| 562 |
+
mask: Optional[Tensor] = None,
|
| 563 |
+
src_key_padding_mask: Optional[Tensor] = None,
|
| 564 |
+
is_causal: Optional[bool] = None) -> Tensor:
|
| 565 |
+
r"""Pass the input through the encoder layers in turn.
|
| 566 |
+
|
| 567 |
+
Args:
|
| 568 |
+
src: the sequence to the encoder (required).
|
| 569 |
+
mask: the mask for the src sequence (optional).
|
| 570 |
+
src_key_padding_mask: the mask for the src keys per batch (optional).
|
| 571 |
+
is_causal: If specified, applies a causal mask as ``mask``.
|
| 572 |
+
Default: ``None``; try to detect a causal mask.
|
| 573 |
+
Warning:
|
| 574 |
+
``is_causal`` provides a hint that ``mask`` is the
|
| 575 |
+
causal mask. Providing incorrect hints can result in
|
| 576 |
+
incorrect execution, including forward and backward
|
| 577 |
+
compatibility.
|
| 578 |
+
|
| 579 |
+
Shape:
|
| 580 |
+
see the docs in :class:`~torch.nn.Transformer`.
|
| 581 |
+
"""
|
| 582 |
+
src_key_padding_mask = F._canonical_mask(
|
| 583 |
+
mask=src_key_padding_mask,
|
| 584 |
+
mask_name="src_key_padding_mask",
|
| 585 |
+
other_type=F._none_or_dtype(mask),
|
| 586 |
+
other_name="mask",
|
| 587 |
+
target_type=src.dtype
|
| 588 |
+
)
|
| 589 |
+
|
| 590 |
+
mask = F._canonical_mask(
|
| 591 |
+
mask=mask,
|
| 592 |
+
mask_name="mask",
|
| 593 |
+
other_type=None,
|
| 594 |
+
other_name="",
|
| 595 |
+
target_type=src.dtype,
|
| 596 |
+
check_other=False,
|
| 597 |
+
)
|
| 598 |
+
|
| 599 |
+
output = src
|
| 600 |
+
convert_to_nested = False
|
| 601 |
+
first_layer = self.layers[0]
|
| 602 |
+
src_key_padding_mask_for_layers = src_key_padding_mask
|
| 603 |
+
why_not_sparsity_fast_path = ''
|
| 604 |
+
str_first_layer = "self.layers[0]"
|
| 605 |
+
batch_first = first_layer.self_attn.batch_first
|
| 606 |
+
# is_fastpath_enabled = torch.backends.mha.get_fastpath_enabled()
|
| 607 |
+
|
| 608 |
+
# if not is_fastpath_enabled:
|
| 609 |
+
# why_not_sparsity_fast_path = "torch.backends.mha.get_fastpath_enabled() was not True"
|
| 610 |
+
if not hasattr(self, "use_nested_tensor"):
|
| 611 |
+
why_not_sparsity_fast_path = "use_nested_tensor attribute not present"
|
| 612 |
+
elif not self.use_nested_tensor:
|
| 613 |
+
why_not_sparsity_fast_path = "self.use_nested_tensor (set in init) was not True"
|
| 614 |
+
elif first_layer.training:
|
| 615 |
+
why_not_sparsity_fast_path = f"{str_first_layer} was in training mode"
|
| 616 |
+
elif not src.dim() == 3:
|
| 617 |
+
why_not_sparsity_fast_path = f"input not batched; expected src.dim() of 3 but got {src.dim()}"
|
| 618 |
+
elif src_key_padding_mask is None:
|
| 619 |
+
why_not_sparsity_fast_path = "src_key_padding_mask was None"
|
| 620 |
+
elif (((not hasattr(self, "mask_check")) or self.mask_check)
|
| 621 |
+
and not torch._nested_tensor_from_mask_left_aligned(src, src_key_padding_mask.logical_not())):
|
| 622 |
+
why_not_sparsity_fast_path = "mask_check enabled, and src and src_key_padding_mask was not left aligned"
|
| 623 |
+
elif output.is_nested:
|
| 624 |
+
why_not_sparsity_fast_path = "NestedTensor input is not supported"
|
| 625 |
+
elif mask is not None:
|
| 626 |
+
why_not_sparsity_fast_path = "src_key_padding_mask and mask were both supplied"
|
| 627 |
+
elif torch.is_autocast_enabled():
|
| 628 |
+
why_not_sparsity_fast_path = "autocast is enabled"
|
| 629 |
+
|
| 630 |
+
if not why_not_sparsity_fast_path:
|
| 631 |
+
tensor_args = (
|
| 632 |
+
src,
|
| 633 |
+
first_layer.self_attn.in_proj_weight,
|
| 634 |
+
first_layer.self_attn.in_proj_bias,
|
| 635 |
+
first_layer.self_attn.out_proj.weight,
|
| 636 |
+
first_layer.self_attn.out_proj.bias,
|
| 637 |
+
first_layer.norm1.weight,
|
| 638 |
+
first_layer.norm1.bias,
|
| 639 |
+
first_layer.norm2.weight,
|
| 640 |
+
first_layer.norm2.bias,
|
| 641 |
+
first_layer.linear1.weight,
|
| 642 |
+
first_layer.linear1.bias,
|
| 643 |
+
first_layer.linear2.weight,
|
| 644 |
+
first_layer.linear2.bias,
|
| 645 |
+
)
|
| 646 |
+
_supported_device_type = ["cpu", "cuda", torch.utils.backend_registration._privateuse1_backend_name]
|
| 647 |
+
if torch.overrides.has_torch_function(tensor_args):
|
| 648 |
+
why_not_sparsity_fast_path = "some Tensor argument has_torch_function"
|
| 649 |
+
elif src.device.type not in _supported_device_type:
|
| 650 |
+
why_not_sparsity_fast_path = f"src device is neither one of {_supported_device_type}"
|
| 651 |
+
elif torch.is_grad_enabled() and any(x.requires_grad for x in tensor_args):
|
| 652 |
+
why_not_sparsity_fast_path = ("grad is enabled and at least one of query or the "
|
| 653 |
+
"input/output projection weights or biases requires_grad")
|
| 654 |
+
|
| 655 |
+
if (not why_not_sparsity_fast_path) and (src_key_padding_mask is not None):
|
| 656 |
+
convert_to_nested = True
|
| 657 |
+
output = torch._nested_tensor_from_mask(output, src_key_padding_mask.logical_not(), mask_check=False)
|
| 658 |
+
src_key_padding_mask_for_layers = None
|
| 659 |
+
|
| 660 |
+
seq_len = _get_seq_len(src, batch_first)
|
| 661 |
+
is_causal = _detect_is_causal_mask(mask, is_causal, seq_len)
|
| 662 |
+
|
| 663 |
+
for mod in self.layers:
|
| 664 |
+
output = mod(output, src_mask=mask, is_causal=is_causal, src_key_padding_mask=src_key_padding_mask_for_layers)
|
| 665 |
+
|
| 666 |
+
if convert_to_nested:
|
| 667 |
+
output = output.to_padded_tensor(0., src.size())
|
| 668 |
+
|
| 669 |
+
if self.norm is not None:
|
| 670 |
+
output = self.norm(output)
|
| 671 |
+
|
| 672 |
+
return output
|
| 673 |
+
|
| 674 |
+
|
| 675 |
+
|
| 676 |
+
|
| 677 |
+
class TransformerDecoder(Module):
|
| 678 |
+
r"""TransformerDecoder is a stack of N decoder layers.
|
| 679 |
+
|
| 680 |
+
Args:
|
| 681 |
+
decoder_layer: an instance of the TransformerDecoderLayer() class (required).
|
| 682 |
+
num_layers: the number of sub-decoder-layers in the decoder (required).
|
| 683 |
+
norm: the layer normalization component (optional).
|
| 684 |
+
|
| 685 |
+
Examples::
|
| 686 |
+
>>> decoder_layer = nn.TransformerDecoderLayer(d_model=512, nhead=8)
|
| 687 |
+
>>> transformer_decoder = nn.TransformerDecoder(decoder_layer, num_layers=6)
|
| 688 |
+
>>> memory = torch.rand(10, 32, 512)
|
| 689 |
+
>>> tgt = torch.rand(20, 32, 512)
|
| 690 |
+
>>> out = transformer_decoder(tgt, memory)
|
| 691 |
+
"""
|
| 692 |
+
|
| 693 |
+
__constants__ = ['norm']
|
| 694 |
+
|
| 695 |
+
def __init__(
|
| 696 |
+
self,
|
| 697 |
+
decoder_layer: "TransformerDecoderLayer",
|
| 698 |
+
num_layers: int,
|
| 699 |
+
use_moe: bool = False,
|
| 700 |
+
norm: Optional[Module] = None
|
| 701 |
+
) -> None:
|
| 702 |
+
super().__init__()
|
| 703 |
+
torch._C._log_api_usage_once(f"torch.nn.modules.{self.__class__.__name__}")
|
| 704 |
+
self.layers = _get_clones(decoder_layer, num_layers)
|
| 705 |
+
self.num_layers = num_layers
|
| 706 |
+
self.use_moe = use_moe
|
| 707 |
+
self.norm = norm
|
| 708 |
+
|
| 709 |
+
|
| 710 |
+
def forward(self, tgt: Tensor, memory: Tensor, tgt_mask: Optional[Tensor] = None,
|
| 711 |
+
memory_mask: Optional[Tensor] = None,
|
| 712 |
+
memory_key_padding_mask: Optional[Tensor] = None, tgt_is_causal: Optional[bool] = None,
|
| 713 |
+
memory_is_causal: bool = False) -> Tensor:
|
| 714 |
+
r"""Pass the inputs (and mask) through the decoder layer in turn.
|
| 715 |
+
|
| 716 |
+
Args:
|
| 717 |
+
tgt: the sequence to the decoder (required).
|
| 718 |
+
memory: the sequence from the last layer of the encoder (required).
|
| 719 |
+
tgt_mask: the mask for the tgt sequence (optional).
|
| 720 |
+
memory_mask: the mask for the memory sequence (optional).
|
| 721 |
+
memory_key_padding_mask: the mask for the memory keys per batch (optional).
|
| 722 |
+
tgt_is_causal: If specified, applies a causal mask as ``tgt mask``.
|
| 723 |
+
Default: ``None``; try to detect a causal mask.
|
| 724 |
+
Warning:
|
| 725 |
+
``tgt_is_causal`` provides a hint that ``tgt_mask`` is
|
| 726 |
+
the causal mask. Providing incorrect hints can result in
|
| 727 |
+
incorrect execution, including forward and backward
|
| 728 |
+
compatibility.
|
| 729 |
+
memory_is_causal: If specified, applies a causal mask as
|
| 730 |
+
``memory mask``.
|
| 731 |
+
Default: ``False``.
|
| 732 |
+
Warning:
|
| 733 |
+
``memory_is_causal`` provides a hint that
|
| 734 |
+
``memory_mask`` is the causal mask. Providing incorrect
|
| 735 |
+
hints can result in incorrect execution, including
|
| 736 |
+
forward and backward compatibility.
|
| 737 |
+
|
| 738 |
+
Shape:
|
| 739 |
+
see the docs in :class:`~torch.nn.Transformer`.
|
| 740 |
+
"""
|
| 741 |
+
output = tgt
|
| 742 |
+
|
| 743 |
+
seq_len = _get_seq_len(tgt, self.layers[0].self_attn.batch_first)
|
| 744 |
+
tgt_is_causal = _detect_is_causal_mask(tgt_mask, tgt_is_causal, seq_len)
|
| 745 |
+
# print(f'target is causal: {tgt_is_causal}')
|
| 746 |
+
|
| 747 |
+
if self.use_moe:
|
| 748 |
+
sum_total_aux_loss = 0
|
| 749 |
+
for mod in self.layers:
|
| 750 |
+
output, total_aux_loss, balance_loss, router_z_loss = mod(output, memory,
|
| 751 |
+
memory_mask=memory_mask,
|
| 752 |
+
memory_key_padding_mask=memory_key_padding_mask,
|
| 753 |
+
tgt_is_causal=tgt_is_causal,
|
| 754 |
+
memory_is_causal=memory_is_causal)
|
| 755 |
+
sum_total_aux_loss += total_aux_loss
|
| 756 |
+
else:
|
| 757 |
+
for mod in self.layers:
|
| 758 |
+
output = mod(output, memory,
|
| 759 |
+
memory_mask=memory_mask,
|
| 760 |
+
memory_key_padding_mask=memory_key_padding_mask,
|
| 761 |
+
tgt_is_causal=tgt_is_causal,
|
| 762 |
+
memory_is_causal=memory_is_causal)
|
| 763 |
+
|
| 764 |
+
if self.norm is not None:
|
| 765 |
+
output = self.norm(output)
|
| 766 |
+
|
| 767 |
+
if self.use_moe:
|
| 768 |
+
return output, sum_total_aux_loss
|
| 769 |
+
else:
|
| 770 |
+
return output
|
| 771 |
+
|
| 772 |
+
|
| 773 |
+
|
| 774 |
+
class TransformerEncoderLayer(Module):
|
| 775 |
+
r"""TransformerEncoderLayer is made up of self-attn and feedforward network.
|
| 776 |
+
|
| 777 |
+
This standard encoder layer is based on the paper "Attention Is All You Need".
|
| 778 |
+
Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit, Llion Jones, Aidan N Gomez,
|
| 779 |
+
Lukasz Kaiser, and Illia Polosukhin. 2017. Attention is all you need. In Advances in
|
| 780 |
+
Neural Information Processing Systems, pages 6000-6010. Users may modify or implement
|
| 781 |
+
in a different way during application.
|
| 782 |
+
|
| 783 |
+
TransformerEncoderLayer can handle either traditional torch.tensor inputs,
|
| 784 |
+
or Nested Tensor inputs. Derived classes are expected to similarly accept
|
| 785 |
+
both input formats. (Not all combinations of inputs are currently
|
| 786 |
+
supported by TransformerEncoderLayer while Nested Tensor is in prototype
|
| 787 |
+
state.)
|
| 788 |
+
|
| 789 |
+
If you are implementing a custom layer, you may derive it either from
|
| 790 |
+
the Module or TransformerEncoderLayer class. If your custom layer
|
| 791 |
+
supports both torch.Tensors and Nested Tensors inputs, make its
|
| 792 |
+
implementation a derived class of TransformerEncoderLayer. If your custom
|
| 793 |
+
Layer supports only torch.Tensor inputs, derive its implementation from
|
| 794 |
+
Module.
|
| 795 |
+
|
| 796 |
+
Args:
|
| 797 |
+
d_model: the number of expected features in the input (required).
|
| 798 |
+
nhead: the number of heads in the multiheadattention models (required).
|
| 799 |
+
dim_feedforward: the dimension of the feedforward network model (default=2048).
|
| 800 |
+
dropout: the dropout value (default=0.1).
|
| 801 |
+
activation: the activation function of the intermediate layer, can be a string
|
| 802 |
+
("relu" or "gelu") or a unary callable. Default: relu
|
| 803 |
+
layer_norm_eps: the eps value in layer normalization components (default=1e-5).
|
| 804 |
+
batch_first: If ``True``, then the input and output tensors are provided
|
| 805 |
+
as (batch, seq, feature). Default: ``False`` (seq, batch, feature).
|
| 806 |
+
norm_first: if ``True``, layer norm is done prior to attention and feedforward
|
| 807 |
+
operations, respectively. Otherwise it's done after. Default: ``False`` (after).
|
| 808 |
+
bias: If set to ``False``, ``Linear`` and ``LayerNorm`` layers will not learn an additive
|
| 809 |
+
bias. Default: ``True``.
|
| 810 |
+
|
| 811 |
+
Examples::
|
| 812 |
+
>>> encoder_layer = nn.TransformerEncoderLayer(d_model=512, nhead=8)
|
| 813 |
+
>>> src = torch.rand(10, 32, 512)
|
| 814 |
+
>>> out = encoder_layer(src)
|
| 815 |
+
|
| 816 |
+
Alternatively, when ``batch_first`` is ``True``:
|
| 817 |
+
>>> encoder_layer = nn.TransformerEncoderLayer(d_model=512, nhead=8, batch_first=True)
|
| 818 |
+
>>> src = torch.rand(32, 10, 512)
|
| 819 |
+
>>> out = encoder_layer(src)
|
| 820 |
+
|
| 821 |
+
Fast path:
|
| 822 |
+
forward() will use a special optimized implementation described in
|
| 823 |
+
`FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness`_ if all of the following
|
| 824 |
+
conditions are met:
|
| 825 |
+
|
| 826 |
+
- Either autograd is disabled (using ``torch.inference_mode`` or ``torch.no_grad``) or no tensor
|
| 827 |
+
argument ``requires_grad``
|
| 828 |
+
- training is disabled (using ``.eval()``)
|
| 829 |
+
- batch_first is ``True`` and the input is batched (i.e., ``src.dim() == 3``)
|
| 830 |
+
- activation is one of: ``"relu"``, ``"gelu"``, ``torch.functional.relu``, or ``torch.functional.gelu``
|
| 831 |
+
- at most one of ``src_mask`` and ``src_key_padding_mask`` is passed
|
| 832 |
+
- if src is a `NestedTensor <https://pytorch.org/docs/stable/nested.html>`_, neither ``src_mask``
|
| 833 |
+
nor ``src_key_padding_mask`` is passed
|
| 834 |
+
- the two ``LayerNorm`` instances have a consistent ``eps`` value (this will naturally be the case
|
| 835 |
+
unless the caller has manually modified one without modifying the other)
|
| 836 |
+
|
| 837 |
+
If the optimized implementation is in use, a
|
| 838 |
+
`NestedTensor <https://pytorch.org/docs/stable/nested.html>`_ can be
|
| 839 |
+
passed for ``src`` to represent padding more efficiently than using a padding
|
| 840 |
+
mask. In this case, a `NestedTensor <https://pytorch.org/docs/stable/nested.html>`_ will be
|
| 841 |
+
returned, and an additional speedup proportional to the fraction of the input that
|
| 842 |
+
is padding can be expected.
|
| 843 |
+
|
| 844 |
+
.. _`FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness`:
|
| 845 |
+
https://arxiv.org/abs/2205.14135
|
| 846 |
+
|
| 847 |
+
"""
|
| 848 |
+
|
| 849 |
+
__constants__ = ['norm_first']
|
| 850 |
+
|
| 851 |
+
def __init__(self, d_model: int, nhead: int, dim_feedforward: int = 2048, dropout: float = 0.1,
|
| 852 |
+
activation: Union[str, Callable[[Tensor], Tensor]] = F.relu,
|
| 853 |
+
layer_norm_eps: float = 1e-5, batch_first: bool = False, norm_first: bool = False,
|
| 854 |
+
bias: bool = True, device=None, dtype=None) -> None:
|
| 855 |
+
factory_kwargs = {'device': device, 'dtype': dtype}
|
| 856 |
+
super().__init__()
|
| 857 |
+
self.self_attn = MultiheadAttention(d_model, nhead, dropout=dropout,
|
| 858 |
+
bias=bias, batch_first=batch_first,
|
| 859 |
+
**factory_kwargs)
|
| 860 |
+
# Implementation of Feedforward model
|
| 861 |
+
self.linear1 = Linear(d_model, dim_feedforward, bias=bias, **factory_kwargs)
|
| 862 |
+
self.dropout = Dropout(dropout)
|
| 863 |
+
self.linear2 = Linear(dim_feedforward, d_model, bias=bias, **factory_kwargs)
|
| 864 |
+
|
| 865 |
+
self.norm_first = norm_first
|
| 866 |
+
self.norm1 = LayerNorm(d_model, eps=layer_norm_eps, bias=bias, **factory_kwargs)
|
| 867 |
+
self.norm2 = LayerNorm(d_model, eps=layer_norm_eps, bias=bias, **factory_kwargs)
|
| 868 |
+
self.dropout1 = Dropout(dropout)
|
| 869 |
+
self.dropout2 = Dropout(dropout)
|
| 870 |
+
|
| 871 |
+
# Legacy string support for activation function.
|
| 872 |
+
if isinstance(activation, str):
|
| 873 |
+
activation = _get_activation_fn(activation)
|
| 874 |
+
|
| 875 |
+
# We can't test self.activation in forward() in TorchScript,
|
| 876 |
+
# so stash some information about it instead.
|
| 877 |
+
if activation is F.relu or isinstance(activation, torch.nn.ReLU):
|
| 878 |
+
self.activation_relu_or_gelu = 1
|
| 879 |
+
elif activation is F.gelu or isinstance(activation, torch.nn.GELU):
|
| 880 |
+
self.activation_relu_or_gelu = 2
|
| 881 |
+
else:
|
| 882 |
+
self.activation_relu_or_gelu = 0
|
| 883 |
+
self.activation = activation
|
| 884 |
+
|
| 885 |
+
def __setstate__(self, state):
|
| 886 |
+
super().__setstate__(state)
|
| 887 |
+
if not hasattr(self, 'activation'):
|
| 888 |
+
self.activation = F.relu
|
| 889 |
+
|
| 890 |
+
|
| 891 |
+
|
| 892 |
+
def forward(
|
| 893 |
+
self,
|
| 894 |
+
src: Tensor,
|
| 895 |
+
src_mask: Optional[Tensor] = None,
|
| 896 |
+
src_key_padding_mask: Optional[Tensor] = None,
|
| 897 |
+
is_causal: bool = False) -> Tensor:
|
| 898 |
+
r"""Pass the input through the encoder layer.
|
| 899 |
+
|
| 900 |
+
Args:
|
| 901 |
+
src: the sequence to the encoder layer (required).
|
| 902 |
+
src_mask: the mask for the src sequence (optional).
|
| 903 |
+
src_key_padding_mask: the mask for the src keys per batch (optional).
|
| 904 |
+
is_causal: If specified, applies a causal mask as ``src mask``.
|
| 905 |
+
Default: ``False``.
|
| 906 |
+
Warning:
|
| 907 |
+
``is_causal`` provides a hint that ``src_mask`` is the
|
| 908 |
+
causal mask. Providing incorrect hints can result in
|
| 909 |
+
incorrect execution, including forward and backward
|
| 910 |
+
compatibility.
|
| 911 |
+
|
| 912 |
+
Shape:
|
| 913 |
+
see the docs in :class:`~torch.nn.Transformer`.
|
| 914 |
+
"""
|
| 915 |
+
src_key_padding_mask = F._canonical_mask(
|
| 916 |
+
mask=src_key_padding_mask,
|
| 917 |
+
mask_name="src_key_padding_mask",
|
| 918 |
+
other_type=F._none_or_dtype(src_mask),
|
| 919 |
+
other_name="src_mask",
|
| 920 |
+
target_type=src.dtype
|
| 921 |
+
)
|
| 922 |
+
|
| 923 |
+
src_mask = F._canonical_mask(
|
| 924 |
+
mask=src_mask,
|
| 925 |
+
mask_name="src_mask",
|
| 926 |
+
other_type=None,
|
| 927 |
+
other_name="",
|
| 928 |
+
target_type=src.dtype,
|
| 929 |
+
check_other=False,
|
| 930 |
+
)
|
| 931 |
+
|
| 932 |
+
# is_fastpath_enabled = torch.backends.mha.get_fastpath_enabled()
|
| 933 |
+
|
| 934 |
+
# see Fig. 1 of https://arxiv.org/pdf/2002.04745v1.pdf
|
| 935 |
+
why_not_sparsity_fast_path = ''
|
| 936 |
+
# if not is_fastpath_enabled:
|
| 937 |
+
# why_not_sparsity_fast_path = "torch.backends.mha.get_fastpath_enabled() was not True"
|
| 938 |
+
if not src.dim() == 3:
|
| 939 |
+
why_not_sparsity_fast_path = f"input not batched; expected src.dim() of 3 but got {src.dim()}"
|
| 940 |
+
elif self.training:
|
| 941 |
+
why_not_sparsity_fast_path = "training is enabled"
|
| 942 |
+
elif not self.self_attn.batch_first:
|
| 943 |
+
why_not_sparsity_fast_path = "self_attn.batch_first was not True"
|
| 944 |
+
elif self.self_attn.in_proj_bias is None:
|
| 945 |
+
why_not_sparsity_fast_path = "self_attn was passed bias=False"
|
| 946 |
+
elif not self.self_attn._qkv_same_embed_dim:
|
| 947 |
+
why_not_sparsity_fast_path = "self_attn._qkv_same_embed_dim was not True"
|
| 948 |
+
elif not self.activation_relu_or_gelu:
|
| 949 |
+
why_not_sparsity_fast_path = "activation_relu_or_gelu was not True"
|
| 950 |
+
elif not (self.norm1.eps == self.norm2.eps):
|
| 951 |
+
why_not_sparsity_fast_path = "norm1.eps is not equal to norm2.eps"
|
| 952 |
+
elif src.is_nested and (src_key_padding_mask is not None or src_mask is not None):
|
| 953 |
+
why_not_sparsity_fast_path = "neither src_key_padding_mask nor src_mask are not supported with NestedTensor input"
|
| 954 |
+
elif self.self_attn.num_heads % 2 == 1:
|
| 955 |
+
why_not_sparsity_fast_path = "num_head is odd"
|
| 956 |
+
elif torch.is_autocast_enabled():
|
| 957 |
+
why_not_sparsity_fast_path = "autocast is enabled"
|
| 958 |
+
if not why_not_sparsity_fast_path:
|
| 959 |
+
tensor_args = (
|
| 960 |
+
src,
|
| 961 |
+
self.self_attn.in_proj_weight,
|
| 962 |
+
self.self_attn.in_proj_bias,
|
| 963 |
+
self.self_attn.out_proj.weight,
|
| 964 |
+
self.self_attn.out_proj.bias,
|
| 965 |
+
self.norm1.weight,
|
| 966 |
+
self.norm1.bias,
|
| 967 |
+
self.norm2.weight,
|
| 968 |
+
self.norm2.bias,
|
| 969 |
+
self.linear1.weight,
|
| 970 |
+
self.linear1.bias,
|
| 971 |
+
self.linear2.weight,
|
| 972 |
+
self.linear2.bias,
|
| 973 |
+
)
|
| 974 |
+
|
| 975 |
+
# We have to use list comprehensions below because TorchScript does not support
|
| 976 |
+
# generator expressions.
|
| 977 |
+
_supported_device_type = ["cpu", "cuda", torch.utils.backend_registration._privateuse1_backend_name]
|
| 978 |
+
if torch.overrides.has_torch_function(tensor_args):
|
| 979 |
+
why_not_sparsity_fast_path = "some Tensor argument has_torch_function"
|
| 980 |
+
elif not all((x.device.type in _supported_device_type) for x in tensor_args):
|
| 981 |
+
why_not_sparsity_fast_path = ("some Tensor argument's device is neither one of "
|
| 982 |
+
f"{_supported_device_type}")
|
| 983 |
+
elif torch.is_grad_enabled() and any(x.requires_grad for x in tensor_args):
|
| 984 |
+
why_not_sparsity_fast_path = ("grad is enabled and at least one of query or the "
|
| 985 |
+
"input/output projection weights or biases requires_grad")
|
| 986 |
+
|
| 987 |
+
if not why_not_sparsity_fast_path:
|
| 988 |
+
merged_mask, mask_type = self.self_attn.merge_masks(src_mask, src_key_padding_mask, src)
|
| 989 |
+
return torch._transformer_encoder_layer_fwd(
|
| 990 |
+
src,
|
| 991 |
+
self.self_attn.embed_dim,
|
| 992 |
+
self.self_attn.num_heads,
|
| 993 |
+
self.self_attn.in_proj_weight,
|
| 994 |
+
self.self_attn.in_proj_bias,
|
| 995 |
+
self.self_attn.out_proj.weight,
|
| 996 |
+
self.self_attn.out_proj.bias,
|
| 997 |
+
self.activation_relu_or_gelu == 2,
|
| 998 |
+
self.norm_first,
|
| 999 |
+
self.norm1.eps,
|
| 1000 |
+
self.norm1.weight,
|
| 1001 |
+
self.norm1.bias,
|
| 1002 |
+
self.norm2.weight,
|
| 1003 |
+
self.norm2.bias,
|
| 1004 |
+
self.linear1.weight,
|
| 1005 |
+
self.linear1.bias,
|
| 1006 |
+
self.linear2.weight,
|
| 1007 |
+
self.linear2.bias,
|
| 1008 |
+
merged_mask,
|
| 1009 |
+
mask_type,
|
| 1010 |
+
)
|
| 1011 |
+
|
| 1012 |
+
|
| 1013 |
+
x = src
|
| 1014 |
+
if self.norm_first:
|
| 1015 |
+
x = x + self._sa_block(self.norm1(x), src_mask, src_key_padding_mask, is_causal=is_causal)
|
| 1016 |
+
x = x + self._ff_block(self.norm2(x))
|
| 1017 |
+
else:
|
| 1018 |
+
x = self.norm1(x + self._sa_block(x, src_mask, src_key_padding_mask, is_causal=is_causal))
|
| 1019 |
+
x = self.norm2(x + self._ff_block(x))
|
| 1020 |
+
|
| 1021 |
+
return x
|
| 1022 |
+
|
| 1023 |
+
|
| 1024 |
+
# self-attention block
|
| 1025 |
+
def _sa_block(self, x: Tensor,
|
| 1026 |
+
attn_mask: Optional[Tensor], key_padding_mask: Optional[Tensor], is_causal: bool = False) -> Tensor:
|
| 1027 |
+
x = self.self_attn(x, x, x,
|
| 1028 |
+
attn_mask=attn_mask,
|
| 1029 |
+
key_padding_mask=key_padding_mask,
|
| 1030 |
+
need_weights=False, is_causal=is_causal)[0]
|
| 1031 |
+
return self.dropout1(x)
|
| 1032 |
+
|
| 1033 |
+
# feed forward block
|
| 1034 |
+
def _ff_block(self, x: Tensor) -> Tensor:
|
| 1035 |
+
x = self.linear2(self.dropout(self.activation(self.linear1(x))))
|
| 1036 |
+
return self.dropout2(x)
|
| 1037 |
+
|
| 1038 |
+
|
| 1039 |
+
|
| 1040 |
+
|
| 1041 |
+
class TransformerDecoderLayer(Module):
|
| 1042 |
+
r"""TransformerDecoderLayer is made up of self-attn, multi-head-attn and feedforward network.
|
| 1043 |
+
|
| 1044 |
+
This standard decoder layer is based on the paper "Attention Is All You Need".
|
| 1045 |
+
Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit, Llion Jones, Aidan N Gomez,
|
| 1046 |
+
Lukasz Kaiser, and Illia Polosukhin. 2017. Attention is all you need. In Advances in
|
| 1047 |
+
Neural Information Processing Systems, pages 6000-6010. Users may modify or implement
|
| 1048 |
+
in a different way during application.
|
| 1049 |
+
|
| 1050 |
+
Args:
|
| 1051 |
+
d_model: the number of expected features in the input (required).
|
| 1052 |
+
nhead: the number of heads in the multiheadattention models (required).
|
| 1053 |
+
dim_feedforward: the dimension of the feedforward network model (default=2048).
|
| 1054 |
+
dropout: the dropout value (default=0.1).
|
| 1055 |
+
activation: the activation function of the intermediate layer, can be a string
|
| 1056 |
+
("relu" or "gelu") or a unary callable. Default: relu
|
| 1057 |
+
layer_norm_eps: the eps value in layer normalization components (default=1e-5).
|
| 1058 |
+
batch_first: If ``True``, then the input and output tensors are provided
|
| 1059 |
+
as (batch, seq, feature). Default: ``False`` (seq, batch, feature).
|
| 1060 |
+
norm_first: if ``True``, layer norm is done prior to self attention, multihead
|
| 1061 |
+
attention and feedforward operations, respectively. Otherwise it's done after.
|
| 1062 |
+
Default: ``False`` (after).
|
| 1063 |
+
bias: If set to ``False``, ``Linear`` and ``LayerNorm`` layers will not learn an additive
|
| 1064 |
+
bias. Default: ``True``.
|
| 1065 |
+
|
| 1066 |
+
Examples::
|
| 1067 |
+
>>> decoder_layer = nn.TransformerDecoderLayer(d_model=512, nhead=8)
|
| 1068 |
+
>>> memory = torch.rand(10, 32, 512)
|
| 1069 |
+
>>> tgt = torch.rand(20, 32, 512)
|
| 1070 |
+
>>> out = decoder_layer(tgt, memory)
|
| 1071 |
+
|
| 1072 |
+
Alternatively, when ``batch_first`` is ``True``:
|
| 1073 |
+
>>> decoder_layer = nn.TransformerDecoderLayer(d_model=512, nhead=8, batch_first=True)
|
| 1074 |
+
>>> memory = torch.rand(32, 10, 512)
|
| 1075 |
+
>>> tgt = torch.rand(32, 20, 512)
|
| 1076 |
+
>>> out = decoder_layer(tgt, memory)
|
| 1077 |
+
"""
|
| 1078 |
+
|
| 1079 |
+
__constants__ = ['norm_first']
|
| 1080 |
+
|
| 1081 |
+
def __init__(self, d_model: int, nhead: int, dim_feedforward: int = 2048, use_moe: bool = False, num_experts: int = 16,
|
| 1082 |
+
dropout: float = 0.1, activation: Union[str, Callable[[Tensor], Tensor]] = F.relu,
|
| 1083 |
+
layer_norm_eps: float = 1e-5, batch_first: bool = False, norm_first: bool = False,
|
| 1084 |
+
bias: bool = True, device=None, dtype=None) -> None:
|
| 1085 |
+
factory_kwargs = {'device': device, 'dtype': dtype}
|
| 1086 |
+
super().__init__()
|
| 1087 |
+
|
| 1088 |
+
self.self_attn = MultiHeadSelfAttention(d_model, nhead, dropout=dropout, batch_first=batch_first, **factory_kwargs)
|
| 1089 |
+
self.multihead_attn = MultiheadAttention(d_model, nhead, dropout=dropout, batch_first=batch_first,
|
| 1090 |
+
bias=bias, **factory_kwargs)
|
| 1091 |
+
self.use_moe = use_moe
|
| 1092 |
+
|
| 1093 |
+
if use_moe:
|
| 1094 |
+
self.moe = MoE(
|
| 1095 |
+
dim = d_model,
|
| 1096 |
+
num_experts = num_experts, # increase the experts (# parameters) of your model without increasing computation
|
| 1097 |
+
gating_top_n = 2, # default to top 2 gating, but can also be more (3 was tested in the paper with a lower threshold)
|
| 1098 |
+
threshold_train = 0.2, # at what threshold to accept a token to be routed to second expert and beyond - 0.2 was optimal for 2 expert routing, and apparently should be lower for 3
|
| 1099 |
+
threshold_eval = 0.2,
|
| 1100 |
+
capacity_factor_train = 1.25, # experts have fixed capacity per batch. we need some extra capacity in case gating is not perfectly balanced.
|
| 1101 |
+
capacity_factor_eval = 2., # capacity_factor_* should be set to a value >=1
|
| 1102 |
+
balance_loss_coef = 1e-2, # multiplier on the auxiliary expert balancing auxiliary loss
|
| 1103 |
+
router_z_loss_coef = 1e-3, # loss weight for router z-loss
|
| 1104 |
+
).to(device)
|
| 1105 |
+
self.moe_block = SparseMoEBlock(
|
| 1106 |
+
self.moe,
|
| 1107 |
+
add_ff_before = True,
|
| 1108 |
+
add_ff_after = True
|
| 1109 |
+
).to(device)
|
| 1110 |
+
else:
|
| 1111 |
+
# Implementation of Feedforward model
|
| 1112 |
+
self.linear1 = Linear(d_model, dim_feedforward, bias=bias, **factory_kwargs)
|
| 1113 |
+
self.dropout = Dropout(dropout)
|
| 1114 |
+
self.linear2 = Linear(dim_feedforward, d_model, bias=bias, **factory_kwargs)
|
| 1115 |
+
|
| 1116 |
+
self.norm_first = norm_first
|
| 1117 |
+
self.norm1 = LayerNorm(d_model, eps=layer_norm_eps, bias=bias, **factory_kwargs)
|
| 1118 |
+
self.norm2 = LayerNorm(d_model, eps=layer_norm_eps, bias=bias, **factory_kwargs)
|
| 1119 |
+
self.norm3 = LayerNorm(d_model, eps=layer_norm_eps, bias=bias, **factory_kwargs)
|
| 1120 |
+
self.dropout1 = Dropout(dropout)
|
| 1121 |
+
self.dropout2 = Dropout(dropout)
|
| 1122 |
+
self.dropout3 = Dropout(dropout)
|
| 1123 |
+
|
| 1124 |
+
# Legacy string support for activation function.
|
| 1125 |
+
if isinstance(activation, str):
|
| 1126 |
+
self.activation = _get_activation_fn(activation)
|
| 1127 |
+
else:
|
| 1128 |
+
self.activation = activation
|
| 1129 |
+
|
| 1130 |
+
def __setstate__(self, state):
|
| 1131 |
+
if 'activation' not in state:
|
| 1132 |
+
state['activation'] = F.relu
|
| 1133 |
+
super().__setstate__(state)
|
| 1134 |
+
|
| 1135 |
+
|
| 1136 |
+
def forward(
|
| 1137 |
+
self,
|
| 1138 |
+
tgt: Tensor,
|
| 1139 |
+
memory: Tensor,
|
| 1140 |
+
memory_mask: Optional[Tensor] = None,
|
| 1141 |
+
memory_key_padding_mask: Optional[Tensor] = None,
|
| 1142 |
+
tgt_is_causal: bool = False,
|
| 1143 |
+
memory_is_causal: bool = False,
|
| 1144 |
+
) -> Tensor:
|
| 1145 |
+
r"""Pass the inputs (and mask) through the decoder layer.
|
| 1146 |
+
|
| 1147 |
+
Args:
|
| 1148 |
+
tgt: the sequence to the decoder layer (required).
|
| 1149 |
+
memory: the sequence from the last layer of the encoder (required).
|
| 1150 |
+
memory_mask: the mask for the memory sequence (optional).
|
| 1151 |
+
memory_key_padding_mask: the mask for the memory keys per batch (optional).
|
| 1152 |
+
tgt_is_causal: If specified, applies a causal mask as ``tgt mask``.
|
| 1153 |
+
Default: ``False``.
|
| 1154 |
+
Warning:
|
| 1155 |
+
``tgt_is_causal`` provides a hint that ``tgt_mask`` is
|
| 1156 |
+
the causal mask. Providing incorrect hints can result in
|
| 1157 |
+
incorrect execution, including forward and backward
|
| 1158 |
+
compatibility.
|
| 1159 |
+
memory_is_causal: If specified, applies a causal mask as
|
| 1160 |
+
``memory mask``.
|
| 1161 |
+
Default: ``False``.
|
| 1162 |
+
Warning:
|
| 1163 |
+
``memory_is_causal`` provides a hint that
|
| 1164 |
+
``memory_mask`` is the causal mask. Providing incorrect
|
| 1165 |
+
hints can result in incorrect execution, including
|
| 1166 |
+
forward and backward compatibility.
|
| 1167 |
+
|
| 1168 |
+
Shape:
|
| 1169 |
+
see the docs in :class:`~torch.nn.Transformer`.
|
| 1170 |
+
"""
|
| 1171 |
+
# see Fig. 1 of https://arxiv.org/pdf/2002.04745v1.pdf
|
| 1172 |
+
|
| 1173 |
+
x = tgt
|
| 1174 |
+
# print(f'target is causal: {tgt_is_causal}')
|
| 1175 |
+
if self.norm_first:
|
| 1176 |
+
x = x + self._sa_block(self.norm1(x), tgt_is_causal)
|
| 1177 |
+
x = x + self._mha_block(self.norm2(x), memory, memory_mask, memory_key_padding_mask, memory_is_causal)
|
| 1178 |
+
if self.use_moe:
|
| 1179 |
+
m, total_aux_loss, balance_loss, router_z_loss = self.moe_block(x)
|
| 1180 |
+
x = x + m
|
| 1181 |
+
else:
|
| 1182 |
+
x = x + self._ff_block(self.norm3(x))
|
| 1183 |
+
else:
|
| 1184 |
+
x = self.norm1(x + self._sa_block(x, tgt_is_causal))
|
| 1185 |
+
x = self.norm2(x + self._mha_block(x, memory, memory_mask, memory_key_padding_mask, memory_is_causal))
|
| 1186 |
+
if self.use_moe:
|
| 1187 |
+
m, total_aux_loss, balance_loss, router_z_loss = self.moe_block(x)
|
| 1188 |
+
x = x + m
|
| 1189 |
+
else:
|
| 1190 |
+
x = self.norm3(x + self._ff_block(x))
|
| 1191 |
+
|
| 1192 |
+
if self.use_moe:
|
| 1193 |
+
return x, total_aux_loss, balance_loss, router_z_loss
|
| 1194 |
+
else:
|
| 1195 |
+
return x
|
| 1196 |
+
|
| 1197 |
+
|
| 1198 |
+
# self-attention block
|
| 1199 |
+
def _sa_block(self, x: Tensor,
|
| 1200 |
+
is_causal: bool = False) -> Tensor:
|
| 1201 |
+
x = self.self_attn(x, is_causal=is_causal)
|
| 1202 |
+
return self.dropout1(x)
|
| 1203 |
+
|
| 1204 |
+
# multihead attention block
|
| 1205 |
+
def _mha_block(self, x: Tensor, mem: Tensor,
|
| 1206 |
+
attn_mask: Optional[Tensor], key_padding_mask: Optional[Tensor], is_causal: bool = False) -> Tensor:
|
| 1207 |
+
x = self.multihead_attn(x, mem, mem,
|
| 1208 |
+
attn_mask=attn_mask,
|
| 1209 |
+
key_padding_mask=key_padding_mask,
|
| 1210 |
+
is_causal=is_causal,
|
| 1211 |
+
need_weights=False)[0]
|
| 1212 |
+
return self.dropout2(x)
|
| 1213 |
+
|
| 1214 |
+
# feed forward block
|
| 1215 |
+
def _ff_block(self, x: Tensor) -> Tensor:
|
| 1216 |
+
x = self.linear2(self.dropout(self.activation(self.linear1(x))))
|
| 1217 |
+
return self.dropout3(x)
|
| 1218 |
+
|
| 1219 |
+
|
| 1220 |
+
|
| 1221 |
+
def _get_clones(module, N):
|
| 1222 |
+
# FIXME: copy.deepcopy() is not defined on nn.module
|
| 1223 |
+
return ModuleList([copy.deepcopy(module) for i in range(N)])
|
| 1224 |
+
|
| 1225 |
+
|
| 1226 |
+
def _get_activation_fn(activation: str) -> Callable[[Tensor], Tensor]:
|
| 1227 |
+
if activation == "relu":
|
| 1228 |
+
return F.relu
|
| 1229 |
+
elif activation == "gelu":
|
| 1230 |
+
return F.gelu
|
| 1231 |
+
|
| 1232 |
+
raise RuntimeError(f"activation should be relu/gelu, not {activation}")
|
| 1233 |
+
|
| 1234 |
+
|
| 1235 |
+
def _detect_is_causal_mask(
|
| 1236 |
+
mask: Optional[Tensor],
|
| 1237 |
+
is_causal: Optional[bool] = None,
|
| 1238 |
+
size: Optional[int] = None,
|
| 1239 |
+
) -> bool:
|
| 1240 |
+
"""Return whether the given attention mask is causal.
|
| 1241 |
+
|
| 1242 |
+
Warning:
|
| 1243 |
+
If ``is_causal`` is not ``None``, its value will be returned as is. If a
|
| 1244 |
+
user supplies an incorrect ``is_causal`` hint,
|
| 1245 |
+
|
| 1246 |
+
``is_causal=False`` when the mask is in fact a causal attention.mask
|
| 1247 |
+
may lead to reduced performance relative to what would be achievable
|
| 1248 |
+
with ``is_causal=True``;
|
| 1249 |
+
``is_causal=True`` when the mask is in fact not a causal attention.mask
|
| 1250 |
+
may lead to incorrect and unpredictable execution - in some scenarios,
|
| 1251 |
+
a causal mask may be applied based on the hint, in other execution
|
| 1252 |
+
scenarios the specified mask may be used. The choice may not appear
|
| 1253 |
+
to be deterministic, in that a number of factors like alignment,
|
| 1254 |
+
hardware SKU, etc influence the decision whether to use a mask or
|
| 1255 |
+
rely on the hint.
|
| 1256 |
+
``size`` if not None, check whether the mask is a causal mask of the provided size
|
| 1257 |
+
Otherwise, checks for any causal mask.
|
| 1258 |
+
"""
|
| 1259 |
+
# Prevent type refinement
|
| 1260 |
+
make_causal = (is_causal is True)
|
| 1261 |
+
|
| 1262 |
+
if is_causal is None and mask is not None:
|
| 1263 |
+
sz = size if size is not None else mask.size(-2)
|
| 1264 |
+
causal_comparison = _generate_square_subsequent_mask(
|
| 1265 |
+
sz, device=mask.device, dtype=mask.dtype)
|
| 1266 |
+
|
| 1267 |
+
# Do not use `torch.equal` so we handle batched masks by
|
| 1268 |
+
# broadcasting the comparison.
|
| 1269 |
+
if mask.size() == causal_comparison.size():
|
| 1270 |
+
make_causal = bool((mask == causal_comparison).all())
|
| 1271 |
+
else:
|
| 1272 |
+
make_causal = False
|
| 1273 |
+
|
| 1274 |
+
return make_causal
|
| 1275 |
+
|
| 1276 |
+
def check_instruments(genereated_seq):
|
| 1277 |
+
ins_present = []
|
| 1278 |
+
ins_count = 0
|
| 1279 |
+
instrument_list = ["piano", "chromatic", "organ", "guitar", "bass", "strings", "ensemble", "brass", "reed", "drum", "pipe", "synth_lead", "synth_pad", "synth_effect", "ethnic", "percussive", "sfx"]
|
| 1280 |
+
for token in genereated_seq:
|
| 1281 |
+
try:
|
| 1282 |
+
ins, pitch, vel = token
|
| 1283 |
+
# print(str(ins))
|
| 1284 |
+
except ValueError:
|
| 1285 |
+
try:
|
| 1286 |
+
ins, pitch = token
|
| 1287 |
+
except ValueError:
|
| 1288 |
+
ins = token
|
| 1289 |
+
if str(ins) in instrument_list:
|
| 1290 |
+
# print('coming here')
|
| 1291 |
+
|
| 1292 |
+
if ('prefix', 'instrument', str(ins)) not in ins_present and ins_count < 15:
|
| 1293 |
+
ins_count += 1
|
| 1294 |
+
print(f'adding instrument {ins}')
|
| 1295 |
+
ins_present.append(('prefix', 'instrument', str(ins)))
|
| 1296 |
+
if ins_present != []:
|
| 1297 |
+
genereated_seq = ins_present + ['<S>']+ genereated_seq +['<E>']
|
| 1298 |
+
else:
|
| 1299 |
+
genereated_seq = genereated_seq +['<E>']
|
| 1300 |
+
print(genereated_seq)
|
| 1301 |
+
return genereated_seq
|
| 1302 |
+
|
| 1303 |
+
def process_caption(gpu_id, captions, model, tokenizer, r_tokenizer):
|
| 1304 |
+
# Detect device: CUDA, MPS, or CPU
|
| 1305 |
+
if torch.cuda.is_available():
|
| 1306 |
+
device = torch.device(f"cuda:{gpu_id}")
|
| 1307 |
+
torch.cuda.set_device(gpu_id)
|
| 1308 |
+
print(f"Using CUDA on GPU {gpu_id}")
|
| 1309 |
+
elif torch.backends.mps.is_available():
|
| 1310 |
+
device = torch.device("mps")
|
| 1311 |
+
print("Using MPS on macOS")
|
| 1312 |
+
else:
|
| 1313 |
+
device = torch.device("cpu")
|
| 1314 |
+
print("Using CPU")
|
| 1315 |
+
|
| 1316 |
+
# Move the model to the selected device
|
| 1317 |
+
model.to(device)
|
| 1318 |
+
model.eval()
|
| 1319 |
+
|
| 1320 |
+
for caption in captions:
|
| 1321 |
+
src = caption['caption']
|
| 1322 |
+
location = caption['location']
|
| 1323 |
+
|
| 1324 |
+
# Tokenize input
|
| 1325 |
+
inputs = tokenizer(src, return_tensors='pt', padding=True, truncation=True)
|
| 1326 |
+
input_ids = nn.utils.rnn.pad_sequence(inputs.input_ids, batch_first=True, padding_value=0)
|
| 1327 |
+
input_ids = input_ids.to(device)
|
| 1328 |
+
attention_mask = nn.utils.rnn.pad_sequence(inputs.attention_mask, batch_first=True, padding_value=0)
|
| 1329 |
+
attention_mask = attention_mask.to(device)
|
| 1330 |
+
|
| 1331 |
+
# Generate output
|
| 1332 |
+
output = model.generate(input_ids, attention_mask, max_len=5000, temperature=0.9)
|
| 1333 |
+
output_list = output[0].tolist()
|
| 1334 |
+
|
| 1335 |
+
# Decode MIDI and save it
|
| 1336 |
+
generated_midi = r_tokenizer.decode(output_list)
|
| 1337 |
+
generated_midi.dump_midi(f"../res/{location}")
|
| 1338 |
+
|
| 1339 |
+
# def process_caption(gpu_id, captions, model, tokenizer, r_tokenizer):
|
| 1340 |
+
# device = gpu_id
|
| 1341 |
+
# torch.cuda.set_device(gpu_id)
|
| 1342 |
+
# model.to(gpu_id)
|
| 1343 |
+
# model.eval()
|
| 1344 |
+
# for caption in captions:
|
| 1345 |
+
# src = caption['caption']
|
| 1346 |
+
# location = caption['location']
|
| 1347 |
+
# #src = "A cinematic electronic soundtrack that evokes an epic and dark atmosphere, featuring cello, contrabass, and drums. The song is set in A minor with a moderate tempo and a 4/4 time signature, creating an emotional and action-packed ambiance suitable for film."
|
| 1348 |
+
# '''
|
| 1349 |
+
# example 1: "A cheerful and melodic pop Christmas song featuring piano, acoustic guitar, vibraphone, bass, and drums, set in the key of Eb minor with a fast tempo of 123 bpm and a 4/4 time signature, creating a joyful and relaxing atmosphere."lmd_full/1/1b9f5f325c2080d345d877f590aa3dbe.mid
|
| 1350 |
+
# example 2: "A melodic electronic song with ambient elements, featuring piano, acoustic guitar, alto saxophone, string ensemble, and electric bass. Set in G minor with a 4/4 time signature, it moves at a lively Presto tempo. The composition evokes a blend of relaxation and darkness, with hints of happiness and a meditative quality."lmd_full/1/152891ac63017b234c33e75e4a4a28c5.mid
|
| 1351 |
+
# example 3: "This motivational electronic and pop song features a clean electric guitar, rock organ, synth voice, acoustic guitar, and vibraphone, creating a melodic and uplifting atmosphere. Set in the key of G# minor with a 4/4 time signature, the track moves at an energetic Allegro tempo of 120 beats per minute. The chord progression of Bbm7 and F# adds to the song's inspiring and corporate feel." lmd_full/1/14347e50e9e8149a9da09f49b188180b.mid
|
| 1352 |
+
# example 4: "This short electronic song in C minor features a brass section, string ensemble, tenor saxophone, clean electric guitar, and slap bass, creating a melodic and slightly dark atmosphere. With a tempo of 124 BPM (Allegro) and a 4/4 time signature, the track incorporates a chord progression of C7/E, Eb6, and Bbm6, adding a touch of corporate and motivational vibes to the overall composition." lmd_full/1/1dc4cd50a5509d8042d27d80bc7e668e.mid
|
| 1353 |
+
# example 5: "An energetic and melodic electronic trance track with a space and retro vibe, featuring drums, distortion guitar, flute, synth bass, and slap bass. Set in A minor with a fast tempo of 138 BPM, the song maintains a 4/4 time signature throughout its duration." lmd_full/3/3328b854ebe7a2fc9a746ede74c410ae.mid
|
| 1354 |
+
# example 6: "A short but energetic rock fragment in C minor, featuring overdriven guitars, electric bass, and drums, with a vivacious tempo of 155 BPM and a 4/4 time signature, evoking a blend of dark and melodic tones." lmd_full/4/4c2232688c5f869b8470a408d197f5e3.mid
|
| 1355 |
+
# example 7: "A classical piece with a cinematic flair, this composition is characterized by its fast tempo and 4/4 time signature. The soprano saxophone and flute take turns leading the melody, supported by the lush tones of the string ensemble, acoustic bass, and pan flute. Set in the key of F minor, the harmonic landscape is painted with the chords Gm7b5, Cm7b5, Fm7, Eaug, and Ab/Eb. The overall mood evokes images of film, with hints of Christmas, drama, documentary, and adventure." lmd_full/9/95bce1b489a11829b4fef39200291f60.mid
|
| 1356 |
+
# exmaple 8: "A slow, dark, and emotional classical piece featuring cello, violin, and viola, likely to be used in a dramatic film soundtrack. The composition is in the key of C minor with a 4/4 time signature, and the main chord progression consists of Cm, G, Cm, and Fm." lmd_full/a/a22aad98ecfe4b3d8a353c2a72132834.mid
|
| 1357 |
+
# example 9: "A slow and emotional classical piece, likely used in a film soundtrack, featuring a church organ as the sole instrument. Written in the key of Eb major with a 3/4 time signature, it evokes a sense of drama and romance. The chord progression of Bb7, Eb, and Ab contributes to the relaxing atmosphere throughout the song." lmd_full/a/af4302a036c9df71e0435df9b08f8c4b.mid
|
| 1358 |
+
# example 10: "A cinematic electronic soundtrack that evokes an epic and dark atmosphere, featuring cello, contrabass, and drums. The song is set in A minor with a moderate tempo and a 4/4 time signature, creating an emotional and action-packed ambiance suitable for film." lmd_full/d/d920b6f451d7a72ae06f154e7c06c4c1.mid
|
| 1359 |
+
# '''
|
| 1360 |
+
# inputs = tokenizer(src, return_tensors='pt', padding=True, truncation=True)
|
| 1361 |
+
# input_ids = nn.utils.rnn.pad_sequence(inputs.input_ids, batch_first=True, padding_value=0)
|
| 1362 |
+
# input_ids = input_ids.to(device)
|
| 1363 |
+
# attention_mask =nn.utils.rnn.pad_sequence(inputs.attention_mask, batch_first=True, padding_value=0)
|
| 1364 |
+
# attention_mask = attention_mask.to(device)
|
| 1365 |
+
# output = model.generate(input_ids, attention_mask,max_len=5000,temperature = 0.9)
|
| 1366 |
+
# output_list = output[0].tolist()
|
| 1367 |
+
# print(type(output_list))
|
| 1368 |
+
# # generated_sequences = [dict_tokenizer[token] for token in output_list[0]]
|
| 1369 |
+
# # generated_sequences = check_instruments(generated_sequences)
|
| 1370 |
+
# # # generated_sequences = [('prefix', 'instrument', 'bass'), ('prefix', 'instrument', 'guitar'), ('prefix', 'instrument', 'piano'), ('prefix', 'instrument', 'guitar'), '<S>' ]+ generated_sequences +['<E>']
|
| 1371 |
+
# # generated_sequences = [token for token in generated_sequences]# if token not in ["<SS>", "<S>", "<E>", "<SEP>"]]
|
| 1372 |
+
# # # print("Generated sequences:", generated_sequences)
|
| 1373 |
+
# # with open('../../generated_seq.pkl', 'wb') as f:
|
| 1374 |
+
# # pickle.dump(generated_sequences, f)
|
| 1375 |
+
# # mid_dict = aria_tokenizer.detokenize(generated_sequences)
|
| 1376 |
+
# # mid = mid_dict.to_midi()
|
| 1377 |
+
# generated_midi = r_tokenizer.decode(output_list)
|
| 1378 |
+
# # print(type(generated_midi))
|
| 1379 |
+
# generated_midi.dump_midi(f"../res/{location}")
|
| 1380 |
+
|
| 1381 |
+
def test_generate(caption):
|
| 1382 |
+
# Detect device: CUDA, MPS, or CPU
|
| 1383 |
+
if torch.cuda.is_available():
|
| 1384 |
+
device = torch.device("cuda")
|
| 1385 |
+
print("Using CUDA on NVIDIA GPU")
|
| 1386 |
+
elif torch.backends.mps.is_available():
|
| 1387 |
+
device = torch.device("mps")
|
| 1388 |
+
print("Using MPS on macOS")
|
| 1389 |
+
else:
|
| 1390 |
+
device = torch.device("cpu")
|
| 1391 |
+
print("Using CPU")
|
| 1392 |
+
|
| 1393 |
+
artifact_folder = '../artifacts'
|
| 1394 |
+
tokenizer_filepath = os.path.join(artifact_folder, "vocab_remi.pkl")
|
| 1395 |
+
caption_dataset_path = '/root/text2midi/captions/train.json'
|
| 1396 |
+
print(f'caption_dataset_path: {caption_dataset_path}')
|
| 1397 |
+
|
| 1398 |
+
# Load the tokenizer dictionary
|
| 1399 |
+
with open(tokenizer_filepath, "rb") as f:
|
| 1400 |
+
r_tokenizer = pickle.load(f)
|
| 1401 |
+
vocab_size = len(r_tokenizer) # +1
|
| 1402 |
+
print("Vocab size: ", vocab_size)
|
| 1403 |
+
|
| 1404 |
+
# Initialize model
|
| 1405 |
+
model = Transformer(vocab_size, 768, 8, 2048, 18, 1024, False, 8, device=device)
|
| 1406 |
+
model.load_state_dict(torch.load('/root/test/text2midi/output_new/epoch_30/pytorch_model.bin', map_location=device))
|
| 1407 |
+
model.to(device) # Move model to detected device
|
| 1408 |
+
model.eval()
|
| 1409 |
+
|
| 1410 |
+
# Prepare input
|
| 1411 |
+
tokenizer = T5Tokenizer.from_pretrained("google/flan-t5-base")
|
| 1412 |
+
|
| 1413 |
+
'''
|
| 1414 |
+
# num_gpus = torch.cuda.device_count()
|
| 1415 |
+
# captions_per_gpu = len(captions) // num_gpus
|
| 1416 |
+
# processes = []
|
| 1417 |
+
# for i in range(num_gpus):
|
| 1418 |
+
# start_idx = i * captions_per_gpu
|
| 1419 |
+
# end_idx = (i + 1) * captions_per_gpu if i != num_gpus - 1 else len(captions)
|
| 1420 |
+
# p = mp.Process(target=process_caption, args=(i, captions[start_idx:end_idx], model, tokenizer, r_tokenizer))
|
| 1421 |
+
# p.start()
|
| 1422 |
+
# processes.append(p)
|
| 1423 |
+
|
| 1424 |
+
# for p in processes:
|
| 1425 |
+
# p.join()
|
| 1426 |
+
'''
|
| 1427 |
+
# src = "A pop song with nostalgic feeling."
|
| 1428 |
+
# src = "A happy christmas song suitable for festive mood."
|
| 1429 |
+
# src = "A melodic electronic song with ambient elements, featuring piano, acoustic guitar, alto saxophone, string ensemble, and electric bass. Set in G minor with a 4/4 time signature, it moves at a lively Presto tempo. The composition evokes a blend of relaxation and darkness, with hints of happiness and a meditative quality."
|
| 1430 |
+
# src="An energetic and melodic electronic trance track with a space and retro vibe, featuring drums, distortion guitar, flute, synth bass, and slap bass. Set in A minor with a fast tempo of 138 BPM, the song maintains a 4/4 time signature throughout its duration."
|
| 1431 |
+
# src="A cheerful and melodic pop Christmas song featuring piano, acoustic guitar, vibraphone, bass, and drums, set in the key of Eb minor with a fast tempo of 123 bpm and a 4/4 time signature, creating a joyful and relaxing atmosphere."
|
| 1432 |
+
# src = "This short electronic song in C minor features a brass section, string ensemble, tenor saxophone, clean electric guitar, and slap bass, creating a melodic and slightly dark atmosphere. With a tempo of 124 BPM (Allegro) and a 4/4 time signature, the track incorporates a chord progression of C7/E, Eb6, and Bbm6, adding a touch of corporate and motivational vibes to the overall composition."
|
| 1433 |
+
# src="This motivational electronic and pop song features a clean electric guitar, rock organ, synth voice, acoustic guitar, and vibraphone, creating a melodic and uplifting atmosphere. Set in the key of G# minor with a 4/4 time signature, the track moves at an energetic Allegro tempo of 120 beats per minute. The chord progression of Bbm7 and F# adds to the song's inspiring and corporate feel."
|
| 1434 |
+
# src = "Played at 149 beats per minute in 2/4 time signature and the key of G major, classical piece with instruments: bassoon, clarinet, flute, horn, oboe, and trumpet."
|
| 1435 |
+
# src= 'Played at 114 beats per minute in 1/4 time signature and the key of g# minor, classical piece with the following instruments: clarinet, english horn, flute, horn, piccolo, trombone, and trumpet.'
|
| 1436 |
+
inputs = tokenizer(caption, return_tensors='pt', padding=True, truncation=True)
|
| 1437 |
+
input_ids = nn.utils.rnn.pad_sequence(inputs.input_ids, batch_first=True, padding_value=0)
|
| 1438 |
+
input_ids = input_ids.to(device)
|
| 1439 |
+
attention_mask = nn.utils.rnn.pad_sequence(inputs.attention_mask, batch_first=True, padding_value=0)
|
| 1440 |
+
attention_mask = attention_mask.to(device)
|
| 1441 |
+
output = model.generate(input_ids, attention_mask, max_len=2000, temperature=0.9)
|
| 1442 |
+
output_list = output[0].tolist()
|
| 1443 |
+
|
| 1444 |
+
# Decode and save MIDI
|
| 1445 |
+
generated_midi = r_tokenizer.decode(output_list)
|
| 1446 |
+
generated_midi.dump_midi(f"../../output_christmas_2.mid")
|
| 1447 |
+
|
| 1448 |
+
def load_model_and_tokenizer(accelerator, model_path, vocab_size, tokenizer_filepath):
|
| 1449 |
+
device = accelerator.device
|
| 1450 |
+
with open(tokenizer_filepath, "rb") as f:
|
| 1451 |
+
r_tokenizer = pickle.load(f)
|
| 1452 |
+
model = Transformer(vocab_size, 768, 8, 2048, 18, 1024, False, 8, device=device)
|
| 1453 |
+
model.load_state_dict(torch.load(model_path, map_location=device))
|
| 1454 |
+
model.to(device)
|
| 1455 |
+
model.eval()
|
| 1456 |
+
tokenizer = T5Tokenizer.from_pretrained("google/flan-t5-base")
|
| 1457 |
+
return model, tokenizer, r_tokenizer
|
| 1458 |
+
|
| 1459 |
+
def process_example(accelerator, model, tokenizer, r_tokenizer, example, location, output_path):
|
| 1460 |
+
device = accelerator.device
|
| 1461 |
+
inputs = tokenizer(example, return_tensors='pt', padding=True, truncation=True).to(device)
|
| 1462 |
+
input_ids = inputs['input_ids']
|
| 1463 |
+
attention_mask = inputs['attention_mask']
|
| 1464 |
+
with torch.no_grad():
|
| 1465 |
+
output = model.module.generate(input_ids, attention_mask, max_len=2000, temperature=0.9)
|
| 1466 |
+
output_list = output[0].tolist()
|
| 1467 |
+
generated_midi = r_tokenizer.decode(output_list)
|
| 1468 |
+
generated_midi.dump_midi(output_path)
|
| 1469 |
+
|
| 1470 |
+
def run_accelerate_generation():
|
| 1471 |
+
accelerator = Accelerator()
|
| 1472 |
+
artifact_folder = '../artifacts'
|
| 1473 |
+
tokenizer_filepath = os.path.join(artifact_folder, "vocab_remi.pkl")
|
| 1474 |
+
model_path = '/root/output_test_new/epoch_30/pytorch_model.bin'
|
| 1475 |
+
captions_path = '/root/captions/train.json'
|
| 1476 |
+
|
| 1477 |
+
with jsonlines.open(captions_path) as reader:
|
| 1478 |
+
selected_captions = [line for line in reader if line.get('test_set') is True]
|
| 1479 |
+
|
| 1480 |
+
with open(tokenizer_filepath, "rb") as f:
|
| 1481 |
+
r_tokenizer = pickle.load(f)
|
| 1482 |
+
|
| 1483 |
+
model, tokenizer, r_tokenizer = load_model_and_tokenizer(accelerator, model_path, len(r_tokenizer), tokenizer_filepath)
|
| 1484 |
+
model = accelerator.prepare(model)
|
| 1485 |
+
|
| 1486 |
+
dataset = CaptionDataset(selected_captions)
|
| 1487 |
+
dataloader = DataLoader(dataset, batch_size=8, num_workers=4, shuffle=False, collate_fn=custom_collate_fn)
|
| 1488 |
+
dataloader = accelerator.prepare(dataloader)
|
| 1489 |
+
|
| 1490 |
+
for captions, locations in dataloader:
|
| 1491 |
+
for example, location in zip(captions, locations):
|
| 1492 |
+
output_path = os.path.join(f'/root/Text2midi/res_acc', location)
|
| 1493 |
+
os.makedirs(os.path.dirname(output_path), exist_ok=True)
|
| 1494 |
+
process_example(accelerator, model, tokenizer, r_tokenizer, example, location, output_path)
|
| 1495 |
+
|
| 1496 |
+
# run_accelerate_generation() #uncomment this and comment __main__ to run accelerate generation
|
| 1497 |
+
|
| 1498 |
+
def main():
|
| 1499 |
+
parser = argparse.ArgumentParser(description="Generate MIDI from caption")
|
| 1500 |
+
parser.add_argument('--caption', type=str, required=True, help='Caption to generate MIDI from')
|
| 1501 |
+
args = parser.parse_args()
|
| 1502 |
+
test_generate(args.caption)
|
| 1503 |
+
|
| 1504 |
+
'''
|
| 1505 |
+
comment out the next section function and uncomment the run_accelerate_generation() function to run the accelerate generation
|
| 1506 |
+
'''
|
| 1507 |
+
if __name__ == "__main__":
|
| 1508 |
+
main()
|
| 1509 |
+
print("Done")
|
text2midi_repo/requirements-mac.txt
ADDED
|
@@ -0,0 +1,368 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#
|
| 2 |
+
# This file is autogenerated by pip-compile with Python 3.10
|
| 3 |
+
# by the following command:
|
| 4 |
+
#
|
| 5 |
+
# pip-compile requirements.in
|
| 6 |
+
#
|
| 7 |
+
accelerate==0.18.0
|
| 8 |
+
# via -r requirements.in
|
| 9 |
+
aiohappyeyeballs==2.4.4
|
| 10 |
+
# via aiohttp
|
| 11 |
+
aiohttp==3.11.10
|
| 12 |
+
# via
|
| 13 |
+
# datasets
|
| 14 |
+
# fsspec
|
| 15 |
+
aiosignal==1.3.1
|
| 16 |
+
# via aiohttp
|
| 17 |
+
annotated-types==0.7.0
|
| 18 |
+
# via pydantic
|
| 19 |
+
async-timeout==5.0.1
|
| 20 |
+
# via aiohttp
|
| 21 |
+
attrs==24.2.0
|
| 22 |
+
# via
|
| 23 |
+
# aiohttp
|
| 24 |
+
# jsonlines
|
| 25 |
+
beartype==0.19.0
|
| 26 |
+
# via st-moe-pytorch
|
| 27 |
+
blis==1.0.1
|
| 28 |
+
# via thinc
|
| 29 |
+
catalogue==2.0.10
|
| 30 |
+
# via
|
| 31 |
+
# spacy
|
| 32 |
+
# srsly
|
| 33 |
+
# thinc
|
| 34 |
+
certifi==2024.8.30
|
| 35 |
+
# via
|
| 36 |
+
# requests
|
| 37 |
+
# sentry-sdk
|
| 38 |
+
charset-normalizer==3.4.0
|
| 39 |
+
# via requests
|
| 40 |
+
click==8.1.7
|
| 41 |
+
# via
|
| 42 |
+
# typer
|
| 43 |
+
# wandb
|
| 44 |
+
cloudpathlib==0.20.0
|
| 45 |
+
# via weasel
|
| 46 |
+
colt5-attention==0.11.1
|
| 47 |
+
# via st-moe-pytorch
|
| 48 |
+
confection==0.1.5
|
| 49 |
+
# via
|
| 50 |
+
# thinc
|
| 51 |
+
# weasel
|
| 52 |
+
cymem==2.0.10
|
| 53 |
+
# via
|
| 54 |
+
# preshed
|
| 55 |
+
# spacy
|
| 56 |
+
# thinc
|
| 57 |
+
datasets==3.1.0
|
| 58 |
+
# via evaluate
|
| 59 |
+
dill==0.3.8
|
| 60 |
+
# via
|
| 61 |
+
# datasets
|
| 62 |
+
# evaluate
|
| 63 |
+
# multiprocess
|
| 64 |
+
docker-pycreds==0.4.0
|
| 65 |
+
# via wandb
|
| 66 |
+
einops==0.8.0
|
| 67 |
+
# via
|
| 68 |
+
# -r requirements.in
|
| 69 |
+
# colt5-attention
|
| 70 |
+
# local-attention
|
| 71 |
+
# st-moe-pytorch
|
| 72 |
+
evaluate==0.4.3
|
| 73 |
+
# via -r requirements.in
|
| 74 |
+
filelock==3.16.1
|
| 75 |
+
# via
|
| 76 |
+
# datasets
|
| 77 |
+
# huggingface-hub
|
| 78 |
+
# torch
|
| 79 |
+
# transformers
|
| 80 |
+
# triton
|
| 81 |
+
frozenlist==1.5.0
|
| 82 |
+
# via
|
| 83 |
+
# aiohttp
|
| 84 |
+
# aiosignal
|
| 85 |
+
fsspec[http]==2024.9.0
|
| 86 |
+
# via
|
| 87 |
+
# datasets
|
| 88 |
+
# evaluate
|
| 89 |
+
# huggingface-hub
|
| 90 |
+
# torch
|
| 91 |
+
gitdb==4.0.11
|
| 92 |
+
# via gitpython
|
| 93 |
+
gitpython==3.1.43
|
| 94 |
+
# via wandb
|
| 95 |
+
huggingface-hub==0.26.3
|
| 96 |
+
# via
|
| 97 |
+
# accelerate
|
| 98 |
+
# datasets
|
| 99 |
+
# evaluate
|
| 100 |
+
# miditok
|
| 101 |
+
# tokenizers
|
| 102 |
+
# transformers
|
| 103 |
+
idna==3.10
|
| 104 |
+
# via
|
| 105 |
+
# requests
|
| 106 |
+
# yarl
|
| 107 |
+
jinja2==3.1.4
|
| 108 |
+
# via
|
| 109 |
+
# spacy
|
| 110 |
+
# torch
|
| 111 |
+
jsonlines==4.0.0
|
| 112 |
+
# via -r requirements.in
|
| 113 |
+
langcodes==3.5.0
|
| 114 |
+
# via spacy
|
| 115 |
+
language-data==1.3.0
|
| 116 |
+
# via langcodes
|
| 117 |
+
local-attention==1.9.15
|
| 118 |
+
# via colt5-attention
|
| 119 |
+
marisa-trie==1.2.1
|
| 120 |
+
# via language-data
|
| 121 |
+
markdown-it-py==3.0.0
|
| 122 |
+
# via rich
|
| 123 |
+
markupsafe==3.0.2
|
| 124 |
+
# via jinja2
|
| 125 |
+
mdurl==0.1.2
|
| 126 |
+
# via markdown-it-py
|
| 127 |
+
miditok==3.0.3
|
| 128 |
+
# via -r requirements.in
|
| 129 |
+
mpmath==1.3.0
|
| 130 |
+
# via sympy
|
| 131 |
+
multidict==6.1.0
|
| 132 |
+
# via
|
| 133 |
+
# aiohttp
|
| 134 |
+
# yarl
|
| 135 |
+
multiprocess==0.70.16
|
| 136 |
+
# via
|
| 137 |
+
# datasets
|
| 138 |
+
# evaluate
|
| 139 |
+
murmurhash==1.0.11
|
| 140 |
+
# via
|
| 141 |
+
# preshed
|
| 142 |
+
# spacy
|
| 143 |
+
# thinc
|
| 144 |
+
networkx==3.4.2
|
| 145 |
+
# via torch
|
| 146 |
+
numpy==2.0.2
|
| 147 |
+
# via
|
| 148 |
+
# -r requirements.in
|
| 149 |
+
# accelerate
|
| 150 |
+
# blis
|
| 151 |
+
# datasets
|
| 152 |
+
# evaluate
|
| 153 |
+
# miditok
|
| 154 |
+
# pandas
|
| 155 |
+
# spacy
|
| 156 |
+
# symusic
|
| 157 |
+
# thinc
|
| 158 |
+
# transformers
|
| 159 |
+
#nvidia-cublas-cu12==12.4.5.8
|
| 160 |
+
# via
|
| 161 |
+
# nvidia-cudnn-cu12
|
| 162 |
+
# nvidia-cusolver-cu12
|
| 163 |
+
# torch
|
| 164 |
+
#nvidia-cuda-cupti-cu12==12.4.127
|
| 165 |
+
# via torch
|
| 166 |
+
#nvidia-cuda-nvrtc-cu12==12.4.127
|
| 167 |
+
# via torch
|
| 168 |
+
#nvidia-cuda-runtime-cu12==12.4.127
|
| 169 |
+
# via torch
|
| 170 |
+
#nvidia-cudnn-cu12==9.1.0.70
|
| 171 |
+
# via torch
|
| 172 |
+
#nvidia-cufft-cu12==11.2.1.3
|
| 173 |
+
# via torch
|
| 174 |
+
#nvidia-curand-cu12==10.3.5.147
|
| 175 |
+
# via torch
|
| 176 |
+
#nvidia-cusolver-cu12==11.6.1.9
|
| 177 |
+
# via torch
|
| 178 |
+
#nvidia-cusparse-cu12==12.3.1.170
|
| 179 |
+
# via
|
| 180 |
+
# nvidia-cusolver-cu12
|
| 181 |
+
# torch
|
| 182 |
+
#nvidia-nccl-cu12==2.21.5
|
| 183 |
+
# via torch
|
| 184 |
+
#nvidia-nvjitlink-cu12==12.4.127
|
| 185 |
+
# via
|
| 186 |
+
# nvidia-cusolver-cu12
|
| 187 |
+
# nvidia-cusparse-cu12
|
| 188 |
+
# torch
|
| 189 |
+
#nvidia-nvtx-cu12==12.4.127
|
| 190 |
+
# via torch
|
| 191 |
+
packaging==24.2
|
| 192 |
+
# via
|
| 193 |
+
# accelerate
|
| 194 |
+
# colt5-attention
|
| 195 |
+
# datasets
|
| 196 |
+
# evaluate
|
| 197 |
+
# huggingface-hub
|
| 198 |
+
# spacy
|
| 199 |
+
# thinc
|
| 200 |
+
# transformers
|
| 201 |
+
# weasel
|
| 202 |
+
pandas==2.2.3
|
| 203 |
+
# via
|
| 204 |
+
# datasets
|
| 205 |
+
# evaluate
|
| 206 |
+
platformdirs==4.3.6
|
| 207 |
+
# via
|
| 208 |
+
# symusic
|
| 209 |
+
# wandb
|
| 210 |
+
preshed==3.0.9
|
| 211 |
+
# via
|
| 212 |
+
# spacy
|
| 213 |
+
# thinc
|
| 214 |
+
propcache==0.2.1
|
| 215 |
+
# via
|
| 216 |
+
# aiohttp
|
| 217 |
+
# yarl
|
| 218 |
+
protobuf==5.29.1
|
| 219 |
+
# via wandb
|
| 220 |
+
psutil==6.1.0
|
| 221 |
+
# via
|
| 222 |
+
# accelerate
|
| 223 |
+
# wandb
|
| 224 |
+
pyarrow==18.1.0
|
| 225 |
+
# via datasets
|
| 226 |
+
pydantic==2.10.3
|
| 227 |
+
# via
|
| 228 |
+
# confection
|
| 229 |
+
# spacy
|
| 230 |
+
# thinc
|
| 231 |
+
# wandb
|
| 232 |
+
# weasel
|
| 233 |
+
pydantic-core==2.27.1
|
| 234 |
+
# via pydantic
|
| 235 |
+
pygments==2.18.0
|
| 236 |
+
# via rich
|
| 237 |
+
pysmartdl==1.3.4
|
| 238 |
+
# via symusic
|
| 239 |
+
python-dateutil==2.9.0.post0
|
| 240 |
+
# via pandas
|
| 241 |
+
pytz==2024.2
|
| 242 |
+
# via pandas
|
| 243 |
+
pyyaml==6.0.2
|
| 244 |
+
# via
|
| 245 |
+
# -r requirements.in
|
| 246 |
+
# accelerate
|
| 247 |
+
# datasets
|
| 248 |
+
# huggingface-hub
|
| 249 |
+
# transformers
|
| 250 |
+
# wandb
|
| 251 |
+
regex==2024.11.6
|
| 252 |
+
# via transformers
|
| 253 |
+
requests==2.32.3
|
| 254 |
+
# via
|
| 255 |
+
# datasets
|
| 256 |
+
# evaluate
|
| 257 |
+
# huggingface-hub
|
| 258 |
+
# spacy
|
| 259 |
+
# transformers
|
| 260 |
+
# wandb
|
| 261 |
+
# weasel
|
| 262 |
+
rich==13.9.4
|
| 263 |
+
# via typer
|
| 264 |
+
safetensors==0.4.5
|
| 265 |
+
# via
|
| 266 |
+
# accelerate
|
| 267 |
+
# transformers
|
| 268 |
+
sentry-sdk==2.19.2
|
| 269 |
+
# via wandb
|
| 270 |
+
sentencepiece==0.2.0
|
| 271 |
+
|
| 272 |
+
setproctitle==1.3.4
|
| 273 |
+
# via wandb
|
| 274 |
+
shellingham==1.5.4
|
| 275 |
+
# via typer
|
| 276 |
+
six==1.17.0
|
| 277 |
+
# via
|
| 278 |
+
# docker-pycreds
|
| 279 |
+
# python-dateutil
|
| 280 |
+
smart-open==7.0.5
|
| 281 |
+
# via weasel
|
| 282 |
+
smmap==5.0.1
|
| 283 |
+
# via gitdb
|
| 284 |
+
spacy==3.8.2
|
| 285 |
+
# via -r requirements.in
|
| 286 |
+
spacy-legacy==3.0.12
|
| 287 |
+
# via spacy
|
| 288 |
+
spacy-loggers==1.0.5
|
| 289 |
+
# via spacy
|
| 290 |
+
srsly==2.4.8
|
| 291 |
+
# via
|
| 292 |
+
# confection
|
| 293 |
+
# spacy
|
| 294 |
+
# thinc
|
| 295 |
+
# weasel
|
| 296 |
+
st-moe-pytorch==0.1.8
|
| 297 |
+
# via -r requirements.in
|
| 298 |
+
sympy==1.13.1
|
| 299 |
+
# via torch
|
| 300 |
+
symusic==0.5.5
|
| 301 |
+
# via miditok
|
| 302 |
+
thinc==8.3.2
|
| 303 |
+
# via spacy
|
| 304 |
+
tokenizers==0.21.0
|
| 305 |
+
# via
|
| 306 |
+
# miditok
|
| 307 |
+
# transformers
|
| 308 |
+
torch==2.5.1
|
| 309 |
+
# via
|
| 310 |
+
# -r requirements.in
|
| 311 |
+
# accelerate
|
| 312 |
+
# colt5-attention
|
| 313 |
+
# local-attention
|
| 314 |
+
# st-moe-pytorch
|
| 315 |
+
tqdm==4.67.1
|
| 316 |
+
# via
|
| 317 |
+
# -r requirements.in
|
| 318 |
+
# datasets
|
| 319 |
+
# evaluate
|
| 320 |
+
# huggingface-hub
|
| 321 |
+
# miditok
|
| 322 |
+
# spacy
|
| 323 |
+
# transformers
|
| 324 |
+
transformers==4.47.0
|
| 325 |
+
# via -r requirements.in
|
| 326 |
+
#triton==3.1.0
|
| 327 |
+
# via torch
|
| 328 |
+
typer==0.15.1
|
| 329 |
+
# via
|
| 330 |
+
# spacy
|
| 331 |
+
# weasel
|
| 332 |
+
typing-extensions==4.12.2
|
| 333 |
+
# via
|
| 334 |
+
# cloudpathlib
|
| 335 |
+
# huggingface-hub
|
| 336 |
+
# multidict
|
| 337 |
+
# pydantic
|
| 338 |
+
# pydantic-core
|
| 339 |
+
# rich
|
| 340 |
+
# torch
|
| 341 |
+
# typer
|
| 342 |
+
# wandb
|
| 343 |
+
tzdata==2024.2
|
| 344 |
+
# via pandas
|
| 345 |
+
urllib3==2.2.3
|
| 346 |
+
# via
|
| 347 |
+
# requests
|
| 348 |
+
# sentry-sdk
|
| 349 |
+
wandb==0.19.0
|
| 350 |
+
# via -r requirements.in
|
| 351 |
+
wasabi==1.1.3
|
| 352 |
+
# via
|
| 353 |
+
# spacy
|
| 354 |
+
# thinc
|
| 355 |
+
# weasel
|
| 356 |
+
weasel==0.4.1
|
| 357 |
+
# via spacy
|
| 358 |
+
wrapt==1.17.0
|
| 359 |
+
# via smart-open
|
| 360 |
+
xxhash==3.5.0
|
| 361 |
+
# via
|
| 362 |
+
# datasets
|
| 363 |
+
# evaluate
|
| 364 |
+
yarl==1.18.3
|
| 365 |
+
# via aiohttp
|
| 366 |
+
|
| 367 |
+
# The following packages are considered to be unsafe in a requirements file:
|
| 368 |
+
# setuptools
|
text2midi_repo/requirements.txt
ADDED
|
@@ -0,0 +1,368 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#
|
| 2 |
+
# This file is autogenerated by pip-compile with Python 3.10
|
| 3 |
+
# by the following command:
|
| 4 |
+
#
|
| 5 |
+
# pip-compile requirements.in
|
| 6 |
+
#
|
| 7 |
+
accelerate==0.18.0
|
| 8 |
+
# via -r requirements.in
|
| 9 |
+
aiohappyeyeballs==2.4.4
|
| 10 |
+
# via aiohttp
|
| 11 |
+
aiohttp==3.11.10
|
| 12 |
+
# via
|
| 13 |
+
# datasets
|
| 14 |
+
# fsspec
|
| 15 |
+
aiosignal==1.3.1
|
| 16 |
+
# via aiohttp
|
| 17 |
+
annotated-types==0.7.0
|
| 18 |
+
# via pydantic
|
| 19 |
+
async-timeout==5.0.1
|
| 20 |
+
# via aiohttp
|
| 21 |
+
attrs==24.2.0
|
| 22 |
+
# via
|
| 23 |
+
# aiohttp
|
| 24 |
+
# jsonlines
|
| 25 |
+
beartype==0.19.0
|
| 26 |
+
# via st-moe-pytorch
|
| 27 |
+
blis==1.0.1
|
| 28 |
+
# via thinc
|
| 29 |
+
catalogue==2.0.10
|
| 30 |
+
# via
|
| 31 |
+
# spacy
|
| 32 |
+
# srsly
|
| 33 |
+
# thinc
|
| 34 |
+
certifi==2024.8.30
|
| 35 |
+
# via
|
| 36 |
+
# requests
|
| 37 |
+
# sentry-sdk
|
| 38 |
+
charset-normalizer==3.4.0
|
| 39 |
+
# via requests
|
| 40 |
+
click==8.1.7
|
| 41 |
+
# via
|
| 42 |
+
# typer
|
| 43 |
+
# wandb
|
| 44 |
+
cloudpathlib==0.20.0
|
| 45 |
+
# via weasel
|
| 46 |
+
colt5-attention==0.11.1
|
| 47 |
+
# via st-moe-pytorch
|
| 48 |
+
confection==0.1.5
|
| 49 |
+
# via
|
| 50 |
+
# thinc
|
| 51 |
+
# weasel
|
| 52 |
+
cymem==2.0.10
|
| 53 |
+
# via
|
| 54 |
+
# preshed
|
| 55 |
+
# spacy
|
| 56 |
+
# thinc
|
| 57 |
+
datasets==3.1.0
|
| 58 |
+
# via evaluate
|
| 59 |
+
dill==0.3.8
|
| 60 |
+
# via
|
| 61 |
+
# datasets
|
| 62 |
+
# evaluate
|
| 63 |
+
# multiprocess
|
| 64 |
+
docker-pycreds==0.4.0
|
| 65 |
+
# via wandb
|
| 66 |
+
einops==0.8.0
|
| 67 |
+
# via
|
| 68 |
+
# -r requirements.in
|
| 69 |
+
# colt5-attention
|
| 70 |
+
# local-attention
|
| 71 |
+
# st-moe-pytorch
|
| 72 |
+
evaluate==0.4.3
|
| 73 |
+
# via -r requirements.in
|
| 74 |
+
filelock==3.16.1
|
| 75 |
+
# via
|
| 76 |
+
# datasets
|
| 77 |
+
# huggingface-hub
|
| 78 |
+
# torch
|
| 79 |
+
# transformers
|
| 80 |
+
# triton
|
| 81 |
+
frozenlist==1.5.0
|
| 82 |
+
# via
|
| 83 |
+
# aiohttp
|
| 84 |
+
# aiosignal
|
| 85 |
+
fsspec[http]==2024.9.0
|
| 86 |
+
# via
|
| 87 |
+
# datasets
|
| 88 |
+
# evaluate
|
| 89 |
+
# huggingface-hub
|
| 90 |
+
# torch
|
| 91 |
+
gitdb==4.0.11
|
| 92 |
+
# via gitpython
|
| 93 |
+
gitpython==3.1.43
|
| 94 |
+
# via wandb
|
| 95 |
+
huggingface-hub==0.26.3
|
| 96 |
+
# via
|
| 97 |
+
# accelerate
|
| 98 |
+
# datasets
|
| 99 |
+
# evaluate
|
| 100 |
+
# miditok
|
| 101 |
+
# tokenizers
|
| 102 |
+
# transformers
|
| 103 |
+
idna==3.10
|
| 104 |
+
# via
|
| 105 |
+
# requests
|
| 106 |
+
# yarl
|
| 107 |
+
jinja2==3.1.4
|
| 108 |
+
# via
|
| 109 |
+
# spacy
|
| 110 |
+
# torch
|
| 111 |
+
jsonlines==4.0.0
|
| 112 |
+
# via -r requirements.in
|
| 113 |
+
langcodes==3.5.0
|
| 114 |
+
# via spacy
|
| 115 |
+
language-data==1.3.0
|
| 116 |
+
# via langcodes
|
| 117 |
+
local-attention==1.9.15
|
| 118 |
+
# via colt5-attention
|
| 119 |
+
marisa-trie==1.2.1
|
| 120 |
+
# via language-data
|
| 121 |
+
markdown-it-py==3.0.0
|
| 122 |
+
# via rich
|
| 123 |
+
markupsafe==3.0.2
|
| 124 |
+
# via jinja2
|
| 125 |
+
mdurl==0.1.2
|
| 126 |
+
# via markdown-it-py
|
| 127 |
+
miditok==3.0.3
|
| 128 |
+
# via -r requirements.in
|
| 129 |
+
mpmath==1.3.0
|
| 130 |
+
# via sympy
|
| 131 |
+
multidict==6.1.0
|
| 132 |
+
# via
|
| 133 |
+
# aiohttp
|
| 134 |
+
# yarl
|
| 135 |
+
multiprocess==0.70.16
|
| 136 |
+
# via
|
| 137 |
+
# datasets
|
| 138 |
+
# evaluate
|
| 139 |
+
murmurhash==1.0.11
|
| 140 |
+
# via
|
| 141 |
+
# preshed
|
| 142 |
+
# spacy
|
| 143 |
+
# thinc
|
| 144 |
+
networkx==3.4.2
|
| 145 |
+
# via torch
|
| 146 |
+
numpy==2.0.2
|
| 147 |
+
# via
|
| 148 |
+
# -r requirements.in
|
| 149 |
+
# accelerate
|
| 150 |
+
# blis
|
| 151 |
+
# datasets
|
| 152 |
+
# evaluate
|
| 153 |
+
# miditok
|
| 154 |
+
# pandas
|
| 155 |
+
# spacy
|
| 156 |
+
# symusic
|
| 157 |
+
# thinc
|
| 158 |
+
# transformers
|
| 159 |
+
nvidia-cublas-cu12==12.4.5.8
|
| 160 |
+
# via
|
| 161 |
+
# nvidia-cudnn-cu12
|
| 162 |
+
# nvidia-cusolver-cu12
|
| 163 |
+
# torch
|
| 164 |
+
nvidia-cuda-cupti-cu12==12.4.127
|
| 165 |
+
# via torch
|
| 166 |
+
nvidia-cuda-nvrtc-cu12==12.4.127
|
| 167 |
+
# via torch
|
| 168 |
+
nvidia-cuda-runtime-cu12==12.4.127
|
| 169 |
+
# via torch
|
| 170 |
+
nvidia-cudnn-cu12==9.1.0.70
|
| 171 |
+
# via torch
|
| 172 |
+
nvidia-cufft-cu12==11.2.1.3
|
| 173 |
+
# via torch
|
| 174 |
+
nvidia-curand-cu12==10.3.5.147
|
| 175 |
+
# via torch
|
| 176 |
+
nvidia-cusolver-cu12==11.6.1.9
|
| 177 |
+
# via torch
|
| 178 |
+
nvidia-cusparse-cu12==12.3.1.170
|
| 179 |
+
# via
|
| 180 |
+
# nvidia-cusolver-cu12
|
| 181 |
+
# torch
|
| 182 |
+
nvidia-nccl-cu12==2.21.5
|
| 183 |
+
# via torch
|
| 184 |
+
nvidia-nvjitlink-cu12==12.4.127
|
| 185 |
+
# via
|
| 186 |
+
# nvidia-cusolver-cu12
|
| 187 |
+
# nvidia-cusparse-cu12
|
| 188 |
+
# torch
|
| 189 |
+
nvidia-nvtx-cu12==12.4.127
|
| 190 |
+
# via torch
|
| 191 |
+
packaging==24.2
|
| 192 |
+
# via
|
| 193 |
+
# accelerate
|
| 194 |
+
# colt5-attention
|
| 195 |
+
# datasets
|
| 196 |
+
# evaluate
|
| 197 |
+
# huggingface-hub
|
| 198 |
+
# spacy
|
| 199 |
+
# thinc
|
| 200 |
+
# transformers
|
| 201 |
+
# weasel
|
| 202 |
+
pandas==2.2.3
|
| 203 |
+
# via
|
| 204 |
+
# datasets
|
| 205 |
+
# evaluate
|
| 206 |
+
platformdirs==4.3.6
|
| 207 |
+
# via
|
| 208 |
+
# symusic
|
| 209 |
+
# wandb
|
| 210 |
+
preshed==3.0.9
|
| 211 |
+
# via
|
| 212 |
+
# spacy
|
| 213 |
+
# thinc
|
| 214 |
+
propcache==0.2.1
|
| 215 |
+
# via
|
| 216 |
+
# aiohttp
|
| 217 |
+
# yarl
|
| 218 |
+
protobuf==5.29.1
|
| 219 |
+
# via wandb
|
| 220 |
+
psutil==6.1.0
|
| 221 |
+
# via
|
| 222 |
+
# accelerate
|
| 223 |
+
# wandb
|
| 224 |
+
pyarrow==18.1.0
|
| 225 |
+
# via datasets
|
| 226 |
+
pydantic==2.10.3
|
| 227 |
+
# via
|
| 228 |
+
# confection
|
| 229 |
+
# spacy
|
| 230 |
+
# thinc
|
| 231 |
+
# wandb
|
| 232 |
+
# weasel
|
| 233 |
+
pydantic-core==2.27.1
|
| 234 |
+
# via pydantic
|
| 235 |
+
pygments==2.18.0
|
| 236 |
+
# via rich
|
| 237 |
+
pysmartdl==1.3.4
|
| 238 |
+
# via symusic
|
| 239 |
+
python-dateutil==2.9.0.post0
|
| 240 |
+
# via pandas
|
| 241 |
+
pytz==2024.2
|
| 242 |
+
# via pandas
|
| 243 |
+
pyyaml==6.0.2
|
| 244 |
+
# via
|
| 245 |
+
# -r requirements.in
|
| 246 |
+
# accelerate
|
| 247 |
+
# datasets
|
| 248 |
+
# huggingface-hub
|
| 249 |
+
# transformers
|
| 250 |
+
# wandb
|
| 251 |
+
regex==2024.11.6
|
| 252 |
+
# via transformers
|
| 253 |
+
requests==2.32.3
|
| 254 |
+
# via
|
| 255 |
+
# datasets
|
| 256 |
+
# evaluate
|
| 257 |
+
# huggingface-hub
|
| 258 |
+
# spacy
|
| 259 |
+
# transformers
|
| 260 |
+
# wandb
|
| 261 |
+
# weasel
|
| 262 |
+
rich==13.9.4
|
| 263 |
+
# via typer
|
| 264 |
+
safetensors==0.4.5
|
| 265 |
+
# via
|
| 266 |
+
# accelerate
|
| 267 |
+
# transformers
|
| 268 |
+
sentry-sdk==2.19.2
|
| 269 |
+
# via wandb
|
| 270 |
+
sentencepiece==0.2.0
|
| 271 |
+
|
| 272 |
+
setproctitle==1.3.4
|
| 273 |
+
# via wandb
|
| 274 |
+
shellingham==1.5.4
|
| 275 |
+
# via typer
|
| 276 |
+
six==1.17.0
|
| 277 |
+
# via
|
| 278 |
+
# docker-pycreds
|
| 279 |
+
# python-dateutil
|
| 280 |
+
smart-open==7.0.5
|
| 281 |
+
# via weasel
|
| 282 |
+
smmap==5.0.1
|
| 283 |
+
# via gitdb
|
| 284 |
+
spacy==3.8.2
|
| 285 |
+
# via -r requirements.in
|
| 286 |
+
spacy-legacy==3.0.12
|
| 287 |
+
# via spacy
|
| 288 |
+
spacy-loggers==1.0.5
|
| 289 |
+
# via spacy
|
| 290 |
+
srsly==2.4.8
|
| 291 |
+
# via
|
| 292 |
+
# confection
|
| 293 |
+
# spacy
|
| 294 |
+
# thinc
|
| 295 |
+
# weasel
|
| 296 |
+
st-moe-pytorch==0.1.8
|
| 297 |
+
# via -r requirements.in
|
| 298 |
+
sympy==1.13.1
|
| 299 |
+
# via torch
|
| 300 |
+
symusic==0.5.5
|
| 301 |
+
# via miditok
|
| 302 |
+
thinc==8.3.2
|
| 303 |
+
# via spacy
|
| 304 |
+
tokenizers==0.21.0
|
| 305 |
+
# via
|
| 306 |
+
# miditok
|
| 307 |
+
# transformers
|
| 308 |
+
torch==2.5.1
|
| 309 |
+
# via
|
| 310 |
+
# -r requirements.in
|
| 311 |
+
# accelerate
|
| 312 |
+
# colt5-attention
|
| 313 |
+
# local-attention
|
| 314 |
+
# st-moe-pytorch
|
| 315 |
+
tqdm==4.67.1
|
| 316 |
+
# via
|
| 317 |
+
# -r requirements.in
|
| 318 |
+
# datasets
|
| 319 |
+
# evaluate
|
| 320 |
+
# huggingface-hub
|
| 321 |
+
# miditok
|
| 322 |
+
# spacy
|
| 323 |
+
# transformers
|
| 324 |
+
transformers==4.47.0
|
| 325 |
+
# via -r requirements.in
|
| 326 |
+
triton==3.1.0
|
| 327 |
+
# via torch
|
| 328 |
+
typer==0.15.1
|
| 329 |
+
# via
|
| 330 |
+
# spacy
|
| 331 |
+
# weasel
|
| 332 |
+
typing-extensions==4.12.2
|
| 333 |
+
# via
|
| 334 |
+
# cloudpathlib
|
| 335 |
+
# huggingface-hub
|
| 336 |
+
# multidict
|
| 337 |
+
# pydantic
|
| 338 |
+
# pydantic-core
|
| 339 |
+
# rich
|
| 340 |
+
# torch
|
| 341 |
+
# typer
|
| 342 |
+
# wandb
|
| 343 |
+
tzdata==2024.2
|
| 344 |
+
# via pandas
|
| 345 |
+
urllib3==2.2.3
|
| 346 |
+
# via
|
| 347 |
+
# requests
|
| 348 |
+
# sentry-sdk
|
| 349 |
+
wandb==0.19.0
|
| 350 |
+
# via -r requirements.in
|
| 351 |
+
wasabi==1.1.3
|
| 352 |
+
# via
|
| 353 |
+
# spacy
|
| 354 |
+
# thinc
|
| 355 |
+
# weasel
|
| 356 |
+
weasel==0.4.1
|
| 357 |
+
# via spacy
|
| 358 |
+
wrapt==1.17.0
|
| 359 |
+
# via smart-open
|
| 360 |
+
xxhash==3.5.0
|
| 361 |
+
# via
|
| 362 |
+
# datasets
|
| 363 |
+
# evaluate
|
| 364 |
+
yarl==1.18.3
|
| 365 |
+
# via aiohttp
|
| 366 |
+
|
| 367 |
+
# The following packages are considered to be unsafe in a requirements file:
|
| 368 |
+
# setuptools
|
text2midi_repo/text2midi_architecture.jpg
ADDED
|
Git LFS Details
|
text2midi_repo/utils/midi_to_wav.py
ADDED
|
@@ -0,0 +1,55 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import subprocess
|
| 3 |
+
from multiprocessing import Pool, cpu_count
|
| 4 |
+
from tqdm import tqdm
|
| 5 |
+
|
| 6 |
+
soundfont_filepath = "/root/soundfont/soundfont.sf"
|
| 7 |
+
|
| 8 |
+
def save_wav(midi_filepath, wav_filepath):
|
| 9 |
+
# Check if the .wav file already exists
|
| 10 |
+
if os.path.isfile(wav_filepath):
|
| 11 |
+
print(f"{wav_filepath} already exists, skipping")
|
| 12 |
+
return wav_filepath
|
| 13 |
+
else:
|
| 14 |
+
print(f"Creating {wav_filepath} from {midi_filepath}")
|
| 15 |
+
|
| 16 |
+
# Run the fluidsynth command to convert MIDI to WAV
|
| 17 |
+
command = f"fluidsynth -r 48000 {soundfont_filepath} -g 1.0 --quiet --no-shell {midi_filepath} -T wav -F {wav_filepath}"
|
| 18 |
+
print(f"Running command: {command}")
|
| 19 |
+
process = subprocess.Popen(command, shell=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
|
| 20 |
+
stdout, stderr = process.communicate()
|
| 21 |
+
|
| 22 |
+
if process.returncode != 0:
|
| 23 |
+
print(f"Error converting {midi_filepath} to {wav_filepath}: {stderr.decode('utf-8')}")
|
| 24 |
+
else:
|
| 25 |
+
print(f"Successfully created {wav_filepath}")
|
| 26 |
+
|
| 27 |
+
return wav_filepath
|
| 28 |
+
|
| 29 |
+
def process_midi_file(midi_filepath):
|
| 30 |
+
# Determine the corresponding wav file path
|
| 31 |
+
relative_path = os.path.relpath(midi_filepath, "/root/Text2midi/res_acc")
|
| 32 |
+
wav_filepath = os.path.join("/root/wav", relative_path.replace('.mid', '.wav'))
|
| 33 |
+
wav_directory = os.path.dirname(wav_filepath)
|
| 34 |
+
|
| 35 |
+
# Ensure the directory exists
|
| 36 |
+
os.makedirs(wav_directory, exist_ok=True)
|
| 37 |
+
|
| 38 |
+
# Convert the MIDI file to WAV
|
| 39 |
+
save_wav(midi_filepath, wav_filepath)
|
| 40 |
+
|
| 41 |
+
def main():
|
| 42 |
+
# Find all .mid files in /root/Text2midi/res_acc
|
| 43 |
+
midi_files = []
|
| 44 |
+
for root, _, files in os.walk("/root/Text2midi/res_acc"):
|
| 45 |
+
for file in files:
|
| 46 |
+
if file.endswith(".mid"):
|
| 47 |
+
midi_files.append(os.path.join(root, file))
|
| 48 |
+
|
| 49 |
+
# Use half of the available CPU cores for multiprocessing
|
| 50 |
+
num_cores = cpu_count() // 2
|
| 51 |
+
with Pool(num_cores) as pool:
|
| 52 |
+
list(tqdm(pool.imap(process_midi_file, midi_files), total=len(midi_files), desc="Processing MIDI files"))
|
| 53 |
+
|
| 54 |
+
if __name__ == "__main__":
|
| 55 |
+
main()
|
text2midi_repo/utils/split_caption.py
ADDED
|
@@ -0,0 +1,27 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import json
|
| 2 |
+
import os
|
| 3 |
+
import random
|
| 4 |
+
import jsonlines
|
| 5 |
+
|
| 6 |
+
def select_and_split_captions(input_path, output_dir, num_splits=6):
|
| 7 |
+
with jsonlines.open(input_path) as reader:
|
| 8 |
+
captions = [line for line in reader if line.get('test_set') is True]
|
| 9 |
+
|
| 10 |
+
selected_captions = captions #random.sample(captions, 500)
|
| 11 |
+
|
| 12 |
+
# Split the selected captions into num_splits groups
|
| 13 |
+
split_size = len(selected_captions) // num_splits
|
| 14 |
+
for i in range(num_splits):
|
| 15 |
+
start_idx = i * split_size
|
| 16 |
+
end_idx = (i + 1) * split_size if i != num_splits - 1 else len(selected_captions)
|
| 17 |
+
split_captions = selected_captions[start_idx:end_idx]
|
| 18 |
+
|
| 19 |
+
output_path = os.path.join(output_dir, f'selected_captions_{i}.json')
|
| 20 |
+
with open(output_path, 'w') as f:
|
| 21 |
+
json.dump(split_captions, f, indent=4)
|
| 22 |
+
print(f'Saved {len(split_captions)} captions to {output_path}')
|
| 23 |
+
|
| 24 |
+
if __name__ == "__main__":
|
| 25 |
+
input_path = '/root/captions/train.json'
|
| 26 |
+
output_dir = '/root/captions/'
|
| 27 |
+
select_and_split_captions(input_path, output_dir)
|