Upload 111 files
This view is limited to 50 files because it contains too many changes. See raw diff.
- preprocess/README.md +155 -0
- preprocess/pipeline.py +161 -0
- preprocess/requirements.txt +34 -0
- preprocess/tools/__init__.py +53 -0
- preprocess/tools/f0_extraction.py +527 -0
- preprocess/tools/g2p.py +72 -0
- preprocess/tools/lyric_transcription.py +283 -0
- preprocess/tools/midi_editor/README.md +170 -0
- preprocess/tools/midi_editor/README_CN.md +170 -0
- preprocess/tools/midi_editor/eslint.config.js +23 -0
- preprocess/tools/midi_editor/index.html +13 -0
- preprocess/tools/midi_editor/package-lock.json +0 -0
- preprocess/tools/midi_editor/package.json +39 -0
- preprocess/tools/midi_editor/postcss.config.js +6 -0
- preprocess/tools/midi_editor/public/vite.svg +1 -0
- preprocess/tools/midi_editor/src/App.css +834 -0
- preprocess/tools/midi_editor/src/App.tsx +675 -0
- preprocess/tools/midi_editor/src/components/AudioTrack.tsx +182 -0
- preprocess/tools/midi_editor/src/components/LyricTable.tsx +301 -0
- preprocess/tools/midi_editor/src/components/PianoRoll.tsx +704 -0
- preprocess/tools/midi_editor/src/constants.ts +8 -0
- preprocess/tools/midi_editor/src/i18n.ts +196 -0
- preprocess/tools/midi_editor/src/index.css +37 -0
- preprocess/tools/midi_editor/src/lib/midi.ts +224 -0
- preprocess/tools/midi_editor/src/main.tsx +10 -0
- preprocess/tools/midi_editor/src/store/useMidiStore.ts +78 -0
- preprocess/tools/midi_editor/src/types.ts +17 -0
- preprocess/tools/midi_editor/tailwind.config.js +33 -0
- preprocess/tools/midi_editor/tsconfig.app.json +28 -0
- preprocess/tools/midi_editor/tsconfig.json +7 -0
- preprocess/tools/midi_editor/tsconfig.node.json +26 -0
- preprocess/tools/midi_editor/vite.config.ts +7 -0
- preprocess/tools/midi_parser.py +598 -0
- preprocess/tools/note_transcription/__init__.py +0 -0
- preprocess/tools/note_transcription/model.py +531 -0
- preprocess/tools/note_transcription/modules/__init__.py +1 -0
- preprocess/tools/note_transcription/modules/commons/__init__.py +1 -0
- preprocess/tools/note_transcription/modules/commons/conformer/__init__.py +1 -0
- preprocess/tools/note_transcription/modules/commons/conformer/conformer.py +96 -0
- preprocess/tools/note_transcription/modules/commons/conformer/espnet_positional_embedding.py +113 -0
- preprocess/tools/note_transcription/modules/commons/conformer/espnet_transformer_attn.py +198 -0
- preprocess/tools/note_transcription/modules/commons/conformer/layers.py +260 -0
- preprocess/tools/note_transcription/modules/commons/conv.py +175 -0
- preprocess/tools/note_transcription/modules/commons/layers.py +85 -0
- preprocess/tools/note_transcription/modules/commons/rel_transformer.py +378 -0
- preprocess/tools/note_transcription/modules/commons/rnn.py +261 -0
- preprocess/tools/note_transcription/modules/commons/transformer.py +751 -0
- preprocess/tools/note_transcription/modules/commons/wavenet.py +109 -0
- preprocess/tools/note_transcription/modules/pe/__init__.py +1 -0
- preprocess/tools/note_transcription/modules/pe/rmvpe/__init__.py +6 -0
preprocess/README.md
ADDED
@@ -0,0 +1,155 @@

# 🎵 SoulX-Singer-Preprocess

This directory offers a comprehensive **singing transcription and editing toolkit** for real-world music audio. It provides the full pipeline from vocal extraction to high-level annotation, optimized for SVS dataset construction. By integrating state-of-the-art models, it transforms raw audio into structured singing data and supports the **customizable creation and editing of lyric-aligned MIDI scores**.

## ✨ Features

The toolkit includes the following core modules:

- 🎤 **Clean Dry Vocal Extraction**
  Extracts the lead vocal track from polyphonic music audio and removes reverberation.

- 📝 **Lyrics Transcription**
  Automatically transcribes lyrics from the clean vocal track.

- 🎶 **Note Transcription**
  Converts the singing voice into note-level representations for SVS.

- 🎼 **MIDI Editor**
  Supports customizable creation and editing of MIDI scores integrated with lyrics.

## 🔧 Python Environment

Before running the pipeline, set up the Python environment as follows:

1. **Install Conda** (if not already installed): https://docs.conda.io/en/latest/miniconda.html

2. **Activate or create a conda environment** (Python 3.10 recommended):

   - If you already have the `soulxsinger` environment:

     ```bash
     conda activate soulxsinger
     ```

   - Otherwise, create it first:

     ```bash
     conda create -n soulxsinger -y python=3.10
     conda activate soulxsinger
     ```

3. **Install dependencies** from the `preprocess` directory:

   ```bash
   cd preprocess
   pip install -r requirements.txt
   ```

## 📁 Data Preparation

Before running the pipeline, prepare the following inputs:

- **Prompt audio**
  Reference audio that provides timbre and style.

- **Target audio**
  Original vocal or music audio to be processed and transcribed.

Configure the corresponding parameters in:

```
example/preprocess.sh
```

Typical configuration includes (an illustrative sketch follows below):

- Input / output paths
- Module enable switches
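For reference, here is a minimal, hypothetical sketch of what such a configuration might boil down to, assuming the script simply forwards these values to `preprocess/pipeline.py` (part of this upload). The flag names below mirror that file's argument parser; the input and output paths are placeholders:

```bash
# Hypothetical sketch -- paths are placeholders; flags mirror the
# argparse options defined in preprocess/pipeline.py.
python -m preprocess.pipeline \
    --audio_path example/audio/target_song.mp3 \
    --save_dir example/transcriptions/target_song \
    --language Mandarin \
    --device cuda:0 \
    --vocal_sep True \
    --midi_transcribe True
```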
## 🚀 Usage

After configuring `preprocess.sh`, run the transcription pipeline with:

```bash
bash example/preprocess.sh
```

The script will automatically execute the following steps:

1. **Vocal separation and dereverberation**
2. **F0 extraction and voice activity detection (VAD)**
3. **Lyrics transcription**
4. **Note transcription**

---

After the pipeline completes, you will obtain **SoulX-Singer-style metadata** that can be used directly for Singing Voice Synthesis (SVS).

**Output paths:**

- The final metadata (a **JSON file**) is written **to the same directory as your input audio**, with the **same filename** (e.g. `audio.mp3` → `audio.json`).
- All **intermediate results** (separated vocal and accompaniment, F0, VAD outputs, etc.) are also saved under the configured **`save_dir`**.

⚠️ **Important Note**

Transcription errors, especially in **lyrics** and **note annotations**, can significantly affect the final SVS quality. We **strongly recommend manually reviewing and correcting** the generated metadata before inference.

To support this, we provide a **MIDI Editor** for editing lyrics, phoneme alignment, note pitches, and durations. The workflow is:

**Export metadata to MIDI** → edit in the MIDI Editor → **import the edited MIDI back to metadata** for SVS.

---

#### Step 1: Metadata → MIDI (for editing)

Convert SoulX-Singer metadata to a MIDI file so you can open it in the MIDI Editor:

```bash
preprocess_root=example/transcriptions/music

python -m preprocess.tools.midi_parser \
    --meta2midi \
    --meta "${preprocess_root}/metadata.json" \
    --midi "${preprocess_root}/vocal.mid"
```

#### Step 2: Edit in the MIDI Editor

Open the MIDI Editor (see the [MIDI Editor Tutorial](tools/midi_editor/README.md)), load `vocal.mid`, and correct lyrics, pitches, or durations as needed. Save the result as, e.g., `vocal_edited.mid`.

#### Step 3: MIDI → Metadata (for SoulX-Singer inference)

Convert the edited MIDI back into SoulX-Singer-style metadata (and cut wavs) for SVS:

```bash
python -m preprocess.tools.midi_parser \
    --midi2meta \
    --midi "${preprocess_root}/vocal_edited.mid" \
    --meta "${preprocess_root}/edit_metadata.json" \
    --vocal "${preprocess_root}/vocal.wav"
```

Use `edit_metadata.json` (and the wavs under `edit_cut_wavs`) as the target metadata in your inference pipeline.

## 🔗 References & Dependencies

This project builds upon the following excellent open-source works:

### 🎧 Vocal Separation & Dereverberation
- [Music Source Separation Training](https://github.com/ZFTurbo/Music-Source-Separation-Training)
- [Lead Vocal Separation](https://huggingface.co/becruily/mel-band-roformer-karaoke)
- [Vocal Dereverberation](https://huggingface.co/anvuew/dereverb_mel_band_roformer)

### 🎼 F0 Extraction
- [RMVPE](https://github.com/Dream-High/RMVPE)

### 📝 Lyrics Transcription (ASR)
- [Paraformer](https://modelscope.cn/models/iic/speech_seaco_paraformer_large_asr_nat-zh-cn-16k-common-vocab8404-pytorch)
- [Parakeet-tdt-0.6b-v2](https://huggingface.co/nvidia/parakeet-tdt-0.6b-v2)

### 🎶 Note Transcription
- [ROSVOT](https://github.com/RickyL-2000/ROSVOT)

We sincerely thank the authors of these repositories for their exceptional open-source contributions, which have been fundamental to the development of this toolkit.
preprocess/pipeline.py
ADDED
@@ -0,0 +1,161 @@

```python
import json
import shutil
import soundfile as sf
from pathlib import Path
import librosa

from preprocess.utils import convert_metadata, merge_short_segments

from preprocess.tools import (
    F0Extractor,
    VocalDetector,
    VocalSeparator,
    NoteTranscriber,
    LyricTranscriber,
)


class PreprocessPipeline:
    def __init__(
        self,
        device: str,
        language: str,
        save_dir: str,
        vocal_sep: bool = True,
        max_merge_duration: int = 60000,
        midi_transcribe: bool = True,
    ):
        self.device = device
        self.language = language
        self.save_dir = save_dir
        self.vocal_sep = vocal_sep
        self.max_merge_duration = max_merge_duration
        self.midi_transcribe = midi_transcribe

        if vocal_sep:
            self.vocal_separator = VocalSeparator(
                sep_model_path="pretrained_models/SoulX-Singer-Preprocess/mel-band-roformer-karaoke/mel_band_roformer_karaoke_becruily.ckpt",
                sep_config_path="pretrained_models/SoulX-Singer-Preprocess/mel-band-roformer-karaoke/config_karaoke_becruily.yaml",
                der_model_path="pretrained_models/SoulX-Singer-Preprocess/dereverb_mel_band_roformer/dereverb_mel_band_roformer_anvuew_sdr_19.1729.ckpt",
                der_config_path="pretrained_models/SoulX-Singer-Preprocess/dereverb_mel_band_roformer/dereverb_mel_band_roformer_anvuew.yaml",
                device=device,
            )
        else:
            self.vocal_separator = None
        self.f0_extractor = F0Extractor(
            model_path="pretrained_models/SoulX-Singer-Preprocess/rmvpe/rmvpe.pt",
            device=device,
        )
        if self.midi_transcribe:
            self.vocal_detector = VocalDetector(
                cut_wavs_output_dir=f"{save_dir}/cut_wavs",
            )
            self.lyric_transcriber = LyricTranscriber(
                zh_model_path="pretrained_models/SoulX-Singer-Preprocess/speech_seaco_paraformer_large_asr_nat-zh-cn-16k-common-vocab8404-pytorch",
                en_model_path="pretrained_models/SoulX-Singer-Preprocess/parakeet-tdt-0.6b-v2/parakeet-tdt-0.6b-v2.nemo",
                device=device,
            )
            self.note_transcriber = NoteTranscriber(
                rosvot_model_path="pretrained_models/SoulX-Singer-Preprocess/rosvot/rosvot/model.pt",
                rwbd_model_path="pretrained_models/SoulX-Singer-Preprocess/rosvot/rwbd/model.pt",
                device=device,
            )
        else:
            self.vocal_detector = None
            self.lyric_transcriber = None
            self.note_transcriber = None

    def run(
        self,
        audio_path: str,
        vocal_sep: bool | None = None,
        max_merge_duration: int | None = None,
        language: str | None = None,
    ) -> None:
        vocal_sep = self.vocal_sep if vocal_sep is None else vocal_sep
        max_merge_duration = self.max_merge_duration if max_merge_duration is None else max_merge_duration
        language = self.language if language is None else language
        output_dir = Path(self.save_dir)
        output_dir.mkdir(parents=True, exist_ok=True)

        if vocal_sep:
            # Perform vocal/accompaniment separation
            sep = self.vocal_separator.process(audio_path)
            vocal = sep.vocals_dereverbed.T
            acc = sep.accompaniment.T
            sample_rate = sep.sample_rate

            vocal_path = output_dir / "vocal.wav"
            acc_path = output_dir / "acc.wav"
            sf.write(vocal_path, vocal, sample_rate)
            sf.write(acc_path, acc, sample_rate)
        else:
            # Use the original audio as vocal source (no separation)
            vocal, sample_rate = librosa.load(audio_path, sr=None, mono=True)
            vocal_path = output_dir / "vocal.wav"
            sf.write(vocal_path, vocal, sample_rate)

        vocal_f0 = self.f0_extractor.process(str(vocal_path), f0_path=str(vocal_path).replace(".wav", "_f0.npy"))

        if not self.midi_transcribe or self.vocal_detector is None or self.lyric_transcriber is None or self.note_transcriber is None:
            return

        segments = self.vocal_detector.process(str(vocal_path), f0=vocal_f0)

        # Transcribe lyrics and notes for each detected vocal segment.
        metadata = []
        for seg in segments:
            self.f0_extractor.process(seg["wav_fn"], f0_path=seg["wav_fn"].replace(".wav", "_f0.npy"))
            words, durs = self.lyric_transcriber.process(
                seg["wav_fn"], language
            )
            seg["words"] = words
            seg["word_durs"] = durs
            seg["language"] = language
            metadata.append(
                self.note_transcriber.process(seg, segment_info=seg)
            )

        # Merge short segments into longer clips (up to max_merge_duration ms).
        merged = merge_short_segments(
            vocal,
            sample_rate,
            metadata,
            output_dir / "long_cut_wavs",
            max_duration_ms=max_merge_duration,
        )

        final_metadata = []

        for item in merged:
            self.f0_extractor.process(item.wav_fn, f0_path=item.wav_fn.replace(".wav", "_f0.npy"))
            final_metadata.append(convert_metadata(item))

        with open(output_dir / "metadata.json", "w", encoding="utf-8") as f:
            json.dump(final_metadata, f, ensure_ascii=False, indent=2)

        # Also place the final metadata next to the input audio (e.g. audio.mp3 -> audio.json).
        shutil.copy(output_dir / "metadata.json", audio_path.replace(".wav", ".json").replace(".mp3", ".json").replace(".flac", ".json"))


def main(args):
    pipeline = PreprocessPipeline(
        device=args.device,
        language=args.language,
        save_dir=args.save_dir,
        vocal_sep=args.vocal_sep,
        max_merge_duration=args.max_merge_duration,
        midi_transcribe=args.midi_transcribe,
    )
    pipeline.run(
        audio_path=args.audio_path,
        language=args.language,
    )


if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument("--audio_path", type=str, required=True, help="Path to the input audio file")
    parser.add_argument("--save_dir", type=str, required=True, help="Directory to save the output files")
    parser.add_argument("--language", type=str, default="Mandarin", help="Language of the audio")
    parser.add_argument("--device", type=str, default="cuda:0", help="Device to run the models on")
    parser.add_argument("--vocal_sep", type=str, default="True", help="Whether to perform vocal separation")
    parser.add_argument("--max_merge_duration", type=int, default=60000, help="Maximum merged segment duration in milliseconds")
    parser.add_argument("--midi_transcribe", type=str, default="True", help="Whether to do MIDI transcription")
    args = parser.parse_args()

    args.vocal_sep = args.vocal_sep.lower() == "true"
    args.midi_transcribe = args.midi_transcribe.lower() == "true"

    main(args)
```
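The pipeline above can also be driven directly from Python rather than through `example/preprocess.sh`. A minimal sketch, assuming the pretrained checkpoints already sit at the `pretrained_models/...` paths hard-coded in `__init__` (the audio path below is a placeholder):

```python
from preprocess.pipeline import PreprocessPipeline

# Sub-models are built once in the constructor; reuse the instance across songs.
pipeline = PreprocessPipeline(
    device="cuda:0",
    language="Mandarin",
    save_dir="example/transcriptions/my_song",  # intermediate results land here
)

# Writes metadata.json under save_dir and copies it next to the input audio.
pipeline.run(audio_path="example/audio/my_song.mp3")
```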
preprocess/requirements.txt
ADDED
@@ -0,0 +1,34 @@

```
beartype==0.22.9
einops==0.8.2
funasr==1.3.0
g2p_en==2.1.0
g2pM==0.1.2.5
librosa==0.11.0
loralib==0.1.2
matplotlib==3.10.8
mido==1.3.3
ml_collections==1.1.0
nemo_toolkit==2.6.1
nltk==3.9.2
numba==0.63.1
numpy==2.2.6
omegaconf==2.3.0
packaging==24.2
praat-parselmouth==0.4.7
pretty_midi==0.2.11
pyloudnorm==0.2.0
pyworld==0.3.5
rotary_embedding_torch==0.8.9
sageattention==1.0.6
scikit_learn==1.7.2
scipy==1.15.3
six==1.17.0
setuptools==81.0.0
scikit_image==0.25.2
soundfile==0.13.1
ToJyutping==3.2.0
torch==2.10.0
torchaudio==2.10.0
tqdm==4.67.1
wandb==0.24.2
webrtcvad==2.0.10
```
preprocess/tools/__init__.py
ADDED
@@ -0,0 +1,53 @@

```python
"""Preprocess tools.

This package provides a thin, stable import surface for common preprocess components.

Examples:
    from preprocess.tools import (
        F0Extractor,
        VocalDetector,
        VocalSeparator,
        NoteTranscriber,
        LyricTranscriber,
    )

Note:
    Keep these imports lightweight. If a tool pulls heavy dependencies at import time,
    consider switching to lazy imports.
"""

from __future__ import annotations

# Core tools
from .f0_extraction import F0Extractor
from .vocal_detection import VocalDetector

# Some tools may live outside this package in different layouts across branches.
# Keep the public surface stable while avoiding hard import failures.
try:
    from .vocal_separation.model import VocalSeparator  # type: ignore
except Exception:  # pragma: no cover
    VocalSeparator = None  # type: ignore

try:
    from .note_transcription.model import NoteTranscriber  # type: ignore
except Exception:  # pragma: no cover
    NoteTranscriber = None  # type: ignore

try:
    from .lyric_transcription import LyricTranscriber
except Exception:  # pragma: no cover
    LyricTranscriber = None  # type: ignore

__all__ = [
    "F0Extractor",
    "VocalDetector",
]

if VocalSeparator is not None:
    __all__.append("VocalSeparator")
if LyricTranscriber is not None:
    __all__.append("LyricTranscriber")
if NoteTranscriber is not None:
    __all__.append("NoteTranscriber")
```
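Because the optional exports degrade to `None` rather than raising at import time, callers should check availability before constructing a tool. A small sketch of that pattern:

```python
from preprocess.tools import VocalSeparator

# Optional exports fall back to None when their subpackage or dependencies
# are missing, so availability must be checked before use.
if VocalSeparator is None:
    raise RuntimeError(
        "VocalSeparator is unavailable in this checkout; "
        "install the vocal separation dependencies first"
    )
```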
preprocess/tools/f0_extraction.py
ADDED
@@ -0,0 +1,527 @@

```python
# https://github.com/Dream-High/RMVPE
import math
import time
import librosa
import numpy as np
from librosa.filters import mel
from scipy.interpolate import interp1d

from typing import Optional

import torch
import torch.nn as nn
import torch.nn.functional as F


class BiGRU(nn.Module):
    def __init__(self, input_features, hidden_features, num_layers):
        super(BiGRU, self).__init__()
        self.gru = nn.GRU(
            input_features,
            hidden_features,
            num_layers=num_layers,
            batch_first=True,
            bidirectional=True,
        )

    def forward(self, x):
        return self.gru(x)[0]


class ConvBlockRes(nn.Module):
    def __init__(self, in_channels, out_channels, momentum=0.01):
        super(ConvBlockRes, self).__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(
                in_channels=in_channels,
                out_channels=out_channels,
                kernel_size=(3, 3),
                stride=(1, 1),
                padding=(1, 1),
                bias=False,
            ),
            nn.BatchNorm2d(out_channels, momentum=momentum),
            nn.ReLU(),
            nn.Conv2d(
                in_channels=out_channels,
                out_channels=out_channels,
                kernel_size=(3, 3),
                stride=(1, 1),
                padding=(1, 1),
                bias=False,
            ),
            nn.BatchNorm2d(out_channels, momentum=momentum),
            nn.ReLU(),
        )
        if in_channels != out_channels:
            self.shortcut = nn.Conv2d(in_channels, out_channels, (1, 1))

    def forward(self, x):
        if not hasattr(self, "shortcut"):
            return self.conv(x) + x
        else:
            return self.conv(x) + self.shortcut(x)


class ResEncoderBlock(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size, n_blocks=1, momentum=0.01):
        super(ResEncoderBlock, self).__init__()
        self.n_blocks = n_blocks
        self.conv = nn.ModuleList()
        self.conv.append(ConvBlockRes(in_channels, out_channels, momentum))
        for i in range(n_blocks - 1):
            self.conv.append(ConvBlockRes(out_channels, out_channels, momentum))
        self.kernel_size = kernel_size
        if self.kernel_size is not None:
            self.pool = nn.AvgPool2d(kernel_size=kernel_size)

    def forward(self, x):
        for conv in self.conv:
            x = conv(x)
        if self.kernel_size is not None:
            return x, self.pool(x)
        else:
            return x


class Encoder(nn.Module):
    def __init__(self, in_channels, in_size, n_encoders, kernel_size, n_blocks, out_channels=16, momentum=0.01):
        super(Encoder, self).__init__()
        self.n_encoders = n_encoders
        self.bn = nn.BatchNorm2d(in_channels, momentum=momentum)
        self.layers = nn.ModuleList()
        self.latent_channels = []
        for i in range(self.n_encoders):
            self.layers.append(
                ResEncoderBlock(in_channels, out_channels, kernel_size, n_blocks, momentum=momentum)
            )
            self.latent_channels.append([out_channels, in_size])
            in_channels = out_channels
            out_channels *= 2
            in_size //= 2
        self.out_size = in_size
        self.out_channel = out_channels

    def forward(self, x):
        concat_tensors = []
        x = self.bn(x)
        for layer in self.layers:
            t, x = layer(x)
            concat_tensors.append(t)
        return x, concat_tensors


class Intermediate(nn.Module):
    def __init__(self, in_channels, out_channels, n_inters, n_blocks, momentum=0.01):
        super(Intermediate, self).__init__()
        self.n_inters = n_inters
        self.layers = nn.ModuleList()
        self.layers.append(ResEncoderBlock(in_channels, out_channels, None, n_blocks, momentum))
        for i in range(self.n_inters - 1):
            self.layers.append(ResEncoderBlock(out_channels, out_channels, None, n_blocks, momentum))

    def forward(self, x):
        for layer in self.layers:
            x = layer(x)
        return x


class ResDecoderBlock(nn.Module):
    def __init__(self, in_channels, out_channels, stride, n_blocks=1, momentum=0.01):
        super(ResDecoderBlock, self).__init__()
        out_padding = (0, 1) if stride == (1, 2) else (1, 1)
        self.n_blocks = n_blocks
        self.conv1 = nn.Sequential(
            nn.ConvTranspose2d(
                in_channels=in_channels,
                out_channels=out_channels,
                kernel_size=(3, 3),
                stride=stride,
                padding=(1, 1),
                output_padding=out_padding,
                bias=False,
            ),
            nn.BatchNorm2d(out_channels, momentum=momentum),
            nn.ReLU(),
        )
        self.conv2 = nn.ModuleList()
        self.conv2.append(ConvBlockRes(out_channels * 2, out_channels, momentum))
        for i in range(n_blocks - 1):
            self.conv2.append(ConvBlockRes(out_channels, out_channels, momentum))

    def forward(self, x, concat_tensor):
        x = self.conv1(x)
        x = torch.cat((x, concat_tensor), dim=1)
        for conv2 in self.conv2:
            x = conv2(x)
        return x


class Decoder(nn.Module):
    def __init__(self, in_channels, n_decoders, stride, n_blocks, momentum=0.01):
        super(Decoder, self).__init__()
        self.layers = nn.ModuleList()
        self.n_decoders = n_decoders
        for i in range(self.n_decoders):
            out_channels = in_channels // 2
            self.layers.append(
                ResDecoderBlock(in_channels, out_channels, stride, n_blocks, momentum)
            )
            in_channels = out_channels

    def forward(self, x, concat_tensors):
        for i, layer in enumerate(self.layers):
            x = layer(x, concat_tensors[-1 - i])
        return x


class DeepUnet(nn.Module):
    def __init__(self, kernel_size, n_blocks, en_de_layers=5, inter_layers=4, in_channels=1, en_out_channels=16):
        super(DeepUnet, self).__init__()
        self.encoder = Encoder(in_channels, 128, en_de_layers, kernel_size, n_blocks, en_out_channels)
        self.intermediate = Intermediate(
            self.encoder.out_channel // 2,
            self.encoder.out_channel,
            inter_layers,
            n_blocks,
        )
        self.decoder = Decoder(self.encoder.out_channel, en_de_layers, kernel_size, n_blocks)

    def forward(self, x):
        x, concat_tensors = self.encoder(x)
        x = self.intermediate(x)
        x = self.decoder(x, concat_tensors)
        return x


class E2E(nn.Module):
    def __init__(self, n_blocks, n_gru, kernel_size, en_de_layers=5, inter_layers=4, in_channels=1, en_out_channels=16):
        super(E2E, self).__init__()
        self.unet = DeepUnet(kernel_size, n_blocks, en_de_layers, inter_layers, in_channels, en_out_channels)
        self.cnn = nn.Conv2d(en_out_channels, 3, (3, 3), padding=(1, 1))
        if n_gru:
            self.fc = nn.Sequential(
                BiGRU(3 * 128, 256, n_gru),
                nn.Linear(512, 360),
                nn.Dropout(0.25),
                nn.Sigmoid(),
            )
        else:
            self.fc = nn.Sequential(
                nn.Linear(3 * 128, 360),
                nn.Dropout(0.25),
                nn.Sigmoid(),
            )

    def forward(self, mel):
        mel = mel.transpose(-1, -2).unsqueeze(1)
        x = self.cnn(self.unet(mel)).transpose(1, 2).flatten(-2)
        x = self.fc(x)
        return x


class MelSpectrogram(torch.nn.Module):
    def __init__(self, is_half, n_mel_channels, sampling_rate, win_length, hop_length,
                 n_fft=None, mel_fmin=0, mel_fmax=None, clamp=1e-5):
        super().__init__()
        n_fft = win_length if n_fft is None else n_fft
        self.hann_window = {}
        mel_basis = mel(
            sr=sampling_rate,
            n_fft=n_fft,
            n_mels=n_mel_channels,
            fmin=mel_fmin,
            fmax=mel_fmax,
            htk=True,
        )
        mel_basis = torch.from_numpy(mel_basis).float()
        self.register_buffer("mel_basis", mel_basis)
        self.n_fft = n_fft
        self.hop_length = hop_length
        self.win_length = win_length
        self.sampling_rate = sampling_rate
        self.n_mel_channels = n_mel_channels
        self.clamp = clamp
        self.is_half = is_half

    def forward(self, audio, keyshift=0, speed=1, center=True):
        factor = 2 ** (keyshift / 12)
        n_fft_new = int(np.round(self.n_fft * factor))
        win_length_new = int(np.round(self.win_length * factor))
        hop_length_new = int(np.round(self.hop_length * speed))

        keyshift_key = str(keyshift) + "_" + str(audio.device)
        if keyshift_key not in self.hann_window:
            self.hann_window[keyshift_key] = torch.hann_window(win_length_new).to(audio.device)

        fft = torch.stft(
            audio,
            n_fft=n_fft_new,
            hop_length=hop_length_new,
            win_length=win_length_new,
            window=self.hann_window[keyshift_key],
            center=center,
            return_complex=True,
        )
        magnitude = torch.sqrt(fft.real.pow(2) + fft.imag.pow(2))

        if keyshift != 0:
            size = self.n_fft // 2 + 1
            resize = magnitude.size(1)
            if resize < size:
                magnitude = F.pad(magnitude, (0, 0, 0, size - resize))
            magnitude = magnitude[:, :size, :] * self.win_length / win_length_new

        mel_output = torch.matmul(self.mel_basis, magnitude)
        if self.is_half:
            mel_output = mel_output.half()
        log_mel_spec = torch.log(torch.clamp(mel_output, min=self.clamp))
        return log_mel_spec


class RMVPE:
    def __init__(self, model_path: str, is_half, device=None):
        self.is_half = is_half
        if device is None:
            device = "cuda:0" if torch.cuda.is_available() else "cpu"
        self.device = torch.device(device) if isinstance(device, str) else device

        self.mel_extractor = MelSpectrogram(
            is_half=is_half,
            n_mel_channels=128,
            sampling_rate=16000,
            win_length=1024,
            hop_length=160,
            n_fft=None,
            mel_fmin=30,
            mel_fmax=8000,
        ).to(self.device)

        model = E2E(n_blocks=4, n_gru=1, kernel_size=(2, 2))
        ckpt = torch.load(model_path, map_location=self.device)
        model.load_state_dict(ckpt)
        model.eval()

        if is_half:
            model = model.half()
        else:
            model = model.float()

        self.model = model.to(self.device)

        cents_mapping = 20 * np.arange(360) + 1997.3794084376191
        self.cents_mapping = np.pad(cents_mapping, (4, 4))  # 368

    def mel2hidden(self, mel):
        with torch.no_grad():
            n_frames = mel.shape[-1]
            # Pad the frame axis up to a multiple of 32 for the U-Net.
            n_pad = 32 * ((n_frames - 1) // 32 + 1) - n_frames
            if n_pad > 0:
                mel = F.pad(mel, (0, n_pad), mode="constant")
            mel = mel.half() if self.is_half else mel.float()
            hidden = self.model(mel)
            return hidden[:, :n_frames]

    def decode(self, hidden, thred=0.03):
        cents_pred = self.to_local_average_cents(hidden, thred=thred)
        f0 = 10 * (2 ** (cents_pred / 1200))
        f0[f0 == 10] = 0  # cents == 0 marks unvoiced frames
        return f0

    def infer_from_audio(self, audio, thred=0.03):
        if not torch.is_tensor(audio):
            audio = torch.from_numpy(audio)

        mel = self.mel_extractor(audio.float().to(self.device).unsqueeze(0), center=True)
        hidden = self.mel2hidden(mel)
        hidden = hidden.squeeze(0).cpu().numpy()

        if self.is_half:
            hidden = hidden.astype("float32")

        f0 = self.decode(hidden, thred=thred)
        return f0

    def to_local_average_cents(self, salience, thred=0.05):
        # Weighted average of cents around the salience peak (+-4 bins).
        center = np.argmax(salience, axis=1)
        salience = np.pad(salience, ((0, 0), (4, 4)))
        center += 4

        todo_salience = []
        todo_cents_mapping = []
        starts = center - 4
        ends = center + 5

        for idx in range(salience.shape[0]):
            todo_salience.append(salience[:, starts[idx]:ends[idx]][idx])
            todo_cents_mapping.append(self.cents_mapping[starts[idx]:ends[idx]])

        todo_salience = np.array(todo_salience)
        todo_cents_mapping = np.array(todo_cents_mapping)
        product_sum = np.sum(todo_salience * todo_cents_mapping, 1)
        weight_sum = np.sum(todo_salience, 1)
        devided = product_sum / weight_sum

        maxx = np.max(salience, axis=1)
        devided[maxx <= thred] = 0

        return devided


class F0Extractor:
    """Extract frame-level f0 from singing voice.

    Wrapper around an RMVPE network that:
    1) loads the checkpoint once in ``__init__``
    2) exposes a simple :py:meth:`process` API and optionally saves ``*_f0.npy``.
    """

    def __init__(
        self,
        model_path: str,
        device: str = "cpu",
        *,
        is_half: bool = False,
        input_sr: int = 16000,
        target_sr: int = 24000,
        hop_size: int = 480,
        max_duration: float = 300,
        thred: float = 0.03,
        verbose: bool = True,
    ):
        """Initialize the f0 extractor.

        Args:
            model_path: Path to the RMVPE checkpoint.
            device: Torch device string, e.g. ``"cuda:0"`` / ``"cpu"``.
            is_half: Whether to run the model in fp16.
            input_sr: Input resample rate used by the RMVPE frontend.
            target_sr: Target sample rate for the output f0 grid.
            hop_size: Target hop size for the output f0 grid.
            max_duration: Max duration (seconds) for the interpolation grid.
            thred: Voicing threshold used when decoding salience.
            verbose: Whether to print verbose logs.
        """
        self.model_path = model_path
        self.input_sr = input_sr
        self.target_sr = target_sr
        self.hop_size = hop_size
        self.max_duration = max_duration
        self.thred = thred

        self.verbose = verbose

        self.model = RMVPE(model_path, is_half=is_half, device=device)

        if self.verbose:
            print(
                "[f0 extraction] init success:",
                f"device={device}",
                f"model_path={model_path}",
                f"is_half={is_half}",
                f"input_sr={input_sr}",
                f"target_sr={target_sr}",
                f"hop_size={hop_size}",
                f"thred={thred}",
            )

    @staticmethod
    def interpolate_f0(
        f0_16k: np.ndarray,
        original_length: int,
        original_sr: int,
        *,
        target_sr: int = 48000,
        hop_size: int = 256,
        max_duration: float = 20.0,
    ) -> np.ndarray:
        """Interpolate f0 from RMVPE's 16 kHz hop grid to the target mel hop grid."""
        mel_target_sr = target_sr
        mel_hop_size = hop_size
        mel_max_duration = max_duration

        batch_max_length = int(mel_max_duration * mel_target_sr / mel_hop_size)
        duration_in_seconds = original_length / original_sr
        effective_target_length = int(duration_in_seconds * mel_target_sr)
        original_frames = math.ceil(effective_target_length / mel_hop_size)
        target_frames = min(original_frames, batch_max_length)

        rmvpe_hop = 160
        t_16k = np.arange(len(f0_16k)) * (rmvpe_hop / 16000.0)
        t_target = np.arange(target_frames) * (mel_hop_size / float(mel_target_sr))

        if len(f0_16k) > 0:
            f_interp = interp1d(
                t_16k,
                f0_16k,
                kind="linear",
                bounds_error=False,
                fill_value=0.0,
                assume_sorted=True,
            )
            f0 = f_interp(t_target)
        else:
            f0 = np.zeros(target_frames)

        if len(f0) != target_frames:
            f0 = (
                f0[:target_frames]
                if len(f0) > target_frames
                else np.pad(f0, (0, target_frames - len(f0)), "constant")
            )

        return f0

    def process(self, audio_path: str, *, f0_path: str | None = None, verbose: Optional[bool] = None) -> np.ndarray:
        """Run f0 extraction for a single wav.

        Args:
            audio_path: Path to the input wav file.
            f0_path: If not None, save the f0 array to this path.
            verbose: Override the instance-level verbose flag for this call.

        Returns:
            np.ndarray: shape ``[T]``, f0 in Hz (0 for unvoiced).
        """
        verbose = self.verbose if verbose is None else verbose
        if verbose:
            print(f"[f0 extraction] process: start: {audio_path}")
        t0 = time.time()

        audio, _ = librosa.load(audio_path, sr=self.input_sr)
        f0_16k = self.model.infer_from_audio(audio, thred=self.thred)
        f0 = self.interpolate_f0(
            f0_16k,
            original_length=audio.shape[-1],
            original_sr=self.input_sr,
            target_sr=self.target_sr,
            hop_size=self.hop_size,
            max_duration=self.max_duration,
        )

        if verbose:
            dt = time.time() - t0
            voiced_ratio = float(np.mean(f0 > 0)) if len(f0) else 0.0
            print(
                "[f0 extraction] process: done:",
                f"frames={len(f0)}",
                f"voiced_ratio={voiced_ratio:.3f}",
                f"time={dt:.3f}s",
            )
        if f0_path is not None:
            np.save(f0_path, f0)

        return f0


if __name__ == "__main__":
    model_path = "pretrained_models/SoulX-Singer-Preprocess/rmvpe/rmvpe.pt"
    audio_path = "example/audio/zh_prompt.mp3"

    pe = F0Extractor(
        model_path,
        device="cuda",
    )
    f0 = pe.process(audio_path, f0_path="example/audio/zh_prompt_f0.npy")
```
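As a quick sanity check on the two time grids that `interpolate_f0` bridges: RMVPE emits one f0 value per 160 samples at 16 kHz (10 ms per frame), while the default output grid uses a 480-sample hop at 24 kHz (20 ms per frame), so the output has roughly half as many frames. The arithmetic, using the defaults above:

```python
# Frame-rate arithmetic behind interpolate_f0 (defaults from this file).
rmvpe_frame_s = 160 / 16000    # 0.010 s per RMVPE frame
target_frame_s = 480 / 24000   # 0.020 s per output frame

ten_seconds = 10.0
print(ten_seconds / rmvpe_frame_s)   # 1000.0 RMVPE frames
print(ten_seconds / target_frame_s)  # 500.0 output frames
```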
preprocess/tools/g2p.py
ADDED
@@ -0,0 +1,72 @@

```python
import re

import ToJyutping
from g2pM import G2pM
from g2p_en import G2p as G2pE

_EN_WORD_RE = re.compile(r"^[A-Za-z]+(?:'[A-Za-z]+)*$")
_ZH_WORD_RE = re.compile(r"[\u4e00-\u9fff]")

EN_FLAG = "en_"
YUE_FLAG = "yue_"
ZH_FLAG = "zh_"

g2p_zh = G2pM()
g2p_en = G2pE()


def is_chinese_char(word: str) -> bool:
    if len(word) != 1:
        return False
    return bool(_ZH_WORD_RE.fullmatch(word))


def is_english_word(word: str) -> bool:
    if not word:
        return False
    return bool(_EN_WORD_RE.fullmatch(word))


def g2p_cantonese(sent):
    return ToJyutping.get_jyutping_list(sent)  # with tone


def g2p_mandarin(sent):
    return g2p_zh(sent, tone=True, char_split=False)


def g2p_english(word):
    return g2p_en(word)


def g2p_transform(words, lang):
    zh_words = []
    transformed_words = [0] * len(words)

    for idx, w in enumerate(words):
        if w == "<SP>":
            transformed_words[idx] = w
            continue

        w = w.replace("?", "").replace(".", "").replace("!", "").replace(",", "")

        if is_chinese_char(w):
            zh_words.append([idx, w])
        else:
            if is_english_word(w):
                w = EN_FLAG + "-".join(g2p_english(w.lower()))
            else:
                w = "<SP>"
            transformed_words[idx] = w

    sent = "".join([k[1] for k in zh_words])

    # Convert all Chinese characters (Mandarin or Cantonese) to phonemes in one batch.
    if len(sent) > 0:
        if lang == "Cantonese":
            g2pm_rst = g2p_cantonese(sent)  # with tone
            g2pm_rst = [YUE_FLAG + k[1] for k in g2pm_rst]
        else:
            g2pm_rst = g2p_mandarin(sent)
            g2pm_rst = [ZH_FLAG + k for k in g2pm_rst]
        for p, w in zip([k[0] for k in zh_words], g2pm_rst):
            transformed_words[p] = w

    return transformed_words
```
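A short usage sketch of `g2p_transform`. The input mixes a Chinese character, a silence token, and an English word; the exact phoneme strings depend on the installed g2pM / g2p_en dictionaries, so the output shown in the comment is illustrative:

```python
from preprocess.tools.g2p import g2p_transform

words = ["我", "<SP>", "love"]
print(g2p_transform(words, lang="Mandarin"))
# Illustrative output: ['zh_wo3', '<SP>', 'en_L-AH1-V']
# Chinese characters receive the "zh_" (or "yue_") prefix, English words the
# "en_" prefix with phonemes joined by "-", and "<SP>" passes through unchanged.
```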
preprocess/tools/lyric_transcription.py
ADDED
|
@@ -0,0 +1,283 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# https://modelscope.cn/models/iic/speech_seaco_paraformer_large_asr_nat-zh-cn-16k-common-vocab8404-pytorch/summary
|
| 2 |
+
# https://huggingface.co/nvidia/parakeet-tdt-0.6b-v2
|
| 3 |
+
import os
|
| 4 |
+
import re
|
| 5 |
+
import time
|
| 6 |
+
from typing import Any, Dict, List, Tuple
|
| 7 |
+
|
| 8 |
+
import librosa
|
| 9 |
+
import numpy as np
|
| 10 |
+
from funasr import AutoModel
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
def _build_words_with_gaps(raw_words, raw_timestamps, wav_fn: str):
|
| 14 |
+
words, word_durs = [], []
|
| 15 |
+
prev = 0.0
|
| 16 |
+
for w, t in zip(raw_words, raw_timestamps):
|
| 17 |
+
s, e = float(t[0]), float(t[1])
|
| 18 |
+
if s > prev:
|
| 19 |
+
words.append("<SP>")
|
| 20 |
+
word_durs.append(s - prev)
|
| 21 |
+
words.append(w)
|
| 22 |
+
word_durs.append(e - s)
|
| 23 |
+
prev = e
|
| 24 |
+
|
| 25 |
+
wav_len = librosa.get_duration(filename=wav_fn)
|
| 26 |
+
if wav_len > prev:
|
| 27 |
+
if len(words) == 0:
|
| 28 |
+
words.append("<SP>")
|
| 29 |
+
word_durs.append(wav_len)
|
| 30 |
+
return words, word_durs
|
| 31 |
+
if words[-1] != "<SP>":
|
| 32 |
+
words.append("<SP>")
|
| 33 |
+
word_durs.append(wav_len - prev)
|
| 34 |
+
else:
|
| 35 |
+
word_durs[-1] += wav_len - prev
|
| 36 |
+
|
| 37 |
+
return words, word_durs
|
| 38 |
+
|
| 39 |
+
def _word_dur_post_process(words, word_durs, f0):
|
| 40 |
+
"""Post-process word durations using f0 to better place silences.
|
| 41 |
+
"""
|
| 42 |
+
# f0 time grid parameters
|
| 43 |
+
sr = 24000 # f0 sample rate
|
| 44 |
+
hop_length = 480 # f0 hop length
|
| 45 |
+
|
| 46 |
+
# Convert word durations (seconds) to frame boundaries on the f0 grid.
|
| 47 |
+
boundaries = np.cumsum([
|
| 48 |
+
0,
|
| 49 |
+
*[
|
| 50 |
+
int(dur * sr / hop_length)
|
| 51 |
+
for dur in word_durs
|
| 52 |
+
],
|
| 53 |
+
]).tolist()
|
| 54 |
+
|
| 55 |
+
sil_tolerance = 5 # tolerance frames for silence detection
|
| 56 |
+
ext_tolerance = 5 # tolerance frames for vocal extension
|
| 57 |
+
|
| 58 |
+
new_words: list[str] = []
|
| 59 |
+
new_word_durs: list[float] = []
|
| 60 |
+
if words:
|
| 61 |
+
new_words.append(words[0])
|
| 62 |
+
new_word_durs.append(word_durs[0])
|
| 63 |
+
|
| 64 |
+
for i in range(1, len(words)):
|
| 65 |
+
word = words[i]
|
| 66 |
+
if word == "<SP>":
|
| 67 |
+
start_frame = boundaries[i]
|
| 68 |
+
end_frame = boundaries[i + 1]
|
| 69 |
+
|
| 70 |
+
num_frames = end_frame - start_frame
|
| 71 |
+
frame_idx = start_frame
|
| 72 |
+
|
| 73 |
+
# Find first region with at least 5 consecutive "unvoiced" frames.
|
| 74 |
+
unvoiced_count = 0
|
| 75 |
+
while frame_idx < end_frame:
|
| 76 |
+
if f0[frame_idx] <= 1: # unvoiced
|
| 77 |
+
unvoiced_count += 1
|
| 78 |
+
if unvoiced_count >= sil_tolerance:
|
| 79 |
+
frame_idx -= sil_tolerance - 1 # back to the last voiced frame
|
| 80 |
+
break
|
| 81 |
+
else:
|
| 82 |
+
unvoiced_count = 0
|
| 83 |
+
frame_idx += 1
|
| 84 |
+
|
| 85 |
+
voice_frames = frame_idx - start_frame
|
| 86 |
+
|
| 87 |
+
if voice_frames >= int(num_frames * 0.9): # over 90% voiced
|
| 88 |
+
# Treat the whole "<SP>" as silence and merge into previous word.
|
| 89 |
+
new_word_durs[-1] += word_durs[i]
|
| 90 |
+
elif voice_frames >= ext_tolerance: # over 5 frames voiced
|
| 91 |
+
# Split the "<SP>" into two parts: leading silence and tail kept as "<SP>".
|
| 92 |
+
dur = voice_frames * hop_length / sr
|
| 93 |
+
new_word_durs[-1] += dur
|
| 94 |
+
new_words.append("<SP>")
|
| 95 |
+
new_word_durs.append(word_durs[i] - dur)
|
| 96 |
+
else:
|
| 97 |
+
# Too short to adjust, keep as-is.
|
| 98 |
+
new_words.append(word)
|
| 99 |
+
new_word_durs.append(word_durs[i])
|
| 100 |
+
else:
|
| 101 |
+
new_words.append(word)
|
| 102 |
+
new_word_durs.append(word_durs[i])
|
| 103 |
+
|
| 104 |
+
return new_words, new_word_durs
|
| 105 |
+
|
| 106 |
+
|
| 107 |
+
class _ASRZhModel:
|
| 108 |
+
"""Mandarin/Cantonese ASR wrapper."""
|
| 109 |
+
|
| 110 |
+
def __init__(self, model_path: str, device: str):
|
| 111 |
+
self.model = AutoModel(
|
| 112 |
+
model=model_path,
|
| 113 |
+
disable_update=True,
|
| 114 |
+
device=device,
|
| 115 |
+
)
|
| 116 |
+
|
| 117 |
+
def process(self, wav_fn):
|
| 118 |
+
out = self.model.generate(wav_fn, output_timestamp=True)[0]
|
| 119 |
+
raw_words = out["text"].replace("@", "").split(" ")
|
| 120 |
+
raw_timestamps = [[t[0] / 1000, t[1] / 1000] for t in out["timestamp"]]
|
| 121 |
+
words, word_durs = _build_words_with_gaps(raw_words, raw_timestamps, wav_fn)
|
| 122 |
+
|
| 123 |
+
f0_path = os.path.splitext(wav_fn)[0] + "_f0.npy"
|
| 124 |
+
if os.path.exists(f0_path):
|
| 125 |
+
words, word_durs = _word_dur_post_process(
|
| 126 |
+
words, word_durs, np.load(f0_path)
|
| 127 |
+
)
|
| 128 |
+
|
| 129 |
+
return words, word_durs
|
| 130 |
+
|
| 131 |
+
|
| 132 |
+
class _ASREnModel:
    """English ASR wrapper for NeMo Parakeet-TDT."""

    def __init__(self, model_path: str, device: str):
        try:
            import nemo.collections.asr as nemo_asr  # type: ignore
        except Exception as e:  # pragma: no cover
            raise ImportError(
                "NeMo (nemo_toolkit) is required for English ASR but is not available in this Python env. "
                "Install it in the active environment, then retry."
            ) from e

        self.model = nemo_asr.models.ASRModel.restore_from(
            restore_path=model_path,
            map_location=device,
        )
        self.model.eval()

    @staticmethod
    def _clean_word(word: str) -> str:
        return re.sub(r"[\?\.,:]", "", word).strip()

    @staticmethod
    def _extract_word_segments(output: Any) -> List[Dict[str, Any]]:
        ts = getattr(output, "timestamp", None)
        if not ts or not isinstance(ts, dict):
            return []
        word_ts = ts.get("word")
        return word_ts if isinstance(word_ts, list) else []

    def process(self, wav_fn: str) -> Tuple[List[str], List[float]]:
        outputs = self.model.transcribe(
            [wav_fn],
            timestamps=True,
            batch_size=1,
            num_workers=0,
        )
        output = outputs[0] if outputs else None

        raw_words: List[str] = []
        raw_timestamps: List[List[float]] = []
        if output is not None:
            for w in self._extract_word_segments(output):
                s, e = float(w.get("start", 0.0)), float(w.get("end", 0.0))
                word = self._clean_word(str(w.get("word", "")))
                if word:
                    raw_words.append(word)
                    raw_timestamps.append([s, e])

        words, durs = _build_words_with_gaps(raw_words, raw_timestamps, wav_fn)

        f0_path = os.path.splitext(wav_fn)[0] + "_f0.npy"
        if os.path.exists(f0_path):
            words, durs = _word_dur_post_process(
                words, durs, np.load(f0_path)
            )

        return words, durs


class LyricTranscriber:
    """Transcribe lyrics from a singing voice segment."""

    def __init__(
        self,
        zh_model_path: str,
        en_model_path: str,
        device: str = "cuda",
        *,
        verbose: bool = True,
    ):
        """Initialize the lyric transcriber.

        Args:
            zh_model_path (str): Path to the Chinese model file.
            en_model_path (str): Path to the English model file.
            device (str): Device to use for tensor operations.
            verbose (bool): Whether to print verbose logs.
        """
        self.verbose = verbose
        self.device = device
        self.zh_model_path = zh_model_path
        self.en_model_path = en_model_path

        if self.verbose:
            print(
                "[lyric transcription] init: start:",
                f"device={device}",
                f"model_path={zh_model_path}",
            )

        # Always initialize the Chinese ASR model.
        self.zh_model = _ASRZhModel(device=device, model_path=zh_model_path)

        # The English ASR model is initialized lazily on the first English
        # request, to avoid the long wait of importing NeMo up front.
        self.en_model = None

        if self.verbose:
            print("[lyric transcription] init: success")

    def process(self, wav_fn, language: str | None = "Mandarin", *, verbose: bool | None = None):
        """Run lyric transcription on one audio file.

        Args:
            wav_fn (str): Path to the audio file.
            language (str | None): Language of the audio. Defaults to "Mandarin". Supports "Mandarin", "Cantonese", and "English".
            verbose (bool | None): Whether to print verbose logs. Defaults to None.
        """
        v = self.verbose if verbose is None else verbose
        if language not in {"Mandarin", "Cantonese", "English"}:
            raise ValueError(f"Unsupported language: {language}, should be one of ['Mandarin', 'Cantonese', 'English']")
        if v:
            print(f"[lyric transcription] process: start: wav_fn={wav_fn} language={language}")
        t0 = time.time()

        lang = (language or "auto").lower()
        if lang in {"english"}:
            if self.en_model is None:
                # Lazy-load the NeMo model only when English is actually used.
                if v:
                    print("[lyric transcription] init English ASR start; make sure NeMo is installed, this may take a while")
                self.en_model = _ASREnModel(model_path=self.en_model_path, device=self.device)
                if v:
                    print("[lyric transcription] init English ASR success")
            out = self.en_model.process(wav_fn)
        else:
            out = self.zh_model.process(wav_fn)

        if v:
            words, durs = out
            n_words = len(words) if isinstance(words, list) else 0
            dur_sum = float(sum(durs)) if isinstance(durs, list) else 0.0
            dt = time.time() - t0
            print(
                "[lyric transcription] process: done:",
                f"n_words={n_words}",
                f"dur_sum={dur_sum:.3f}s",
                f"time={dt:.3f}s",
            )

        return out


if __name__ == "__main__":
    m = LyricTranscriber(
        zh_model_path="pretrained_models/SoulX-Singer-Preprocess/speech_seaco_paraformer_large_asr_nat-zh-cn-16k-common-vocab8404-pytorch",
        en_model_path="pretrained_models/SoulX-Singer-Preprocess/parakeet-tdt-0.6b-v2/parakeet-tdt-0.6b-v2.nemo",
        device="cuda",
    )
    print(m.process("example/audio/zh_prompt.mp3", language="Mandarin"))
    print(m.process("example/audio/en_prompt.mp3", language="English"))
preprocess/tools/midi_editor/README.md
ADDED
@@ -0,0 +1,170 @@
# 🎹 MIDI Editor - Web-based Singing MIDI Editor

[English](README.md) | [简体中文](README_CN.md)

A full-featured web MIDI editor for singing voice preprocessing. It supports real-time drag editing of MIDI notes, lyric editing, audio waveform alignment, and importing/exporting MIDI files with lyrics.

## ✨ Features

### 🎼 Piano Roll Editing

- **Visual note editing**: full range from C1 to C8 with an intuitive piano key layout
- **Drag operations**:
  - Move notes: drag note blocks to adjust position and pitch
  - Resize start: drag the left edge to adjust the start time
  - Resize end: drag the right edge to adjust the end time
- **Quick pitch adjust**:
  - Command/Ctrl + Up/Down to nudge the selected note's pitch
  - Use the Transpose control in the toolbar to shift all notes at once
- **Double-click to add**: add new notes quickly in empty areas
- **Piano key preview**: click a key on the left to audition the pitch

### 🔍 Zoom & Navigation

- **Horizontal zoom**
- **Vertical zoom**
- **Dynamic snapping**: finer snap granularity at higher zoom, down to 0.01 s (see the sketch below)
- **Auto scroll**: keeps the playhead visible during playback
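The snap rule can be pictured as follows. This is an illustrative sketch only; the actual base step and formula live in `PianoRoll.tsx`, and the `0.1 / zoom` factor here is an assumption:

```ts
// Hypothetical snap rule: the step shrinks as you zoom in,
// clamped at the documented 0.01 s minimum.
const snapStep = (zoom: number): number => Math.max(0.01, 0.1 / zoom)

snapStep(1)  // 0.1 s at default zoom (assumed base step)
snapStep(20) // hits the 0.01 s floor at high zoom
```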
### 📝 Lyric Editing

- **Inline editing**: edit lyrics for each note in the side list
- **Batch fill**: enter lyrics and auto-fill notes in order
- **Fill from selection**: start batch fill from the currently selected note
- **Precise fields**: edit PITCH, START, and END directly
- **Confirm edits**: press Enter or click ✓ to confirm, avoiding accidental changes

### 🎵 Audio Alignment

- **Waveform display**: import audio to display its waveform, synced with the MIDI timeline
- **Formats**: MP3, WAV, OGG, FLAC, M4A, AAC
- **Sync playback**: play audio and MIDI together with independent volume controls
- **Click to seek**: click the waveform or timeline to seek

### ⚠️ Overlap Detection

- **Visual highlight**: overlapping notes blink in red
- **One-click fix**: remove all overlaps automatically

### 📥 Import & Export

- **MIDI import**: parse standard MIDI files with automatic lyric metadata extraction
- **MIDI export**: export MIDI files with lyric information (see the sketch below)
|
| 57 |
+
|
| 58 |
+
- **Theme toggle**: light and dark modes
|
| 59 |
+
- **Responsive layout**: adapts to window size
|
| 60 |
+
- **SVG grid**: cross-browser grid rendering
|
| 61 |
+
- **Status feedback**: real-time state and error tips
|
| 62 |
+
|
| 63 |
+
## 🚀 Quick Start
|
| 64 |
+
|
| 65 |
+
### Requirements
|
| 66 |
+
|
| 67 |
+
- Node.js 18+
|
| 68 |
+
- npm or yarn
|
| 69 |
+
|
| 70 |
+
### Install
|
| 71 |
+
|
| 72 |
+
```bash
|
| 73 |
+
# Install dependencies
|
| 74 |
+
npm install
|
| 75 |
+
|
| 76 |
+
# Start dev server
|
| 77 |
+
npm run dev
|
| 78 |
+
|
| 79 |
+
# Expose to LAN
|
| 80 |
+
npm run dev -- --host 0.0.0.0
|
| 81 |
+
```
|
| 82 |
+
|
| 83 |
+
### Build
|
| 84 |
+
|
| 85 |
+
```bash
|
| 86 |
+
# Build for production
|
| 87 |
+
npm run build
|
| 88 |
+
|
| 89 |
+
# Preview build
|
| 90 |
+
npm run preview
|
| 91 |
+
```
|
| 92 |
+
|
| 93 |
+
## 📖 Usage
|
| 94 |
+
|
| 95 |
+
### Basic Workflow
|
| 96 |
+
|
| 97 |
+
1. **Import MIDI**: click Import MIDI and select a .mid file
|
| 98 |
+
2. **Edit notes**: drag notes in the piano roll to adjust time and pitch
|
| 99 |
+
3. **Add lyrics**: edit lyrics in the right-side list, or use batch fill
|
| 100 |
+
4. **Align audio** (optional): import reference audio for side-by-side editing
|
| 101 |
+
5. **Export**: click Export MIDI to save
|
| 102 |
+
|
| 103 |
+
### Shortcuts
|
| 104 |
+
|
| 105 |
+
| Action | Description |
|
| 106 |
+
|------|------|
|
| 107 |
+
| Double-click piano roll | Add a new note |
|
| 108 |
+
| Double-click note | Edit lyric |
|
| 109 |
+
| Drag note | Move note and pitch |
|
| 110 |
+
| Drag note edges | Resize note |
|
| 111 |
+
| Backspace / Delete | Delete selected note |
|
| 112 |
+
| Enter | Confirm value edits |
|
| 113 |
+
| Escape | Cancel value edits |
|
| 114 |
+
| Ctrl(Command) + Wheel | Horizontal zoom |
|
| 115 |
+
| Ctrl(Command) + Shift(Option) + Wheel | Vertical zoom |
|
| 116 |
+
|
| 117 |
+
### Playback Controls
|
| 118 |
+
|
| 119 |
+
| Button | Description |
|
| 120 |
+
|------|------|
|
| 121 |
+
| ⏮ | Go to start |
|
| 122 |
+
| ⏪ | Back 2 seconds |
|
| 123 |
+
| ▶ / ⏸ | Play / Pause |
|
| 124 |
+
| ⏩ | Forward 2 seconds |
|
| 125 |
+
| ⏭ | Go to end |
|
| 126 |
+
| Selection | Play selected region |
|
| 127 |
+
|
| 128 |
+
## 🛠 Tech Stack
|
| 129 |
+
|
| 130 |
+
- **Frontend**: React 19 + TypeScript
|
| 131 |
+
- **Build**: Vite 7
|
| 132 |
+
- **State**: Zustand
|
| 133 |
+
- **Audio**: Tone.js
|
| 134 |
+
- **Waveform**: WaveSurfer.js
|
| 135 |
+
- **MIDI**: @tonejs/midi
|
| 136 |
+
- **Styles**: CSS with custom variables
|
| 137 |
+
|
| 138 |
+
## 📁 Project Structure
|
| 139 |
+
|
| 140 |
+
```
|
| 141 |
+
.
|
| 142 |
+
├── eslint.config.js
|
| 143 |
+
├── index.html
|
| 144 |
+
├── package.json
|
| 145 |
+
├── postcss.config.js
|
| 146 |
+
├── README.md
|
| 147 |
+
├── README_CN.md
|
| 148 |
+
├── tailwind.config.js
|
| 149 |
+
├── tsconfig.app.json
|
| 150 |
+
├── tsconfig.json
|
| 151 |
+
├── tsconfig.node.json
|
| 152 |
+
├── vite.config.ts
|
| 153 |
+
├── public/
|
| 154 |
+
└── src/
|
| 155 |
+
├── App.css # Main styles (theme variables, layout, components)
|
| 156 |
+
├── App.tsx # Main app component (transport, import/export, transpose)
|
| 157 |
+
├── constants.ts # Constants (grid width, row height, pitch range)
|
| 158 |
+
├── i18n.ts # Internationalization (zh/en translations, smart lyric tokenizer)
|
| 159 |
+
├── index.css # Global styles (Tailwind, root font, theme gradients)
|
| 160 |
+
├── main.tsx # React entry point
|
| 161 |
+
├── types.ts # Type definitions (NoteEvent, TimeSignature, etc.)
|
| 162 |
+
├── components/
|
| 163 |
+
│ ├── AudioTrack.tsx # Audio waveform display component
|
| 164 |
+
│ ├── LyricTable.tsx # Lyric editing table component
|
| 165 |
+
│ └── PianoRoll.tsx # Piano roll editor component
|
| 166 |
+
├── lib/
|
| 167 |
+
│ └── midi.ts # MIDI import/export utilities (UTF-8 lyric encoding)
|
| 168 |
+
└── store/
|
| 169 |
+
└── useMidiStore.ts # Zustand state management
|
| 170 |
+
```
|
preprocess/tools/midi_editor/README_CN.md
ADDED
@@ -0,0 +1,170 @@
# 🎹 MIDI Editor - Web-based Singing MIDI Editor

[English](README.md) | [简体中文](README_CN.md)

A fully featured web-based editor for singing-voice MIDI files. It supports real-time drag editing of MIDI notes, lyric editing, audio waveform alignment, and import/export of MIDI files with lyrics.

## ✨ Features

### 🎼 Piano Roll Editing

- **Visual note editing**: full C1-C8 range with an intuitive piano key layout
- **Drag operations**:
  - Move notes: drag note blocks to adjust position and pitch
  - Adjust onset: drag a note's left edge to change its start time
  - Adjust offset: drag a note's right edge to change its end time
- **Quick pitch adjust**:
  - Command/Ctrl + Up/Down adjusts the selected note's pitch
  - The Transpose control in the toolbar shifts all pitches at once
- **Double-click to add**: double-click an empty spot in the piano roll to add a note
- **Piano key preview**: click a key on the left to audition its pitch

### 🔍 Zoom & Navigation

- **Horizontal zoom**
- **Vertical zoom**
- **Dynamic precision**: the higher the zoom, the finer the note-snapping granularity (minimum 0.01 s)
- **Auto scroll**: the playhead stays visible during playback

### 📝 Lyric Editing

- **Inline editing**: edit each note's lyric directly in the right-side list
- **Batch fill**: enter a passage of lyrics and fill it onto notes character by character, in order
- **Start from selection**: batch fill can start from the currently selected note
- **Precise fields**: PITCH, START, and END can be edited directly
- **Confirm mechanism**: press Enter or click ✓ after changing a value, avoiding accidental edits

### 🎵 Audio Alignment

- **Waveform display**: imported audio is shown as a waveform that scrolls in sync with the MIDI
- **Supported formats**: MP3, WAV, OGG, FLAC, M4A, AAC
- **Synchronized playback**: audio and MIDI play together, each with its own volume control
- **Click to seek**: click the waveform or the time ruler to jump the playback position

### ⚠️ Overlap Detection

- **Visual marking**: notes that overlap in time are shown in red and blink
- **One-click fix**: the de-overlap button repairs all overlaps automatically

### 📥 Import & Export

- **MIDI import**: supports standard MIDI files and parses lyric metadata automatically
- **MIDI export**: exports MIDI files that include the lyric information

### 🎨 UI Features

- **Theme toggle**: light and dark themes
- **Responsive layout**: adapts to the window size
- **SVG grid**: cross-browser compatible grid rendering
- **Status feedback**: operation state and error messages are shown in real time

## 🚀 Quick Start

### Requirements

- Node.js 18+
- npm or yarn

### Install

```bash
# Install dependencies
npm install

# Start the dev server
npm run dev

# Serve on the local network
npm run dev -- --host 0.0.0.0
```

### Build

```bash
# Build the production version
npm run build

# Preview the build output
npm run preview
```

## 📖 Usage Guide

### Basic Workflow

1. **Import MIDI**: click the Import MIDI button and pick a .mid file
2. **Edit notes**: drag in the piano roll to adjust note positions and durations
3. **Add lyrics**: enter sentence-level lyrics in the right-side list, or edit character by character
4. **Align audio** (optional): import reference audio for side-by-side editing
5. **Export the file**: click Export MIDI with Lyrics to save

### Shortcuts

| Action | Description |
|------|------|
| Double-click the piano roll | Add a new note |
| Double-click a note | Edit its lyric |
| Drag a note | Move its position/pitch |
| Drag a note's edges | Adjust its duration |
| Backspace / Delete | Delete the selected note |
| Enter | Confirm a value edit |
| Escape | Cancel a value edit |
| Ctrl (Command) + Wheel | Horizontal zoom |
| Ctrl (Command) + Shift (Option) + Wheel | Vertical zoom |

### Playback Controls

| Button | Function |
|------|------|
| ⏮ | Go to the start |
| ⏪ | Back 2 seconds |
| ▶ / ⏸ | Play / Pause |
| ⏩ | Forward 2 seconds |
| ⏭ | Jump to the end |
| Selection | Play the selected region |

## 🛠 Tech Stack

- **Frontend framework**: React 19 + TypeScript
- **Build tool**: Vite 7
- **State management**: Zustand
- **Audio engine**: Tone.js
- **Waveform display**: WaveSurfer.js
- **MIDI parsing**: @tonejs/midi
- **Styles**: CSS (custom-variable theming)

## 📁 Project Structure

```
.
├── eslint.config.js
├── index.html
├── package.json
├── postcss.config.js
├── README.md
├── README_CN.md
├── tailwind.config.js
├── tsconfig.app.json
├── tsconfig.json
├── tsconfig.node.json
├── vite.config.ts
├── public/
└── src/
    ├── App.css          # Main styles (theme variables, layout, component styles)
    ├── App.tsx          # Main app component (transport, import/export, transpose, etc.)
    ├── constants.ts     # Constants (grid width, row height, pitch range)
    ├── i18n.ts          # Internationalization (zh/en translations, smart lyric tokenizer)
    ├── index.css        # Global styles (Tailwind, root font, theme gradients)
    ├── main.tsx         # React entry point
    ├── types.ts         # Type definitions (NoteEvent, TimeSignature, etc.)
    ├── components/
    │   ├── AudioTrack.tsx   # Audio waveform display component
    │   ├── LyricTable.tsx   # Lyric editing table component
    │   └── PianoRoll.tsx    # Piano roll editor component
    ├── lib/
    │   └── midi.ts          # MIDI import/export utilities (UTF-8 lyric encoding/decoding)
    └── store/
        └── useMidiStore.ts  # Zustand state management
```
preprocess/tools/midi_editor/eslint.config.js
ADDED
@@ -0,0 +1,23 @@
import js from '@eslint/js'
import globals from 'globals'
import reactHooks from 'eslint-plugin-react-hooks'
import reactRefresh from 'eslint-plugin-react-refresh'
import tseslint from 'typescript-eslint'
import { defineConfig, globalIgnores } from 'eslint/config'

export default defineConfig([
  globalIgnores(['dist']),
  {
    files: ['**/*.{ts,tsx}'],
    extends: [
      js.configs.recommended,
      tseslint.configs.recommended,
      reactHooks.configs.flat.recommended,
      reactRefresh.configs.vite,
    ],
    languageOptions: {
      ecmaVersion: 2020,
      globals: globals.browser,
    },
  },
])
preprocess/tools/midi_editor/index.html
ADDED
@@ -0,0 +1,13 @@
<!doctype html>
<html lang="en">
  <head>
    <meta charset="UTF-8" />
    <link rel="icon" type="image/svg+xml" href="/vite.svg" />
    <meta name="viewport" content="width=device-width, initial-scale=1.0" />
    <title>SoulX-Singer MIDI Editor</title>
  </head>
  <body>
    <div id="root"></div>
    <script type="module" src="/src/main.tsx"></script>
  </body>
</html>
preprocess/tools/midi_editor/package-lock.json
ADDED
The diff for this file is too large to render.
preprocess/tools/midi_editor/package.json
ADDED
@@ -0,0 +1,39 @@
{
  "name": "midi-editor",
  "private": true,
  "version": "0.0.0",
  "type": "module",
  "scripts": {
    "dev": "vite",
    "build": "tsc -b && vite build",
    "lint": "eslint .",
    "preview": "vite preview"
  },
  "dependencies": {
    "@tonejs/midi": "^2.0.28",
    "class-variance-authority": "^0.7.1",
    "nanoid": "^5.1.6",
    "react": "^19.2.0",
    "react-dom": "^19.2.0",
    "tone": "^15.1.22",
    "wavesurfer.js": "^7.12.1",
    "zustand": "^5.0.10"
  },
  "devDependencies": {
    "@eslint/js": "^9.39.1",
    "@types/node": "^24.10.1",
    "@types/react": "^19.2.5",
    "@types/react-dom": "^19.2.3",
    "@vitejs/plugin-react": "^5.1.1",
    "autoprefixer": "^10.4.20",
    "eslint": "^9.39.1",
    "eslint-plugin-react-hooks": "^7.0.1",
    "eslint-plugin-react-refresh": "^0.4.24",
    "globals": "^16.5.0",
    "postcss": "^8.4.47",
    "tailwindcss": "^3.4.15",
    "typescript": "~5.9.3",
    "typescript-eslint": "^8.46.4",
    "vite": "^7.2.4"
  }
}
preprocess/tools/midi_editor/postcss.config.js
ADDED
@@ -0,0 +1,6 @@
export default {
  plugins: {
    tailwindcss: {},
    autoprefixer: {},
  },
}
preprocess/tools/midi_editor/public/vite.svg
ADDED
preprocess/tools/midi_editor/src/App.css
ADDED
@@ -0,0 +1,834 @@
.app-shell {
  padding: 24px;
  color: var(--text-primary);
  width: 100%;
  max-width: 100%;
  margin: 0;
  height: 100vh;
  max-height: 100vh;
  display: flex;
  flex-direction: column;
  overflow: hidden;
  box-sizing: border-box;
}

.topbar {
  display: flex;
  align-items: center;
  justify-content: space-between;
  gap: 24px;
  background: var(--panel-strong);
  border: 1px solid var(--border-subtle);
  border-radius: 16px;
  padding: 20px 24px;
  box-shadow: var(--shadow-panel);
}

.topbar h1 {
  margin: 4px 0 0 0;
  font-size: 26px;
  letter-spacing: -0.5px;
}

.eyebrow {
  margin: 0;
  text-transform: uppercase;
  font-size: 12px;
  letter-spacing: 2px;
  color: var(--text-muted);
}

.muted {
  margin: 6px 0 0 0;
  color: var(--text-muted);
}

.actions {
  display: flex;
  gap: 10px;
  align-items: center;
}

.transpose-group {
  display: flex;
  align-items: center;
}

.transpose-select {
  padding: 10px 10px;
  border-radius: 12px;
  border: 1px solid var(--border-soft);
  background: var(--button-soft-bg);
  color: var(--button-soft-text);
  font-weight: 600;
  font-size: 14px;
  cursor: pointer;
  outline: none;
  appearance: none;
  -webkit-appearance: none;
  background-image: url("data:image/svg+xml,%3Csvg xmlns='http://www.w3.org/2000/svg' width='10' height='6'%3E%3Cpath d='M0 0l5 6 5-6z' fill='%23888'/%3E%3C/svg%3E");
  background-repeat: no-repeat;
  background-position: right 10px center;
  padding-right: 26px;
  transition: transform 140ms ease, box-shadow 140ms ease, background 140ms ease;
}

.transpose-select:hover {
  transform: translateY(-1px);
}

.transpose-select:focus {
  border-color: var(--accent);
}

.icon-toggle {
  width: 40px;
  height: 40px;
  border-radius: 999px;
  border: 1px solid var(--border-soft);
  background: var(--button-ghost-bg);
  color: var(--button-ghost-text);
  display: inline-flex;
  align-items: center;
  justify-content: center;
  font-size: 18px;
  cursor: pointer;
}

.icon-toggle:hover {
  transform: translateY(-1px);
}

.lang-label {
  font-size: 14px;
  font-weight: 700;
  line-height: 1;
}

.audio-bar {
  margin-top: 14px;
  padding: 12px 16px;
  border-radius: 14px;
  background: var(--panel-strong);
  border: 1px solid var(--border-subtle);
  display: flex;
  align-items: center;
  justify-content: space-between;
  gap: 16px;
}

.audio-left {
  display: flex;
  align-items: center;
  gap: 12px;
}

.audio-hint {
  color: var(--text-muted);
  font-size: 12px;
}

.audio-right {
  display: flex;
  align-items: center;
  gap: 20px;
}

.volume-control {
  display: flex;
  align-items: center;
  gap: 8px;
}

.volume-label {
  font-size: 12px;
  color: var(--text-muted);
  min-width: 32px;
}

.volume-slider {
  width: 80px;
  height: 4px;
  cursor: pointer;
  accent-color: var(--accent);
}

.volume-value {
  font-size: 11px;
  color: var(--text-muted);
  min-width: 36px;
  text-align: right;
}

.toggle {
  display: inline-flex;
  align-items: center;
  gap: 8px;
  font-size: 13px;
  color: var(--text-primary);
}

.panel {
  margin-top: 18px;
  background: var(--panel);
  border: 1px solid var(--border-subtle);
  border-radius: 16px;
  padding: 18px;
  box-shadow: var(--shadow-panel);
  display: flex;
  flex-direction: column;
  flex: 1;
  min-height: 0;
  overflow: hidden;
}

.panel-split {
  display: grid;
  grid-template-columns: minmax(0, 1fr) 360px;
  gap: 16px;
  align-items: stretch;
  flex: 1;
  min-height: 0;
  max-height: 100%;
  overflow: hidden;
}

.panel-main {
  min-width: 0;
  display: flex;
  flex-direction: column;
  min-height: 0;
  max-height: 100%;
  overflow: hidden;
}

.panel-side {
  display: flex;
  flex-direction: column;
  gap: 16px;
  width: 360px;
  max-width: 360px;
  /* Use absolute positioning to enforce height */
  position: relative;
  overflow: hidden;
}

.controls {
  display: grid;
  grid-template-columns: repeat(auto-fit, minmax(180px, 1fr));
  gap: 14px;
  align-items: center;
  background: var(--panel-soft);
  padding: 12px 14px;
  border-radius: 12px;
  border: 1px solid var(--border-soft);
  flex-shrink: 0;
}

.controls label {
  display: block;
  font-size: 12px;
  text-transform: uppercase;
  letter-spacing: 1px;
  color: var(--text-muted);
  margin-bottom: 4px;
}

.controls input[type='number'] {
  width: 100%;
  padding: 10px 12px;
  border-radius: 10px;
  border: 1px solid var(--border-soft);
  background: var(--input-bg);
  color: var(--text-primary);
}

.timesig {
  display: flex;
  align-items: center;
  gap: 6px;
}

.timesig span {
  font-weight: 700;
  color: var(--text-muted);
}

.transport {
  display: flex;
  gap: 6px;
  align-items: center;
  grid-column: 1 / -1;
}

.transport button {
  padding: 6px 10px !important;
  font-size: 13px !important;
  min-width: 0;
}

.status {
  grid-column: 1 / -1;
  color: var(--text-muted);
  font-size: 13px;
}

.transport-divider {
  width: 1px;
  height: 20px;
  background: var(--border-soft);
  margin: 0 2px;
  flex-shrink: 0;
}

.selection-btn {
  font-size: 12px !important;
  padding: 6px 10px !important;
  white-space: nowrap;
}

.selection-btn.active {
  background: var(--accent) !important;
  color: white !important;
}

.button,
.actions button,
.transport button,
.ghost,
.primary,
.json-btn,
.soft {
  cursor: pointer;
  border-radius: 12px;
  border: 1px solid transparent;
  padding: 10px 14px;
  font-weight: 600;
  transition: transform 140ms ease, box-shadow 140ms ease, background 140ms ease, border 140ms ease;
  color: #0f1528;
}

.ghost {
  background: var(--button-ghost-bg);
  color: var(--button-ghost-text);
  border-color: var(--border-soft);
}

.primary {
  background: linear-gradient(135deg, var(--accent), var(--accent-strong));
  color: var(--button-primary-text);
  box-shadow: 0 8px 26px rgba(72, 228, 194, 0.2);
}

.json-btn {
  background: linear-gradient(135deg, #f59e0b, #d97706);
  color: #fff;
  box-shadow: 0 8px 26px rgba(245, 158, 11, 0.2);
}

.soft {
  background: var(--button-soft-bg);
  color: var(--button-soft-text);
  border: 1px solid var(--border-soft);
}

.ghost:disabled,
.primary:disabled,
.json-btn:disabled,
.soft:disabled {
  opacity: 0.6;
  cursor: not-allowed;
}

.ghost:hover,
.primary:hover,
.json-btn:hover,
.soft:hover {
  transform: translateY(-1px);
}

.piano-shell {
  border-radius: 12px;
  background: var(--panel-strong);
  border: 1px solid var(--border-subtle);
  overflow: hidden;
  flex: 1;
  min-height: 0;
  max-height: 100%;
  display: flex;
  flex-direction: column;
}

.ruler {
  position: relative;
  height: 32px;
  background: var(--panel-soft);
  border-bottom: 1px solid var(--border-soft);
  min-width: 100%;
}

.ruler-shell {
  display: flex;
}

.ruler-spacer {
  background: var(--panel-soft);
  border-bottom: 1px solid var(--border-soft);
  height: 32px;
}

.ruler-scroll {
  overflow: hidden;
  flex: 1;
  height: 32px;
  cursor: pointer;
}

.measure-mark {
  position: absolute;
  top: 0;
  height: 100%;
  display: flex;
  flex-direction: column;
  align-items: flex-start;
  font-size: 10px;
  color: var(--text-muted);
  padding-left: 4px;
  border-left: 1px solid var(--border-soft);
}

.measure-mark span {
  margin-top: 2px;
}

.ruler-playhead {
  position: absolute;
  top: 0;
  width: 2px;
  height: 100%;
  background: #ff7043;
  pointer-events: none;
  z-index: 10;
}

.ruler-scroll.selecting {
  cursor: crosshair;
}

.selection-range {
  position: absolute;
  top: 0;
  height: 100%;
  background: rgba(66, 165, 245, 0.35);
  border-left: 2px solid #42a5f5;
  border-right: 2px solid #42a5f5;
  pointer-events: none;
  z-index: 5;
}

.grid-selection-range {
  position: absolute;
  top: 0;
  background: rgba(66, 165, 245, 0.15);
  border-left: 2px dashed #42a5f5;
  border-right: 2px dashed #42a5f5;
  pointer-events: none;
  z-index: 1;
}

.roll-body {
  display: flex;
  flex: 1;
  min-height: 0;
  overflow: hidden;
}

.pitch-rail {
  background: var(--panel-strong);
  border-right: 1px solid var(--border-subtle);
  color: var(--text-primary);
  font-size: 12px;
  text-align: right;
  overflow: hidden;
  flex-shrink: 0;
  height: 100%;
}

.pitch-cell {
  border-bottom: 1px solid var(--border-soft);
  display: flex;
  align-items: center;
  justify-content: flex-end;
  padding: 0 4px;
  font-variant-numeric: tabular-nums;
  box-sizing: border-box;
}

.pitch-white {
  background: rgba(255, 255, 255, 0.06);
  color: var(--text-primary);
}

.pitch-black {
  background: rgba(0, 0, 0, 0.35);
  color: rgba(233, 238, 247, 0.9);
}

.pitch-c {
  background: rgba(100, 150, 255, 0.15);
  font-weight: 600;
}

.pitch-label {
  font-size: 10px;
}

.roll-grid {
  position: relative;
  overflow: auto;
  flex: 1;
  min-height: 0;
  background-color: var(--grid-bg);
}

.grid-content {
  background-color: var(--grid-bg);
}

.grid-svg {
  shape-rendering: crispEdges;
}

.grid-overlay {
  position: relative;
}

.note-chip {
  position: absolute;
  background: linear-gradient(135deg, var(--accent), var(--accent-strong));
  border-radius: 6px;
  border: 1px solid rgba(255, 255, 255, 0.16);
  box-shadow: 0 10px 22px rgba(0, 0, 0, 0.25);
  display: flex;
  align-items: center;
  justify-content: center;
  color: var(--note-text);
  font-weight: 700;
  user-select: none;
  box-sizing: border-box;
}

.note-active {
  outline: 2px solid #ff7043;
  z-index: 2;
}

.note-overlap {
  background: linear-gradient(135deg, #ef5350 0%, #ff7043 100%) !important;
  animation: pulse-overlap 1s ease-in-out infinite;
}

/* Selected overlapping note - more visible outline */
.note-overlap.note-active {
  outline: 3px solid #1e40af;
  outline-offset: 1px;
  box-shadow: 0 0 12px rgba(30, 64, 175, 0.8);
  animation: none;
}

@keyframes pulse-overlap {
  0%, 100% { opacity: 1; }
  50% { opacity: 0.7; }
}

.playhead {
  position: absolute;
  top: 0;
  width: 2px;
  background: #ff7043;
  box-shadow: 0 0 12px rgba(255, 112, 67, 0.6);
  pointer-events: none;
  z-index: 20;
}

.pitch-rail-inner {
  will-change: transform;
}

.note-label {
  width: 100%;
  text-align: center;
  font-size: 12px;
  padding: 0 12px;
  overflow: hidden;
  text-overflow: ellipsis;
  white-space: nowrap;
}

.note-handle {
  position: absolute;
  top: 0;
  width: 8px;
  height: 100%;
  background: rgba(255, 255, 255, 0.25);
  cursor: ew-resize;
}

.note-handle.start {
  left: 0;
  border-radius: 6px 0 0 6px;
}

.note-handle.end {
  right: 0;
  border-radius: 0 6px 6px 0;
}

.lyric-container {
  flex: 1;
  min-height: 0;
  position: relative;
}

.lyric-card {
  border: 1px solid rgba(255, 255, 255, 0.06);
  border-radius: 12px;
  background: var(--panel-soft);
  overflow: hidden;
  display: flex;
  flex-direction: column;
  /* Force fixed height with absolute positioning */
  position: absolute;
  top: 0;
  left: 0;
  right: 0;
  bottom: 0;
}

.lyric-bulk {
  display: flex;
  gap: 8px;
  padding: 10px 12px;
  border-bottom: 1px solid rgba(255, 255, 255, 0.06);
  align-items: center;
}

.lyric-bulk-input {
  flex: 1;
  padding: 8px 10px;
  border-radius: 10px;
  border: 1px solid var(--border-soft);
  background: var(--input-bg);
  color: var(--text-primary);
  resize: vertical;
}

.lyric-header,
.lyric-row {
  display: grid;
  grid-template-columns: 1.4fr 0.5fr 0.5fr 0.5fr;
  gap: 8px;
  padding: 10px 12px;
  align-items: center;
}

.lyric-header {
  font-size: 12px;
  text-transform: uppercase;
  letter-spacing: 1px;
  color: var(--text-muted);
  border-bottom: 1px solid rgba(255, 255, 255, 0.06);
}

.lyric-list {
  overflow-y: auto;
  overflow-x: hidden;
  flex: 1;
  min-height: 0;
}

.lyric-row {
  border-bottom: 1px solid rgba(255, 255, 255, 0.04);
}

.lyric-row:hover {
  background: rgba(255, 255, 255, 0.03);
}

.lyric-row-active {
  background: rgba(72, 228, 194, 0.08);
  border-left: 3px solid #48e4c2;
}

.lyric-input {
  width: 100%;
  padding: 8px 10px;
  border-radius: 10px;
  border: 1px solid var(--border-soft);
  background: var(--input-bg);
  color: var(--text-primary);
}

.lyric-meta {
  color: var(--text-muted);
  font-variant-numeric: tabular-nums;
}

.editable-cell {
  position: relative;
  display: flex;
  align-items: center;
  gap: 2px;
}

.lyric-meta-input {
  width: 100%;
  padding: 2px 4px;
  border: 1px solid transparent;
  border-radius: 4px;
  background: transparent;
  color: var(--text-muted);
  font-size: 12px;
  font-variant-numeric: tabular-nums;
  text-align: center;
  outline: none;
  transition: border-color 0.15s, background-color 0.15s;
}

.lyric-meta-input:hover {
  background: var(--surface-elevated);
}

.lyric-meta-input:focus {
  border-color: var(--accent);
  background: var(--surface-elevated);
  color: var(--text-primary);
}

.lyric-meta-dirty {
  border-color: #f59e0b !important;
  background: rgba(245, 158, 11, 0.1) !important;
}

.confirm-btn {
  flex-shrink: 0;
  width: 18px;
  height: 18px;
  padding: 0;
  border: none;
  border-radius: 4px;
  background: #22c55e;
  color: white;
  font-size: 12px;
  font-weight: bold;
  cursor: pointer;
  display: flex;
  align-items: center;
  justify-content: center;
  transition: background 0.15s;
}

.confirm-btn:hover {
  background: #16a34a;
}

/* Hide number input spinners */
.lyric-meta-input::-webkit-outer-spin-button,
.lyric-meta-input::-webkit-inner-spin-button {
  -webkit-appearance: none;
  margin: 0;
}

.lyric-meta-input[type=number] {
  -moz-appearance: textfield;
}

.lyric-empty {
  padding: 16px;
  color: var(--text-muted);
  text-align: center;
}

.audio-track {
  display: grid;
  grid-template-columns: 80px 1fr;
  gap: 12px;
  align-items: center;
  padding: 12px 14px;
  border-radius: 12px;
  border: 1px solid var(--border-soft);
  background: var(--panel-soft);
  margin-bottom: 12px;
  flex-shrink: 0;
}

.audio-track-label {
  font-size: 12px;
  text-transform: uppercase;
  letter-spacing: 1px;
  color: var(--text-muted);
}

.audio-wave {
  width: 100%;
  height: 80px;
  min-height: 80px;
}

:root {
  --text-primary: #e9eef7;
  --text-muted: rgba(233, 238, 247, 0.7);
  --panel: rgba(13, 16, 28, 0.8);
  --panel-strong: rgba(16, 21, 35, 0.95);
  --panel-soft: rgba(255, 255, 255, 0.03);
  --border-subtle: rgba(255, 255, 255, 0.08);
  --border-soft: rgba(255, 255, 255, 0.12);
  --input-bg: rgba(255, 255, 255, 0.06);
  --grid-bg: rgba(14, 18, 30, 0.9);
  --grid-line-minor: rgba(233, 238, 247, 0.08);
  --grid-line-major: rgba(233, 238, 247, 0.16);
  --accent: #48e4c2;
  --accent-strong: #4b64bc;
  --note-text: #0b1122;
  --button-ghost-bg: rgba(233, 238, 247, 0.18);
  --button-ghost-text: #ffffff;
  --button-soft-bg: rgba(255, 255, 255, 0.14);
  --button-soft-text: #ffffff;
  --button-primary-text: #0b1122;
  --shadow-panel: 0 18px 40px rgba(0, 0, 0, 0.32);
}

:root[data-theme='light'] {
  --text-primary: #1b2238;
  --text-muted: rgba(27, 34, 56, 0.7);
  --panel: rgba(255, 255, 255, 0.9);
  --panel-strong: rgba(250, 252, 255, 0.98);
  --panel-soft: rgba(15, 23, 42, 0.04);
  --border-subtle: rgba(15, 23, 42, 0.12);
  --border-soft: rgba(15, 23, 42, 0.16);
  --input-bg: rgba(15, 23, 42, 0.06);
  --grid-bg: rgba(248, 250, 255, 0.95);
  --grid-line-minor: rgba(15, 23, 42, 0.12);
  --grid-line-major: rgba(15, 23, 42, 0.24);
  --accent: #3f8cff;
  --accent-strong: #4b64bc;
  --note-text: #ffffff;
  --button-ghost-bg: rgba(15, 23, 42, 0.06);
  --button-ghost-text: #1b2238;
  --button-soft-bg: rgba(15, 23, 42, 0.06);
  --button-soft-text: #1b2238;
  --button-primary-text: #0b1122;
  --shadow-panel: 0 18px 40px rgba(15, 23, 42, 0.15);
}

.sr-only {
  position: absolute;
  width: 1px;
  height: 1px;
  padding: 0;
  margin: -1px;
  overflow: hidden;
  clip: rect(0, 0, 0, 0);
  white-space: nowrap;
  border: 0;
}
preprocess/tools/midi_editor/src/App.tsx
ADDED
@@ -0,0 +1,675 @@
import { useCallback, useEffect, useMemo, useRef, useState } from 'react'
|
| 2 |
+
import * as Tone from 'tone'
|
| 3 |
+
import { PianoRoll } from './components/PianoRoll'
|
| 4 |
+
import { LyricTable } from './components/LyricTable'
|
| 5 |
+
import { AudioTrack } from './components/AudioTrack'
|
| 6 |
+
import { useMidiStore } from './store/useMidiStore'
|
| 7 |
+
import { exportMidi, importMidiFile } from './lib/midi'
|
| 8 |
+
import type { TimeSignature } from './types'
|
| 9 |
+
import type { Lang } from './i18n'
|
| 10 |
+
import { getTranslations } from './i18n'
|
| 11 |
+
import { BASE_GRID_SECOND_WIDTH, BASE_ROW_HEIGHT, LOW_NOTE, HIGH_NOTE } from './constants'
|
| 12 |
+
import './App.css'
|
| 13 |
+
|
| 14 |
+
type PlayEvent = {
|
| 15 |
+
time: number
|
| 16 |
+
midi: number
|
| 17 |
+
duration: number
|
| 18 |
+
velocity: number
|
| 19 |
+
}
|
| 20 |
+
|
| 21 |
+
function App() {
|
| 22 |
+
const {
|
| 23 |
+
notes,
|
| 24 |
+
tempo,
|
| 25 |
+
timeSignature,
|
| 26 |
+
selectedId,
|
| 27 |
+
playhead,
|
| 28 |
+
ppq,
|
| 29 |
+
addNote,
|
| 30 |
+
updateNote,
|
| 31 |
+
removeNote,
|
| 32 |
+
setNotes,
|
| 33 |
+
setTempo,
|
| 34 |
+
setTimeSignature,
|
| 35 |
+
setPpq,
|
| 36 |
+
select,
|
| 37 |
+
setPlayhead,
|
| 38 |
+
} = useMidiStore()
|
| 39 |
+
|
| 40 |
+
const [lang, setLang] = useState<Lang>('zh')
|
| 41 |
+
const t = getTranslations(lang)
|
| 42 |
+
|
| 43 |
+
const [status, setStatus] = useState(t.ready)
|
| 44 |
+
const [isPlaying, setIsPlaying] = useState(false)
|
| 45 |
+
const [theme, setTheme] = useState<'dark' | 'light'>('light')
|
| 46 |
+
const [audioUrl, setAudioUrl] = useState<string | null>(null)
|
| 47 |
+
const [audioDuration, setAudioDuration] = useState(0)
|
| 48 |
+
const [midiVolume, setMidiVolume] = useState(80) // 0-100
|
| 49 |
+
const [audioVolume, setAudioVolume] = useState(80) // 0-100
|
| 50 |
+
const [horizontalZoom, setHorizontalZoom] = useState(1)
|
| 51 |
+
const [verticalZoom, setVerticalZoom] = useState(1)
|
| 52 |
+
const [focusLyricId, setFocusLyricId] = useState<string | null>(null)
|
| 53 |
+
// Selection range for loop playback (in seconds)
|
| 54 |
+
const [selectionStart, setSelectionStart] = useState<number | null>(null)
|
| 55 |
+
const [selectionEnd, setSelectionEnd] = useState<number | null>(null)
|
| 56 |
+
const [isSelectingRange, setIsSelectingRange] = useState(false)
|
| 57 |
+
const fileInputRef = useRef<HTMLInputElement | null>(null)
|
| 58 |
+
const audioInputRef = useRef<HTMLInputElement | null>(null)
|
| 59 |
+
const audioRef = useRef<HTMLAudioElement | null>(null)
|
| 60 |
+
const partRef = useRef<Tone.Part<PlayEvent> | null>(null)
|
| 61 |
+
const synthRef = useRef<Tone.PolySynth | null>(null)
|
| 62 |
+
const rafRef = useRef<number | null>(null)
|
| 63 |
+
const audioScrollRef = useRef<HTMLDivElement | null>(null)
|
| 64 |
+
|
| 65 |
+
useEffect(() => {
|
| 66 |
+
return () => {
|
| 67 |
+
stopPlayback()
|
| 68 |
+
synthRef.current?.dispose()
|
| 69 |
+
}
|
| 70 |
+
}, [])
|
| 71 |
+
|
| 72 |
+
useEffect(() => {
|
| 73 |
+
document.documentElement.dataset.theme = theme
|
| 74 |
+
}, [theme])
|
| 75 |
+
|
| 76 |
+
// Update status text when language changes
|
| 77 |
+
useEffect(() => {
|
| 78 |
+
setStatus(t.ready)
|
| 79 |
+
}, [lang])
|
| 80 |
+
|
| 81 |
+
// Sync audio volume - also trigger when audioUrl changes (new audio loaded)
|
| 82 |
+
useEffect(() => {
|
| 83 |
+
if (audioRef.current) {
|
| 84 |
+
audioRef.current.volume = audioVolume / 100
|
| 85 |
+
}
|
| 86 |
+
}, [audioVolume, audioUrl])
|
| 87 |
+
|
| 88 |
+
// Sync MIDI synth volume
|
| 89 |
+
useEffect(() => {
|
| 90 |
+
if (synthRef.current) {
|
| 91 |
+
// Convert 0-100 to dB scale (-60 to 0)
|
| 92 |
+
const dbValue = midiVolume === 0 ? -Infinity : (midiVolume / 100) * 60 - 60
|
| 93 |
+
synthRef.current.volume.value = dbValue
|
| 94 |
+
}
|
| 95 |
+
}, [midiVolume])
|
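  // Worked example of the linear volume mapping above: midiVolume 100 → 0 dB,
  // 80 → (80/100)*60 - 60 = -12 dB, 50 → -30 dB, and 0 mutes the synth (-Infinity dB).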

  useEffect(() => {
    if (!audioUrl) return
    return () => {
      URL.revokeObjectURL(audioUrl)
    }
  }, [audioUrl])

  const ensureSynth = async () => {
    await Tone.start()
    if (!synthRef.current) {
      synthRef.current = new Tone.PolySynth(Tone.Synth).toDestination()
      // Apply current volume
      const dbValue = midiVolume === 0 ? -Infinity : (midiVolume / 100) * 60 - 60
      synthRef.current.volume.value = dbValue
    }
  }

  const playPreviewNote = useCallback(async (midi: number) => {
    await ensureSynth()
    const frequency = Tone.Frequency(midi, 'midi').toFrequency()
    synthRef.current?.triggerAttackRelease(frequency, '8n', Tone.now(), 0.7)
  }, [midiVolume])

  useEffect(() => {
    const onKeyDown = (event: KeyboardEvent) => {
      if (!selectedId) return
      const target = event.target as HTMLElement | null
      if (target && ['INPUT', 'TEXTAREA'].includes(target.tagName)) return

      // Delete note
      if (event.key === 'Backspace' || event.key === 'Delete') {
        event.preventDefault()
        removeNote(selectedId)
        select(null)
        return
      }

      // Cmd/Ctrl + Up/Down to adjust pitch
      const isCmdOrCtrl = event.metaKey || event.ctrlKey
      if (isCmdOrCtrl && (event.key === 'ArrowUp' || event.key === 'ArrowDown')) {
        event.preventDefault()
        const selectedNote = notes.find(n => n.id === selectedId)
        if (!selectedNote) return

        const delta = event.key === 'ArrowUp' ? 1 : -1
        const newMidi = Math.max(LOW_NOTE, Math.min(HIGH_NOTE, selectedNote.midi + delta))

        if (newMidi !== selectedNote.midi) {
          updateNote(selectedId, { midi: newMidi })
          playPreviewNote(newMidi)
        }
      }
    }
    window.addEventListener('keydown', onKeyDown)
    return () => window.removeEventListener('keydown', onKeyDown)
  }, [selectedId, notes, removeNote, select, updateNote, playPreviewNote])

  const noteEvents = useMemo<PlayEvent[]>(
    () =>
      notes.map((note) => ({
        time: (60 / tempo) * note.start,
        duration: (60 / tempo) * note.duration,
        midi: note.midi,
        velocity: note.velocity,
      })),
    [notes, tempo],
  )

  const beatToSeconds = (beat: number) => beat * (60 / tempo)
  const secondsToBeat = (seconds: number) => seconds / (60 / tempo)
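  // Example: at 120 BPM one beat lasts 60/120 = 0.5 s, so beat 8 maps to 4.0 s
  // and 3.0 s maps back to beat 6; noteEvents above applies the same factor.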
  const seekBySeconds = (deltaSeconds: number) => {
    const maxNoteEnd = notes.reduce((acc, n) => Math.max(acc, n.start + n.duration), 0)
    const maxBeat = Math.max(secondsToBeat(audioDuration), maxNoteEnd)
    const nextSeconds = Math.max(0, Math.min(beatToSeconds(maxBeat), beatToSeconds(playhead) + deltaSeconds))
    seekToBeat(secondsToBeat(nextSeconds))
  }

  const gridSecondWidth = BASE_GRID_SECOND_WIDTH * horizontalZoom
  const rowHeight = BASE_ROW_HEIGHT * verticalZoom

  // Calculate MIDI content width to sync with audio track
  const midiContentWidth = useMemo(() => {
    const noteEndSeconds = notes.reduce((acc, n) => {
      const endBeat = n.start + n.duration
      return Math.max(acc, beatToSeconds(endBeat))
    }, 8)
    const maxSeconds = Math.max(noteEndSeconds + 10, audioDuration + 10, 30)
    return maxSeconds * gridSecondWidth
  }, [notes, audioDuration, gridSecondWidth, beatToSeconds])

  const seekToBeat = (beat: number) => {
    setPlayhead(beat)
    Tone.Transport.seconds = beatToSeconds(beat)
    if (audioRef.current) {
      audioRef.current.currentTime = beatToSeconds(beat)
    }
  }

  const schedulePlayback = async () => {
    if (!notes.length && !audioUrl) return
    await ensureSynth()
    partRef.current?.dispose()
    Tone.Transport.cancel()
    Tone.Transport.stop()
    Tone.Transport.bpm.value = tempo

    // Determine playback range
    const hasSelection = selectionStart !== null && selectionEnd !== null && selectionEnd > selectionStart
    const startSeconds = hasSelection ? selectionStart : beatToSeconds(playhead)
    const endSeconds = hasSelection ? selectionEnd : null
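    // Example: with a 2.0-5.0 s selection, playback starts at 2.0 s and the rAF
    // tick below pauses and rewinds to 2.0 s once the clock reaches 5.0 s.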

    Tone.Transport.seconds = startSeconds

    // Filter notes within selection range if applicable
    const filteredEvents = hasSelection
      ? noteEvents.filter(e => e.time >= startSeconds && e.time < endSeconds!)
      : noteEvents

    if (filteredEvents.length) {
      partRef.current = new Tone.Part((time, event) => {
        if (midiVolume === 0) return
        const frequency = Tone.Frequency(event.midi, 'midi').toFrequency()
        synthRef.current?.triggerAttackRelease(frequency, event.duration, time, event.velocity)
      }, filteredEvents)
      partRef.current.start(0)
    }
    Tone.Transport.start()
    if (audioRef.current && audioUrl) {
      audioRef.current.currentTime = startSeconds
      if (audioVolume > 0) {
        audioRef.current.play().catch(() => null)
      }
    }
    setIsPlaying(true)
    setStatus(hasSelection ? t.selectionPlayback : t.playing)

    const tick = () => {
      const seconds =
        audioRef.current && audioUrl && !audioRef.current.paused
          ? audioRef.current.currentTime
          : Tone.Transport.seconds

      // Stop at selection end
      if (endSeconds !== null && seconds >= endSeconds) {
        pausePlayback()
        seekToBeat(secondsToBeat(selectionStart!))
        setStatus(t.selectionDone)
        return
      }

      const beat = seconds / (60 / tempo)
      setPlayhead(beat)
      rafRef.current = requestAnimationFrame(tick)
    }
    rafRef.current = requestAnimationFrame(tick)
  }

  const stopPlayback = () => {
    Tone.Transport.stop()
    Tone.Transport.cancel()
    partRef.current?.dispose()
    partRef.current = null
    setIsPlaying(false)
    setPlayhead(0)
    if (audioRef.current) {
      audioRef.current.pause()
      audioRef.current.currentTime = 0
    }
    if (rafRef.current) {
      cancelAnimationFrame(rafRef.current)
      rafRef.current = null
    }
  }

  const pausePlayback = () => {
    Tone.Transport.stop()
    partRef.current?.dispose()
    partRef.current = null
    setIsPlaying(false)
    if (audioRef.current) {
      audioRef.current.pause()
    }
    if (rafRef.current) {
      cancelAnimationFrame(rafRef.current)
      rafRef.current = null
    }
  }

  const handlePlayToggle = async () => {
    if (isPlaying) {
      pausePlayback()
      setStatus(t.paused)
    } else {
      await schedulePlayback()
    }
  }

  const handleImportClick = () => fileInputRef.current?.click()
  const handleAudioImportClick = () => audioInputRef.current?.click()

  const handleFileChange = async (event: React.ChangeEvent<HTMLInputElement>) => {
    const file = event.target.files?.[0]
    if (!file) return

    try {
      const snapshot = await importMidiFile(file)
      setNotes(snapshot.notes)
      setTempo(snapshot.tempo)
      setTimeSignature(snapshot.timeSignature as TimeSignature)
      setPpq(snapshot.ppq) // Preserve original ppq for accurate export
      setStatus(t.imported(file.name))
    } catch (error) {
      console.error(error)
      setStatus(t.importFailed)
    } finally {
      event.target.value = ''
    }
  }

  const handleAudioChange = (event: React.ChangeEvent<HTMLInputElement>) => {
    const file = event.target.files?.[0]
    if (!file) return

    // Validate audio file type
    const validAudioTypes = ['audio/mpeg', 'audio/wav', 'audio/ogg', 'audio/flac', 'audio/mp4', 'audio/aac', 'audio/x-m4a']
    const validExtensions = ['.mp3', '.wav', '.ogg', '.flac', '.m4a', '.aac']
    const fileName = file.name.toLowerCase()
    const isValidType = validAudioTypes.includes(file.type) || file.type.startsWith('audio/')
    const isValidExtension = validExtensions.some(ext => fileName.endsWith(ext))

    if (!isValidType && !isValidExtension) {
      setStatus(t.unsupportedFormat(validExtensions.join(', ')))
      event.target.value = ''
      return
    }

    const url = URL.createObjectURL(file)
    setAudioUrl(url)
    setStatus(t.audioImported(file.name))
    event.target.value = ''
  }

  // Fix overlapping notes by trimming the first note to end where the second begins
  // Returns the number of fixed overlaps
  const fixOverlaps = (): number => {
    const sortedNotes = [...notes].sort((a, b) => a.start - b.start)
    let fixCount = 0

    for (let i = 0; i < sortedNotes.length - 1; i++) {
      const noteA = sortedNotes[i]
      const noteB = sortedNotes[i + 1]
      const noteAEnd = noteA.start + noteA.duration

      // If noteA overlaps with noteB
      if (noteAEnd > noteB.start) {
        // Trim noteA to end at noteB's start
        const newDuration = Math.max(0.01, noteB.start - noteA.start)
        updateNote(noteA.id, { duration: newDuration })
        fixCount++
      }
    }

    return fixCount
  }
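  // Example: notes spanning beats [0, 1.5) and [1, 2) overlap, so the first is
  // trimmed to [0, 1); durations are floored at 0.01 beats to avoid zero-length notes.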

  // UI handler for fix overlaps button
  const handleFixOverlaps = () => {
    const fixCount = fixOverlaps()
    if (fixCount > 0) {
      setStatus(t.fixedOverlaps(fixCount))
    } else {
      setStatus(t.noOverlaps)
    }
  }

  const handleExport = () => {
    // Auto-fix overlaps before export
    fixOverlaps()

    // Get the latest notes from store (after fix, zustand set is synchronous)
    const latestNotes = useMidiStore.getState().notes

    const blob = exportMidi({ notes: latestNotes, tempo, timeSignature, ppq })
    const url = URL.createObjectURL(blob)
    const anchor = document.createElement('a')
    anchor.href = url
    anchor.download = 'vocal-midi.mid'
    anchor.click()
    URL.revokeObjectURL(url)
    setStatus(t.exported)
  }

  const handleTranspose = (semitones: number) => {
    if (semitones === 0 || !notes.length) return
    for (const note of notes) {
      const newMidi = Math.max(0, Math.min(127, note.midi + semitones))
      updateNote(note.id, { midi: newMidi })
    }
    setStatus(t.transposed(semitones))
  }
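  // Example: transposing by +2 maps C4 (MIDI 60) to D4 (62); out-of-range results
  // are clamped to the MIDI range [0, 127] rather than rejected.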

  return (
    <div className="app-shell">
      <header className="topbar">
        <div>
          <p className="eyebrow">{t.eyebrow}</p>
          <h1>{t.title}</h1>
          <p className="muted">{t.subtitle}</p>
        </div>
        <div className="actions">
          <button className="primary" onClick={handleImportClick}>
            {t.importMidi}
          </button>
          <button className="primary" onClick={handleExport}>
            {t.exportMidi}
          </button>
          <div className="transpose-group" title={t.transposeTooltip}>
            <select
              className="transpose-select"
              value={0}
              onChange={(e) => {
                const val = Number(e.target.value)
                if (val !== 0) handleTranspose(val)
                e.target.value = '0'
              }}
            >
              <option value={0}>{t.transpose}</option>
              {Array.from({ length: 24 }, (_, i) => i - 12)
                .filter(v => v !== 0)
                .reverse()
                .map(v => (
                  <option key={v} value={v}>
                    {v > 0 ? `+${v}` : v}
                  </option>
                ))}
            </select>
          </div>
          <button className="soft" onClick={handleFixOverlaps} title={t.fixOverlapsTooltip}>
            {t.fixOverlaps}
          </button>
          <button className="icon-toggle" onClick={() => setTheme(theme === 'dark' ? 'light' : 'dark')}>
            {theme === 'dark' ? (
              <span className="icon" aria-label={t.switchToLight}>
                ☀️
              </span>
            ) : (
              <span className="icon" aria-label={t.switchToDark}>
                🌙
              </span>
            )}
          </button>
          <button
            className="icon-toggle"
            onClick={() => setLang(lang === 'zh' ? 'en' : 'zh')}
            title={lang === 'zh' ? 'Switch to English' : '切换到中文'}
          >
            <span className="lang-label">{lang === 'zh' ? 'EN' : '中'}</span>
          </button>
          <input ref={fileInputRef} type="file" accept=".mid,.midi" className="sr-only" onChange={handleFileChange} />
        </div>
      </header>

      <section className="audio-bar">
        <div className="audio-left">
          <button className="ghost" onClick={handleAudioImportClick}>
            {t.importAudio}
          </button>
          <input
            ref={audioInputRef}
            type="file"
            accept=".mp3,.wav,.ogg,.flac,.m4a,.aac"
            className="sr-only"
            onChange={handleAudioChange}
          />
          <span className="audio-hint">{t.audioHint}</span>
        </div>
        <div className="audio-right">
          <div className="volume-control">
            <span className="volume-label">{t.midiLabel}</span>
            <input
              type="range"
              min={0}
              max={100}
              value={midiVolume}
              onChange={(e) => setMidiVolume(Number(e.target.value))}
              className="volume-slider"
            />
            <span className="volume-value">{midiVolume}%</span>
          </div>
          <div className="volume-control">
            <span className="volume-label">{t.audioLabel}</span>
            <input
              type="range"
              min={0}
              max={100}
              value={audioVolume}
              onChange={(e) => setAudioVolume(Number(e.target.value))}
              className="volume-slider"
            />
            <span className="volume-value">{audioVolume}%</span>
          </div>
        </div>
      </section>

      <section className="panel panel-split">
        <div className="panel-main">
          {audioUrl && (
            <AudioTrack
              key={audioUrl}
              ref={audioScrollRef}
              audioUrl={audioUrl}
              muted={audioVolume === 0}
              onSeek={(seconds) => seekToBeat(secondsToBeat(seconds))}
              playheadSeconds={beatToSeconds(playhead)}
              gridSecondWidth={gridSecondWidth}
              minContentWidth={midiContentWidth}
            />
          )}
          <PianoRoll
            notes={notes}
            selectedId={selectedId}
            timeSignature={timeSignature}
            tempo={tempo}
            playhead={playhead}
            selectionStart={selectionStart}
            selectionEnd={selectionEnd}
            onAddNote={addNote}
            onSelect={select}
            onUpdateNote={updateNote}
            onSeek={seekToBeat}
            onScroll={(left) => {
              if (audioScrollRef.current) {
                audioScrollRef.current.scrollLeft = left
              }
            }}
            onZoom={(deltaH, deltaV) => {
              if (deltaH !== 0) {
                setHorizontalZoom(prev => Math.max(0.5, prev + deltaH))
              }
              if (deltaV !== 0) {
                setVerticalZoom(prev => Math.max(0.6, Math.min(2.5, prev + deltaV)))
              }
            }}
            onPlayNote={playPreviewNote}
            onFocusLyric={(noteId) => {
              select(noteId)
              setFocusLyricId(noteId)
            }}
            onSelectionChange={(start, end) => {
              setSelectionStart(start)
              setSelectionEnd(end)
            }}
            isSelectingRange={isSelectingRange}
            audioDuration={audioDuration}
            gridSecondWidth={gridSecondWidth}
            rowHeight={rowHeight}
          />
        </div>
        <aside className="panel-side">
          <div className="controls">
            <div className="toggle" style={{ justifyContent: 'space-between' }}>
              <span>{t.horizontalZoom}</span>
              <input
                type="range"
                min={0.5}
                max={10}
                step={0.1}
                value={Math.min(horizontalZoom, 10)}
                onChange={(e) => setHorizontalZoom(Number(e.target.value))}
                style={{ width: '140px' }}
              />
              <span style={{ width: 42, textAlign: 'right' }}>{horizontalZoom.toFixed(1)}x</span>
            </div>
            <div className="toggle" style={{ justifyContent: 'space-between' }}>
              <span>{t.verticalZoom}</span>
              <input
                type="range"
                min={0.6}
                max={2.5}
                step={0.1}
                value={verticalZoom}
                onChange={(e) => setVerticalZoom(Number(e.target.value))}
                style={{ width: '140px' }}
              />
              <span style={{ width: 42, textAlign: 'right' }}>{verticalZoom.toFixed(1)}x</span>
            </div>
            <div className="transport">
              <button
                className="soft"
                onClick={() => {
                  setPlayhead(0)
                  seekToBeat(0)
                }}
                title={t.goToStart}
              >
                ⏮
              </button>
              <button
                className="soft"
                onClick={() => seekBySeconds(-2)}
                title={t.back2s}
              >
                ⏪
              </button>
              <button
                className="primary"
                onClick={handlePlayToggle}
                disabled={!notes.length && !audioUrl}
                title={isPlaying ? t.pause : (selectionStart !== null && selectionEnd !== null ? t.playSelection : t.play)}
              >
                {isPlaying ? '⏸' : '▶'}
              </button>
              <button
                className="soft"
                onClick={() => seekBySeconds(2)}
                title={t.forward2s}
              >
                ⏩
              </button>
              <button
                className="soft"
                onClick={() => {
                  const maxNoteEnd = notes.reduce((acc, n) => Math.max(acc, n.start + n.duration), 0)
                  seekToBeat(Math.max(secondsToBeat(audioDuration), maxNoteEnd))
                }}
                title={t.goToEnd}
              >
                ⏭
              </button>
              <span className="transport-divider" />
              <button
                className={`soft selection-btn ${isSelectingRange ? 'active' : ''}`}
                onClick={() => {
                  if (isSelectingRange) {
                    // Exiting selection mode - auto clear selection
                    setIsSelectingRange(false)
                    setSelectionStart(null)
                    setSelectionEnd(null)
                  } else {
                    setIsSelectingRange(true)
                  }
                }}
                title={isSelectingRange ? t.exitSelectMode : t.setRangeTooltip}
              >
                {isSelectingRange ? `📍 ${t.selectingRange}` : `📍 ${t.setRange}`}
              </button>
            </div>
            <div className="status">{status}</div>
          </div>
          <div className="lyric-container">
            <LyricTable
              notes={notes}
              selectedId={selectedId}
              tempo={tempo}
              focusLyricId={focusLyricId}
              lang={lang}
              onSelect={select}
              onUpdate={updateNote}
              onFocusHandled={() => setFocusLyricId(null)}
            />
          </div>
        </aside>
      </section>
      <audio
        ref={audioRef}
        src={audioUrl ?? undefined}
        preload="auto"
        className="sr-only"
        onLoadedMetadata={(e) => {
          setAudioDuration(e.currentTarget.duration)
          // Ensure volume is set when audio loads
          e.currentTarget.volume = audioVolume / 100
        }}
      />
    </div>
  )
}

export default App

preprocess/tools/midi_editor/src/components/AudioTrack.tsx
ADDED
@@ -0,0 +1,182 @@
import { useEffect, useRef, forwardRef, useState } from 'react'
import WaveSurfer from 'wavesurfer.js'
import { PITCH_WIDTH } from '../constants'

export type AudioTrackProps = {
  audioUrl: string | null
  muted: boolean
  onSeek: (seconds: number) => void
  mediaElement?: HTMLAudioElement | null
  playheadSeconds: number
  gridSecondWidth: number
  minContentWidth?: number // Minimum width to match MIDI editor area
}

export const AudioTrack = forwardRef<HTMLDivElement, AudioTrackProps>(
  ({ audioUrl, muted, onSeek, playheadSeconds, gridSecondWidth, minContentWidth = 0 }, ref) => {
    const containerRef = useRef<HTMLDivElement | null>(null)
    const waveRef = useRef<WaveSurfer | null>(null)
    const [waveWidth, setWaveWidth] = useState(0)

    useEffect(() => {
      if (!containerRef.current) return
      if (!audioUrl) {
        try {
          waveRef.current?.destroy()
        } catch {
          // ignore teardown errors
        }
        waveRef.current = null
        setWaveWidth(0)
        return
      }

      let cancelled = false

      // Clean up existing instance
      if (waveRef.current) {
        try {
          waveRef.current.destroy()
        } catch {
          // ignore teardown errors
        }
      }

      waveRef.current = WaveSurfer.create({
        container: containerRef.current,
        waveColor: '#4b64bc',
        progressColor: '#4b64bc',
        cursorColor: 'transparent',
        barWidth: 2,
        barGap: 2,
        height: 60,
        normalize: true,
        minPxPerSec: gridSecondWidth,
        interact: false,
        hideScrollbar: true,
        autoScroll: false,
      })

      waveRef.current.load(audioUrl).catch(() => null)
      waveRef.current.on('error', () => null)

      waveRef.current.on('ready', () => {
        if (cancelled || !waveRef.current) return
        const duration = waveRef.current.getDuration()
        const requiredWidth = duration * gridSecondWidth
        setWaveWidth(requiredWidth)
      })
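      // Example: a 30 s file at gridSecondWidth 80 px/s yields a 2400 px waveform,
      // keeping the wave on the same time scale as the piano-roll grid.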

      return () => {
        cancelled = true
        try {
          waveRef.current?.destroy()
        } catch {
          // ignore teardown errors
        }
        waveRef.current = null
      }
    }, [audioUrl, gridSecondWidth])

    useEffect(() => {
      if (!waveRef.current) return
      waveRef.current.setOptions({
        waveColor: muted ? '#9aa6b2' : '#4b64bc',
        progressColor: muted ? '#c0c9d4' : '#4b64bc',
      })
    }, [muted])

    if (!audioUrl) return null

    // Content width should be at least as wide as MIDI editor
    const contentWidth = Math.max(waveWidth, minContentWidth)

    return (
      <div
        className="audio-track-row"
        style={{
          display: 'flex',
          borderBottom: '1px solid var(--border-soft)',
          height: '70px',
          flexShrink: 0
        }}
      >
        <div
          className="audio-gutter"
          style={{
            width: PITCH_WIDTH,
            flexShrink: 0,
            background: 'var(--panel-strong)',
            borderRight: '1px solid var(--border-subtle)',
            display: 'flex',
            alignItems: 'center',
            justifyContent: 'center',
            fontSize: '11px',
            color: 'var(--text-muted)',
            fontWeight: 600,
          }}
        >
          AUDIO
        </div>

        {/* Scroll Mask - Controlled by parent via ref */}
        <div
          ref={ref}
          className="audio-scroll-mask"
          style={{
            flex: 1,
            overflow: 'hidden',
            position: 'relative',
            background: 'var(--panel-soft)',
          }}
          onClick={(e) => {
            const rect = e.currentTarget.getBoundingClientRect()
            const scrollMask = e.currentTarget as HTMLDivElement
            const x = e.clientX - rect.left + scrollMask.scrollLeft
            const seconds = x / gridSecondWidth
            onSeek(seconds)
          }}
        >
          {/* Container that matches MIDI editor width */}
          <div
            className="audio-content"
            style={{
              width: contentWidth > 0 ? contentWidth : '100%',
              height: '100%',
              position: 'relative'
            }}
          >
            {/* WaveSurfer container - only as wide as audio */}
            <div
              ref={containerRef}
              className="wave-container"
              style={{
                width: waveWidth > 0 ? waveWidth : '100%',
                height: '100%',
                position: 'absolute',
                left: 0,
                top: 0
              }}
            />

            {/* Custom Playhead */}
            <div
              className="audio-playhead"
              style={{
                position: 'absolute',
                top: 0,
                bottom: 0,
                width: '2px',
                background: '#ff7043',
                boxShadow: '0 0 12px rgba(255, 112, 67, 0.6)',
                left: playheadSeconds * gridSecondWidth,
                zIndex: 10,
                pointerEvents: 'none',
              }}
            />
          </div>
        </div>
      </div>
    )
  }
)

preprocess/tools/midi_editor/src/components/LyricTable.tsx
ADDED
@@ -0,0 +1,301 @@
import { useEffect, useMemo, useRef, useState } from 'react'
import type { NoteEvent } from '../types'
import type { Lang } from '../i18n'
import { getTranslations, tokenizeLyrics } from '../i18n'

export type LyricTableProps = {
  notes: NoteEvent[]
  selectedId: string | null
  tempo: number
  focusLyricId: string | null
  lang: Lang
  onSelect: (id: string | null) => void
  onUpdate: (id: string, patch: Partial<NoteEvent>) => void
  onScrollToNote?: (noteId: string) => void
  onFocusHandled?: () => void
}

const formatSeconds = (beats: number, tempo: number) => {
  const seconds = beats * (60 / tempo)
  return Number.parseFloat(seconds.toFixed(2))
}

const secondsToBeats = (seconds: number, tempo: number) => {
  return seconds * (tempo / 60)
}
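
// Example: at 120 BPM, formatSeconds(4, 120) → 2 (4 beats = 2.00 s) and
// secondsToBeats(2, 120) → 4; the table displays seconds but stores beats.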

// Editable cell with confirmation
function EditableCell({
  value,
  noteId,
  field,
  tempo,
  onConfirm,
  confirmTitle,
  type = 'number',
  min,
  step
}: {
  value: number
  noteId: string
  field: 'midi' | 'start' | 'end'
  tempo: number
  onConfirm: (noteId: string, field: string, value: number) => void
  confirmTitle?: string
  type?: string
  min?: number
  step?: number
}) {
  const displayValue = field === 'midi' ? value : formatSeconds(value, tempo)
  const [localValue, setLocalValue] = useState(String(displayValue))
  const [isDirty, setIsDirty] = useState(false)
  const inputRef = useRef<HTMLInputElement>(null)

  // Sync with external value when it changes (and not dirty)
  useEffect(() => {
    if (!isDirty) {
      setLocalValue(String(displayValue))
    }
  }, [displayValue, isDirty])

  const handleChange = (e: React.ChangeEvent<HTMLInputElement>) => {
    setLocalValue(e.target.value)
    setIsDirty(true)
  }

  const handleConfirm = () => {
    const parsed = parseFloat(localValue)
    if (!isNaN(parsed)) {
      if (field === 'midi') {
        if (parsed >= 0 && parsed <= 127) {
          onConfirm(noteId, field, Math.round(parsed))
        }
      } else {
        if (parsed >= 0) {
          onConfirm(noteId, field, secondsToBeats(parsed, tempo))
        }
      }
    }
    setIsDirty(false)
  }

  const handleKeyDown = (e: React.KeyboardEvent) => {
    if (e.key === 'Enter') {
      e.preventDefault()
      handleConfirm()
      inputRef.current?.blur()
    } else if (e.key === 'Escape') {
      setLocalValue(String(displayValue))
      setIsDirty(false)
      inputRef.current?.blur()
    }
  }

  const handleBlur = () => {
    if (isDirty) {
      // Reset to original on blur without confirm
      setLocalValue(String(displayValue))
      setIsDirty(false)
    }
  }

  return (
    <div className="editable-cell">
      <input
        ref={inputRef}
        className={`lyric-meta-input ${isDirty ? 'lyric-meta-dirty' : ''}`}
        type={type}
        min={min}
        step={step}
        value={localValue}
        onChange={handleChange}
        onKeyDown={handleKeyDown}
        onBlur={handleBlur}
        onClick={(e) => e.stopPropagation()}
      />
      {isDirty && (
        <button
          className="confirm-btn"
          onMouseDown={(e) => {
            e.preventDefault() // Prevent input blur
            e.stopPropagation()
          }}
          onClick={(e) => {
            e.stopPropagation()
            handleConfirm()
          }}
          title={confirmTitle}
        >
          ✓
        </button>
      )}
    </div>
  )
}

export function LyricTable({ notes, selectedId, tempo, focusLyricId, lang, onSelect, onUpdate, onScrollToNote, onFocusHandled }: LyricTableProps) {
  const t = getTranslations(lang)
  const listRef = useRef<HTMLDivElement | null>(null)
  const inputRefs = useRef<Map<string, HTMLInputElement>>(new Map())
  const sorted = useMemo(() => [...notes].sort((a, b) => a.start - b.start), [notes])

  // Scroll to selected note (no auto-focus on single click)
  useEffect(() => {
    if (!selectedId || !listRef.current) return

    const target = listRef.current.querySelector<HTMLDivElement>(`[data-note-id="${selectedId}"]`)
    if (target) {
      target.scrollIntoView({ block: 'nearest', behavior: 'smooth' })
    }
  }, [selectedId])

  // Focus lyric input when requested (double-click on note or click on list row)
  useEffect(() => {
    if (!focusLyricId) return

    const input = inputRefs.current.get(focusLyricId)
    if (input) {
      setTimeout(() => {
        input.focus()
        input.select()
      }, 50)
    }
    onFocusHandled?.()
  }, [focusLyricId, onFocusHandled])

  // Fill lyrics from selected note onwards
  // Uses smart tokenizer: CJK chars -> one per note, English words -> one per note
  const handleBulkFill = (bulkText: string) => {
    if (!sorted.length) return
    const tokens = tokenizeLyrics(bulkText)
    if (!tokens.length) return

    let startIndex = 0
    if (selectedId) {
      const selectedIndex = sorted.findIndex(n => n.id === selectedId)
      if (selectedIndex >= 0) {
        startIndex = selectedIndex
      }
    }

    let tokenIndex = 0
    for (let i = startIndex; i < sorted.length && tokenIndex < tokens.length; i++) {
      onUpdate(sorted[i].id, { lyric: tokens[tokenIndex] })
      tokenIndex++
    }
  }
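  // Example (per the tokenizer contract described above): "你好 world" would fill
  // three notes with '你', '好', 'world'; tokens beyond the last note are dropped.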

  const handleRowClick = (noteId: string) => {
    onSelect(noteId)
    onScrollToNote?.(noteId)
  }

  const handleFieldConfirm = (noteId: string, field: string, value: number) => {
    const note = notes.find(n => n.id === noteId)
    if (!note) return

    if (field === 'midi') {
      onUpdate(noteId, { midi: value })
    } else if (field === 'start') {
      // Keep END the same, adjust duration accordingly
      const currentEnd = note.start + note.duration
      const newDuration = Math.max(0.01, currentEnd - value)
      onUpdate(noteId, { start: value, duration: newDuration })
    } else if (field === 'end') {
      // End changed, update duration
      const newDuration = Math.max(0.01, value - note.start)
      onUpdate(noteId, { duration: newDuration })
    }
  }
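  // Example: editing START of a note spanning beats [1.0, 2.0) to 0.5 yields
  // [0.5, 2.0): the end stays fixed and the duration grows to 1.5 beats.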

  return (
    <div className="lyric-card">
      <div className="lyric-bulk">
        <textarea
          className="lyric-bulk-input"
          rows={2}
          placeholder={selectedId ? t.fillPlaceholderSelected : t.fillPlaceholderDefault}
          onKeyDown={(e) => {
            if (e.key === 'Enter' && !e.shiftKey) {
              e.preventDefault()
              handleBulkFill(e.currentTarget.value)
            }
          }}
        />
        <button
          className="soft"
          type="button"
          onClick={(e) => {
            const textarea = e.currentTarget.previousElementSibling as HTMLTextAreaElement
            handleBulkFill(textarea.value)
          }}
        >
          {t.fillButton.split('\n').map((line, i) => (
            <span key={i}>{line}{i === 0 && <br/>}</span>
          ))}
        </button>
      </div>
      <div className="lyric-header" style={{ flexShrink: 0 }}>
        <div>LYRIC</div>
        <div>PITCH</div>
        <div>START</div>
        <div>END</div>
      </div>
      <div className="lyric-list" ref={listRef}>
        {sorted.map((note) => (
          <div
            key={note.id}
            className={`lyric-row ${selectedId === note.id ? 'lyric-row-active' : ''}`}
            data-note-id={note.id}
            onClick={() => handleRowClick(note.id)}
          >
            <input
              ref={(el) => {
                if (el) {
                  inputRefs.current.set(note.id, el)
                } else {
                  inputRefs.current.delete(note.id)
                }
              }}
              className="lyric-input"
              value={note.lyric}
              placeholder={t.lyricPlaceholder}
              onChange={(event) => onUpdate(note.id, { lyric: event.target.value })}
              onClick={(e) => e.stopPropagation()}
            />
            <EditableCell
              value={note.midi}
              noteId={note.id}
              field="midi"
              tempo={tempo}
              onConfirm={handleFieldConfirm}
              confirmTitle={t.confirmEdit}
              min={0}
            />
            <EditableCell
              value={note.start}
              noteId={note.id}
              field="start"
              tempo={tempo}
              onConfirm={handleFieldConfirm}
              confirmTitle={t.confirmEdit}
              min={0}
              step={0.01}
            />
            <EditableCell
              value={note.start + note.duration}
              noteId={note.id}
              field="end"
              tempo={tempo}
              onConfirm={handleFieldConfirm}
              confirmTitle={t.confirmEdit}
              min={0}
              step={0.01}
            />
          </div>
        ))}
        {sorted.length === 0 && <div className="lyric-empty">{t.emptyHint}</div>}
      </div>
    </div>
  )
}

preprocess/tools/midi_editor/src/components/PianoRoll.tsx
ADDED
@@ -0,0 +1,704 @@
import { useEffect, useMemo, useRef, useState, useCallback, memo } from 'react'
import type React from 'react'
import type { NoteEvent, TimeSignature } from '../types'
import { PITCH_WIDTH, LOW_NOTE, HIGH_NOTE } from '../constants'

const midiToName = (midi: number) => {
  const names = ['C', 'C#', 'D', 'D#', 'E', 'F', 'F#', 'G', 'G#', 'A', 'A#', 'B']
  const octave = Math.floor(midi / 12) - 1
  return `${names[midi % 12]}${octave}`
}
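
// Example: midiToName(60) → 'C4' (60 % 12 = 0 → 'C'; floor(60 / 12) - 1 = 4),
// and midiToName(69) → 'A4', the 440 Hz reference pitch.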

// Memoized note component to prevent unnecessary re-renders
const NoteChip = memo(function NoteChip({
  note,
  left,
  top,
  width,
  height,
  fontSize,
  isSelected,
  isOverlapping,
  onPointerDown,
  onDoubleClick,
}: {
  note: NoteEvent
  left: number
  top: number
  width: number
  height: number
  fontSize: number
  isSelected: boolean
  isOverlapping: boolean
  onPointerDown: (event: React.PointerEvent<HTMLDivElement>, mode: 'move' | 'resize-start' | 'resize-end') => void
  onDoubleClick: (event: React.MouseEvent<HTMLDivElement>) => void
}) {
  return (
    <div
      className={`note-chip ${isSelected ? 'note-active' : ''} ${isOverlapping ? 'note-overlap' : ''}`}
      style={{
        left,
        top: top + 1,
        width,
        height,
        willChange: 'transform', // GPU acceleration hint
      }}
      onPointerDown={(e) => onPointerDown(e, 'move')}
      onDoubleClick={onDoubleClick}
    >
      <div className="note-label" style={{ fontSize }}>
        <span>{note.lyric || '\u00a0'}</span>
      </div>
      <div className="note-handle start" onPointerDown={(e) => { e.stopPropagation(); onPointerDown(e, 'resize-start') }} />
      <div className="note-handle end" onPointerDown={(e) => { e.stopPropagation(); onPointerDown(e, 'resize-end') }} />
    </div>
  )
})

// Dynamic snap based on zoom level - higher zoom = finer snap
const getSnapSeconds = (gridSecondWidth: number) => {
  // At base width (80px/s), snap is 0.1s
  // At 2x zoom (160px/s), snap is 0.05s
  // At 4x zoom (320px/s), snap is 0.025s
  // At 8x zoom (640px/s), snap is 0.01s
  const baseSnap = 0.1
  const zoomFactor = gridSecondWidth / 80
  return Math.max(0.01, baseSnap / zoomFactor)
}

const snapSeconds = (value: number, gridSecondWidth: number) => {
  const snap = getSnapSeconds(gridSecondWidth)
  return Math.max(0, Math.round(value / snap) * snap)
}
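
// Example: at 2x zoom (gridSecondWidth 160) the snap step is 0.1 / 2 = 0.05 s,
// so snapSeconds(1.23, 160) rounds to 1.25; negative values are clamped to 0.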

export type PianoRollProps = {
  notes: NoteEvent[]
  selectedId: string | null
  timeSignature: TimeSignature
  tempo: number
  playhead: number // in beats
  selectionStart: number | null // in seconds
  selectionEnd: number | null // in seconds
  onAddNote: (note: Partial<NoteEvent>) => NoteEvent
  onUpdateNote: (id: string, patch: Partial<NoteEvent>) => void
  onSelect: (id: string | null) => void
  onSeek: (beat: number) => void
  onScroll?: (left: number) => void
  onZoom?: (deltaH: number, deltaV: number) => void
  onPlayNote?: (midi: number) => void
  onFocusLyric?: (noteId: string) => void
  onSelectionChange?: (start: number | null, end: number | null) => void
  isSelectingRange?: boolean
  audioDuration?: number
  gridSecondWidth: number
  rowHeight: number
}

export function PianoRoll({
  notes,
  selectedId,
  timeSignature: _timeSignature,
  tempo,
  playhead,
  selectionStart,
  selectionEnd,
  onAddNote,
  onSelect,
  onUpdateNote,
  onSeek,
  onScroll,
  onZoom,
  onPlayNote,
  onFocusLyric,
  onSelectionChange,
  isSelectingRange = false,
  audioDuration = 0,
  gridSecondWidth,
  rowHeight
}: PianoRollProps) {
  const scrollContainerRef = useRef<HTMLDivElement | null>(null)
  const rulerScrollRef = useRef<HTMLDivElement | null>(null)
  const [scrollTop, setScrollTop] = useState(0)
  const [scrollLeft, setScrollLeft] = useState(0)
  const [viewportWidth, setViewportWidth] = useState(800)
  const [viewportHeight, setViewportHeight] = useState(400)
  const dragRef = useRef<{
    id: string
    mode: 'move' | 'resize-start' | 'resize-end'
    originX: number
    originY: number
    startSeconds: number
    durationSeconds: number
    midi: number
    lastMidi?: number // Track last midi for pitch change sound
  } | null>(null)

  // Selection drag state
  const selectionDragRef = useRef<{
    startX: number
    startSeconds: number
  } | null>(null)

  // Store callbacks in refs to avoid stale closures in event handlers
  const onPlayNoteRef = useRef(onPlayNote)
  const onUpdateNoteRef = useRef(onUpdateNote)

  useEffect(() => {
    onPlayNoteRef.current = onPlayNote
    onUpdateNoteRef.current = onUpdateNote
  }, [onPlayNote, onUpdateNote])

  // Conversion helpers
  const beatToSeconds = useCallback((beat: number) => beat * (60 / tempo), [tempo])
  const secondsToBeat = useCallback((seconds: number) => seconds / (60 / tempo), [tempo])

  // Calculate dimensions
  const totalRows = HIGH_NOTE - LOW_NOTE + 1
  const contentHeight = totalRows * rowHeight
  const [containerWidth, setContainerWidth] = useState(1200)

  // Track container size
  useEffect(() => {
    const container = scrollContainerRef.current
    if (!container) return

    const observer = new ResizeObserver((entries) => {
      for (const entry of entries) {
        setContainerWidth(entry.contentRect.width)
        setViewportWidth(entry.contentRect.width)
        setViewportHeight(entry.contentRect.height)
      }
    })
    observer.observe(container)
    return () => observer.disconnect()
  }, [])

  const maxSeconds = useMemo(() => {
    const noteEndSeconds = notes.reduce((acc, n) => {
      const endBeat = n.start + n.duration
      return Math.max(acc, beatToSeconds(endBeat))
    }, 8)
    // Ensure grid extends at least 2x the visible area for smoother scrolling
    const minSecondsForView = (containerWidth / gridSecondWidth) * 2
    return Math.max(noteEndSeconds + 10, audioDuration + 10, minSecondsForView, 30)
  }, [notes, audioDuration, beatToSeconds, containerWidth, gridSecondWidth])

  const contentWidth = maxSeconds * gridSecondWidth

  // Drag handlers - use refs to avoid stale closure issues
  const handlePointerMove = useCallback((event: PointerEvent) => {
    const drag = dragRef.current
    if (!drag) return

    const dxSeconds = (event.clientX - drag.originX) / gridSecondWidth
    const dy = (event.clientY - drag.originY) / rowHeight

    if (drag.mode === 'move') {
      const nextSeconds = snapSeconds(drag.startSeconds + dxSeconds, gridSecondWidth)
      const nextMidi = Math.min(HIGH_NOTE, Math.max(LOW_NOTE, Math.round(drag.midi - dy)))

      // Play sound when pitch changes
      if (nextMidi !== drag.lastMidi && onPlayNoteRef.current) {
        onPlayNoteRef.current(nextMidi)
        drag.lastMidi = nextMidi
      }

      onUpdateNoteRef.current(drag.id, {
        start: secondsToBeat(nextSeconds),
        midi: nextMidi
      })
    }

    if (drag.mode === 'resize-start') {
      const nextSeconds = snapSeconds(drag.startSeconds + dxSeconds, gridSecondWidth)
      const delta = drag.startSeconds - nextSeconds
      const nextDurationSeconds = Math.max(0.05, drag.durationSeconds + delta)
      onUpdateNoteRef.current(drag.id, {
        start: secondsToBeat(nextSeconds),
        duration: secondsToBeat(nextDurationSeconds)
      })
    }

    if (drag.mode === 'resize-end') {
      const nextDurationSeconds = Math.max(0.05, snapSeconds(drag.durationSeconds + dxSeconds, gridSecondWidth))
      onUpdateNoteRef.current(drag.id, { duration: secondsToBeat(nextDurationSeconds) })
    }
  }, [gridSecondWidth, rowHeight, secondsToBeat])
| 227 |
+
|
| 228 |
+
const handlePointerUp = useCallback(() => {
|
| 229 |
+
dragRef.current = null
|
| 230 |
+
window.removeEventListener('pointermove', handlePointerMove)
|
| 231 |
+
window.removeEventListener('pointerup', handlePointerUp)
|
| 232 |
+
}, [handlePointerMove])
|
| 233 |
+
|
| 234 |
+
useEffect(() => {
|
| 235 |
+
return () => {
|
| 236 |
+
window.removeEventListener('pointermove', handlePointerMove)
|
| 237 |
+
window.removeEventListener('pointerup', handlePointerUp)
|
| 238 |
+
}
|
| 239 |
+
}, [handlePointerMove, handlePointerUp])
|
| 240 |
+
|
| 241 |
+
// Scroll sync
|
| 242 |
+
useEffect(() => {
|
| 243 |
+
const container = scrollContainerRef.current
|
| 244 |
+
const ruler = rulerScrollRef.current
|
| 245 |
+
if (!container || !ruler) return
|
| 246 |
+
|
| 247 |
+
const handleScroll = () => {
|
| 248 |
+
ruler.scrollLeft = container.scrollLeft
|
| 249 |
+
setScrollTop(container.scrollTop)
|
| 250 |
+
setScrollLeft(container.scrollLeft)
|
| 251 |
+
if (onScroll) onScroll(container.scrollLeft)
|
| 252 |
+
}
|
| 253 |
+
|
| 254 |
+
container.addEventListener('scroll', handleScroll)
|
| 255 |
+
return () => container.removeEventListener('scroll', handleScroll)
|
| 256 |
+
}, [onScroll])
|
| 257 |
+
|
| 258 |
+
// Zoom support via wheel/trackpad
|
| 259 |
+
// Mac: Cmd+滚轮 (水平缩放), Cmd+Shift+滚轮 (垂直缩放), 或双指捏合
|
| 260 |
+
// Windows/Linux: Ctrl+滚轮 (水平缩放), Ctrl+Shift+滚轮 (垂直缩放)
|
| 261 |
+
useEffect(() => {
|
| 262 |
+
const container = scrollContainerRef.current
|
| 263 |
+
if (!container || !onZoom) return
|
| 264 |
+
|
| 265 |
+
const handleWheel = (e: WheelEvent) => {
|
| 266 |
+
// Ctrl (Windows/Linux/捏合) or Cmd (Mac) triggers zoom
|
| 267 |
+
const isZoomTrigger = e.ctrlKey || e.metaKey
|
| 268 |
+
|
| 269 |
+
if (isZoomTrigger) {
|
| 270 |
+
e.preventDefault()
|
| 271 |
+
e.stopPropagation()
|
| 272 |
+
|
| 273 |
+
// Use deltaY for zoom amount, normalize for different input methods
|
| 274 |
+
// Pinch gestures typically have smaller delta values
|
| 275 |
+
let delta = -e.deltaY
|
| 276 |
+
if (Math.abs(delta) > 10) {
|
| 277 |
+
// Likely a mouse wheel, scale down
|
| 278 |
+
delta = delta * 0.01
|
| 279 |
+
} else {
|
| 280 |
+
// Likely a trackpad pinch, scale appropriately
|
| 281 |
+
delta = delta * 0.05
|
| 282 |
+
}
|
| 283 |
+
|
| 284 |
+
// Shift or Alt/Option for vertical zoom, otherwise horizontal
|
| 285 |
+
if (e.shiftKey || e.altKey) {
|
| 286 |
+
onZoom(0, delta)
|
| 287 |
+
} else {
|
| 288 |
+
onZoom(delta, 0)
|
| 289 |
+
}
|
| 290 |
+
}
|
| 291 |
+
}
|
| 292 |
+
|
| 293 |
+
container.addEventListener('wheel', handleWheel, { passive: false })
|
| 294 |
+
return () => container.removeEventListener('wheel', handleWheel)
|
| 295 |
+
}, [onZoom])
|
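The delta normalization above relies on a browser heuristic: pinch gestures arrive as ctrl+wheel events with small `deltaY` values, while physical wheels report large ones. The same dispatch logic, stripped of React for illustration (`el` and `applyZoom` are placeholders):

```ts
// Minimal sketch of the zoom dispatch above, with no React involved.
function attachZoom(el: HTMLElement, applyZoom: (dx: number, dy: number) => void) {
  el.addEventListener(
    'wheel',
    (e: WheelEvent) => {
      if (!(e.ctrlKey || e.metaKey)) return // plain scrolling stays untouched
      e.preventDefault()
      const scale = Math.abs(e.deltaY) > 10 ? 0.01 : 0.05 // mouse wheel vs. pinch
      const delta = -e.deltaY * scale
      if (e.shiftKey || e.altKey) applyZoom(0, delta) // vertical zoom
      else applyZoom(delta, 0) // horizontal zoom
    },
    { passive: false }, // required, otherwise preventDefault() is ignored
  )
}
```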

  // Playhead auto-scroll
  useEffect(() => {
    if (!scrollContainerRef.current) return
    const container = scrollContainerRef.current
    const playheadX = beatToSeconds(playhead) * gridSecondWidth
    const viewStart = container.scrollLeft
    const viewEnd = viewStart + container.clientWidth

    if (playheadX > viewEnd) {
      container.scrollLeft = playheadX
    } else if (playheadX < viewStart) {
      container.scrollLeft = playheadX
    }
  }, [playhead, gridSecondWidth, beatToSeconds])

  // Selection auto-scroll
  useEffect(() => {
    if (!scrollContainerRef.current || !selectedId) return
    const note = notes.find((n) => n.id === selectedId)
    if (!note) return
    const container = scrollContainerRef.current
    const noteX = beatToSeconds(note.start) * gridSecondWidth
    const noteY = (HIGH_NOTE - note.midi) * rowHeight

    const viewStart = container.scrollLeft
    const viewEnd = viewStart + container.clientWidth
    if (noteX < viewStart + 50 || noteX > viewEnd - 50) {
      container.scrollLeft = Math.max(0, noteX - container.clientWidth * 0.35)
    }

    const viewTop = container.scrollTop
    const viewBottom = viewTop + container.clientHeight
    if (noteY < viewTop || noteY > viewBottom - rowHeight) {
      container.scrollTop = Math.max(0, noteY - container.clientHeight * 0.4)
    }
  }, [selectedId, notes, gridSecondWidth, rowHeight, beatToSeconds])

  const handleGridDoubleClick = (event: React.MouseEvent<HTMLDivElement>) => {
    // Only add note if clicking on empty space (not on a note)
    const target = event.target as HTMLElement
    if (target.closest('.note-chip')) return

    if (!scrollContainerRef.current) return
    const container = scrollContainerRef.current
    const rect = container.getBoundingClientRect()
    const x = event.clientX - rect.left + container.scrollLeft
    const y = event.clientY - rect.top + container.scrollTop

    const seconds = snapSeconds(x / gridSecondWidth, gridSecondWidth)
    const pitch = Math.min(HIGH_NOTE, Math.max(LOW_NOTE, HIGH_NOTE - Math.floor(y / rowHeight)))

    const created = onAddNote({
      start: secondsToBeat(seconds),
      midi: pitch,
      duration: secondsToBeat(0.5),
      lyric: ''
    })
    onSelect(created.id)
  }

  const startDrag = (
    event: React.PointerEvent<HTMLDivElement>,
    note: NoteEvent,
    mode: 'move' | 'resize-start' | 'resize-end',
  ) => {
    event.preventDefault()
    event.stopPropagation()
    dragRef.current = {
      id: note.id,
      mode,
      originX: event.clientX,
      originY: event.clientY,
      startSeconds: beatToSeconds(note.start),
      durationSeconds: beatToSeconds(note.duration),
      midi: note.midi,
      lastMidi: note.midi, // Initialize last midi
    }
    window.addEventListener('pointermove', handlePointerMove)
    window.addEventListener('pointerup', handlePointerUp)
    onSelect(note.id)

    // Play sound when clicking/selecting note
    if (onPlayNote) {
      onPlayNote(note.midi)
    }
  }

  // Second-based ruler labels
  const secondLabels = useMemo(() => {
    const labels = [] as Array<{ left: number; label: string }>
    const totalSeconds = Math.ceil(maxSeconds)
    for (let s = 0; s <= totalSeconds; s += 1) {
      labels.push({ left: s * gridSecondWidth, label: `${s}s` })
    }
    return labels
  }, [maxSeconds, gridSecondWidth])

  // Piano keys
  const pitchRows = useMemo(() => {
    const rows = [] as Array<{ midi: number; isBlack: boolean; label: string; isC: boolean }>
    const black = new Set([1, 3, 6, 8, 10])
    for (let p = HIGH_NOTE; p >= LOW_NOTE; p -= 1) {
      const name = midiToName(p)
      const isC = p % 12 === 0
      rows.push({ midi: p, isBlack: black.has(p % 12), label: name, isC })
    }
    return rows
  }, [])

  // Detect overlapping notes using optimized sweep line algorithm
  const overlappingNoteIds = useMemo(() => {
    if (notes.length < 2) return new Set<string>()

    const overlapping = new Set<string>()
    const sortedNotes = [...notes].sort((a, b) => a.start - b.start)
    const EPSILON = 0.05 // Tolerance for floating point comparison

    // Use a sliding window approach - more efficient for typical music data
    // Active notes: notes that haven't ended yet
    const activeNotes: typeof sortedNotes = []

    for (const note of sortedNotes) {
      // Remove notes that have ended before current note starts
      while (activeNotes.length > 0) {
        const firstActive = activeNotes[0]
        const firstActiveEnd = firstActive.start + firstActive.duration
        if (firstActiveEnd <= note.start + EPSILON) {
          activeNotes.shift()
        } else {
          break
        }
      }

      // Check overlap with remaining active notes
      for (const activeNote of activeNotes) {
        const activeEnd = activeNote.start + activeNote.duration
        if (note.start < activeEnd - EPSILON) {
          overlapping.add(activeNote.id)
          overlapping.add(note.id)
        }
      }

      // Add current note to active set (maintain sorted order by end time)
      const noteEnd = note.start + note.duration
      let insertIndex = activeNotes.length
      for (let i = 0; i < activeNotes.length; i++) {
        const aEnd = activeNotes[i].start + activeNotes[i].duration
        if (noteEnd < aEnd) {
          insertIndex = i
          break
        }
      }
      activeNotes.splice(insertIndex, 0, note)
    }
    return overlapping
  }, [notes])
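The interval-overlap logic above is independent of React and easy to unit-test in isolation. A minimal standalone version over plain records (the `Interval` type exists only for this sketch):

```ts
type Interval = { id: string; start: number; duration: number }

// Standalone version of the sweep: sort by start, keep an active set ordered
// by end time, and flag any pair whose ranges intersect beyond `epsilon`.
function findOverlaps(items: Interval[], epsilon = 0.05): Set<string> {
  const flagged = new Set<string>()
  const sorted = [...items].sort((a, b) => a.start - b.start)
  const active: Interval[] = []
  for (const item of sorted) {
    // Drop intervals that ended before `item` starts
    while (active.length && active[0].start + active[0].duration <= item.start + epsilon) {
      active.shift()
    }
    for (const other of active) {
      if (item.start < other.start + other.duration - epsilon) {
        flagged.add(other.id)
        flagged.add(item.id)
      }
    }
    // Insert `item` keeping `active` sorted by end time
    const end = item.start + item.duration
    let idx = active.findIndex((a) => end < a.start + a.duration)
    if (idx === -1) idx = active.length
    active.splice(idx, 0, item)
  }
  return flagged
}
```

For instance, `findOverlaps([{ id: 'a', start: 0, duration: 1 }, { id: 'b', start: 0.5, duration: 1 }])` flags both ids, since `b` starts before `a` ends.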

  // Calculate visible area with buffer for smooth scrolling
  const BUFFER_PX = 200 // Render notes slightly outside viewport for smooth scrolling
  const visibleArea = useMemo(() => {
    return {
      left: Math.max(0, scrollLeft - BUFFER_PX),
      right: scrollLeft + viewportWidth + BUFFER_PX,
      top: Math.max(0, scrollTop - BUFFER_PX),
      bottom: scrollTop + viewportHeight + BUFFER_PX,
    }
  }, [scrollLeft, scrollTop, viewportWidth, viewportHeight])

  // Filter notes to only render visible ones (virtualization)
  const visibleNotes = useMemo(() => {
    return notes.filter(note => {
      const noteSeconds = beatToSeconds(note.start)
      const noteDurationSeconds = beatToSeconds(note.duration)
      const noteLeft = noteSeconds * gridSecondWidth
      const noteRight = noteLeft + noteDurationSeconds * gridSecondWidth
      const noteTop = (HIGH_NOTE - note.midi) * rowHeight
      const noteBottom = noteTop + rowHeight

      // Check if note intersects with visible area
      const horizontallyVisible = noteRight >= visibleArea.left && noteLeft <= visibleArea.right
      const verticallyVisible = noteBottom >= visibleArea.top && noteTop <= visibleArea.bottom

      return horizontallyVisible && verticallyVisible
    })
  }, [notes, visibleArea, gridSecondWidth, rowHeight, beatToSeconds])

  // Calculate visible grid lines (virtualization)
  const visibleGridLines = useMemo(() => {
    const startSecond = Math.max(0, Math.floor(visibleArea.left / gridSecondWidth) - 1)
    const endSecond = Math.ceil(visibleArea.right / gridSecondWidth) + 1
    const startRow = Math.max(0, Math.floor(visibleArea.top / rowHeight) - 1)
    const endRow = Math.min(totalRows, Math.ceil(visibleArea.bottom / rowHeight) + 1)

    return {
      horizontalLines: Array.from({ length: endRow - startRow + 1 }, (_, i) => startRow + i),
      verticalLines: Array.from({ length: endSecond - startSecond + 1 }, (_, i) => startSecond + i),
    }
  }, [visibleArea, gridSecondWidth, rowHeight, totalRows])
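Both memos implement the same windowing idea: pad the viewport by `BUFFER_PX`, then keep only elements whose bounding box intersects it. Reduced to a predicate (the `Rect` type is illustrative):

```ts
type Rect = { left: number; right: number; top: number; bottom: number }

// True when two axis-aligned boxes intersect; this is the test used above to
// decide whether a note (or grid line) falls inside the buffered viewport.
const intersects = (a: Rect, b: Rect): boolean =>
  a.right >= b.left && a.left <= b.right && a.bottom >= b.top && a.top <= b.bottom
```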

  const playheadSeconds = beatToSeconds(playhead)

  // Selection drag handlers
  const handleRulerPointerDown = (event: React.PointerEvent<HTMLDivElement>) => {
    if (!isSelectingRange) {
      // Normal click to seek
      const rect = event.currentTarget.getBoundingClientRect()
      const x = event.clientX - rect.left + (rulerScrollRef.current?.scrollLeft ?? 0)
      const seconds = x / gridSecondWidth
      onSeek(secondsToBeat(seconds))
      return
    }

    // Start selection drag
    const rect = event.currentTarget.getBoundingClientRect()
    const x = event.clientX - rect.left + (rulerScrollRef.current?.scrollLeft ?? 0)
    const seconds = Math.max(0, x / gridSecondWidth)

    selectionDragRef.current = {
      startX: event.clientX,
      startSeconds: seconds,
    }

    onSelectionChange?.(seconds, seconds)

    const handleSelectionMove = (e: PointerEvent) => {
      if (!selectionDragRef.current) return
      const currentX = e.clientX - rect.left + (rulerScrollRef.current?.scrollLeft ?? 0)
      const currentSeconds = Math.max(0, currentX / gridSecondWidth)
      const start = Math.min(selectionDragRef.current.startSeconds, currentSeconds)
      const end = Math.max(selectionDragRef.current.startSeconds, currentSeconds)
      onSelectionChange?.(start, end)
    }

    const handleSelectionUp = () => {
      selectionDragRef.current = null
      window.removeEventListener('pointermove', handleSelectionMove)
      window.removeEventListener('pointerup', handleSelectionUp)
    }

    window.addEventListener('pointermove', handleSelectionMove)
    window.addEventListener('pointerup', handleSelectionUp)
  }

  return (
    <div className="piano-shell">
      {/* Ruler */}
      <div className="ruler-shell">
        <div className="ruler-spacer" style={{ width: PITCH_WIDTH, flexShrink: 0 }} />
        <div
          ref={rulerScrollRef}
          className={`ruler-scroll ${isSelectingRange ? 'selecting' : ''}`}
          onPointerDown={handleRulerPointerDown}
        >
          <div className="ruler" style={{ width: contentWidth }}>
            {secondLabels.map((mark) => (
              <div key={mark.left} className="measure-mark" style={{ left: mark.left }}>
                <span>{mark.label}</span>
              </div>
            ))}
            {/* Selection range indicator */}
            {selectionStart !== null && selectionEnd !== null && selectionEnd > selectionStart && (
              <div
                className="selection-range"
                style={{
                  left: selectionStart * gridSecondWidth,
                  width: (selectionEnd - selectionStart) * gridSecondWidth
                }}
              />
            )}
            {/* Ruler playhead indicator */}
            <div
              className="ruler-playhead"
              style={{ left: playheadSeconds * gridSecondWidth }}
            />
          </div>
        </div>
      </div>

      {/* Main content area */}
      <div className="roll-body">
        {/* Piano keys - synced with vertical scroll */}
        <div className="pitch-rail" style={{ width: PITCH_WIDTH }}>
          <div
            className="pitch-rail-inner"
            style={{
              transform: `translateY(${-scrollTop}px)`,
              height: contentHeight
            }}
          >
            {pitchRows.map((pitch) => (
              <div
                key={pitch.midi}
                className={`pitch-cell ${pitch.isBlack ? 'pitch-black' : 'pitch-white'} ${pitch.isC ? 'pitch-c' : ''}`}
                style={{ height: rowHeight, cursor: 'pointer' }}
                onClick={() => onPlayNote?.(pitch.midi)}
                onMouseDown={(e) => e.preventDefault()}
              >
                <span className="pitch-label">{pitch.label}</span>
              </div>
            ))}
          </div>
        </div>

        {/* Scrollable grid area */}
        <div
          ref={scrollContainerRef}
          className="roll-grid"
          onDoubleClick={handleGridDoubleClick}
        >
          <div
            className="grid-content"
            style={{
              width: contentWidth,
              height: contentHeight,
              position: 'relative'
            }}
          >
            {/* SVG Grid - virtualized for performance */}
            <svg
              className="grid-svg"
              width={contentWidth}
              height={contentHeight}
              style={{ position: 'absolute', top: 0, left: 0, pointerEvents: 'none' }}
            >
              {/* Horizontal lines (pitch rows) - only visible ones */}
              {visibleGridLines.horizontalLines.map(i => (
                <line
                  key={`h-${i}`}
                  x1={visibleArea.left}
                  y1={i * rowHeight}
                  x2={visibleArea.right}
                  y2={i * rowHeight}
                  stroke="var(--grid-line-minor)"
                  strokeWidth={1}
                />
              ))}
              {/* Vertical lines (seconds) - only visible ones */}
              {visibleGridLines.verticalLines.map(i => (
                <line
                  key={`v-${i}`}
                  x1={i * gridSecondWidth}
                  y1={visibleArea.top}
                  x2={i * gridSecondWidth}
                  y2={visibleArea.bottom}
                  stroke="var(--grid-line-minor)"
                  strokeWidth={1}
                />
              ))}
            </svg>

            {/* Selection range in grid */}
            {selectionStart !== null && selectionEnd !== null && selectionEnd > selectionStart && (
              <div
                className="grid-selection-range"
                style={{
                  left: selectionStart * gridSecondWidth,
                  width: (selectionEnd - selectionStart) * gridSecondWidth,
                  height: contentHeight
                }}
              />
            )}

            {/* Playhead */}
            <div
              className="playhead"
              style={{
                left: playheadSeconds * gridSecondWidth,
                height: contentHeight
              }}
            />

            {/* Notes - virtualized: only render visible notes */}
            {visibleNotes.map((note) => {
              const noteSeconds = beatToSeconds(note.start)
              const noteDurationSeconds = beatToSeconds(note.duration)
              const left = noteSeconds * gridSecondWidth
              const top = (HIGH_NOTE - note.midi) * rowHeight
              const noteWidthPx = Math.max(noteDurationSeconds * gridSecondWidth, 4)
              const noteHeight = rowHeight - 2
              const isOverlapping = overlappingNoteIds.has(note.id)
              // Dynamic font size based on row height (base: 12px at 20px row height)
              const fontSize = Math.max(10, Math.min(24, rowHeight * 0.6))

              return (
                <NoteChip
                  key={note.id}
                  note={note}
                  left={left}
                  top={top}
                  width={noteWidthPx}
                  height={noteHeight}
                  fontSize={fontSize}
                  isSelected={selectedId === note.id}
                  isOverlapping={isOverlapping}
                  onPointerDown={(event, mode) => startDrag(event, note, mode)}
                  onDoubleClick={(event) => {
                    event.stopPropagation()
                    onFocusLyric?.(note.id)
                  }}
                />
              )
            })}
          </div>
        </div>
      </div>
    </div>
  )
}
preprocess/tools/midi_editor/src/constants.ts
ADDED
@@ -0,0 +1,8 @@
// Base values used for scaling; actual runtime values are derived in components
export const BASE_GRID_SECOND_WIDTH = 80
export const BASE_ROW_HEIGHT = 20
export const PITCH_WIDTH = 60
// C-1 to C8 range (MIDI note numbers)
// LOW_NOTE = 0 to support SP markers (pitch=0) in some MIDI files
export const LOW_NOTE = 0 // C-1 (also supports pitch=0 for SP markers)
export const HIGH_NOTE = 108 // C8
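`PianoRoll.tsx` labels each key row via `midiToName`. Under the convention these constants imply (MIDI 0 = C-1, 60 = C4, 108 = C8), a minimal implementation would look like the sketch below; the repo's actual helper may differ in details such as how pitch-0 SP markers are labeled:

```ts
const NOTE_NAMES = ['C', 'C#', 'D', 'D#', 'E', 'F', 'F#', 'G', 'G#', 'A', 'A#', 'B']

// Map a MIDI note number to a name, e.g. 60 -> "C4", 108 -> "C8", 0 -> "C-1".
function midiToName(midi: number): string {
  const octave = Math.floor(midi / 12) - 1
  return `${NOTE_NAMES[midi % 12]}${octave}`
}
```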
preprocess/tools/midi_editor/src/i18n.ts
ADDED
@@ -0,0 +1,196 @@
export type Lang = 'zh' | 'en'

const zh = {
  // Header
  eyebrow: '歌声 MIDI 编辑器',
  title: 'SoulX-Singer MIDI Editor',
  subtitle: '导入、拖拽、实时修改歌词并导出标准 MIDI。',
  switchToLight: '切换到亮色',
  switchToDark: '切换到暗色',
  importJson: '导入 JSON',
  exportJson: '导出 JSON',
  importMidi: '导入 MIDI',
  exportMidi: '导出 MIDI',
  transpose: '移调',
  transposeTooltip: '整体升降调:所有音符的音高同步改变',
  transposed: (n: number) => `已移调 ${n > 0 ? '+' : ''}${n} 半音`,
  fixOverlaps: '消除重叠',
  fixOverlapsTooltip: '自动消除重叠:将重叠音符的音尾提前到下一个音的音头',
  jsonImported: (name: string) => `已从 JSON 载入 ${name}`,
  jsonImportFailed: 'JSON 导入失败,请确认文件格式正确',
  jsonExported: '已导出 META JSON 文件',

  // Audio bar
  importAudio: '对齐音频导入',
  audioHint: '导入后显示音频波形并与 MIDI 同步走带',
  midiLabel: 'MIDI',
  audioLabel: '音频',

  // Controls
  horizontalZoom: '水平缩放',
  verticalZoom: '垂直缩放',
  goToStart: '回到开头',
  back2s: '后退 2 秒',
  pause: '暂停',
  playSelection: '播放选区',
  play: '播放',
  forward2s: '前进 2 秒',
  goToEnd: '回到结尾',
  selectingRange: '选区中',
  setRange: '设选区',
  exitSelectMode: '退出选区模式(并清除选区)',
  setRangeTooltip: '设置选区:在时间轴上拖拽选择播放范围',

  // Status
  ready: '准备就绪',
  selectionPlayback: '选区回放中...',
  playing: '正在回放...',
  selectionDone: '选区播放完毕',
  paused: '已暂停',
  imported: (name: string) => `已载入 ${name}`,
  importFailed: '导入失败,请确认文件合法',
  audioImported: (name: string) => `已载入音频 ${name}`,
  unsupportedFormat: (exts: string) => `不支持的文件格式,请选择音频文件(${exts})`,
  fixedOverlaps: (count: number) => `已修复 ${count} 个重叠音符`,
  noOverlaps: '没有检测到重叠音符',
  exported: '已导出包含歌词的 MIDI 文件',

  // Lyric table
  fillPlaceholderSelected: '从选中音符开始按词/字填充',
  fillPlaceholderDefault: '输入歌词,点击按词/字填充',
  fillButton: '按词\n填充',
  lyricPlaceholder: '输入歌词',
  emptyHint: '导入或双击钢琴卷帘以添加音符',
  confirmEdit: '确认修改 (Enter)',
}

const en: typeof zh = {
  // Header
  eyebrow: 'Vocal MIDI Editor',
  title: 'SoulX-Singer MIDI Editor',
  subtitle: 'Import, drag, edit lyrics in real-time, and export standard MIDI.',
  switchToLight: 'Switch to light',
  switchToDark: 'Switch to dark',
  importJson: 'Import JSON',
  exportJson: 'Export JSON',
  importMidi: 'Import MIDI',
  exportMidi: 'Export MIDI',
  transpose: 'Transpose',
  transposeTooltip: 'Transpose all notes up or down by semitones',
  transposed: (n: number) => `Transposed ${n > 0 ? '+' : ''}${n} semitone(s)`,
  fixOverlaps: 'Fix Overlaps',
  fixOverlapsTooltip: 'Auto fix overlaps: trim note end to the start of the next note',
  jsonImported: (name: string) => `Loaded ${name} from JSON`,
  jsonImportFailed: 'JSON import failed, please check the file format',
  jsonExported: 'Exported META JSON file',

  // Audio bar
  importAudio: 'Import Audio',
  audioHint: 'Display audio waveform synced with MIDI transport',
  midiLabel: 'MIDI',
  audioLabel: 'Audio',

  // Controls
  horizontalZoom: 'H-Zoom',
  verticalZoom: 'V-Zoom',
  goToStart: 'Go to start',
  back2s: 'Back 2s',
  pause: 'Pause',
  playSelection: 'Play selection',
  play: 'Play',
  forward2s: 'Forward 2s',
  goToEnd: 'Go to end',
  selectingRange: 'Selecting',
  setRange: 'Select',
  exitSelectMode: 'Exit selection mode (and clear selection)',
  setRangeTooltip: 'Set selection: drag on the timeline to select playback range',

  // Status
  ready: 'Ready',
  selectionPlayback: 'Playing selection...',
  playing: 'Playing...',
  selectionDone: 'Selection playback done',
  paused: 'Paused',
  imported: (name: string) => `Loaded ${name}`,
  importFailed: 'Import failed, please check the file',
  audioImported: (name: string) => `Loaded audio ${name}`,
  unsupportedFormat: (exts: string) => `Unsupported format, please select an audio file (${exts})`,
  fixedOverlaps: (count: number) => `Fixed ${count} overlapping note(s)`,
  noOverlaps: 'No overlapping notes detected',
  exported: 'Exported MIDI file with lyrics',

  // Lyric table
  fillPlaceholderSelected: 'Fill words from selected note',
  fillPlaceholderDefault: 'Enter lyrics, click fill button',
  fillButton: 'Fill\nWords',
  lyricPlaceholder: 'Type lyric',
  emptyHint: 'Import or double-click piano roll to add notes',
  confirmEdit: 'Confirm (Enter)',
}

const translations: Record<Lang, typeof zh> = { zh, en }

export type Translations = typeof zh

export function getTranslations(lang: Lang): Translations {
  return translations[lang]
}

// Smart tokenizer for lyrics: CJK characters are individual tokens, Latin words are grouped
function isCJK(char: string): boolean {
  const code = char.codePointAt(0) || 0
  return (
    (code >= 0x4E00 && code <= 0x9FFF) || // CJK Unified Ideographs
    (code >= 0x3400 && code <= 0x4DBF) || // CJK Extension A
    (code >= 0x20000 && code <= 0x2A6DF) || // CJK Extension B
    (code >= 0x3040 && code <= 0x309F) || // Hiragana
    (code >= 0x30A0 && code <= 0x30FF) || // Katakana
    (code >= 0xAC00 && code <= 0xD7AF) // Hangul Syllables
  )
}

/**
 * Tokenize lyrics text for note filling.
 * - CJK characters: each character becomes one token (one per note)
 * - Latin/English words: each space-separated word becomes one token (one per note)
 * - Mixed text is handled correctly
 *
 * Examples:
 * "你好世界" -> ["你", "好", "世", "界"]
 * "hello world" -> ["hello", "world"]
 * "I love 你" -> ["I", "love", "你"]
 * "something wrong" -> ["something", "wrong"]
 */
export function tokenizeLyrics(text: string): string[] {
  const tokens: string[] = []
  const cleaned = text.trim()
  if (!cleaned) return tokens

  let i = 0
  while (i < cleaned.length) {
    const char = cleaned[i]

    // Skip whitespace
    if (/\s/.test(char)) {
      i++
      continue
    }

    // CJK character - each is a separate token
    if (isCJK(char)) {
      tokens.push(char)
      i++
      continue
    }

    // Latin/number/other - collect until whitespace or CJK
    let word = ''
    while (i < cleaned.length && !/\s/.test(cleaned[i]) && !isCJK(cleaned[i])) {
      word += cleaned[i]
      i++
    }
    if (word) tokens.push(word)
  }

  return tokens
}
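The doc comment's examples can be checked directly. One caveat worth knowing: the scanner walks UTF-16 code units, so BMP CJK splits correctly, while astral-plane characters (e.g. CJK Extension B) would need `Array.from` to keep their surrogate pairs together.

```ts
import { tokenizeLyrics } from './i18n'

console.log(tokenizeLyrics('你好世界'))    // ["你", "好", "世", "界"]
console.log(tokenizeLyrics('hello world')) // ["hello", "world"]
console.log(tokenizeLyrics('I love 你'))   // ["I", "love", "你"]
```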
preprocess/tools/midi_editor/src/index.css
ADDED
@@ -0,0 +1,37 @@
@tailwind base;
@tailwind components;
@tailwind utilities;

:root {
  font-family: 'Space Grotesk', 'IBM Plex Sans', system-ui, sans-serif;
  color: var(--text-primary);
  background: radial-gradient(circle at 20% 20%, rgba(72, 228, 194, 0.08), transparent 35%),
    radial-gradient(circle at 80% 0%, rgba(75, 100, 188, 0.24), transparent 40%),
    #0f1528;
  text-rendering: optimizeLegibility;
  -webkit-font-smoothing: antialiased;
}

:root[data-theme='light'] {
  background: radial-gradient(circle at 20% 20%, rgba(63, 140, 255, 0.08), transparent 35%),
    radial-gradient(circle at 80% 0%, rgba(75, 100, 188, 0.14), transparent 40%),
    #f5f7fb;
}

* {
  box-sizing: border-box;
}

body {
  margin: 0;
  min-height: 100vh;
  background: transparent;
}

#root {
  min-height: 100vh;
}

a {
  color: inherit;
}
preprocess/tools/midi_editor/src/lib/midi.ts
ADDED
@@ -0,0 +1,224 @@
import { Midi } from '@tonejs/midi'
import { writeMidi } from 'midi-file'
import type { MidiData, MidiEvent } from 'midi-file'
import type { NoteEvent, ProjectSnapshot, TimeSignature } from '../types'

const DEFAULT_SIGNATURE: TimeSignature = [4, 4]

// Decode UTF-8 byte string (latin1 encoded) to proper Unicode string
// This matches: text.encode("latin1").decode("utf-8") in Python
function decodeUtf8ByteString(byteString: string): string {
  try {
    const bytes = new Uint8Array(byteString.length)
    for (let i = 0; i < byteString.length; i++) {
      bytes[i] = byteString.charCodeAt(i)
    }
    return new TextDecoder('utf-8').decode(bytes)
  } catch {
    return byteString
  }
}

// Encode Unicode string to UTF-8 byte string (latin1 encoding)
// This matches: text.encode("utf-8").decode("latin1") in Python
function encodeUtf8ByteString(text: string): string {
  const bytes = new TextEncoder().encode(text)
  let output = ''
  bytes.forEach((b) => {
    output += String.fromCharCode(b)
  })
  return output
}
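These two module-private helpers exist because `midi-file` stores meta-event text as one character per byte (latin1-style). Round-tripping a CJK lyric illustrates the contract; the byte values below are the standard UTF-8 encoding of '你' (U+4F60):

```ts
const byteString = encodeUtf8ByteString('你')
// byteString.length === 3; char codes are [0xE4, 0xBD, 0xA0], the UTF-8 bytes of '你'
const restored = decodeUtf8ByteString(byteString)
// restored === '你'
```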

export async function importMidiFile(file: File): Promise<ProjectSnapshot> {
  const buffer = await file.arrayBuffer()
  return parseMidiBuffer(buffer)
}

export async function parseMidiBuffer(buffer: ArrayBuffer): Promise<ProjectSnapshot> {
  const midi = new Midi(buffer)
  const tempo = midi.header.tempos[0]?.bpm ?? 120
  const timeSignature = (midi.header.timeSignatures[0]?.timeSignature as TimeSignature | undefined) ?? DEFAULT_SIGNATURE

  // Merge notes from all tracks and sort by ticks then by midi (for stable ordering)
  const allNotes = midi.tracks
    .flatMap(t => t.notes)
    .sort((a, b) => a.ticks - b.ticks || a.midi - b.midi)

  // Get lyrics from header.meta and sort by ticks
  const lyricEvents = midi.header.meta
    .filter((event) => event.type === 'lyrics')
    .sort((a, b) => a.ticks - b.ticks)

  // Match lyrics to notes by tick position
  // Each lyric should be consumed by exactly one note at the same tick
  const lyricsByTick = new Map<number, string[]>()
  for (const event of lyricEvents) {
    const existing = lyricsByTick.get(event.ticks) || []
    existing.push(decodeUtf8ByteString(event.text))
    lyricsByTick.set(event.ticks, existing)
  }

  // Track which lyrics have been used at each tick position
  const usedLyricIndices = new Map<number, number>()

  const notes: NoteEvent[] = allNotes.map((note, index) => {
    const beat = note.ticks / midi.header.ppq
    const durationBeats = note.durationTicks / midi.header.ppq

    let lyric = ''

    // First try exact tick match
    const lyricsAtTick = lyricsByTick.get(note.ticks)
    if (lyricsAtTick && lyricsAtTick.length > 0) {
      const usedIndex = usedLyricIndices.get(note.ticks) || 0
      if (usedIndex < lyricsAtTick.length) {
        lyric = lyricsAtTick[usedIndex]
        usedLyricIndices.set(note.ticks, usedIndex + 1)
      }
    }

    // If no exact match, try nearby ticks (within small tolerance)
    if (!lyric) {
      const tolerance = midi.header.ppq / 100 // Very small tolerance
      for (const [tick, lyrics] of lyricsByTick.entries()) {
        if (Math.abs(tick - note.ticks) <= tolerance) {
          const usedIndex = usedLyricIndices.get(tick) || 0
          if (usedIndex < lyrics.length) {
            lyric = lyrics[usedIndex]
            usedLyricIndices.set(tick, usedIndex + 1)
            break
          }
        }
      }
    }

    return {
      id: `${index}-${note.midi}-${Math.round(note.ticks)}`,
      midi: note.midi,
      start: beat,
      duration: Math.max(durationBeats, 0.0625),
      velocity: note.velocity,
      lyric,
    }
  })

  return { tempo, timeSignature, notes, ppq: midi.header.ppq }
}
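Since `NoteEvent.start` is stored in beats, the conversion above is simply `beat = ticks / ppq`: with the common ppq of 480, a note at tick 960 starts on beat 2, and a quarter note (one beat) spans exactly 480 ticks. The 0.0625-beat floor on `duration` corresponds to a 64th note.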

// Used to add absoluteTime property for sorting
type WithAbsoluteTime<T> = T & { absoluteTime: number }

export function exportMidi(snapshot: ProjectSnapshot): Blob {
  const ppq = snapshot.ppq ?? 480 // Use original ppq if available, otherwise default to 480
  const microsecondsPerBeat = Math.round(60000000 / snapshot.tempo) // Convert BPM to microseconds per beat

  // Sort notes by start time, then by midi for stable ordering
  const sortedNotes = [...snapshot.notes].sort((a, b) => a.start - b.start || a.midi - b.midi)

  // Build events for a single track containing both lyrics and notes
  // Event order at same tick: note_off (0) < lyrics (1) < note_on (2)
  // This matches meta.py's tg2midi implementation
  const events: Array<WithAbsoluteTime<MidiEvent>> = []

  // Add all note events and their corresponding lyrics
  sortedNotes.forEach((note) => {
    const startTicks = Math.round(note.start * ppq)
    const endTicks = Math.round((note.start + note.duration) * ppq)
    const velocity = Math.round(note.velocity * 127)

    // Add lyric event at the same tick as note_on (but will be sorted before it)
    const lyricText = note.lyric ?? ''
    const encodedLyric = encodeUtf8ByteString(lyricText)

    // Lyric event - sort key 1 (after note_off, before note_on)
    events.push({
      absoluteTime: startTicks,
      deltaTime: 0,
      meta: true,
      type: 'lyrics',
      text: encodedLyric,
      _sortKey: 1,
    } as WithAbsoluteTime<MidiEvent> & { _sortKey: number })

    // Note on event - sort key 2 (after lyrics)
    events.push({
      absoluteTime: startTicks,
      deltaTime: 0,
      type: 'noteOn',
      channel: 0,
      noteNumber: note.midi,
      velocity: velocity,
      _sortKey: 2,
    } as WithAbsoluteTime<MidiEvent> & { _sortKey: number })

    // Note off event - sort key 0 (before everything at same tick)
    events.push({
      absoluteTime: endTicks,
      deltaTime: 0,
      type: 'noteOff',
      channel: 0,
      noteNumber: note.midi,
      velocity: 0,
      _sortKey: 0,
    } as WithAbsoluteTime<MidiEvent> & { _sortKey: number })
  })

  // Sort events by absoluteTime, then by _sortKey
  events.sort((a, b) => {
    const aKey = (a as { _sortKey?: number })._sortKey ?? 1
    const bKey = (b as { _sortKey?: number })._sortKey ?? 1
    return a.absoluteTime - b.absoluteTime || aKey - bKey
  })

  // Convert absolute time to delta time
  let lastTick = 0
  events.forEach(event => {
    event.deltaTime = event.absoluteTime - lastTick
    lastTick = event.absoluteTime
    delete (event as { absoluteTime?: number }).absoluteTime
    delete (event as { _sortKey?: number })._sortKey
  })

  // Build the MIDI track with header events
  const track: MidiEvent[] = [
    // Set tempo
    {
      deltaTime: 0,
      meta: true,
      type: 'setTempo',
      microsecondsPerBeat: microsecondsPerBeat,
    },
    // Time signature
    {
      deltaTime: 0,
      meta: true,
      type: 'timeSignature',
      numerator: snapshot.timeSignature[0],
      denominator: snapshot.timeSignature[1],
      metronome: 24,
      thirtyseconds: 8,
    },
    // All note and lyric events
    ...events,
    // End of track
    {
      deltaTime: 0,
      meta: true,
      type: 'endOfTrack',
    },
  ]

  // Build MIDI data structure
  const midiData: MidiData = {
    header: {
      format: 0, // Single track format (type 0)
      numTracks: 1,
      ticksPerBeat: ppq,
    },
    tracks: [track],
  }

  const bytes = writeMidi(midiData)
  return new Blob([new Uint8Array(bytes)], { type: 'audio/midi' })
}
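A typical way to hand the returned Blob to the browser for download, using only standard DOM APIs; the snapshot values and file name here are made up for illustration:

```ts
import { exportMidi } from './lib/midi'

const blob = exportMidi({
  tempo: 110,
  timeSignature: [4, 4],
  notes: [{ id: 'n1', midi: 60, start: 0, duration: 1, velocity: 0.9, lyric: 'la' }],
  ppq: 480,
})
const url = URL.createObjectURL(blob)
const a = document.createElement('a')
a.href = url
a.download = 'edited.mid' // hypothetical file name
a.click()
URL.revokeObjectURL(url)
```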
preprocess/tools/midi_editor/src/main.tsx
ADDED
@@ -0,0 +1,10 @@
import { StrictMode } from 'react'
import { createRoot } from 'react-dom/client'
import './index.css'
import App from './App.tsx'

createRoot(document.getElementById('root')!).render(
  <StrictMode>
    <App />
  </StrictMode>,
)
preprocess/tools/midi_editor/src/store/useMidiStore.ts
ADDED
@@ -0,0 +1,78 @@
import { nanoid } from 'nanoid'
import { create } from 'zustand'
import type { NoteEvent, TimeSignature } from '../types'

const clamp = (value: number, min: number, max: number) =>
  Math.min(Math.max(value, min), max)

export type MidiStore = {
  tempo: number
  timeSignature: TimeSignature
  notes: NoteEvent[]
  selectedId: string | null
  playhead: number
  ppq: number | undefined // Ticks per quarter note (for preserving original MIDI timing)
  addNote: (partial?: Partial<NoteEvent>) => NoteEvent
  updateNote: (id: string, partial: Partial<NoteEvent>) => void
  removeNote: (id: string) => void
  setNotes: (notes: NoteEvent[]) => void
  setTempo: (tempo: number) => void
  setTimeSignature: (sig: TimeSignature) => void
  setPpq: (ppq: number | undefined) => void
  select: (id: string | null) => void
  setLyric: (id: string, lyric: string) => void
  setPlayhead: (beat: number) => void
  clear: () => void
}

const defaultNotes: NoteEvent[] = [
  { id: nanoid(), midi: 64, start: 0, duration: 1.5, velocity: 0.9, lyric: 'la' },
  { id: nanoid(), midi: 67, start: 1.5, duration: 1.5, velocity: 0.85, lyric: 'na' },
  { id: nanoid(), midi: 69, start: 3, duration: 2, velocity: 0.8, lyric: 'ah' },
]

export const useMidiStore = create<MidiStore>((set) => ({
  tempo: 110,
  timeSignature: [4, 4],
  notes: defaultNotes,
  selectedId: null,
  playhead: 0,
  ppq: undefined,
  addNote: (partial = {}) => {
    const note: NoteEvent = {
      id: nanoid(),
      midi: partial.midi ?? 64,
      start: partial.start ?? 0,
      duration: partial.duration ?? 1,
      velocity: clamp(partial.velocity ?? 0.85, 0, 1),
      lyric: partial.lyric ?? '',
    }
    set((state) => ({ notes: [...state.notes, note] }))
    return note
  },
  updateNote: (id, partial) => {
    set((state) => ({
      notes: state.notes.map((note) =>
        note.id === id
          ? {
              ...note,
              ...partial,
              duration: Math.max(partial.duration ?? note.duration, 0.0625),
            }
          : note,
      ),
    }))
  },
  removeNote: (id) => set((state) => ({ notes: state.notes.filter((n) => n.id !== id) })),
  setNotes: (notes) => set(() => ({ notes })),
  setTempo: (tempo) => set(() => ({ tempo: clamp(tempo, 30, 240) })),
  setTimeSignature: (sig) => set(() => ({ timeSignature: sig })),
  setPpq: (ppq) => set(() => ({ ppq })),
  select: (id) => set(() => ({ selectedId: id })),
  setLyric: (id, lyric) =>
    set((state) => ({
      notes: state.notes.map((note) => (note.id === id ? { ...note, lyric } : note)),
    })),
  setPlayhead: (beat) => set(() => ({ playhead: Math.max(beat, 0) })),
  clear: () => set(() => ({ notes: [], selectedId: null })),
}))
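Components consume this store with zustand's selector form, so they re-render only when the selected slice changes. A minimal sketch (the component name is illustrative):

```tsx
import { useMidiStore } from './store/useMidiStore'

function TempoBadge() {
  // Selector subscriptions: re-render only when `tempo` or `setTempo` changes.
  const tempo = useMidiStore((s) => s.tempo)
  const setTempo = useMidiStore((s) => s.setTempo)
  // setTempo clamps to [30, 240], so repeated clicks saturate at 240 BPM.
  return <button onClick={() => setTempo(tempo + 5)}>{tempo} BPM</button>
}
```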
preprocess/tools/midi_editor/src/types.ts
ADDED
@@ -0,0 +1,17 @@
export type NoteEvent = {
  id: string
  midi: number
  start: number // in beats
  duration: number // in beats
  velocity: number
  lyric: string
}

export type TimeSignature = [number, number]

export type ProjectSnapshot = {
  tempo: number
  timeSignature: TimeSignature
  notes: NoteEvent[]
  ppq?: number // Ticks per quarter note (for preserving original MIDI timing)
}
preprocess/tools/midi_editor/tailwind.config.js
ADDED
@@ -0,0 +1,33 @@
/** @type {import('tailwindcss').Config} */
export default {
  content: ['./index.html', './src/**/*.{ts,tsx,js,jsx}'],
  theme: {
    extend: {
      fontFamily: {
        display: ['"Space Grotesk"', '"IBM Plex Sans"', 'system-ui', 'sans-serif'],
        mono: ['"JetBrains Mono"', 'ui-monospace', 'SFMono-Regular', 'monospace'],
      },
      colors: {
        ink: {
          50: '#f4f7fb',
          100: '#dfe7f5',
          200: '#beceec',
          300: '#95addf',
          400: '#6a87ce',
          500: '#4b64bc',
          600: '#3b4ea7',
          700: '#32418a',
          800: '#2c376f',
          900: '#262f5c',
        },
        ember: '#ff7043',
        mint: '#48e4c2',
      },
      boxShadow: {
        panel: '0 14px 35px rgba(0, 0, 0, 0.25)',
      },
    },
  },
  plugins: [],
}
preprocess/tools/midi_editor/tsconfig.app.json
ADDED
@@ -0,0 +1,28 @@
{
  "compilerOptions": {
    "tsBuildInfoFile": "./node_modules/.tmp/tsconfig.app.tsbuildinfo",
    "target": "ES2022",
    "useDefineForClassFields": true,
    "lib": ["ES2022", "DOM", "DOM.Iterable"],
    "module": "ESNext",
    "types": ["vite/client"],
    "skipLibCheck": true,

    /* Bundler mode */
    "moduleResolution": "bundler",
    "allowImportingTsExtensions": true,
    "verbatimModuleSyntax": true,
    "moduleDetection": "force",
    "noEmit": true,
    "jsx": "react-jsx",

    /* Linting */
    "strict": true,
    "noUnusedLocals": true,
    "noUnusedParameters": true,
    "erasableSyntaxOnly": true,
    "noFallthroughCasesInSwitch": true,
    "noUncheckedSideEffectImports": true
  },
  "include": ["src"]
}
preprocess/tools/midi_editor/tsconfig.json
ADDED
@@ -0,0 +1,7 @@
{
  "files": [],
  "references": [
    { "path": "./tsconfig.app.json" },
    { "path": "./tsconfig.node.json" }
  ]
}
preprocess/tools/midi_editor/tsconfig.node.json
ADDED
@@ -0,0 +1,26 @@
{
  "compilerOptions": {
    "tsBuildInfoFile": "./node_modules/.tmp/tsconfig.node.tsbuildinfo",
    "target": "ES2023",
    "lib": ["ES2023"],
    "module": "ESNext",
    "types": ["node"],
    "skipLibCheck": true,

    /* Bundler mode */
    "moduleResolution": "bundler",
    "allowImportingTsExtensions": true,
    "verbatimModuleSyntax": true,
    "moduleDetection": "force",
    "noEmit": true,

    /* Linting */
    "strict": true,
    "noUnusedLocals": true,
    "noUnusedParameters": true,
    "erasableSyntaxOnly": true,
    "noFallthroughCasesInSwitch": true,
    "noUncheckedSideEffectImports": true
  },
  "include": ["vite.config.ts"]
}
preprocess/tools/midi_editor/vite.config.ts
ADDED
@@ -0,0 +1,7 @@
import { defineConfig } from 'vite'
import react from '@vitejs/plugin-react'

// https://vite.dev/config/
export default defineConfig({
  plugins: [react()],
})
preprocess/tools/midi_parser.py
ADDED
@@ -0,0 +1,598 @@
"""
SoulX-Singer MIDI <-> metadata converter.

Converts between SoulX-Singer-style metadata JSON (with note_text, note_dur,
note_pitch, note_type per segment) and standard MIDI files. Uses an internal
Note dataclass (start_s, note_dur, note_text, note_pitch, note_type) as the
intermediate representation.
"""
import os
import json
import shutil
from dataclasses import dataclass
from typing import Any, List, Tuple, Union

import librosa
import mido
from soundfile import write

from .f0_extraction import F0Extractor
from .g2p import g2p_transform


# Audio, MIDI and segmentation constants
SAMPLE_RATE = 44100  # Audio sample rate for any wav cuts during midi2meta
MIDI_TICKS_PER_BEAT = 500  # The number of MIDI ticks per beat; affects the time resolution of MIDI output and conversion accuracy.
MIDI_TEMPO = 500000  # Microseconds per beat (120 BPM)
MIDI_TIME_SIGNATURE = (4, 4)  # Default time signature; not critical for conversion but included in MIDI output.
MIDI_VELOCITY = 64  # Default velocity for note_on events; not critical for conversion but required for MIDI format.
END_EXTENSION_SEC = 0.4  # Extend each segment end by this much silence (sec) to give the model more context
MAX_GAP_SEC = 2.0  # Gap threshold to split segments in midi2meta (sec)
MAX_SEGMENT_DUR_SUM_SEC = 60.0  # Max total duration sum of notes in a single metadata segment before splitting into multiple segments (sec)
SILENCE_THRESHOLD_SEC = 0.2  # Threshold to insert explicit <SP> note for long silences between notes in midi2notes (sec)


@dataclass
class Note:
    """Single note: text, duration (seconds), pitch (MIDI), type. start_s is absolute start time in seconds (for ordering / MIDI)."""
    start_s: float
    note_dur: float
    note_text: str
    note_pitch: int
    note_type: int

    @property
    def end_s(self) -> float:
        return self.start_s + self.note_dur


def _seconds_to_ticks(seconds: float, ticks_per_beat: int, tempo: int) -> int:
    """Convert seconds to MIDI ticks based on tempo and ticks per beat."""
    return int(round(seconds * ticks_per_beat * 1_000_000 / tempo))
| 52 |
+
|
| 53 |
+
|
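

# A quick sanity check of the tick conversion above, using the module
# constants (illustrative comment only, not executed by the pipeline):
# at MIDI_TEMPO = 500000 us/beat (120 BPM) and MIDI_TICKS_PER_BEAT = 500,
# one second spans two beats, i.e. 1000 ticks:
#   _seconds_to_ticks(1.0, 500, 500000)  -> 1000
#   _seconds_to_ticks(0.25, 500, 500000) -> 250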


def _append_segment_to_meta(
    meta_data: List[dict],
    meta_path_str: str,
    cut_wavs_output_dir: str | None,
    vocal_file: str | None,
    language: str,
    audio_data: Any | None,
    pitch_extractor: F0Extractor | None,
    note_start: List[float],
    note_end: List[float],
    note_text: List[Any],
    note_pitch: List[Any],
    note_type: List[Any],
    note_dur: List[float],
) -> None:
    """Helper for midi2meta: append the current segment (accumulated in note_*) to the meta_data list, with optional wav cut and pitch extraction."""
    if not all((note_start, note_end, note_text, note_pitch, note_type, note_dur)):
        return

    base_name = os.path.splitext(os.path.basename(meta_path_str))[0]
    item_name = f"{base_name}_{len(meta_data)}"
    wav_fn = None
    if cut_wavs_output_dir and vocal_file and audio_data is not None:
        wav_fn = os.path.join(cut_wavs_output_dir, f"{item_name}.wav")
        end_pad = int(END_EXTENSION_SEC * SAMPLE_RATE)
        start_sample = max(0, int(note_start[0] * SAMPLE_RATE))
        end_sample = min(len(audio_data), int(note_end[-1] * SAMPLE_RATE) + end_pad)

        end_pad_dur = (end_sample / SAMPLE_RATE - note_end[-1]) if end_sample > int(note_end[-1] * SAMPLE_RATE) else 0.0
        if end_pad_dur > 0:
            note_dur = note_dur + [end_pad_dur]
            note_text = note_text + ["<SP>"]
            note_pitch = note_pitch + [0]
            note_type = note_type + [1]
        start_ms = int(start_sample / SAMPLE_RATE * 1000)
        end_ms = int(end_sample / SAMPLE_RATE * 1000)
        write(wav_fn, audio_data[start_sample:end_sample], SAMPLE_RATE)
    else:
        start_ms = int(note_start[0] * 1000)
        end_ms = int(note_end[-1] * 1000)

    if pitch_extractor is not None:
        if not wav_fn or not os.path.isfile(wav_fn):
            raise FileNotFoundError(f"Segment wav file not found: {wav_fn}")
        f0 = pitch_extractor.process(wav_fn)
    else:
        f0 = []

    note_text_list = list(note_text)
    note_pitch_list = list(note_pitch)
    note_type_list = list(note_type)
    note_dur_list = list(note_dur)

    meta_data.append(
        {
            "index": item_name,
            "language": language,
            "time": [start_ms, end_ms],
            "duration": " ".join(str(round(x, 2)) for x in note_dur_list),
            "text": " ".join(note_text_list),
            "phoneme": " ".join(g2p_transform(note_text_list, language)),
            "note_pitch": " ".join(str(x) for x in note_pitch_list),
            "note_type": " ".join(str(x) for x in note_type_list),
            "f0": " ".join(str(round(float(x), 1)) for x in f0),
        }
    )


def meta2notes(meta_path: str) -> List[Note]:
    """Parse SoulX-Singer metadata JSON into a flat list of Note (absolute start_s)."""
    with open(meta_path, "r", encoding="utf-8") as f:
        segments = json.load(f)
    if not isinstance(segments, list):
        raise ValueError(f"Metadata must be a list of segments, got {type(segments).__name__}")
    if not segments:
        raise ValueError("Metadata has no segments.")

    notes: List[Note] = []
    for seg in segments:
        offset_s = seg["time"][0] / 1000
        words = [str(x).replace("<AP>", "<SP>") for x in seg["text"].split()]
        word_durs = [float(x) for x in seg["duration"].split()]
        pitches = [int(x) for x in seg["note_pitch"].split()]
        types = [int(x) if words[i] != "<SP>" else 1 for i, x in enumerate(seg["note_type"].split())]
        if len(words) != len(word_durs) or len(word_durs) != len(pitches) or len(pitches) != len(types):
            raise ValueError(
                f"Length mismatch in segment {seg.get('item_name', '?')}: "
                "note_text, note_dur, note_pitch, note_type must have same length"
            )
        current_s = offset_s
        for text, dur, pitch, type_ in zip(words, word_durs, pitches, types):
            notes.append(
                Note(
                    start_s=current_s,
                    note_dur=float(dur),
                    note_text=str(text),
                    note_pitch=int(pitch),
                    note_type=int(type_),
                )
            )
            current_s += float(dur)
    return notes


def notes2meta(
    notes: List[Note],
    meta_path: str,
    vocal_file: str | None,
    language: str,
    pitch_extractor: F0Extractor | None,
) -> None:
    """Write SoulX-Singer metadata JSON from a list of Note (segmenting + wav cuts)."""
    meta_path_str = str(meta_path)

    cut_wavs_output_dir = None
    if vocal_file:
        cut_wavs_output_dir = os.path.join(os.path.dirname(vocal_file), "cut_wavs_tmp")
        os.makedirs(cut_wavs_output_dir, exist_ok=True)

    note_text: List[Any] = []
    note_pitch: List[Any] = []
    note_type: List[Any] = []
    note_dur: List[float] = []
    note_start: List[float] = []
    note_end: List[float] = []
    meta_data: List[dict] = []
    audio_data = None
    if vocal_file:
        audio_data, _ = librosa.load(vocal_file, sr=SAMPLE_RATE, mono=True)
    dur_sum = 0.0

    def flush_current_segment() -> None:
        nonlocal dur_sum
        _append_segment_to_meta(
            meta_data,
            meta_path_str,
            cut_wavs_output_dir,
            vocal_file,
            language,
            audio_data,
            pitch_extractor,
            note_start,
            note_end,
            note_text,
            note_pitch,
            note_type,
            note_dur,
        )
        note_text.clear()
        note_pitch.clear()
        note_type.clear()
        note_dur.clear()
        note_start.clear()
        note_end.clear()
        dur_sum = 0.0

    def append_note(start: float, end: float, text: str, pitch: int, type_: int) -> None:
        nonlocal dur_sum
        duration = end - start
        if duration <= 0:
            return

        if len(note_text) > 0 and text == "<SP>" and note_text[-1] == "<SP>":
            note_dur[-1] += duration
            note_end[-1] = end
        else:
            note_text.append(text)
            note_pitch.append(pitch)
            note_type.append(type_)
            note_dur.append(duration)
            note_start.append(start)
            note_end.append(end)
        dur_sum += duration

    for note in notes:
        start = float(note.start_s)
        end = float(note.end_s)
        text = note.note_text
        pitch = note.note_pitch
        type_ = note.note_type

        if text == "" or pitch == "" or type_ == "":
            append_note(start, end, "<SP>", 0, 1)
            continue

        # cut the segment when it ends with a long <SP> note
        if (
            len(note_text) > 0
            and note_text[-1] == "<SP>"
            and note_dur[-1] > MAX_GAP_SEC
        ):
            note_text.pop()
            note_pitch.pop()
            note_type.pop()
            note_dur.pop()
            note_start.pop()
            note_end.pop()

            dur_sum = sum(note_dur)
            flush_current_segment()

        # cut the segment if adding the current note would exceed the max duration sum threshold
        if dur_sum + (end - start) > MAX_SEGMENT_DUR_SUM_SEC and len(note_text) > 0:
            flush_current_segment()

        append_note(start, end, text, int(pitch), int(type_))

    if note_text:
        flush_current_segment()

    with open(meta_path_str, "w", encoding="utf-8") as f:
        json.dump(meta_data, f, ensure_ascii=False, indent=2)

    if cut_wavs_output_dir:
        try:
            shutil.rmtree(cut_wavs_output_dir, ignore_errors=True)
        except Exception:
            pass


def notes2midi(
    notes: List[Note],
    midi_path: str,
) -> None:
    """Write a MIDI file from a list of Note."""
    if not notes:
        raise ValueError("Empty note list.")

    events: List[Tuple[int, int, Union[mido.Message, mido.MetaMessage]]] = []
    for n in notes:
        start_s = n.start_s
        end_s = n.end_s
        if end_s <= start_s:
            continue

        start_ticks = _seconds_to_ticks(
            start_s, MIDI_TICKS_PER_BEAT, MIDI_TEMPO
        )
        end_ticks = _seconds_to_ticks(
            end_s, MIDI_TICKS_PER_BEAT, MIDI_TEMPO
        )
        if end_ticks <= start_ticks:
            end_ticks = start_ticks + 1

        lyric = n.note_text
        # Some DAWs store lyric text as latin1-compatible bytes; keep best-effort round-trip.
        try:
            lyric = lyric.encode("utf-8").decode("latin1")
        except (UnicodeEncodeError, UnicodeDecodeError):
            pass
        if n.note_type == 3:
            lyric = "-"

        events.append(
            (start_ticks, 1, mido.MetaMessage("lyrics", text=lyric, time=0))
        )
        events.append(
            (
                start_ticks,
                2,
                mido.Message(
                    "note_on",
                    note=n.note_pitch,
                    velocity=MIDI_VELOCITY,
                    time=0,
                ),
            )
        )
        events.append(
            (
                end_ticks,
                0,
                mido.Message("note_off", note=n.note_pitch, velocity=0, time=0),
            )
        )

    events.sort(key=lambda x: (x[0], x[1]))

    mid = mido.MidiFile(ticks_per_beat=MIDI_TICKS_PER_BEAT)
    track = mido.MidiTrack()
    mid.tracks.append(track)

    track.append(mido.MetaMessage("set_tempo", tempo=MIDI_TEMPO, time=0))
    track.append(
        mido.MetaMessage(
            "time_signature",
            numerator=MIDI_TIME_SIGNATURE[0],
            denominator=MIDI_TIME_SIGNATURE[1],
            time=0,
        )
    )

    last_tick = 0
    for tick, _, msg in events:
        msg.time = max(0, tick - last_tick)
        track.append(msg)
        last_tick = tick

    track.append(mido.MetaMessage("end_of_track", time=0))
    mid.save(midi_path)


def midi2notes(midi_path: str) -> List[Note]:
    """Parse a MIDI file into a list of Note."""
    mid = mido.MidiFile(midi_path)
    ticks_per_beat = mid.ticks_per_beat
    tempo = 500000

    raw_notes: List[dict] = []
    lyrics: List[Tuple[int, str]] = []

    for track in mid.tracks:
        abs_ticks = 0
        active = {}
        for msg in track:
            abs_ticks += msg.time
            if msg.type == "set_tempo":
                tempo = msg.tempo
            elif msg.type == "lyrics":
                text = msg.text
                try:
                    text = text.encode("latin1").decode("utf-8")
                except Exception:
                    pass
                lyrics.append((abs_ticks, text))
            elif msg.type == "note_on":
                key = (msg.channel, msg.note)
                if msg.velocity > 0:
                    active[key] = (abs_ticks, msg.velocity)
                else:
                    if key in active:
                        start_ticks, vel = active.pop(key)
                        raw_notes.append(
                            {
                                "midi": msg.note,
                                "start_ticks": start_ticks,
                                "duration_ticks": abs_ticks - start_ticks,
                                "velocity": vel,
                                "lyric": "",
                            }
                        )
            elif msg.type == "note_off":
                key = (msg.channel, msg.note)
                if key in active:
                    start_ticks, vel = active.pop(key)
                    raw_notes.append(
                        {
                            "midi": msg.note,
                            "start_ticks": start_ticks,
                            "duration_ticks": abs_ticks - start_ticks,
                            "velocity": vel,
                            "lyric": "",
                        }
                    )

    if not raw_notes:
        raise ValueError("No notes found in MIDI file")

    for n in raw_notes:
        n["end_ticks"] = n["start_ticks"] + n["duration_ticks"]

    raw_notes.sort(key=lambda n: n["start_ticks"])
    lyrics.sort(key=lambda x: x[0])

    trimmed = []
    # Remove/trim overlaps so generated notes are strictly non-overlapping in the tick domain.
    for note in raw_notes:
        while trimmed:
            prev = trimmed[-1]
            if note["start_ticks"] < prev["end_ticks"]:
                prev["end_ticks"] = note["start_ticks"]
                prev["duration_ticks"] = prev["end_ticks"] - prev["start_ticks"]
                if prev["duration_ticks"] <= 0:
                    trimmed.pop()
                    continue
            break
        trimmed.append(note)
    raw_notes = trimmed

    tolerance = ticks_per_beat // 100
    # Attach lyrics near note_on positions with a small tick tolerance.
    lyric_idx = 0
    for note in raw_notes:
        while lyric_idx < len(lyrics) and lyrics[lyric_idx][0] < note["start_ticks"] - tolerance:
            lyric_idx += 1
        if lyric_idx < len(lyrics):
            lyric_ticks, lyric_text = lyrics[lyric_idx]
            if abs(lyric_ticks - note["start_ticks"]) <= tolerance:
                note["lyric"] = lyric_text
                lyric_idx += 1

    def ticks_to_seconds(ticks: int) -> float:
        return (ticks / ticks_per_beat) * (tempo / 1_000_000)

    result: List[Note] = []
    prev_end_s = 0.0
    for idx, n in enumerate(raw_notes):
        start_s = ticks_to_seconds(n["start_ticks"])
        end_s = ticks_to_seconds(n["end_ticks"])
        if prev_end_s > start_s:
            start_s = prev_end_s
        dur_s = end_s - start_s
        if dur_s <= 0:
            continue

        lyric = n.get("lyric", "")
        # SoulX-Singer convention mapping from lyric token to note_type/text.
        if not lyric:
            note_type = 2
            text = "啦"
        elif lyric == "<SP>":
            note_type = 1
            text = "<SP>"
        elif lyric == "-":
            note_type = 3
            text = raw_notes[idx - 1].get("lyric", "-") if idx > 0 else "-"
        else:
            note_type = 2
            text = lyric

        if start_s - prev_end_s > SILENCE_THRESHOLD_SEC:
            # Explicitly represent long gaps as <SP> notes.
            result.append(
                Note(
                    start_s=prev_end_s,
                    note_dur=start_s - prev_end_s,
                    note_text="<SP>",
                    note_pitch=0,
                    note_type=1,
                )
            )
        else:
            if len(result) > 0:
                result[-1].note_dur = start_s - result[-1].start_s

        result.append(
            Note(
                start_s=start_s,
                note_dur=dur_s,
                note_text=text,
                note_pitch=n["midi"],
                note_type=note_type,
            )
        )
        prev_end_s = end_s

    return result


class MidiParser:
    def __init__(
        self,
        rmvpe_model_path: str,
        device: str = "cuda",
    ) -> None:
        self.rmvpe_model_path = rmvpe_model_path
        self.device = device
        self.pitch_extractor: F0Extractor | None = None

    def _get_pitch_extractor(self) -> F0Extractor:
        if self.pitch_extractor is None:
            self.pitch_extractor = F0Extractor(
                self.rmvpe_model_path,
                device=self.device,
                verbose=False,
            )
        return self.pitch_extractor

    def midi2meta(
        self,
        midi_path: str,
        meta_path: str,
        vocal_file: str | None = None,
        language: str = "Mandarin",
    ) -> None:
        meta_dir = os.path.dirname(meta_path)
        if meta_dir:
            os.makedirs(meta_dir, exist_ok=True)

        notes = midi2notes(midi_path)
        pitch_extractor = self._get_pitch_extractor() if vocal_file else None
        notes2meta(
            notes,
            meta_path,
            vocal_file,
            language,
            pitch_extractor=pitch_extractor,
        )
        print(f"Saved Meta to {meta_path}")

    def meta2midi(self, meta_path: str, midi_path: str) -> None:
        notes = meta2notes(meta_path)
        notes2midi(notes, midi_path)
        print(f"Saved MIDI to {midi_path}")


if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser(
        description="Convert SoulX-Singer metadata JSON <-> MIDI."
    )
    parser.add_argument("--meta", type=str, help="Path to metadata JSON")
    parser.add_argument("--midi", type=str, help="Path to MIDI file")
    parser.add_argument("--vocal", type=str, default=None, help="Path to vocal wav (optional for midi2meta)")
    parser.add_argument("--language", type=str, default="Mandarin", help="Lyric language for metadata phoneme conversion (default: Mandarin)")
    parser.add_argument(
        "--meta2midi",
        action="store_true",
        help="Convert meta -> midi (requires --meta and --midi)",
    )
    parser.add_argument(
        "--midi2meta",
        action="store_true",
        help="Convert midi -> meta (requires --midi and --meta; --vocal is optional)",
    )
    parser.add_argument(
        "--rmvpe_model_path",
        type=str,
        help="Path to RMVPE model",
        default="pretrained_models/SoulX-Singer-Preprocess/rmvpe/rmvpe.pt",
    )
    parser.add_argument(
        "--device",
        type=str,
        help="Device to use for RMVPE",
        default="cuda",
    )
    args = parser.parse_args()
    midi_parser = MidiParser(
        rmvpe_model_path=args.rmvpe_model_path,
        device=args.device,
    )

    if args.meta2midi:
        if not args.meta or not args.midi:
            parser.error("--meta2midi requires --meta and --midi")
        midi_parser.meta2midi(args.meta, args.midi)
    elif args.midi2meta:
        if not args.midi or not args.meta:
            parser.error("--midi2meta requires --midi and --meta")
        midi_parser.midi2meta(args.midi, args.meta, args.vocal, args.language)
    else:
        parser.print_help()
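
For orientation, a minimal round-trip sketch of the converter above (the class and methods are the ones defined in this file; file paths and the package import are hypothetical placeholders):

# Hypothetical usage sketch; adjust paths/import to your checkout.
from preprocess.tools.midi_parser import MidiParser

mp = MidiParser(rmvpe_model_path="pretrained_models/SoulX-Singer-Preprocess/rmvpe/rmvpe.pt")
mp.meta2midi("song_meta.json", "song.mid")                      # metadata JSON -> MIDI
mp.midi2meta("song.mid", "song_meta_rt.json", vocal_file=None)  # MIDI -> metadata JSON (f0 stays empty without a vocal wav)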
preprocess/tools/note_transcription/__init__.py
ADDED
File without changes
preprocess/tools/note_transcription/model.py
ADDED
@@ -0,0 +1,531 @@
# https://github.com/RickyL-2000/ROSVOT
import math
import sys
import traceback
import json
import time
from pathlib import Path
from typing import Any, Dict, Optional

import librosa
import numpy as np
import torch
import matplotlib.pyplot as plt

from .utils.os_utils import safe_path
from .utils.commons.hparams import set_hparams
from .utils.commons.ckpt_utils import load_ckpt
from .utils.commons.dataset_utils import pad_or_cut_xd
from .utils.audio.mel import MelNet
from .utils.audio.pitch_utils import (
    norm_interp_f0,
    denorm_f0,
    f0_to_coarse,
    boundary2Interval,
    save_midi,
    midi_to_hz,
)
from .utils.rosvot_utils import (
    get_mel_len,
    align_word,
    regulate_real_note_itv,
    regulate_ill_slur,
    bd_to_durs,
)
from .modules.pe.rmvpe import RMVPE
from .modules.rosvot.rosvot import MidiExtractor, WordbdExtractor


@torch.no_grad()
def infer_sample(
    item: Dict[str, Any],
    hparams: Dict[str, Any],
    models: Dict[str, Any],
    device: torch.device,
    *,
    save_dir: Optional[str] = None,
    apply_rwbd: Optional[bool] = None,
    # outputs
    save_plot: bool = False,
    no_save_midi: bool = True,
    no_save_npy: bool = True,
    verbose: bool = False,
) -> Dict[str, Any]:
    if "item_name" not in item or "wav_fn" not in item:
        raise ValueError('item must contain keys: "item_name" and "wav_fn"')

    item_name = item["item_name"]
    wav_src = item["wav_fn"]

    # Decide RWBD usage
    if apply_rwbd is None:
        apply_rwbd_ = ("word_durs" not in item)
    else:
        apply_rwbd_ = bool(apply_rwbd)

    # Models
    model = models["model"]
    mel_net = models["mel_net"]
    pe = models.get("pe")
    wbd_predictor = models.get("wbd_predictor")

    if wbd_predictor is None and apply_rwbd_:
        raise ValueError("apply_rwbd is True but wbd_predictor model is not provided in models")

    # ---- Prepare Data ----
    if isinstance(wav_src, str):
        wav, _ = librosa.core.load(wav_src, sr=hparams["audio_sample_rate"])
    else:
        wav = wav_src
        if not isinstance(wav, np.ndarray):
            wav = np.asarray(wav)
    wav = wav.astype(np.float32)

    # Calculate timestamps and alignment lengths
    wav_len_samples = wav.shape[-1]
    mel_len = get_mel_len(wav_len_samples, hparams["hop_size"])

    # Word boundary preparation
    mel2word = None
    word_durs_filtered = None

    if not apply_rwbd_:
        if "word_durs" not in item:
            raise ValueError('apply_rwbd=False but item has no "word_durs"')

        wd_raw = list(item["word_durs"])
        min_word_dur = hparams.get("min_word_dur", 20) / 1000
        word_durs_filtered = []

        for i, wd in enumerate(wd_raw):
            if wd < min_word_dur:
                if i == 0 and len(wd_raw) > 1:
                    wd_raw[i + 1] += wd
                elif len(word_durs_filtered) > 0:
                    word_durs_filtered[-1] += wd
            else:
                word_durs_filtered.append(wd)

        mel2word, _ = align_word(word_durs_filtered, mel_len, hparams["hop_size"], hparams["audio_sample_rate"])
        mel2word = np.asarray(mel2word)
        if mel2word.size > 0 and mel2word[0] == 0:
            mel2word = mel2word + 1

        mel2word_len = int(np.sum(mel2word > 0))
        real_len = min(mel_len, mel2word_len)
    else:
        real_len = min(mel_len, hparams["max_frames"])

    T = math.ceil(min(real_len, hparams["max_frames"]) / hparams["frames_multiple"]) * hparams["frames_multiple"]

    # ---- Input Tensors & Padding ----
    target_samples = T * hparams["hop_size"]
    wav_t = torch.from_numpy(wav).float().to(device).unsqueeze(0)  # [1, L]
    if wav_t.shape[-1] < target_samples:
        wav_t = pad_or_cut_xd(wav_t, target_samples, 1)

    # ---- Pitch Extraction ----
    if pe is not None:
        f0s, uvs = pe.get_pitch_batch(
            wav_t,
            sample_rate=hparams["audio_sample_rate"],
            hop_size=hparams["hop_size"],
            lengths=[real_len],
            fmax=hparams["f0_max"],
            fmin=hparams["f0_min"],
        )
        f0_1d, uv_1d = norm_interp_f0(f0s[0][:T])
        f0_t = pad_or_cut_xd(torch.FloatTensor(f0_1d).to(device), T, 0).unsqueeze(0)
        uv_t = pad_or_cut_xd(torch.FloatTensor(uv_1d).to(device), T, 0).long().unsqueeze(0)
        pitch_coarse = f0_to_coarse(denorm_f0(f0_t, uv_t)).to(device)
        f0_np = denorm_f0(f0_t, uv_t)[0].detach().cpu().numpy()[:real_len]
    else:
        f0_t = uv_t = pitch_coarse = None
        f0_np = None

    # ---- Mel Extraction ----
    mel = mel_net(wav_t)  # [1, T_padded, C]
    mel = pad_or_cut_xd(mel, T, 1)

    # Construct non-padding mask
    mel_nonpadding_mask = torch.zeros(1, T, device=device)
    mel_nonpadding_mask[:, :real_len] = 1.0

    # Apply mask to mel (zero out padding)
    mel = (mel.transpose(1, 2) * mel_nonpadding_mask.unsqueeze(1)).transpose(1, 2)
    # Re-calculate non_padding bool mask
    mel_nonpadding = mel.abs().sum(-1) > 0

    # ---- Word Boundary ----
    word_durs_used = None
    if apply_rwbd_:
        mel_input = mel[:, :, : hparams.get("wbd_use_mel_bins", 80)]
        wbd_outputs = wbd_predictor(
            mel=mel_input,
            pitch=pitch_coarse,
            uv=uv_t,
            non_padding=mel_nonpadding,
            train=False,
        )
        word_bd = wbd_outputs["word_bd_pred"]  # [1, T]
    else:
        # Construct word_bd from provided durs
        mel2word_t = pad_or_cut_xd(torch.LongTensor(mel2word).to(device), T, 0)
        word_bd = torch.zeros_like(mel2word_t)
        # Vectorized check
        word_bd[1:] = (mel2word_t[1:] != mel2word_t[:-1]).long()
        word_bd[real_len:] = 0
        word_bd = word_bd.unsqueeze(0)  # [1, T]

        word_durs_used = np.array(word_durs_filtered)

    # ---- Main Inference ----
    mel_input = mel[:, :, : hparams.get("use_mel_bins", 80)]
    outputs = model(
        mel=mel_input,
        word_bd=word_bd,
        pitch=pitch_coarse,
        uv=uv_t,
        non_padding=mel_nonpadding,
        train=False,
    )

    note_lengths = outputs["note_lengths"].detach().cpu().numpy()
    note_bd_pred = outputs["note_bd_pred"][0].detach().cpu().numpy()[:real_len]
    note_pred = outputs["note_pred"][0].detach().cpu().numpy()[: note_lengths[0]]
    note_bd_logits = torch.sigmoid(outputs["note_bd_logits"])[0].detach().cpu().numpy()[:real_len]

    if note_pred.shape == (0,):
        if verbose:
            print(f"skip {item_name}: no notes detected")
        return {
            "item_name": item_name,
            "pitches": [],
            "note_durs": [],
            "note2words": None,
        }

    # ---- Post-Processing & Regulation ----
    note_itv_pred = boundary2Interval(note_bd_pred)
    note2words = None

    if apply_rwbd_:
        word_bd_np = outputs['word_bd_pred'][0].detach().cpu().numpy()[:real_len]
        word_durs_derived = np.array(bd_to_durs(word_bd_np)) * hparams['hop_size'] / hparams['audio_sample_rate']
        word_durs_for_reg = word_durs_derived
        word_bd_for_reg = word_bd_np
    else:
        word_bd_for_reg = word_bd[0].detach().cpu().numpy()[:real_len]
        word_durs_for_reg = word_durs_used

    should_regulate = hparams.get("infer_regulate_real_note_itv", True) and (not apply_rwbd_)

    if should_regulate and (word_durs_for_reg is not None):
        try:
            note_itv_pred_secs, note2words = regulate_real_note_itv(
                note_itv_pred,
                note_bd_pred,
                word_bd_for_reg,
                word_durs_for_reg,
                hparams["hop_size"],
                hparams["audio_sample_rate"],
            )
            note_pred, note_itv_pred_secs, note2words = regulate_ill_slur(note_pred, note_itv_pred_secs, note2words)
        except Exception as err:
            if verbose:
                _, exc_value, exc_tb = sys.exc_info()
                tb = traceback.extract_tb(exc_tb)[-1]
                print(f"postprocess failed: {err}: {exc_value} in {tb[0]}:{tb[1]} '{tb[2]}' in {tb[3]}")
            # Fallback
            note_itv_pred_secs = note_itv_pred * hparams["hop_size"] / hparams["audio_sample_rate"]
            note2words = None
    else:
        note_itv_pred_secs = note_itv_pred * hparams["hop_size"] / hparams["audio_sample_rate"]

    # ---- Output ----
    note_durs = [float((itv[1] - itv[0])) for itv in note_itv_pred_secs]

    out = {
        "item_name": item_name,
        "pitches": note_pred.tolist(),
        "note_durs": note_durs,
        "note2words": note2words.tolist() if note2words is not None else None,
    }

    # ---- Saving ----
    if save_dir is not None:
        save_dir_path = Path(save_dir)
        save_dir_path.mkdir(parents=True, exist_ok=True)
        fn = str(item_name)

        if not no_save_midi:
            save_midi(note_pred, note_itv_pred_secs, safe_path(save_dir_path / "midi" / f"{fn}.mid"))

        if not no_save_npy:
            np.save(safe_path(save_dir_path / "npy" / f"[note]{fn}.npy"), out, allow_pickle=True)

        if save_plot:
            fig = plt.figure()
            if f0_np is not None:
                plt.plot(f0_np, color="red", label="f0")

            midi_pred = np.zeros(note_bd_pred.shape[0], dtype=np.float32)
            itvs = np.round(note_itv_pred_secs * hparams["audio_sample_rate"] / hparams["hop_size"]).astype(int)
            for i, itv in enumerate(itvs):
                midi_pred[itv[0] : itv[1]] = note_pred[i]
            plt.plot(midi_to_hz(midi_pred), color="blue", label="pred midi")
            plt.plot(note_bd_logits * 100, color="green", label="note bd logits x100")
            plt.legend()
            plt.tight_layout()
            plt.savefig(safe_path(save_dir_path / "plot" / f"[MIDI]{fn}.png"), format="png")
            plt.close(fig)

    return out
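

# Worked example of the frame rounding above, with hypothetical hparams
# (hop_size=256, frames_multiple=8, chosen only for illustration):
#   real_len = 1875 mel frames  ->  T = ceil(1875 / 8) * 8 = 1880
# so the wav is padded to 1880 * 256 samples before mel/pitch extraction,
# and the 5 trailing frames are masked out via mel_nonpadding_mask.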


def load_rosvot_models(ckpt, config="", wbd_ckpt="", wbd_config="", device="cuda:0", verbose=False, thr=0.85):
    """
    Load models once to reuse across multiple items.
    """
    dev = torch.device(device)

    # 1. Hparams
    config_path = Path(ckpt).with_name("config.yaml") if config == "" else config
    pe_ckpt = Path(ckpt).parent.parent / "rmvpe/model.pt"
    hparams = set_hparams(
        config=config_path,
        print_hparams=verbose,
        hparams_str=f"note_bd_threshold={thr}",
    )

    # 2. Main Model
    model = MidiExtractor(hparams)
    load_ckpt(model, ckpt, verbose=verbose)
    model.eval().to(dev)

    # 3. MelNet
    mel_net = MelNet(hparams)
    mel_net.to(dev)

    # 4. Pitch Extractor
    pe = None
    if hparams.get("use_pitch_embed", False):
        pe = RMVPE(pe_ckpt, device=dev)

    # 5. Word Boundary Predictor (optional; loaded whenever a checkpoint is provided)
    wbd_predictor = None
    if wbd_ckpt:
        wbd_config_path = Path(wbd_ckpt).with_name("config.yaml") if wbd_config == "" else wbd_config
        wbd_hparams = set_hparams(
            config=wbd_config_path,
            print_hparams=False,
            hparams_str="",
        )
        hparams.update({
            "wbd_use_mel_bins": wbd_hparams["use_mel_bins"],
            "min_word_dur": wbd_hparams["min_word_dur"],
        })
        wbd_predictor = WordbdExtractor(wbd_hparams)
        load_ckpt(wbd_predictor, wbd_ckpt, verbose=verbose)
        wbd_predictor.eval().to(dev)

    models = {
        "model": model,
        "mel_net": mel_net,
        "pe": pe,
        "wbd_predictor": wbd_predictor,
    }
    return hparams, models


class NoteTranscriber:
    """Note transcription wrapper based on ROSVOT.

    Loads ROSVOT and optional RWBD models once in ``__init__`` and
    exposes a :py:meth:`process` API that turns an item dict into
    aligned note metadata for downstream SVS.
    """

    def __init__(
        self,
        rosvot_model_path: str,
        rwbd_model_path: str,
        *,
        rosvot_config_path: str = "",
        rwbd_config_path: str = "",
        device: str = "cuda:0",
        thr: float = 0.85,
        verbose: bool = True,
    ):
        """Initialize the note transcriber.

        Args:
            rosvot_model_path: Path to the main ROSVOT checkpoint.
            rwbd_model_path: Path to the word-boundary (RWBD) checkpoint.
            rosvot_config_path: Optional config YAML path for ROSVOT.
            rwbd_config_path: Optional config YAML path for RWBD.
            device: Torch device string, e.g. ``"cuda:0"`` / ``"cpu"``.
            thr: Note boundary threshold.
            verbose: Whether to print verbose logs.
        """
        self.verbose = verbose
        self.device = torch.device(device)
        self.hparams, self.models = load_rosvot_models(
            ckpt=rosvot_model_path,
            config=rosvot_config_path,
            wbd_ckpt=rwbd_model_path,
            wbd_config=rwbd_config_path,
            device=device,
            verbose=verbose,
            thr=thr,
        )

        if self.verbose:
            print(
                "[note transcription] init success:",
                f"device={self.device}",
                f"rosvot_model_path={rosvot_model_path}",
                f"rwbd_model_path={rwbd_model_path if rwbd_model_path else 'None'}",
                f"thr={thr}",
            )

    def process(
        self,
        item: Dict[str, Any],
        *,
        segment_info: Optional[Dict[str, Any]] = None,
        save_dir: Optional[str] = None,
        apply_rwbd: Optional[bool] = None,
        save_plot: bool = False,
        no_save_midi: bool = True,
        no_save_npy: bool = True,
        verbose: Optional[bool] = None,
    ) -> Dict[str, Any]:
        """Run ROSVOT on a single item and post-process outputs.

        Args:
            item: Input metadata dict with at least ``item_name`` and ``wav_fn``.
            segment_info: Optional segment metadata for sliced audio.
            save_dir: Optional directory for debug artifacts (plots, midis).
            apply_rwbd: Whether to run RWBD-based word boundary refinement.
            save_plot: Whether to save diagnostic plots.
            no_save_midi: If True, skip saving midi.
            no_save_npy: If True, skip saving numpy intermediates.
            verbose: Override instance-level verbose flag for this call.

        Returns:
            Dict with aligned note information for downstream SVS.
        """
        v = self.verbose if verbose is None else verbose
        if v:
            item_name = item.get("item_name", "")
            wav_fn = item.get("wav_fn", "")
            print(f"[note transcription] process: start: item_name={item_name} wav_fn={wav_fn}")
        t0 = time.time()

        rosvot_out = infer_sample(
            item,
            self.hparams,
            self.models,
            device=self.device,
            save_dir=save_dir,
            apply_rwbd=apply_rwbd,
            save_plot=save_plot,
            no_save_midi=no_save_midi,
            no_save_npy=no_save_npy,
            verbose=v,
        )

        out = self.post_process(
            metadata=item,
            segment_info=segment_info,
            rosvot_out=rosvot_out,
        )

        if v:
            dt = time.time() - t0
            print(
                "[note transcription] process: done:",
                f"item_name={out.get('item_name', '')}",
                f"n_notes={len(out.get('note_pitch', []) or [])}",
                f"time={dt:.3f}s",
            )

        return out

    @staticmethod
    def _normalize_note2words(note2words: list[int]) -> list[int]:
        if not note2words:
            return []
        normalized = [note2words[0]]
        for idx in range(1, len(note2words)):
            if note2words[idx] < normalized[-1]:
                normalized.append(normalized[-1])
            else:
                normalized.append(note2words[idx])
        return normalized

    @staticmethod
    def _build_ep_types(note2words: list[int], align_words: list[str]) -> list[int]:
        ep_types: list[int] = []
        prev = -1
        for i, w in zip(note2words, align_words):
            if w == "<SP>":
                ep_types.append(1)
            else:
                ep_types.append(2 if i != prev else 3)
                prev = i
        return ep_types

    def post_process(
        self,
        *,
        metadata: Dict[str, Any],
        segment_info: Dict[str, Any],
        rosvot_out: Dict[str, Any],
    ) -> Dict[str, Any]:
        """Build aligned note metadata using ROSVOT outputs."""
        note2words_raw = rosvot_out.get("note2words") or []
        note2words = self._normalize_note2words(note2words_raw)
        align_words = [
            metadata["words"][idx - 1]
            for idx in note2words_raw
            if 0 < idx <= len(metadata["words"])
        ]
        ep_types = self._build_ep_types(note2words, align_words) if align_words else []

        return {
            "item_name": rosvot_out.get("item_name", "") if not segment_info else segment_info["item_name"],
            "wav_fn": metadata.get("wav_fn", "") if not segment_info else segment_info["wav_fn"],
            "origin_wav_fn": metadata.get("origin_wav_fn", "") if not segment_info else segment_info["origin_wav_fn"],
            "start_time_ms": "" if not segment_info else segment_info["start_time_ms"],
            "end_time_ms": "" if not segment_info else segment_info["end_time_ms"],
            "language": metadata.get("language", ""),
            "note_text": align_words,
            "note_dur": rosvot_out.get("note_durs", []),
            "note_type": ep_types,
            "note_pitch": rosvot_out.get("pitches", []),
        }


if __name__ == "__main__":

    item = {
        'item_name': 'vocal_0',
        'wav_fn': 'example/audio/zh_prompt.mp3',
        'start_time_ms': 320,
        'end_time_ms': 10687,
        'origin_wav_fn': 'example/audio/zh_prompt.mp3',
        'duration': 10367,
        'words': ['<SP>', '除', '了', '想', '你', '<SP>', '除', '了', '爱', '你', '<SP>', '我', '什', '么', '什', '么', '都', '愿', '意'],
        'word_durs': [0.21, 0.36, 0.26, 0.7000000000000001, 0.96, 0.3800000000000001, 0.43999999999999995, 0.3799999999999999, 0.6400000000000001, 0.9600000000000002, 1.1199999999999999, 0.28000000000000025, 0.3799999999999999, 0.3199999999999994, 0.3200000000000003, 0.3799999999999999, 0.3200000000000003, 0.5, 1.457981859410431],
        'language': 'Mandarin'
    }

    m = NoteTranscriber(
        rosvot_model_path="pretrained_models/SoulX-Singer-Preprocess/rosvot/rosvot/model.pt",
        rwbd_model_path="pretrained_models/SoulX-Singer-Preprocess/rosvot/rwbd/model.pt",
        device="cuda"
    )
    out = m.process(item, segment_info=item)

    print(out)
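
As a quick illustration of the note_type convention produced by _build_ep_types above (1 = rest/<SP>, 2 = the first note of a word, 3 = a slur note continuing the same word), a toy example with invented values:

# Two consecutive notes mapped to the same word index form a slur (type 3).
note2words  = [1, 2, 2, 3]
align_words = ["除", "了", "了", "<SP>"]
assert NoteTranscriber._build_ep_types(note2words, align_words) == [2, 2, 3, 1]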
preprocess/tools/note_transcription/modules/__init__.py
ADDED
@@ -0,0 +1 @@
"""ROSVOT model submodules."""
preprocess/tools/note_transcription/modules/commons/__init__.py
ADDED
@@ -0,0 +1 @@
"""Common ROSVOT layers and utilities."""
preprocess/tools/note_transcription/modules/commons/conformer/__init__.py
ADDED
@@ -0,0 +1 @@
"""Conformer layers for ROSVOT."""
preprocess/tools/note_transcription/modules/commons/conformer/conformer.py
ADDED
@@ -0,0 +1,96 @@
from torch import nn
from .espnet_positional_embedding import RelPositionalEncoding, ScaledPositionalEncoding, PositionalEncoding
from .espnet_transformer_attn import RelPositionMultiHeadedAttention, MultiHeadedAttention
from .layers import Swish, ConvolutionModule, EncoderLayer, MultiLayeredConv1d
from ..layers import Embedding


class ConformerLayers(nn.Module):
    def __init__(self, hidden_size, num_layers, kernel_size=9, dropout=0.0, num_heads=4,
                 use_last_norm=True, save_hidden=False):
        super().__init__()
        self.use_last_norm = use_last_norm
        self.layers = nn.ModuleList()
        positionwise_layer = MultiLayeredConv1d
        positionwise_layer_args = (hidden_size, hidden_size * 4, 1, dropout)
        self.pos_embed = RelPositionalEncoding(hidden_size, dropout)
        self.encoder_layers = nn.ModuleList([EncoderLayer(
            hidden_size,
            RelPositionMultiHeadedAttention(num_heads, hidden_size, 0.0),
            positionwise_layer(*positionwise_layer_args),
            positionwise_layer(*positionwise_layer_args),
            ConvolutionModule(hidden_size, kernel_size, Swish()),
            dropout,
        ) for _ in range(num_layers)])
        if self.use_last_norm:
            self.layer_norm = nn.LayerNorm(hidden_size)
        else:
            self.layer_norm = nn.Linear(hidden_size, hidden_size)
        self.save_hidden = save_hidden
        if save_hidden:
            self.hiddens = []

    def forward(self, x, padding_mask=None):
        """
        :param x: [B, T, H]
        :param padding_mask: [B, T]
        :return: [B, T, H]
        """
        self.hiddens = []
        nonpadding_mask = x.abs().sum(-1) > 0
        x = self.pos_embed(x)
        for l in self.encoder_layers:
            x, mask = l(x, nonpadding_mask[:, None, :])
            if self.save_hidden:
                self.hiddens.append(x[0])
        x = x[0]
        x = self.layer_norm(x) * nonpadding_mask.float()[:, :, None]
        return x


class FastConformerLayers(ConformerLayers):
    def __init__(self, hidden_size, num_layers, kernel_size=9, dropout=0.0, num_heads=4,
                 use_last_norm=True, save_hidden=False):
        super(ConformerLayers, self).__init__()
        self.use_last_norm = use_last_norm
        self.layers = nn.ModuleList()
        positionwise_layer = MultiLayeredConv1d
        positionwise_layer_args = (hidden_size, hidden_size * 4, 1, dropout)
        self.pos_embed = PositionalEncoding(hidden_size, dropout)
        self.encoder_layers = nn.ModuleList([EncoderLayer(
            hidden_size,
            MultiHeadedAttention(num_heads, hidden_size, 0.0, flash=True),
            positionwise_layer(*positionwise_layer_args),
            positionwise_layer(*positionwise_layer_args),
            ConvolutionModule(hidden_size, kernel_size, Swish()),
            dropout,
        ) for _ in range(num_layers)])
        if self.use_last_norm:
            self.layer_norm = nn.LayerNorm(hidden_size)
        else:
            self.layer_norm = nn.Linear(hidden_size, hidden_size)
        self.save_hidden = save_hidden
        if save_hidden:
            self.hiddens = []


class ConformerEncoder(ConformerLayers):
    def __init__(self, hidden_size, dict_size, num_layers=None):
        conformer_enc_kernel_size = 9
        super().__init__(hidden_size, num_layers, conformer_enc_kernel_size)
        self.embed = Embedding(dict_size, hidden_size, padding_idx=0)

    def forward(self, x):
        """
        :param x: source tokens, [B, T]
        :return: [B, T, C]
        """
        x = self.embed(x)  # [B, T, H]
        x = super(ConformerEncoder, self).forward(x)
        return x


class ConformerDecoder(ConformerLayers):
    def __init__(self, hidden_size, num_layers):
        conformer_dec_kernel_size = 9
        super().__init__(hidden_size, num_layers, conformer_dec_kernel_size)
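
A shape-level smoke test for ConformerLayers is sketched below. It assumes the espnet-style EncoderLayer in layers.py (which returns an (output, mask) pair whose output carries the relative positional embedding, unpacked by the loop above), so treat it as illustrative rather than a shipped test:

import torch

layers = ConformerLayers(hidden_size=64, num_layers=2, kernel_size=9, num_heads=4)
x = torch.randn(2, 100, 64)
x[1, 60:] = 0.0   # all-zero frames are treated as padding by the internal mask
y = layers(x)     # -> [2, 100, 64]; the final masked norm re-zeros padded frames
assert y.shape == (2, 100, 64)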
preprocess/tools/note_transcription/modules/commons/conformer/espnet_positional_embedding.py
ADDED
@@ -0,0 +1,113 @@
import math
import torch


class PositionalEncoding(torch.nn.Module):
    """Positional encoding.

    Args:
        d_model (int): Embedding dimension.
        dropout_rate (float): Dropout rate.
        max_len (int): Maximum input length.
        reverse (bool): Whether to reverse the input position.
    """

    def __init__(self, d_model, dropout_rate, max_len=5000, reverse=False):
        """Construct a PositionalEncoding object."""
        super(PositionalEncoding, self).__init__()
        self.d_model = d_model
        self.reverse = reverse
        self.xscale = math.sqrt(self.d_model)
        self.dropout = torch.nn.Dropout(p=dropout_rate)
        self.pe = None
        self.extend_pe(torch.tensor(0.0).expand(1, max_len))

    def extend_pe(self, x):
        """Reset the positional encodings."""
        if self.pe is not None:
            if self.pe.size(1) >= x.size(1):
                if self.pe.dtype != x.dtype or self.pe.device != x.device:
                    self.pe = self.pe.to(dtype=x.dtype, device=x.device)
                return
        pe = torch.zeros(x.size(1), self.d_model)
        if self.reverse:
            position = torch.arange(
                x.size(1) - 1, -1, -1.0, dtype=torch.float32
            ).unsqueeze(1)
        else:
            position = torch.arange(0, x.size(1), dtype=torch.float32).unsqueeze(1)
        div_term = torch.exp(
            torch.arange(0, self.d_model, 2, dtype=torch.float32)
            * -(math.log(10000.0) / self.d_model)
        )
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(0)
        self.pe = pe.to(device=x.device, dtype=x.dtype)

    def forward(self, x: torch.Tensor):
        """Add positional encoding.

        Args:
            x (torch.Tensor): Input tensor (batch, time, `*`).
        Returns:
            torch.Tensor: Encoded tensor (batch, time, `*`).
        """
        self.extend_pe(x)
        x = x * self.xscale + self.pe[:, : x.size(1)]
        return self.dropout(x)


class ScaledPositionalEncoding(PositionalEncoding):
    """Scaled positional encoding module.

    See Sec. 3.2 of https://arxiv.org/abs/1809.08895

    Args:
        d_model (int): Embedding dimension.
        dropout_rate (float): Dropout rate.
        max_len (int): Maximum input length.
    """

    def __init__(self, d_model, dropout_rate, max_len=5000):
        """Initialize class."""
        super().__init__(d_model=d_model, dropout_rate=dropout_rate, max_len=max_len)
        self.alpha = torch.nn.Parameter(torch.tensor(1.0))

    def reset_parameters(self):
        """Reset parameters."""
        self.alpha.data = torch.tensor(1.0)

    def forward(self, x):
        """Add positional encoding.

        Args:
            x (torch.Tensor): Input tensor (batch, time, `*`).
        Returns:
            torch.Tensor: Encoded tensor (batch, time, `*`).
        """
        self.extend_pe(x)
        x = x + self.alpha * self.pe[:, : x.size(1)]
        return self.dropout(x)


class RelPositionalEncoding(PositionalEncoding):
    """Relative positional encoding module.

    See: Appendix B in https://arxiv.org/abs/1901.02860

    Args:
        d_model (int): Embedding dimension.
        dropout_rate (float): Dropout rate.
        max_len (int): Maximum input length.
    """

    def __init__(self, d_model, dropout_rate, max_len=5000):
        """Initialize class."""
        super().__init__(d_model, dropout_rate, max_len, reverse=True)

    def forward(self, x):
        """Compute positional encoding.

        Args:
            x (torch.Tensor): Input tensor (batch, time, `*`).
        Returns:
            torch.Tensor: Encoded tensor (batch, time, `*`).
            torch.Tensor: Positional embedding tensor (1, time, `*`).
        """
        self.extend_pe(x)
        x = x * self.xscale
        pos_emb = self.pe[:, : x.size(1)]
        return self.dropout(x), self.dropout(pos_emb)
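
For reference, the sinusoid table that extend_pe builds can be reproduced standalone; a minimal sketch with toy sizes (illustration only, not part of the module):

import math
import torch

d_model, max_len = 8, 4
position = torch.arange(0, max_len, dtype=torch.float32).unsqueeze(1)
div_term = torch.exp(torch.arange(0, d_model, 2, dtype=torch.float32)
                     * -(math.log(10000.0) / d_model))
pe = torch.zeros(max_len, d_model)
pe[:, 0::2] = torch.sin(position * div_term)
pe[:, 1::2] = torch.cos(position * div_term)
# The first channel pair oscillates fastest: pe[p, 0] == sin(p), pe[p, 1] == cos(p).
assert torch.allclose(pe[:, 0], torch.sin(position.squeeze(1)))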
preprocess/tools/note_transcription/modules/commons/conformer/espnet_transformer_attn.py
ADDED
@@ -0,0 +1,198 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-

# Copyright 2019 Shigeki Karita
# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)

"""Multi-Head Attention layer definition."""

from packaging import version
import math

import numpy
import torch
from torch import nn


class MultiHeadedAttention(nn.Module):
    """Multi-Head Attention layer.

    Args:
        n_head (int): The number of heads.
        n_feat (int): The number of features.
        dropout_rate (float): Dropout rate.
    """

    def __init__(self, n_head, n_feat, dropout_rate, flash=False):
        """Construct a MultiHeadedAttention object."""
        super(MultiHeadedAttention, self).__init__()
        assert n_feat % n_head == 0
        # We assume d_v always equals d_k
        self.d_k = n_feat // n_head
        self.h = n_head
        self.linear_q = nn.Linear(n_feat, n_feat)
        self.linear_k = nn.Linear(n_feat, n_feat)
        self.linear_v = nn.Linear(n_feat, n_feat)
        self.linear_out = nn.Linear(n_feat, n_feat)
        self.attn = None
        self.dropout = nn.Dropout(p=dropout_rate)
        self.dropout_rate = dropout_rate
        self.flash = flash

    def forward_qkv(self, query, key, value):
        """Transform query, key and value.

        Args:
            query (torch.Tensor): Query tensor (#batch, time1, size).
            key (torch.Tensor): Key tensor (#batch, time2, size).
            value (torch.Tensor): Value tensor (#batch, time2, size).
        Returns:
            torch.Tensor: Transformed query tensor (#batch, n_head, time1, d_k).
            torch.Tensor: Transformed key tensor (#batch, n_head, time2, d_k).
            torch.Tensor: Transformed value tensor (#batch, n_head, time2, d_k).
        """
        n_batch = query.size(0)
        q = self.linear_q(query).view(n_batch, -1, self.h, self.d_k)
        k = self.linear_k(key).view(n_batch, -1, self.h, self.d_k)
        v = self.linear_v(value).view(n_batch, -1, self.h, self.d_k)
        q = q.transpose(1, 2)  # (batch, head, time1, d_k)
        k = k.transpose(1, 2)  # (batch, head, time2, d_k)
        v = v.transpose(1, 2)  # (batch, head, time2, d_k)

        return q, k, v

    def forward_attention(self, value, scores, mask):
        """Compute attention context vector.

        Args:
            value (torch.Tensor): Transformed value (#batch, n_head, time2, d_k).
            scores (torch.Tensor): Attention score (#batch, n_head, time1, time2).
            mask (torch.Tensor): Mask (#batch, 1, time2) or (#batch, time1, time2).
        Returns:
            torch.Tensor: Transformed value (#batch, time1, d_model)
                weighted by the attention score (#batch, time1, time2).
        """
        n_batch = value.size(0)
        if mask is not None:
            mask = mask.unsqueeze(1).eq(0)  # (batch, 1, *, time2)
            min_value = float(
                numpy.finfo(torch.tensor(0, dtype=scores.dtype).numpy().dtype).min
            )
            scores = scores.masked_fill(mask, min_value)
            self.attn = torch.softmax(scores, dim=-1).masked_fill(
                mask, 0.0
            )  # (batch, head, time1, time2)
        else:
            self.attn = torch.softmax(scores, dim=-1)  # (batch, head, time1, time2)

        p_attn = self.dropout(self.attn)
        x = torch.matmul(p_attn, value)  # (batch, head, time1, d_k)
        x = (
            x.transpose(1, 2).contiguous().view(n_batch, -1, self.h * self.d_k)
        )  # (batch, time1, d_model)

        return self.linear_out(x)  # (batch, time1, d_model)
| 92 |
+
|
| 93 |
+
def forward(self, query, key, value, mask):
|
| 94 |
+
"""Compute scaled dot product attention.
|
| 95 |
+
Args:
|
| 96 |
+
query (torch.Tensor): Query tensor (#batch, time1, size).
|
| 97 |
+
key (torch.Tensor): Key tensor (#batch, time2, size).
|
| 98 |
+
value (torch.Tensor): Value tensor (#batch, time2, size).
|
| 99 |
+
mask (torch.Tensor): Mask tensor (#batch, 1, time2) or
|
| 100 |
+
(#batch, time1, time2).
|
| 101 |
+
Returns:
|
| 102 |
+
torch.Tensor: Output tensor (#batch, time1, d_model).
|
| 103 |
+
"""
|
| 104 |
+
q, k, v = self.forward_qkv(query, key, value)
|
| 105 |
+
if version.parse(torch.__version__) >= version.parse("2.0") and self.flash:
|
| 106 |
+
n_batch = value.size(0)
|
| 107 |
+
x = torch.nn.functional.scaled_dot_product_attention(
|
| 108 |
+
q, k, v, attn_mask=mask.unsqueeze(1) if mask is not None else None, dropout_p=self.dropout_rate)
|
| 109 |
+
x = (
|
| 110 |
+
x.transpose(1, 2).contiguous().view(n_batch, -1, self.h * self.d_k)
|
| 111 |
+
) # (batch, time1, d_model)
|
| 112 |
+
return self.linear_out(x)
|
| 113 |
+
else:
|
| 114 |
+
scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.d_k)
|
| 115 |
+
return self.forward_attention(v, scores, mask)
|
| 116 |
+
|
| 117 |
+
|
| 118 |
+
class RelPositionMultiHeadedAttention(MultiHeadedAttention):
|
| 119 |
+
"""Multi-Head Attention layer with relative position encoding.
|
| 120 |
+
Paper: https://arxiv.org/abs/1901.02860
|
| 121 |
+
Args:
|
| 122 |
+
n_head (int): The number of heads.
|
| 123 |
+
n_feat (int): The number of features.
|
| 124 |
+
dropout_rate (float): Dropout rate.
|
| 125 |
+
"""
|
| 126 |
+
|
| 127 |
+
def __init__(self, n_head, n_feat, dropout_rate):
|
| 128 |
+
"""Construct an RelPositionMultiHeadedAttention object."""
|
| 129 |
+
super().__init__(n_head, n_feat, dropout_rate)
|
| 130 |
+
# linear transformation for positional ecoding
|
| 131 |
+
self.linear_pos = nn.Linear(n_feat, n_feat, bias=False)
|
| 132 |
+
# these two learnable bias are used in matrix c and matrix d
|
| 133 |
+
# as described in https://arxiv.org/abs/1901.02860 Section 3.3
|
| 134 |
+
self.pos_bias_u = nn.Parameter(torch.Tensor(self.h, self.d_k))
|
| 135 |
+
self.pos_bias_v = nn.Parameter(torch.Tensor(self.h, self.d_k))
|
| 136 |
+
torch.nn.init.xavier_uniform_(self.pos_bias_u)
|
| 137 |
+
torch.nn.init.xavier_uniform_(self.pos_bias_v)
|
| 138 |
+
|
| 139 |
+
def rel_shift(self, x, zero_triu=False):
|
| 140 |
+
"""Compute relative positinal encoding.
|
| 141 |
+
Args:
|
| 142 |
+
x (torch.Tensor): Input tensor (batch, time, size).
|
| 143 |
+
zero_triu (bool): If true, return the lower triangular part of the matrix.
|
| 144 |
+
Returns:
|
| 145 |
+
torch.Tensor: Output tensor.
|
| 146 |
+
"""
|
| 147 |
+
zero_pad = torch.zeros((*x.size()[:3], 1), device=x.device, dtype=x.dtype)
|
| 148 |
+
x_padded = torch.cat([zero_pad, x], dim=-1)
|
| 149 |
+
|
| 150 |
+
x_padded = x_padded.view(*x.size()[:2], x.size(3) + 1, x.size(2))
|
| 151 |
+
x = x_padded[:, :, 1:].view_as(x)
|
| 152 |
+
|
| 153 |
+
if zero_triu:
|
| 154 |
+
ones = torch.ones((x.size(2), x.size(3)))
|
| 155 |
+
x = x * torch.tril(ones, x.size(3) - x.size(2))[None, None, :, :]
|
| 156 |
+
|
| 157 |
+
return x
|
| 158 |
+
|
| 159 |
+
def forward(self, query, key, value, pos_emb, mask):
|
| 160 |
+
"""Compute 'Scaled Dot Product Attention' with rel. positional encoding.
|
| 161 |
+
Args:
|
| 162 |
+
query (torch.Tensor): Query tensor (#batch, time1, size).
|
| 163 |
+
key (torch.Tensor): Key tensor (#batch, time2, size).
|
| 164 |
+
value (torch.Tensor): Value tensor (#batch, time2, size).
|
| 165 |
+
pos_emb (torch.Tensor): Positional embedding tensor (#batch, time2, size).
|
| 166 |
+
mask (torch.Tensor): Mask tensor (#batch, 1, time2) or
|
| 167 |
+
(#batch, time1, time2).
|
| 168 |
+
Returns:
|
| 169 |
+
torch.Tensor: Output tensor (#batch, time1, d_model).
|
| 170 |
+
"""
|
| 171 |
+
q, k, v = self.forward_qkv(query, key, value)
|
| 172 |
+
q = q.transpose(1, 2) # (batch, time1, head, d_k)
|
| 173 |
+
|
| 174 |
+
n_batch_pos = pos_emb.size(0)
|
| 175 |
+
p = self.linear_pos(pos_emb).view(n_batch_pos, -1, self.h, self.d_k)
|
| 176 |
+
p = p.transpose(1, 2) # (batch, head, time1, d_k)
|
| 177 |
+
|
| 178 |
+
# (batch, head, time1, d_k)
|
| 179 |
+
q_with_bias_u = (q + self.pos_bias_u).transpose(1, 2)
|
| 180 |
+
# (batch, head, time1, d_k)
|
| 181 |
+
q_with_bias_v = (q + self.pos_bias_v).transpose(1, 2)
|
| 182 |
+
|
| 183 |
+
# compute attention score
|
| 184 |
+
# first compute matrix a and matrix c
|
| 185 |
+
# as described in https://arxiv.org/abs/1901.02860 Section 3.3
|
| 186 |
+
# (batch, head, time1, time2)
|
| 187 |
+
matrix_ac = torch.matmul(q_with_bias_u, k.transpose(-2, -1))
|
| 188 |
+
|
| 189 |
+
# compute matrix b and matrix d
|
| 190 |
+
# (batch, head, time1, time2)
|
| 191 |
+
matrix_bd = torch.matmul(q_with_bias_v, p.transpose(-2, -1))
|
| 192 |
+
matrix_bd = self.rel_shift(matrix_bd)
|
| 193 |
+
|
| 194 |
+
scores = (matrix_ac + matrix_bd) / math.sqrt(
|
| 195 |
+
self.d_k
|
| 196 |
+
) # (batch, head, time1, time2)
|
| 197 |
+
|
| 198 |
+
return self.forward_attention(v, scores, mask)
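As a hedged usage sketch of the plain attention path above (all sizes invented): nonzero mask entries mark valid key positions, since forward_attention fills scores where mask.eq(0).

import torch

# 4 heads over 256 features; self-attention with a padding mask of shape
# (#batch, 1, time2), matching the docstring convention above.
mha = MultiHeadedAttention(n_head=4, n_feat=256, dropout_rate=0.1)
x = torch.randn(2, 50, 256)                     # (#batch, time, size)
mask = torch.ones(2, 1, 50, dtype=torch.bool)   # True = attend
mask[1, :, 40:] = False                         # pretend sample 2 is padded
out = mha(x, x, x, mask)                        # (2, 50, 256)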
preprocess/tools/note_transcription/modules/commons/conformer/layers.py
ADDED
@@ -0,0 +1,260 @@
from torch import nn
import torch

from ..layers import LayerNorm


class ConvolutionModule(nn.Module):
    """ConvolutionModule in Conformer model.
    Args:
        channels (int): The number of channels of conv layers.
        kernel_size (int): Kernel size of conv layers.
    """

    def __init__(self, channels, kernel_size, activation=nn.ReLU(), bias=True):
        """Construct a ConvolutionModule object."""
        super(ConvolutionModule, self).__init__()
        # kernel_size should be an odd number for 'SAME' padding
        assert (kernel_size - 1) % 2 == 0

        self.pointwise_conv1 = nn.Conv1d(
            channels,
            2 * channels,
            kernel_size=1,
            stride=1,
            padding=0,
            bias=bias,
        )
        self.depthwise_conv = nn.Conv1d(
            channels,
            channels,
            kernel_size,
            stride=1,
            padding=(kernel_size - 1) // 2,
            groups=channels,
            bias=bias,
        )
        self.norm = nn.BatchNorm1d(channels)
        self.pointwise_conv2 = nn.Conv1d(
            channels,
            channels,
            kernel_size=1,
            stride=1,
            padding=0,
            bias=bias,
        )
        self.activation = activation

    def forward(self, x):
        """Compute convolution module.
        Args:
            x (torch.Tensor): Input tensor (#batch, time, channels).
        Returns:
            torch.Tensor: Output tensor (#batch, time, channels).
        """
        # exchange the temporal dimension and the feature dimension
        x = x.transpose(1, 2)

        # GLU mechanism
        x = self.pointwise_conv1(x)  # (batch, 2*channel, time)
        x = nn.functional.glu(x, dim=1)  # (batch, channel, time)

        # 1D depthwise conv
        x = self.depthwise_conv(x)
        x = self.activation(self.norm(x))

        x = self.pointwise_conv2(x)

        return x.transpose(1, 2)


class MultiLayeredConv1d(torch.nn.Module):
    """Multi-layered conv1d for Transformer block.
    This is a module of multi-layered conv1d designed
    to replace the position-wise feed-forward network
    in a Transformer block, introduced in
    `FastSpeech: Fast, Robust and Controllable Text to Speech`_.
    .. _`FastSpeech: Fast, Robust and Controllable Text to Speech`:
        https://arxiv.org/pdf/1905.09263.pdf
    """

    def __init__(self, in_chans, hidden_chans, kernel_size, dropout_rate):
        """Initialize MultiLayeredConv1d module.
        Args:
            in_chans (int): Number of input channels.
            hidden_chans (int): Number of hidden channels.
            kernel_size (int): Kernel size of conv1d.
            dropout_rate (float): Dropout rate.
        """
        super(MultiLayeredConv1d, self).__init__()
        self.w_1 = torch.nn.Conv1d(
            in_chans,
            hidden_chans,
            kernel_size,
            stride=1,
            padding=(kernel_size - 1) // 2,
        )
        self.w_2 = torch.nn.Conv1d(
            hidden_chans,
            in_chans,
            kernel_size,
            stride=1,
            padding=(kernel_size - 1) // 2,
        )
        self.dropout = torch.nn.Dropout(dropout_rate)

    def forward(self, x):
        """Calculate forward propagation.
        Args:
            x (torch.Tensor): Batch of input tensors (B, T, in_chans).
        Returns:
            torch.Tensor: Batch of output tensors (B, T, in_chans).
        """
        x = torch.relu(self.w_1(x.transpose(-1, 1))).transpose(-1, 1)
        return self.w_2(self.dropout(x).transpose(-1, 1)).transpose(-1, 1)


class Swish(torch.nn.Module):
    """Construct a Swish object."""

    def forward(self, x):
        """Return Swish activation function."""
        return x * torch.sigmoid(x)


class EncoderLayer(nn.Module):
    """Encoder layer module.
    Args:
        size (int): Input dimension.
        self_attn (torch.nn.Module): Self-attention module instance.
            `MultiHeadedAttention` or `RelPositionMultiHeadedAttention` instance
            can be used as the argument.
        feed_forward (torch.nn.Module): Feed-forward module instance.
            `PositionwiseFeedForward`, `MultiLayeredConv1d`, or `Conv1dLinear` instance
            can be used as the argument.
        feed_forward_macaron (torch.nn.Module): Additional feed-forward module instance.
            `PositionwiseFeedForward`, `MultiLayeredConv1d`, or `Conv1dLinear` instance
            can be used as the argument.
        conv_module (torch.nn.Module): Convolution module instance.
            `ConvolutionModule` instance can be used as the argument.
        dropout_rate (float): Dropout rate.
        normalize_before (bool): Whether to use layer_norm before the first block.
        concat_after (bool): Whether to concat attention layer's input and output.
            If True, an additional linear will be applied,
            i.e. x -> x + linear(concat(x, att(x))).
            If False, no additional linear will be applied, i.e. x -> x + att(x).
    """

    def __init__(
        self,
        size,
        self_attn,
        feed_forward,
        feed_forward_macaron,
        conv_module,
        dropout_rate,
        normalize_before=True,
        concat_after=False,
    ):
        """Construct an EncoderLayer object."""
        super(EncoderLayer, self).__init__()
        self.self_attn = self_attn
        self.feed_forward = feed_forward
        self.feed_forward_macaron = feed_forward_macaron
        self.conv_module = conv_module
        self.norm_ff = LayerNorm(size)  # for the FNN module
        self.norm_mha = LayerNorm(size)  # for the MHA module
        if feed_forward_macaron is not None:
            self.norm_ff_macaron = LayerNorm(size)
            self.ff_scale = 0.5
        else:
            self.ff_scale = 1.0
        if self.conv_module is not None:
            self.norm_conv = LayerNorm(size)  # for the CNN module
            self.norm_final = LayerNorm(size)  # for the final output of the block
        self.dropout = nn.Dropout(dropout_rate)
        self.size = size
        self.normalize_before = normalize_before
        self.concat_after = concat_after
        if self.concat_after:
            self.concat_linear = nn.Linear(size + size, size)

    def forward(self, x_input, mask, cache=None):
        """Compute encoded features.
        Args:
            x_input (Union[Tuple, torch.Tensor]): Input tensor w/ or w/o pos emb.
                - w/ pos emb: Tuple of tensors [(#batch, time, size), (1, time, size)].
                - w/o pos emb: Tensor (#batch, time, size).
            mask (torch.Tensor): Mask tensor for the input (#batch, time).
            cache (torch.Tensor): Cache tensor of the input (#batch, time - 1, size).
        Returns:
            torch.Tensor: Output tensor (#batch, time, size).
            torch.Tensor: Mask tensor (#batch, time).
        """
        if isinstance(x_input, tuple):
            x, pos_emb = x_input[0], x_input[1]
        else:
            x, pos_emb = x_input, None

        # whether to use macaron style
        if self.feed_forward_macaron is not None:
            residual = x
            if self.normalize_before:
                x = self.norm_ff_macaron(x)
            x = residual + self.ff_scale * self.dropout(self.feed_forward_macaron(x))
            if not self.normalize_before:
                x = self.norm_ff_macaron(x)

        # multi-headed self-attention module
        residual = x
        if self.normalize_before:
            x = self.norm_mha(x)

        if cache is None:
            x_q = x
        else:
            assert cache.shape == (x.shape[0], x.shape[1] - 1, self.size)
            x_q = x[:, -1:, :]
            residual = residual[:, -1:, :]
            mask = None if mask is None else mask[:, -1:, :]

        if pos_emb is not None:
            x_att = self.self_attn(x_q, x, x, pos_emb, mask)
        else:
            x_att = self.self_attn(x_q, x, x, mask)

        if self.concat_after:
            x_concat = torch.cat((x, x_att), dim=-1)
            x = residual + self.concat_linear(x_concat)
        else:
            x = residual + self.dropout(x_att)
        if not self.normalize_before:
            x = self.norm_mha(x)

        # convolution module
        if self.conv_module is not None:
            residual = x
            if self.normalize_before:
                x = self.norm_conv(x)
            x = residual + self.dropout(self.conv_module(x))
            if not self.normalize_before:
                x = self.norm_conv(x)

        # feed forward module
        residual = x
        if self.normalize_before:
            x = self.norm_ff(x)
        x = residual + self.ff_scale * self.dropout(self.feed_forward(x))
        if not self.normalize_before:
            x = self.norm_ff(x)

        if self.conv_module is not None:
            x = self.norm_final(x)

        if cache is not None:
            x = torch.cat([cache, x], dim=1)

        if pos_emb is not None:
            return (x, pos_emb), mask

        return x, mask
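To show how the pieces in this file compose, here is a hedged sketch of one Conformer block. It assumes RelPositionMultiHeadedAttention and RelPositionalEncoding are imported from the sibling espnet_transformer_attn and espnet_positional_embedding modules above; every size is illustrative, not the repo's configuration.

import torch

size = 256
layer = EncoderLayer(
    size=size,
    self_attn=RelPositionMultiHeadedAttention(4, size, 0.1),
    feed_forward=MultiLayeredConv1d(size, 1024, 3, 0.1),
    feed_forward_macaron=MultiLayeredConv1d(size, 1024, 3, 0.1),
    conv_module=ConvolutionModule(size, kernel_size=31, activation=Swish()),
    dropout_rate=0.1,
)
x, pos_emb = RelPositionalEncoding(size, 0.1)(torch.randn(2, 50, size))
mask = torch.ones(2, 1, 50, dtype=torch.bool)
(out, pos_emb), mask = layer((x, pos_emb), mask)   # out: (2, 50, size)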
preprocess/tools/note_transcription/modules/commons/conv.py
ADDED
@@ -0,0 +1,175 @@
import math
import torch
import torch.nn as nn
import torch.nn.functional as F

from .layers import LayerNorm, Embedding


class LambdaLayer(nn.Module):
    def __init__(self, lambd):
        super(LambdaLayer, self).__init__()
        self.lambd = lambd

    def forward(self, x):
        return self.lambd(x)


def init_weights_func(m):
    classname = m.__class__.__name__
    if classname.find("Conv1d") != -1:
        torch.nn.init.xavier_uniform_(m.weight)


def get_norm_builder(norm_type, channels, ln_eps=1e-6):
    if norm_type == 'bn':
        norm_builder = lambda: nn.BatchNorm1d(channels)
    elif norm_type == 'in':
        norm_builder = lambda: nn.InstanceNorm1d(channels, affine=True)
    elif norm_type == 'gn':
        norm_builder = lambda: nn.GroupNorm(8, channels)
    elif norm_type == 'ln':
        norm_builder = lambda: LayerNorm(channels, dim=1, eps=ln_eps)
    else:
        norm_builder = lambda: nn.Identity()
    return norm_builder


def get_act_builder(act_type):
    if act_type == 'gelu':
        act_builder = lambda: nn.GELU()
    elif act_type == 'relu':
        act_builder = lambda: nn.ReLU(inplace=True)
    elif act_type == 'leakyrelu':
        act_builder = lambda: nn.LeakyReLU(negative_slope=0.01, inplace=True)
    elif act_type == 'swish':
        act_builder = lambda: nn.SiLU(inplace=True)
    else:
        act_builder = lambda: nn.Identity()
    return act_builder


class ResidualBlock(nn.Module):
    """Implements conv->PReLU->norm n-times"""

    def __init__(self, channels, kernel_size, dilation, n=2, norm_type='bn', dropout=0.0,
                 c_multiple=2, ln_eps=1e-12, act_type='gelu'):
        super(ResidualBlock, self).__init__()

        norm_builder = get_norm_builder(norm_type, channels, ln_eps)
        act_builder = get_act_builder(act_type)

        self.blocks = [
            nn.Sequential(
                norm_builder(),
                nn.Conv1d(channels, c_multiple * channels, kernel_size, dilation=dilation,
                          padding=(dilation * (kernel_size - 1)) // 2),
                LambdaLayer(lambda x: x * kernel_size ** -0.5),
                act_builder(),
                nn.Conv1d(c_multiple * channels, channels, 1, dilation=dilation),
            )
            for i in range(n)
        ]

        self.blocks = nn.ModuleList(self.blocks)
        self.dropout = dropout

    def forward(self, x):
        nonpadding = (x.abs().sum(1) > 0).float()[:, None, :]
        for b in self.blocks:
            x_ = b(x)
            if self.dropout > 0 and self.training:
                x_ = F.dropout(x_, self.dropout, training=self.training)
            x = x + x_
            x = x * nonpadding
        return x


class ConvBlocks(nn.Module):
    """Decodes the expanded phoneme encoding into spectrograms"""

    def __init__(self, hidden_size, out_dims, dilations, kernel_size,
                 norm_type='ln', layers_in_block=2, c_multiple=2,
                 dropout=0.0, ln_eps=1e-5,
                 init_weights=True, is_BTC=True, num_layers=None, post_net_kernel=3, act_type='gelu'):
        super(ConvBlocks, self).__init__()
        self.is_BTC = is_BTC
        if num_layers is not None:
            dilations = [1] * num_layers
        self.res_blocks = nn.Sequential(
            *[ResidualBlock(hidden_size, kernel_size, d,
                            n=layers_in_block, norm_type=norm_type, c_multiple=c_multiple,
                            dropout=dropout, ln_eps=ln_eps, act_type=act_type)
              for d in dilations],
        )
        norm = get_norm_builder(norm_type, hidden_size, ln_eps)()
        self.last_norm = norm
        self.post_net1 = nn.Conv1d(hidden_size, out_dims, kernel_size=post_net_kernel,
                                   padding=post_net_kernel // 2)
        if init_weights:
            self.apply(init_weights_func)

    def forward(self, x, nonpadding=None):
        """
        :param x: [B, T, H]
        :return: [B, T, H]
        """
        if self.is_BTC:
            x = x.transpose(1, 2)
        if nonpadding is None:
            nonpadding = (x.abs().sum(1) > 0).float()[:, None, :]
        elif self.is_BTC:
            nonpadding = nonpadding.transpose(1, 2)
        x = self.res_blocks(x) * nonpadding
        x = self.last_norm(x) * nonpadding
        x = self.post_net1(x) * nonpadding
        if self.is_BTC:
            x = x.transpose(1, 2)
        return x


class TextConvEncoder(ConvBlocks):
    def __init__(self, dict_size, hidden_size, out_dims, dilations, kernel_size,
                 norm_type='ln', layers_in_block=2, c_multiple=2,
                 dropout=0.0, ln_eps=1e-5, init_weights=True, num_layers=None, post_net_kernel=3):
        super().__init__(hidden_size, out_dims, dilations, kernel_size,
                         norm_type, layers_in_block, c_multiple,
                         dropout, ln_eps, init_weights, num_layers=num_layers,
                         post_net_kernel=post_net_kernel)
        self.embed_tokens = Embedding(dict_size, hidden_size, 0)
        self.embed_scale = math.sqrt(hidden_size)

    def forward(self, txt_tokens):
        """
        :param txt_tokens: [B, T]
        :return: {
            'encoder_out': [B x T x C]
        }
        """
        x = self.embed_scale * self.embed_tokens(txt_tokens)
        return super().forward(x)


class ConditionalConvBlocks(ConvBlocks):
    def __init__(self, hidden_size, c_cond, c_out, dilations, kernel_size,
                 norm_type='ln', layers_in_block=2, c_multiple=2,
                 dropout=0.0, ln_eps=1e-5, init_weights=True, is_BTC=True, num_layers=None):
        super().__init__(hidden_size, c_out, dilations, kernel_size,
                         norm_type, layers_in_block, c_multiple,
                         dropout, ln_eps, init_weights, is_BTC=False, num_layers=num_layers)
        self.g_prenet = nn.Conv1d(c_cond, hidden_size, 3, padding=1)
        self.is_BTC_ = is_BTC
        if init_weights:
            self.g_prenet.apply(init_weights_func)

    def forward(self, x, cond, nonpadding=None):
        if self.is_BTC_:
            x = x.transpose(1, 2)
            cond = cond.transpose(1, 2)
            if nonpadding is not None:
                nonpadding = nonpadding.transpose(1, 2)
        if nonpadding is None:
            nonpadding = x.abs().sum(1)[:, None]
        x = x + self.g_prenet(cond)
        x = x * nonpadding
        x = super(ConditionalConvBlocks, self).forward(x)  # input needs to be BTC
        if self.is_BTC_:
            x = x.transpose(1, 2)
        return x
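A brief, hedged shape check for ConvBlocks (all sizes invented): padding frames are all-zero vectors, which is exactly what the internal nonpadding mask keys on, so they stay zero through the stack.

import torch

net = ConvBlocks(hidden_size=192, out_dims=192,
                 dilations=[1, 2, 4, 1, 2, 4], kernel_size=5)
x = torch.randn(2, 80, 192)   # [B, T, H]
x[1, 60:] = 0.0               # zero out padding frames for sample 2
y = net(x)                    # [B, T, H]; padded frames remain zero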
preprocess/tools/note_transcription/modules/commons/layers.py
ADDED
@@ -0,0 +1,85 @@
import torch
from torch import nn
from torch.autograd import Function


class LayerNorm(torch.nn.LayerNorm):
    """Layer normalization module.
    :param int nout: output dim size
    :param int dim: dimension to be normalized
    """

    def __init__(self, nout, dim=-1, eps=1e-5):
        """Construct a LayerNorm object."""
        super(LayerNorm, self).__init__(nout, eps=eps)
        self.dim = dim

    def forward(self, x):
        """Apply layer normalization.
        :param torch.Tensor x: input tensor
        :return: layer normalized tensor
        :rtype: torch.Tensor
        """
        if self.dim == -1:
            return super(LayerNorm, self).forward(x)
        return super(LayerNorm, self).forward(x.transpose(1, -1)).transpose(1, -1)


class Reshape(nn.Module):
    def __init__(self, *args):
        super(Reshape, self).__init__()
        self.shape = args

    def forward(self, x):
        return x.view(self.shape)


class Permute(nn.Module):
    def __init__(self, *args):
        super(Permute, self).__init__()
        self.args = args

    def forward(self, x):
        return x.permute(self.args)


def Linear(in_features, out_features, bias=True, init_type='xavier'):
    m = nn.Linear(in_features, out_features, bias)
    if init_type == 'xavier':
        nn.init.xavier_uniform_(m.weight)
    elif init_type == 'kaiming':
        nn.init.kaiming_normal_(m.weight, mode='fan_in')
    if bias:
        nn.init.constant_(m.bias, 0.)
    return m


def Embedding(num_embeddings, embedding_dim, padding_idx=None, init_type='normal'):
    m = nn.Embedding(num_embeddings, embedding_dim, padding_idx=padding_idx)
    if init_type == 'normal':
        nn.init.normal_(m.weight, mean=0, std=embedding_dim ** -0.5)
    elif init_type == 'kaiming':
        nn.init.kaiming_normal_(m.weight, mode='fan_in')
    if padding_idx is not None:
        nn.init.constant_(m.weight[padding_idx], 0)
    return m


class GradientReverseFunction(Function):
    @staticmethod
    def forward(ctx, input, coeff=1.):
        ctx.coeff = coeff
        output = input * 1.0
        return output

    @staticmethod
    def backward(ctx, grad_output):
        return grad_output.neg() * ctx.coeff, None


class GRL(nn.Module):
    def __init__(self):
        super(GRL, self).__init__()

    def forward(self, *input):
        return GradientReverseFunction.apply(*input)
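A small self-contained check of the gradient reversal layer defined above: identity in the forward pass, sign-flipped (and coeff-scaled) gradient in the backward pass, the standard trick behind adversarial domain or speaker classifiers. coeff=1.0 is chosen purely for the demo.

import torch

grl = GRL()
x = torch.ones(3, requires_grad=True)
y = grl(x, 1.0)        # identity in the forward direction
y.sum().backward()
print(x.grad)          # tensor([-1., -1., -1.]) -- reversed gradient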
preprocess/tools/note_transcription/modules/commons/rel_transformer.py
ADDED
@@ -0,0 +1,378 @@
import math
import torch
from torch import nn
from torch.nn import functional as F

from .layers import Embedding


def convert_pad_shape(pad_shape):
    l = pad_shape[::-1]
    pad_shape = [item for sublist in l for item in sublist]
    return pad_shape


def shift_1d(x):
    x = F.pad(x, convert_pad_shape([[0, 0], [0, 0], [1, 0]]))[:, :, :-1]
    return x


def sequence_mask(length, max_length=None):
    if max_length is None:
        max_length = length.max()
    x = torch.arange(max_length, dtype=length.dtype, device=length.device)
    return x.unsqueeze(0) < length.unsqueeze(1)


class Encoder(nn.Module):
    def __init__(self, hidden_channels, filter_channels, n_heads, n_layers, kernel_size=1, p_dropout=0.,
                 window_size=None, block_length=None, pre_ln=False, **kwargs):
        super().__init__()
        self.hidden_channels = hidden_channels
        self.filter_channels = filter_channels
        self.n_heads = n_heads
        self.n_layers = n_layers
        self.kernel_size = kernel_size
        self.p_dropout = p_dropout
        self.window_size = window_size
        self.block_length = block_length
        self.pre_ln = pre_ln

        self.drop = nn.Dropout(p_dropout)
        self.attn_layers = nn.ModuleList()
        self.norm_layers_1 = nn.ModuleList()
        self.ffn_layers = nn.ModuleList()
        self.norm_layers_2 = nn.ModuleList()
        for i in range(self.n_layers):
            self.attn_layers.append(
                MultiHeadAttention(hidden_channels, hidden_channels, n_heads, window_size=window_size,
                                   p_dropout=p_dropout, block_length=block_length))
            self.norm_layers_1.append(LayerNorm(hidden_channels))
            self.ffn_layers.append(
                FFN(hidden_channels, hidden_channels, filter_channels, kernel_size, p_dropout=p_dropout))
            self.norm_layers_2.append(LayerNorm(hidden_channels))
        if pre_ln:
            self.last_ln = LayerNorm(hidden_channels)

    def forward(self, x, x_mask):
        attn_mask = x_mask.unsqueeze(2) * x_mask.unsqueeze(-1)
        for i in range(self.n_layers):
            x = x * x_mask
            x_ = x
            if self.pre_ln:
                x = self.norm_layers_1[i](x)
            y = self.attn_layers[i](x, x, attn_mask)
            y = self.drop(y)
            x = x_ + y
            if not self.pre_ln:
                x = self.norm_layers_1[i](x)

            x_ = x
            if self.pre_ln:
                x = self.norm_layers_2[i](x)
            y = self.ffn_layers[i](x, x_mask)
            y = self.drop(y)
            x = x_ + y
            if not self.pre_ln:
                x = self.norm_layers_2[i](x)
        if self.pre_ln:
            x = self.last_ln(x)
        x = x * x_mask
        return x


class MultiHeadAttention(nn.Module):
    def __init__(self, channels, out_channels, n_heads, window_size=None, heads_share=True, p_dropout=0.,
                 block_length=None, proximal_bias=False, proximal_init=False):
        super().__init__()
        assert channels % n_heads == 0

        self.channels = channels
        self.out_channels = out_channels
        self.n_heads = n_heads
        self.window_size = window_size
        self.heads_share = heads_share
        self.block_length = block_length
        self.proximal_bias = proximal_bias
        self.p_dropout = p_dropout
        self.attn = None

        self.k_channels = channels // n_heads
        self.conv_q = nn.Conv1d(channels, channels, 1)
        self.conv_k = nn.Conv1d(channels, channels, 1)
        self.conv_v = nn.Conv1d(channels, channels, 1)
        if window_size is not None:
            n_heads_rel = 1 if heads_share else n_heads
            rel_stddev = self.k_channels ** -0.5
            self.emb_rel_k = nn.Parameter(torch.randn(n_heads_rel, window_size * 2 + 1, self.k_channels) * rel_stddev)
            self.emb_rel_v = nn.Parameter(torch.randn(n_heads_rel, window_size * 2 + 1, self.k_channels) * rel_stddev)
        self.conv_o = nn.Conv1d(channels, out_channels, 1)
        self.drop = nn.Dropout(p_dropout)

        nn.init.xavier_uniform_(self.conv_q.weight)
        nn.init.xavier_uniform_(self.conv_k.weight)
        if proximal_init:
            self.conv_k.weight.data.copy_(self.conv_q.weight.data)
            self.conv_k.bias.data.copy_(self.conv_q.bias.data)
        nn.init.xavier_uniform_(self.conv_v.weight)

    def forward(self, x, c, attn_mask=None):
        q = self.conv_q(x)
        k = self.conv_k(c)
        v = self.conv_v(c)

        x, self.attn = self.attention(q, k, v, mask=attn_mask)

        x = self.conv_o(x)
        return x

    def attention(self, query, key, value, mask=None):
        # reshape [b, d, t] -> [b, n_h, t, d_k]
        b, d, t_s, t_t = (*key.size(), query.size(2))
        query = query.view(b, self.n_heads, self.k_channels, t_t).transpose(2, 3)
        key = key.view(b, self.n_heads, self.k_channels, t_s).transpose(2, 3)
        value = value.view(b, self.n_heads, self.k_channels, t_s).transpose(2, 3)

        scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(self.k_channels)
        if self.window_size is not None:
            assert t_s == t_t, "Relative attention is only available for self-attention."
            key_relative_embeddings = self._get_relative_embeddings(self.emb_rel_k, t_s)
            rel_logits = self._matmul_with_relative_keys(query, key_relative_embeddings)
            rel_logits = self._relative_position_to_absolute_position(rel_logits)
            scores_local = rel_logits / math.sqrt(self.k_channels)
            scores = scores + scores_local
        if self.proximal_bias:
            assert t_s == t_t, "Proximal bias is only available for self-attention."
            scores = scores + self._attention_bias_proximal(t_s).to(device=scores.device, dtype=scores.dtype)
        if mask is not None:
            scores = scores.masked_fill(mask == 0, -1e4)
            if self.block_length is not None:
                block_mask = torch.ones_like(scores).triu(-self.block_length).tril(self.block_length)
                scores = scores * block_mask + -1e4 * (1 - block_mask)
        p_attn = F.softmax(scores, dim=-1)  # [b, n_h, t_t, t_s]
        p_attn = self.drop(p_attn)
        output = torch.matmul(p_attn, value)
        if self.window_size is not None:
            relative_weights = self._absolute_position_to_relative_position(p_attn)
            value_relative_embeddings = self._get_relative_embeddings(self.emb_rel_v, t_s)
            output = output + self._matmul_with_relative_values(relative_weights, value_relative_embeddings)
        output = output.transpose(2, 3).contiguous().view(b, d, t_t)  # [b, n_h, t_t, d_k] -> [b, d, t_t]
        return output, p_attn

    def _matmul_with_relative_values(self, x, y):
        """
        x: [b, h, l, m]
        y: [h or 1, m, d]
        ret: [b, h, l, d]
        """
        ret = torch.matmul(x, y.unsqueeze(0))
        return ret

    def _matmul_with_relative_keys(self, x, y):
        """
        x: [b, h, l, d]
        y: [h or 1, m, d]
        ret: [b, h, l, m]
        """
        ret = torch.matmul(x, y.unsqueeze(0).transpose(-2, -1))
        return ret

    def _get_relative_embeddings(self, relative_embeddings, length):
        max_relative_position = 2 * self.window_size + 1
        # Pad first before slice to avoid using cond ops.
        pad_length = max(length - (self.window_size + 1), 0)
        slice_start_position = max((self.window_size + 1) - length, 0)
        slice_end_position = slice_start_position + 2 * length - 1
        if pad_length > 0:
            padded_relative_embeddings = F.pad(
                relative_embeddings,
                convert_pad_shape([[0, 0], [pad_length, pad_length], [0, 0]]))
        else:
            padded_relative_embeddings = relative_embeddings
        used_relative_embeddings = padded_relative_embeddings[:, slice_start_position:slice_end_position]
        return used_relative_embeddings

    def _relative_position_to_absolute_position(self, x):
        """
        x: [b, h, l, 2*l-1]
        ret: [b, h, l, l]
        """
        batch, heads, length, _ = x.size()
        # Concat columns of pad to shift from relative to absolute indexing.
        x = F.pad(x, convert_pad_shape([[0, 0], [0, 0], [0, 0], [0, 1]]))

        # Concat extra elements so to add up to shape (len+1, 2*len-1).
        x_flat = x.view([batch, heads, length * 2 * length])
        x_flat = F.pad(x_flat, convert_pad_shape([[0, 0], [0, 0], [0, length - 1]]))

        # Reshape and slice out the padded elements.
        x_final = x_flat.view([batch, heads, length + 1, 2 * length - 1])[:, :, :length, length - 1:]
        return x_final

    def _absolute_position_to_relative_position(self, x):
        """
        x: [b, h, l, l]
        ret: [b, h, l, 2*l-1]
        """
        batch, heads, length, _ = x.size()
        # pad along column
        x = F.pad(x, convert_pad_shape([[0, 0], [0, 0], [0, 0], [0, length - 1]]))
        x_flat = x.view([batch, heads, length ** 2 + length * (length - 1)])
        # add 0's in the beginning that will skew the elements after reshape
        x_flat = F.pad(x_flat, convert_pad_shape([[0, 0], [0, 0], [length, 0]]))
        x_final = x_flat.view([batch, heads, length, 2 * length])[:, :, :, 1:]
        return x_final

    def _attention_bias_proximal(self, length):
        """Bias for self-attention to encourage attention to close positions.
        Args:
            length: an integer scalar.
        Returns:
            a Tensor with shape [1, 1, length, length]
        """
        r = torch.arange(length, dtype=torch.float32)
        diff = torch.unsqueeze(r, 0) - torch.unsqueeze(r, 1)
        return torch.unsqueeze(torch.unsqueeze(-torch.log1p(torch.abs(diff)), 0), 0)


class FFN(nn.Module):
    def __init__(self, in_channels, out_channels, filter_channels, kernel_size, p_dropout=0., activation=None):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.filter_channels = filter_channels
        self.kernel_size = kernel_size
        self.p_dropout = p_dropout
        self.activation = activation

        self.conv_1 = nn.Conv1d(in_channels, filter_channels, kernel_size, padding=kernel_size // 2)
        self.conv_2 = nn.Conv1d(filter_channels, out_channels, 1)
        self.drop = nn.Dropout(p_dropout)

    def forward(self, x, x_mask):
        x = self.conv_1(x * x_mask)
        if self.activation == "gelu":
            x = x * torch.sigmoid(1.702 * x)
        else:
            x = torch.relu(x)
        x = self.drop(x)
        x = self.conv_2(x * x_mask)
        return x * x_mask


class LayerNorm(nn.Module):
    def __init__(self, channels, eps=1e-4):
        super().__init__()
        self.channels = channels
        self.eps = eps

        self.gamma = nn.Parameter(torch.ones(channels))
        self.beta = nn.Parameter(torch.zeros(channels))

    def forward(self, x):
        n_dims = len(x.shape)
        mean = torch.mean(x, 1, keepdim=True)
        variance = torch.mean((x - mean) ** 2, 1, keepdim=True)

        x = (x - mean) * torch.rsqrt(variance + self.eps)

        shape = [1, -1] + [1] * (n_dims - 2)
        x = x * self.gamma.view(*shape) + self.beta.view(*shape)
        return x


class ConvReluNorm(nn.Module):
    def __init__(self, in_channels, hidden_channels, out_channels, kernel_size, n_layers, p_dropout):
        super().__init__()
        self.in_channels = in_channels
        self.hidden_channels = hidden_channels
        self.out_channels = out_channels
        self.kernel_size = kernel_size
        self.n_layers = n_layers
        self.p_dropout = p_dropout
        assert n_layers > 1, "Number of layers should be larger than 1."

        self.conv_layers = nn.ModuleList()
        self.norm_layers = nn.ModuleList()
        self.conv_layers.append(nn.Conv1d(in_channels, hidden_channels, kernel_size, padding=kernel_size // 2))
        self.norm_layers.append(LayerNorm(hidden_channels))
        self.relu_drop = nn.Sequential(
            nn.ReLU(),
            nn.Dropout(p_dropout))
        for _ in range(n_layers - 1):
            self.conv_layers.append(nn.Conv1d(hidden_channels, hidden_channels, kernel_size, padding=kernel_size // 2))
            self.norm_layers.append(LayerNorm(hidden_channels))
        self.proj = nn.Conv1d(hidden_channels, out_channels, 1)
        self.proj.weight.data.zero_()
        self.proj.bias.data.zero_()

    def forward(self, x, x_mask):
        x_org = x
        for i in range(self.n_layers):
            x = self.conv_layers[i](x * x_mask)
            x = self.norm_layers[i](x)
            x = self.relu_drop(x)
        x = x_org + self.proj(x)
        return x * x_mask


class RelTransformerEncoder(nn.Module):
    def __init__(self,
                 n_vocab,
                 out_channels,
                 hidden_channels,
                 filter_channels,
                 n_heads,
                 n_layers,
                 kernel_size,
                 p_dropout=0.0,
                 window_size=4,
                 block_length=None,
                 prenet=True,
                 pre_ln=True,
                 ):

        super().__init__()

        self.n_vocab = n_vocab
        self.out_channels = out_channels
        self.hidden_channels = hidden_channels
        self.filter_channels = filter_channels
        self.n_heads = n_heads
        self.n_layers = n_layers
        self.kernel_size = kernel_size
        self.p_dropout = p_dropout
        self.window_size = window_size
        self.block_length = block_length
        self.prenet = prenet
        if n_vocab > 0:
            self.emb = Embedding(n_vocab, hidden_channels, padding_idx=0)

        if prenet:
            self.pre = ConvReluNorm(hidden_channels, hidden_channels, hidden_channels,
                                    kernel_size=5, n_layers=3, p_dropout=0)
        self.encoder = Encoder(
            hidden_channels,
            filter_channels,
            n_heads,
            n_layers,
            kernel_size,
            p_dropout,
            window_size=window_size,
            block_length=block_length,
            pre_ln=pre_ln,
        )

    def forward(self, x, x_mask=None):
        if self.n_vocab > 0:
            x_lengths = (x > 0).long().sum(-1)
            x = self.emb(x) * math.sqrt(self.hidden_channels)  # [b, t, h]
        else:
            x_lengths = (x.abs().sum(-1) > 0).long().sum(-1)
        x = torch.transpose(x, 1, -1)  # [b, h, t]
        x_mask = torch.unsqueeze(sequence_mask(x_lengths, x.size(2)), 1).to(x.dtype)

        if self.prenet:
            x = self.pre(x, x_mask)
        x = self.encoder(x, x_mask)
        return x.transpose(1, 2)
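A hedged end-to-end sketch of RelTransformerEncoder on a batch of padded token sequences; vocabulary size and model dims are invented for the demo, and token id 0 is the padding index the encoder keys its mask on.

import torch

enc = RelTransformerEncoder(
    n_vocab=100, out_channels=192, hidden_channels=192,
    filter_channels=768, n_heads=2, n_layers=4, kernel_size=3,
)
tokens = torch.randint(1, 100, (2, 40))  # [b, t]
tokens[1, 30:] = 0                       # pad the second sequence
h = enc(tokens)                          # [b, t, hidden_channels] = (2, 40, 192)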
preprocess/tools/note_transcription/modules/commons/rnn.py
ADDED
@@ -0,0 +1,261 @@
import torch
from torch import nn
import torch.nn.functional as F


class PreNet(nn.Module):
    def __init__(self, in_dims, fc1_dims=256, fc2_dims=128, dropout=0.5):
        super().__init__()
        self.fc1 = nn.Linear(in_dims, fc1_dims)
        self.fc2 = nn.Linear(fc1_dims, fc2_dims)
        self.p = dropout

    def forward(self, x):
        x = self.fc1(x)
        x = F.relu(x)
        x = F.dropout(x, self.p, training=self.training)
        x = self.fc2(x)
        x = F.relu(x)
        x = F.dropout(x, self.p, training=self.training)
        return x


class HighwayNetwork(nn.Module):
    def __init__(self, size):
        super().__init__()
        self.W1 = nn.Linear(size, size)
        self.W2 = nn.Linear(size, size)
        self.W1.bias.data.fill_(0.)

    def forward(self, x):
        x1 = self.W1(x)
        x2 = self.W2(x)
        g = torch.sigmoid(x2)
        y = g * F.relu(x1) + (1. - g) * x
        return y


class BatchNormConv(nn.Module):
    def __init__(self, in_channels, out_channels, kernel, relu=True):
        super().__init__()
        self.conv = nn.Conv1d(in_channels, out_channels, kernel, stride=1, padding=kernel // 2, bias=False)
        self.bnorm = nn.BatchNorm1d(out_channels)
        self.relu = relu

    def forward(self, x):
        x = self.conv(x)
        x = F.relu(x) if self.relu is True else x
        return self.bnorm(x)


class ConvNorm(torch.nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size=1, stride=1,
                 padding=None, dilation=1, bias=True, w_init_gain='linear'):
        super(ConvNorm, self).__init__()
        if padding is None:
            assert (kernel_size % 2 == 1)
            padding = int(dilation * (kernel_size - 1) / 2)

        self.conv = torch.nn.Conv1d(in_channels, out_channels,
                                    kernel_size=kernel_size, stride=stride,
                                    padding=padding, dilation=dilation,
                                    bias=bias)

        torch.nn.init.xavier_uniform_(
            self.conv.weight, gain=torch.nn.init.calculate_gain(w_init_gain))

    def forward(self, signal):
        conv_signal = self.conv(signal)
        return conv_signal


class CBHG(nn.Module):
    def __init__(self, K, in_channels, channels, proj_channels, num_highways):
        super().__init__()

        # List of all rnns to call `flatten_parameters()` on
        self._to_flatten = []

        self.bank_kernels = [i for i in range(1, K + 1)]
        self.conv1d_bank = nn.ModuleList()
        for k in self.bank_kernels:
            conv = BatchNormConv(in_channels, channels, k)
            self.conv1d_bank.append(conv)

        self.maxpool = nn.MaxPool1d(kernel_size=2, stride=1, padding=1)

        self.conv_project1 = BatchNormConv(len(self.bank_kernels) * channels, proj_channels[0], 3)
        self.conv_project2 = BatchNormConv(proj_channels[0], proj_channels[1], 3, relu=False)

        # Fix the highway input if necessary
        if proj_channels[-1] != channels:
            self.highway_mismatch = True
            self.pre_highway = nn.Linear(proj_channels[-1], channels, bias=False)
        else:
            self.highway_mismatch = False

        self.highways = nn.ModuleList()
        for i in range(num_highways):
            hn = HighwayNetwork(channels)
            self.highways.append(hn)

        self.rnn = nn.GRU(channels, channels, batch_first=True, bidirectional=True)
        self._to_flatten.append(self.rnn)

        # Avoid fragmentation of RNN parameters and associated warning
        self._flatten_parameters()

    def forward(self, x):
        # Although we `_flatten_parameters()` on init, when using DataParallel
        # the model gets replicated, making it no longer guaranteed that the
        # weights are contiguous in GPU memory. Hence, we must call it again
        self._flatten_parameters()

        # Save these for later
        residual = x
        seq_len = x.size(-1)
        conv_bank = []

        # Convolution bank
        for conv in self.conv1d_bank:
            c = conv(x)  # Convolution
            conv_bank.append(c[:, :, :seq_len])

        # Stack along the channel axis
        conv_bank = torch.cat(conv_bank, dim=1)

        # dump the last padding to fit residual
        x = self.maxpool(conv_bank)[:, :, :seq_len]

        # Conv1d projections
        x = self.conv_project1(x)
        x = self.conv_project2(x)

        # Residual connection
        x = x + residual

        # Through the highways
        x = x.transpose(1, 2)
        if self.highway_mismatch is True:
            x = self.pre_highway(x)
        for h in self.highways:
            x = h(x)

        # And then the RNN
        x, _ = self.rnn(x)
        return x

    def _flatten_parameters(self):
        """Calls `flatten_parameters` on all the rnns used by the WaveRNN. Used
        to improve efficiency and avoid PyTorch yelling at us."""
        [m.flatten_parameters() for m in self._to_flatten]


class TacotronEncoder(nn.Module):
    def __init__(self, embed_dims, num_chars, cbhg_channels, K, num_highways, dropout):
        super().__init__()
        self.embedding = nn.Embedding(num_chars, embed_dims)
        self.pre_net = PreNet(embed_dims, embed_dims, embed_dims, dropout=dropout)
        self.cbhg = CBHG(K=K, in_channels=cbhg_channels, channels=cbhg_channels,
                         proj_channels=[cbhg_channels, cbhg_channels],
                         num_highways=num_highways)
        self.proj_out = nn.Linear(cbhg_channels * 2, cbhg_channels)

    def forward(self, x):
        x = self.embedding(x)
        x = self.pre_net(x)
        x.transpose_(1, 2)
        x = self.cbhg(x)
        x = self.proj_out(x)
        return x
|
| 171 |
+
|
| 172 |
+
|
| 173 |
+
class RNNEncoder(nn.Module):
|
| 174 |
+
def __init__(self, num_chars, embedding_dim, n_convolutions=3, kernel_size=5):
|
| 175 |
+
super(RNNEncoder, self).__init__()
|
| 176 |
+
self.embedding = nn.Embedding(num_chars, embedding_dim, padding_idx=0)
|
| 177 |
+
convolutions = []
|
| 178 |
+
for _ in range(n_convolutions):
|
| 179 |
+
conv_layer = nn.Sequential(
|
| 180 |
+
ConvNorm(embedding_dim,
|
| 181 |
+
embedding_dim,
|
| 182 |
+
kernel_size=kernel_size, stride=1,
|
| 183 |
+
padding=int((kernel_size - 1) / 2),
|
| 184 |
+
dilation=1, w_init_gain='relu'),
|
| 185 |
+
nn.BatchNorm1d(embedding_dim))
|
| 186 |
+
convolutions.append(conv_layer)
|
| 187 |
+
self.convolutions = nn.ModuleList(convolutions)
|
| 188 |
+
|
| 189 |
+
self.lstm = nn.LSTM(embedding_dim, int(embedding_dim / 2), 1,
|
| 190 |
+
batch_first=True, bidirectional=True)
|
| 191 |
+
|
| 192 |
+
def forward(self, x):
|
| 193 |
+
input_lengths = (x > 0).sum(-1)
|
| 194 |
+
input_lengths = input_lengths.cpu().numpy()
|
| 195 |
+
|
| 196 |
+
x = self.embedding(x)
|
| 197 |
+
x = x.transpose(1, 2) # [B, H, T]
|
| 198 |
+
for conv in self.convolutions:
|
| 199 |
+
x = F.dropout(F.relu(conv(x)), 0.5, self.training) + x
|
| 200 |
+
x = x.transpose(1, 2) # [B, T, H]
|
| 201 |
+
|
| 202 |
+
# pytorch tensor are not reversible, hence the conversion
|
| 203 |
+
x = nn.utils.rnn.pack_padded_sequence(x, input_lengths, batch_first=True, enforce_sorted=False)
|
| 204 |
+
|
| 205 |
+
self.lstm.flatten_parameters()
|
| 206 |
+
outputs, _ = self.lstm(x)
|
| 207 |
+
outputs, _ = nn.utils.rnn.pad_packed_sequence(outputs, batch_first=True)
|
| 208 |
+
|
| 209 |
+
return outputs
|
| 210 |
+
|
| 211 |
+
|
| 212 |
+
class DecoderRNN(torch.nn.Module):
|
| 213 |
+
def __init__(self, hidden_size, decoder_rnn_dim, dropout):
|
| 214 |
+
super(DecoderRNN, self).__init__()
|
| 215 |
+
self.in_conv1d = nn.Sequential(
|
| 216 |
+
torch.nn.Conv1d(
|
| 217 |
+
in_channels=hidden_size,
|
| 218 |
+
out_channels=hidden_size,
|
| 219 |
+
kernel_size=9, padding=4,
|
| 220 |
+
),
|
| 221 |
+
torch.nn.ReLU(),
|
| 222 |
+
torch.nn.Conv1d(
|
| 223 |
+
in_channels=hidden_size,
|
| 224 |
+
out_channels=hidden_size,
|
| 225 |
+
kernel_size=9, padding=4,
|
| 226 |
+
),
|
| 227 |
+
)
|
| 228 |
+
self.ln = nn.LayerNorm(hidden_size)
|
| 229 |
+
if decoder_rnn_dim == 0:
|
| 230 |
+
decoder_rnn_dim = hidden_size * 2
|
| 231 |
+
self.rnn = torch.nn.LSTM(
|
| 232 |
+
input_size=hidden_size,
|
| 233 |
+
hidden_size=decoder_rnn_dim,
|
| 234 |
+
num_layers=1,
|
| 235 |
+
batch_first=True,
|
| 236 |
+
bidirectional=True,
|
| 237 |
+
dropout=dropout
|
| 238 |
+
)
|
| 239 |
+
self.rnn.flatten_parameters()
|
| 240 |
+
self.conv1d = torch.nn.Conv1d(
|
| 241 |
+
in_channels=decoder_rnn_dim * 2,
|
| 242 |
+
out_channels=hidden_size,
|
| 243 |
+
kernel_size=3,
|
| 244 |
+
padding=1,
|
| 245 |
+
)
|
| 246 |
+
|
| 247 |
+
def forward(self, x):
|
| 248 |
+
input_masks = x.abs().sum(-1).ne(0).data[:, :, None]
|
| 249 |
+
input_lengths = input_masks.sum([-1, -2])
|
| 250 |
+
input_lengths = input_lengths.cpu().numpy()
|
| 251 |
+
|
| 252 |
+
x = self.in_conv1d(x.transpose(1, 2)).transpose(1, 2)
|
| 253 |
+
x = self.ln(x)
|
| 254 |
+
x = nn.utils.rnn.pack_padded_sequence(x, input_lengths, batch_first=True, enforce_sorted=False)
|
| 255 |
+
self.rnn.flatten_parameters()
|
| 256 |
+
x, _ = self.rnn(x) # [B, T, C]
|
| 257 |
+
x, _ = nn.utils.rnn.pad_packed_sequence(x, batch_first=True)
|
| 258 |
+
x = x * input_masks
|
| 259 |
+
pre_mel = self.conv1d(x.transpose(1, 2)).transpose(1, 2) # [B, T, C]
|
| 260 |
+
pre_mel = pre_mel * input_masks
|
| 261 |
+
return pre_mel
|
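Editor's note: a minimal shape-check sketch for the CBHG block above, not part of the diff. The import path is an assumption based on this repo's layout, and the residual connection requires proj_channels[-1] == in_channels:

import torch
# Assumed import path; adjust to wherever rnn.py lives in your checkout.
from preprocess.tools.note_transcription.modules.commons.rnn import CBHG

cbhg = CBHG(K=8, in_channels=128, channels=128,
            proj_channels=[128, 128], num_highways=4)
x = torch.randn(2, 128, 50)   # [batch, channels, time], conv-style layout
y = cbhg(x)                   # the bidirectional GRU doubles the channel dim
print(y.shape)                # torch.Size([2, 50, 256]) = [batch, time, 2 * channels]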
preprocess/tools/note_transcription/modules/commons/transformer.py
ADDED
@@ -0,0 +1,751 @@
import math
import torch
from torch import nn
from torch.nn import Parameter, Linear
from .layers import LayerNorm, Embedding
from ...utils.nn.seq_utils import (
    get_incremental_state,
    set_incremental_state,
    softmax,
    make_positions,
)
import torch.nn.functional as F

DEFAULT_MAX_SOURCE_POSITIONS = 2000
DEFAULT_MAX_TARGET_POSITIONS = 2000


class SinusoidalPositionalEmbedding(nn.Module):
    """This module produces sinusoidal positional embeddings of any length.

    Padding symbols are ignored.
    """

    def __init__(self, embedding_dim, padding_idx, init_size=1024):
        super().__init__()
        self.embedding_dim = embedding_dim
        self.padding_idx = padding_idx
        self.weights = SinusoidalPositionalEmbedding.get_embedding(
            init_size,
            embedding_dim,
            padding_idx,
        )
        self.register_buffer('_float_tensor', torch.FloatTensor(1))

    @staticmethod
    def get_embedding(num_embeddings, embedding_dim, padding_idx=None):
        """Build sinusoidal embeddings.

        This matches the implementation in tensor2tensor, but differs slightly
        from the description in Section 3.5 of "Attention Is All You Need".
        """
        half_dim = embedding_dim // 2
        emb = math.log(10000) / (half_dim - 1)
        emb = torch.exp(torch.arange(half_dim, dtype=torch.float) * -emb)
        emb = torch.arange(num_embeddings, dtype=torch.float).unsqueeze(1) * emb.unsqueeze(0)
        emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1).view(num_embeddings, -1)
        if embedding_dim % 2 == 1:
            # zero pad
            emb = torch.cat([emb, torch.zeros(num_embeddings, 1)], dim=1)
        if padding_idx is not None:
            emb[padding_idx, :] = 0
        return emb

    def forward(self, input, incremental_state=None, timestep=None, positions=None, **kwargs):
        """Input is expected to be of size [bsz x seqlen]."""
        bsz, seq_len = input.shape[:2]
        max_pos = self.padding_idx + 1 + seq_len
        if self.weights is None or max_pos > self.weights.size(0):
            # recompute/expand embeddings if needed
            self.weights = SinusoidalPositionalEmbedding.get_embedding(
                max_pos,
                self.embedding_dim,
                self.padding_idx,
            )
        self.weights = self.weights.to(self._float_tensor)

        if incremental_state is not None:
            # positions is the same for every token when decoding a single step
            pos = timestep.view(-1)[0] + 1 if timestep is not None else seq_len
            return self.weights[self.padding_idx + pos, :].expand(bsz, 1, -1)

        positions = make_positions(input, self.padding_idx) if positions is None else positions
        return self.weights.index_select(0, positions.view(-1)).view(bsz, seq_len, -1).detach()

    def max_positions(self):
        """Maximum number of supported positions."""
        return int(1e5)  # an arbitrary large number


class TransformerFFNLayer(nn.Module):
    def __init__(self, hidden_size, filter_size, padding="SAME", kernel_size=1, dropout=0., act='gelu'):
        super().__init__()
        self.kernel_size = kernel_size
        self.dropout = dropout
        self.act = act
        if padding == 'SAME':
            self.ffn_1 = nn.Conv1d(hidden_size, filter_size, kernel_size, padding=kernel_size // 2)
        elif padding == 'LEFT':
            self.ffn_1 = nn.Sequential(
                nn.ConstantPad1d((kernel_size - 1, 0), 0.0),
                nn.Conv1d(hidden_size, filter_size, kernel_size)
            )
        self.ffn_2 = Linear(filter_size, hidden_size)

    def forward(self, x, incremental_state=None):
        # x: T x B x C
        if incremental_state is not None:
            saved_state = self._get_input_buffer(incremental_state)
            if 'prev_input' in saved_state:
                prev_input = saved_state['prev_input']
                x = torch.cat((prev_input, x), dim=0)
            x = x[-self.kernel_size:]
            saved_state['prev_input'] = x
            self._set_input_buffer(incremental_state, saved_state)

        x = self.ffn_1(x.permute(1, 2, 0)).permute(2, 0, 1)
        x = x * self.kernel_size ** -0.5

        if incremental_state is not None:
            x = x[-1:]
        if self.act == 'gelu':
            x = F.gelu(x)
        if self.act == 'relu':
            x = F.relu(x)
        x = F.dropout(x, self.dropout, training=self.training)
        x = self.ffn_2(x)
        return x

    def _get_input_buffer(self, incremental_state):
        return get_incremental_state(
            self,
            incremental_state,
            'f',
        ) or {}

    def _set_input_buffer(self, incremental_state, buffer):
        set_incremental_state(
            self,
            incremental_state,
            'f',
            buffer,
        )

    def clear_buffer(self, incremental_state):
        if incremental_state is not None:
            saved_state = self._get_input_buffer(incremental_state)
            if 'prev_input' in saved_state:
                del saved_state['prev_input']
            self._set_input_buffer(incremental_state, saved_state)


class MultiheadAttention(nn.Module):
    def __init__(self, embed_dim, num_heads, kdim=None, vdim=None, dropout=0., bias=True,
                 add_bias_kv=False, add_zero_attn=False, self_attention=False,
                 encoder_decoder_attention=False):
        super().__init__()
        self.embed_dim = embed_dim
        self.kdim = kdim if kdim is not None else embed_dim
        self.vdim = vdim if vdim is not None else embed_dim
        self.qkv_same_dim = self.kdim == embed_dim and self.vdim == embed_dim

        self.num_heads = num_heads
        self.dropout = dropout
        self.head_dim = embed_dim // num_heads
        assert self.head_dim * num_heads == self.embed_dim, "embed_dim must be divisible by num_heads"
        self.scaling = self.head_dim ** -0.5

        self.self_attention = self_attention
        self.encoder_decoder_attention = encoder_decoder_attention

        assert not self.self_attention or self.qkv_same_dim, 'Self-attention requires query, key and ' \
                                                             'value to be of the same size'

        if self.qkv_same_dim:
            self.in_proj_weight = Parameter(torch.Tensor(3 * embed_dim, embed_dim))
        else:
            self.k_proj_weight = Parameter(torch.Tensor(embed_dim, self.kdim))
            self.v_proj_weight = Parameter(torch.Tensor(embed_dim, self.vdim))
            self.q_proj_weight = Parameter(torch.Tensor(embed_dim, embed_dim))

        if bias:
            self.in_proj_bias = Parameter(torch.Tensor(3 * embed_dim))
        else:
            self.register_parameter('in_proj_bias', None)

        self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias)

        if add_bias_kv:
            self.bias_k = Parameter(torch.Tensor(1, 1, embed_dim))
            self.bias_v = Parameter(torch.Tensor(1, 1, embed_dim))
        else:
            self.bias_k = self.bias_v = None

        self.add_zero_attn = add_zero_attn

        self.reset_parameters()

        self.enable_torch_version = hasattr(F, "multi_head_attention_forward")
        self.last_attn_probs = None

    def reset_parameters(self):
        if self.qkv_same_dim:
            nn.init.xavier_uniform_(self.in_proj_weight)
        else:
            nn.init.xavier_uniform_(self.k_proj_weight)
            nn.init.xavier_uniform_(self.v_proj_weight)
            nn.init.xavier_uniform_(self.q_proj_weight)

        nn.init.xavier_uniform_(self.out_proj.weight)
        if self.in_proj_bias is not None:
            nn.init.constant_(self.in_proj_bias, 0.)
            nn.init.constant_(self.out_proj.bias, 0.)
        if self.bias_k is not None:
            nn.init.xavier_normal_(self.bias_k)
        if self.bias_v is not None:
            nn.init.xavier_normal_(self.bias_v)

    def forward(
            self,
            query, key, value,
            key_padding_mask=None,
            incremental_state=None,
            need_weights=True,
            static_kv=False,
            attn_mask=None,
            before_softmax=False,
            need_head_weights=False,
            enc_dec_attn_constraint_mask=None,
            reset_attn_weight=None
    ):
        """Input shape: Time x Batch x Channel

        Args:
            key_padding_mask (ByteTensor, optional): mask to exclude
                keys that are pads, of shape `(batch, src_len)`, where
                padding elements are indicated by 1s.
            need_weights (bool, optional): return the attention weights,
                averaged over heads (default: True).
            attn_mask (ByteTensor, optional): typically used to
                implement causal attention, where the mask prevents the
                attention from looking forward in time (default: None).
            before_softmax (bool, optional): return the raw attention
                weights and values before the attention softmax.
            need_head_weights (bool, optional): return the attention
                weights for each head. Implies *need_weights*. Default:
                return the average attention weights over all heads.
        """
        if need_head_weights:
            need_weights = True

        tgt_len, bsz, embed_dim = query.size()
        assert embed_dim == self.embed_dim
        assert list(query.size()) == [tgt_len, bsz, embed_dim]
        if self.enable_torch_version and incremental_state is None and not static_kv and reset_attn_weight is None:
            if self.qkv_same_dim:
                return F.multi_head_attention_forward(query, key, value,
                                                      self.embed_dim, self.num_heads,
                                                      self.in_proj_weight,
                                                      self.in_proj_bias, self.bias_k, self.bias_v,
                                                      self.add_zero_attn, self.dropout,
                                                      self.out_proj.weight, self.out_proj.bias,
                                                      self.training, key_padding_mask, need_weights,
                                                      attn_mask)
            else:
                return F.multi_head_attention_forward(query, key, value,
                                                      self.embed_dim, self.num_heads,
                                                      torch.empty([0]),
                                                      self.in_proj_bias, self.bias_k, self.bias_v,
                                                      self.add_zero_attn, self.dropout,
                                                      self.out_proj.weight, self.out_proj.bias,
                                                      self.training, key_padding_mask, need_weights,
                                                      attn_mask, use_separate_proj_weight=True,
                                                      q_proj_weight=self.q_proj_weight,
                                                      k_proj_weight=self.k_proj_weight,
                                                      v_proj_weight=self.v_proj_weight)

        if incremental_state is not None:
            saved_state = self._get_input_buffer(incremental_state)
            if 'prev_key' in saved_state:
                # previous time steps are cached - no need to recompute
                # key and value if they are static
                if static_kv:
                    assert self.encoder_decoder_attention and not self.self_attention
                    key = value = None
        else:
            saved_state = None

        if self.self_attention:
            # self-attention
            q, k, v = self.in_proj_qkv(query)
        elif self.encoder_decoder_attention:
            # encoder-decoder attention
            q = self.in_proj_q(query)
            if key is None:
                assert value is None
                k = v = None
            else:
                k = self.in_proj_k(key)
                v = self.in_proj_v(key)
        else:
            q = self.in_proj_q(query)
            k = self.in_proj_k(key)
            v = self.in_proj_v(value)
        q *= self.scaling

        if self.bias_k is not None:
            assert self.bias_v is not None
            k = torch.cat([k, self.bias_k.repeat(1, bsz, 1)])
            v = torch.cat([v, self.bias_v.repeat(1, bsz, 1)])
            if attn_mask is not None:
                attn_mask = torch.cat([attn_mask, attn_mask.new_zeros(attn_mask.size(0), 1)], dim=1)
            if key_padding_mask is not None:
                key_padding_mask = torch.cat(
                    [key_padding_mask, key_padding_mask.new_zeros(key_padding_mask.size(0), 1)], dim=1)

        q = q.contiguous().view(tgt_len, bsz * self.num_heads, self.head_dim).transpose(0, 1)
        if k is not None:
            k = k.contiguous().view(-1, bsz * self.num_heads, self.head_dim).transpose(0, 1)
        if v is not None:
            v = v.contiguous().view(-1, bsz * self.num_heads, self.head_dim).transpose(0, 1)

        if saved_state is not None:
            # saved states are stored with shape (bsz, num_heads, seq_len, head_dim)
            if 'prev_key' in saved_state:
                prev_key = saved_state['prev_key'].view(bsz * self.num_heads, -1, self.head_dim)
                if static_kv:
                    k = prev_key
                else:
                    k = torch.cat((prev_key, k), dim=1)
            if 'prev_value' in saved_state:
                prev_value = saved_state['prev_value'].view(bsz * self.num_heads, -1, self.head_dim)
                if static_kv:
                    v = prev_value
                else:
                    v = torch.cat((prev_value, v), dim=1)
            if 'prev_key_padding_mask' in saved_state and saved_state['prev_key_padding_mask'] is not None:
                prev_key_padding_mask = saved_state['prev_key_padding_mask']
                if static_kv:
                    key_padding_mask = prev_key_padding_mask
                else:
                    key_padding_mask = torch.cat((prev_key_padding_mask, key_padding_mask), dim=1)

            saved_state['prev_key'] = k.view(bsz, self.num_heads, -1, self.head_dim)
            saved_state['prev_value'] = v.view(bsz, self.num_heads, -1, self.head_dim)
            saved_state['prev_key_padding_mask'] = key_padding_mask

            self._set_input_buffer(incremental_state, saved_state)

        src_len = k.size(1)

        # This is part of a workaround to get around fork/join parallelism
        # not supporting Optional types.
        if key_padding_mask is not None and key_padding_mask.shape == torch.Size([]):
            key_padding_mask = None

        if key_padding_mask is not None:
            assert key_padding_mask.size(0) == bsz
            assert key_padding_mask.size(1) == src_len

        if self.add_zero_attn:
            src_len += 1
            k = torch.cat([k, k.new_zeros((k.size(0), 1) + k.size()[2:])], dim=1)
            v = torch.cat([v, v.new_zeros((v.size(0), 1) + v.size()[2:])], dim=1)
            if attn_mask is not None:
                attn_mask = torch.cat([attn_mask, attn_mask.new_zeros(attn_mask.size(0), 1)], dim=1)
            if key_padding_mask is not None:
                key_padding_mask = torch.cat(
                    [key_padding_mask, torch.zeros(key_padding_mask.size(0), 1).type_as(key_padding_mask)], dim=1)

        attn_weights = torch.bmm(q, k.transpose(1, 2))
        attn_weights = self.apply_sparse_mask(attn_weights, tgt_len, src_len, bsz)

        assert list(attn_weights.size()) == [bsz * self.num_heads, tgt_len, src_len]

        if attn_mask is not None:
            if len(attn_mask.shape) == 2:
                attn_mask = attn_mask.unsqueeze(0)
            elif len(attn_mask.shape) == 3:
                attn_mask = attn_mask[:, None].repeat([1, self.num_heads, 1, 1]).reshape(
                    bsz * self.num_heads, tgt_len, src_len)
            attn_weights = attn_weights + attn_mask

        if enc_dec_attn_constraint_mask is not None:  # bs x head x L_kv
            attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
            attn_weights = attn_weights.masked_fill(
                enc_dec_attn_constraint_mask.unsqueeze(2).bool(),
                -1e8,
            )
            attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)

        if key_padding_mask is not None:
            # don't attend to padding symbols
            attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
            attn_weights = attn_weights.masked_fill(
                key_padding_mask.unsqueeze(1).unsqueeze(2),
                -1e8,
            )
            attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)

        attn_logits = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)

        if before_softmax:
            return attn_weights, v

        attn_weights_float = softmax(attn_weights, dim=-1)
        attn_weights = attn_weights_float.type_as(attn_weights)
        attn_probs = F.dropout(attn_weights_float.type_as(attn_weights), p=self.dropout, training=self.training)

        if reset_attn_weight is not None:
            if reset_attn_weight:
                self.last_attn_probs = attn_probs.detach()
            else:
                assert self.last_attn_probs is not None
                attn_probs = self.last_attn_probs
        attn = torch.bmm(attn_probs, v)
        assert list(attn.size()) == [bsz * self.num_heads, tgt_len, self.head_dim]
        attn = attn.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim)
        attn = self.out_proj(attn)

        if need_weights:
            attn_weights = attn_weights_float.view(bsz, self.num_heads, tgt_len, src_len).transpose(1, 0)
            if not need_head_weights:
                # average attention weights over heads
                attn_weights = attn_weights.mean(dim=0)
        else:
            attn_weights = None

        return attn, (attn_weights, attn_logits)

    def in_proj_qkv(self, query):
        return self._in_proj(query).chunk(3, dim=-1)

    def in_proj_q(self, query):
        if self.qkv_same_dim:
            return self._in_proj(query, end=self.embed_dim)
        else:
            bias = self.in_proj_bias
            if bias is not None:
                bias = bias[:self.embed_dim]
            return F.linear(query, self.q_proj_weight, bias)

    def in_proj_k(self, key):
        if self.qkv_same_dim:
            return self._in_proj(key, start=self.embed_dim, end=2 * self.embed_dim)
        else:
            weight = self.k_proj_weight
            bias = self.in_proj_bias
            if bias is not None:
                bias = bias[self.embed_dim:2 * self.embed_dim]
            return F.linear(key, weight, bias)

    def in_proj_v(self, value):
        if self.qkv_same_dim:
            return self._in_proj(value, start=2 * self.embed_dim)
        else:
            weight = self.v_proj_weight
            bias = self.in_proj_bias
            if bias is not None:
                bias = bias[2 * self.embed_dim:]
            return F.linear(value, weight, bias)

    def _in_proj(self, input, start=0, end=None):
        weight = self.in_proj_weight
        bias = self.in_proj_bias
        weight = weight[start:end, :]
        if bias is not None:
            bias = bias[start:end]
        return F.linear(input, weight, bias)

    def _get_input_buffer(self, incremental_state):
        return get_incremental_state(
            self,
            incremental_state,
            'attn_state',
        ) or {}

    def _set_input_buffer(self, incremental_state, buffer):
        set_incremental_state(
            self,
            incremental_state,
            'attn_state',
            buffer,
        )

    def apply_sparse_mask(self, attn_weights, tgt_len, src_len, bsz):
        return attn_weights

    def clear_buffer(self, incremental_state=None):
        if incremental_state is not None:
            saved_state = self._get_input_buffer(incremental_state)
            if 'prev_key' in saved_state:
                del saved_state['prev_key']
            if 'prev_value' in saved_state:
                del saved_state['prev_value']
            self._set_input_buffer(incremental_state, saved_state)


class EncSALayer(nn.Module):
    def __init__(self, c, num_heads, dropout, attention_dropout=0.1,
                 relu_dropout=0.1, kernel_size=9, padding='SAME', act='gelu'):
        super().__init__()
        self.c = c
        self.dropout = dropout
        self.num_heads = num_heads
        if num_heads > 0:
            self.layer_norm1 = LayerNorm(c)
            self.self_attn = MultiheadAttention(
                self.c, num_heads, self_attention=True, dropout=attention_dropout, bias=False)
        self.layer_norm2 = LayerNorm(c)
        self.ffn = TransformerFFNLayer(
            c, 4 * c, kernel_size=kernel_size, dropout=relu_dropout, padding=padding, act=act)

    def forward(self, x, encoder_padding_mask=None, **kwargs):
        layer_norm_training = kwargs.get('layer_norm_training', None)
        if layer_norm_training is not None:
            self.layer_norm1.training = layer_norm_training
            self.layer_norm2.training = layer_norm_training
        if self.num_heads > 0:
            residual = x
            x = self.layer_norm1(x)
            x, _ = self.self_attn(
                query=x,
                key=x,
                value=x,
                key_padding_mask=encoder_padding_mask
            )
            x = F.dropout(x, self.dropout, training=self.training)
            x = residual + x
            x = x * (1 - encoder_padding_mask.float()).transpose(0, 1)[..., None]

        residual = x
        x = self.layer_norm2(x)
        x = self.ffn(x)
        x = F.dropout(x, self.dropout, training=self.training)
        x = residual + x
        x = x * (1 - encoder_padding_mask.float()).transpose(0, 1)[..., None]
        return x


class DecSALayer(nn.Module):
    def __init__(self, c, num_heads, dropout, attention_dropout=0.1, relu_dropout=0.1,
                 kernel_size=9, act='gelu'):
        super().__init__()
        self.c = c
        self.dropout = dropout
        self.layer_norm1 = LayerNorm(c)
        self.self_attn = MultiheadAttention(
            c, num_heads, self_attention=True, dropout=attention_dropout, bias=False
        )
        self.layer_norm2 = LayerNorm(c)
        self.encoder_attn = MultiheadAttention(
            c, num_heads, encoder_decoder_attention=True, dropout=attention_dropout, bias=False,
        )
        self.layer_norm3 = LayerNorm(c)
        self.ffn = TransformerFFNLayer(
            c, 4 * c, padding='LEFT', kernel_size=kernel_size, dropout=relu_dropout, act=act)

    def forward(
            self,
            x,
            encoder_out=None,
            encoder_padding_mask=None,
            incremental_state=None,
            self_attn_mask=None,
            self_attn_padding_mask=None,
            attn_out=None,
            reset_attn_weight=None,
            **kwargs,
    ):
        layer_norm_training = kwargs.get('layer_norm_training', None)
        if layer_norm_training is not None:
            self.layer_norm1.training = layer_norm_training
            self.layer_norm2.training = layer_norm_training
            self.layer_norm3.training = layer_norm_training
        residual = x
        x = self.layer_norm1(x)
        x, _ = self.self_attn(
            query=x,
            key=x,
            value=x,
            key_padding_mask=self_attn_padding_mask,
            incremental_state=incremental_state,
            attn_mask=self_attn_mask
        )
        x = F.dropout(x, self.dropout, training=self.training)
        x = residual + x

        attn_logits = None
        if encoder_out is not None or attn_out is not None:
            residual = x
            x = self.layer_norm2(x)
        if encoder_out is not None:
            x, attn = self.encoder_attn(
                query=x,
                key=encoder_out,
                value=encoder_out,
                key_padding_mask=encoder_padding_mask,
                incremental_state=incremental_state,
                static_kv=True,
                enc_dec_attn_constraint_mask=get_incremental_state(self, incremental_state,
                                                                   'enc_dec_attn_constraint_mask'),
                reset_attn_weight=reset_attn_weight
            )
            attn_logits = attn[1]
        elif attn_out is not None:
            x = self.encoder_attn.in_proj_v(attn_out)
        if encoder_out is not None or attn_out is not None:
            x = F.dropout(x, self.dropout, training=self.training)
            x = residual + x

        residual = x
        x = self.layer_norm3(x)
        x = self.ffn(x, incremental_state=incremental_state)
        x = F.dropout(x, self.dropout, training=self.training)
        x = residual + x
        return x, attn_logits

    def clear_buffer(self, input, encoder_out=None, encoder_padding_mask=None, incremental_state=None):
        self.encoder_attn.clear_buffer(incremental_state)
        self.ffn.clear_buffer(incremental_state)

    def set_buffer(self, name, tensor, incremental_state):
        return set_incremental_state(self, incremental_state, name, tensor)


class TransformerEncoderLayer(nn.Module):
    def __init__(self, hidden_size, dropout, kernel_size=9, num_heads=2):
        super().__init__()
        self.hidden_size = hidden_size
        self.dropout = dropout
        self.num_heads = num_heads
        self.op = EncSALayer(
            hidden_size, num_heads, dropout=dropout,
            attention_dropout=0.0, relu_dropout=dropout,
            kernel_size=kernel_size)

    def forward(self, x, **kwargs):
        return self.op(x, **kwargs)


class TransformerDecoderLayer(nn.Module):
    def __init__(self, hidden_size, dropout, kernel_size=9, num_heads=2):
        super().__init__()
        self.hidden_size = hidden_size
        self.dropout = dropout
        self.num_heads = num_heads
        self.op = DecSALayer(
            hidden_size, num_heads, dropout=dropout,
            attention_dropout=0.0, relu_dropout=dropout,
            kernel_size=kernel_size)

    def forward(self, x, **kwargs):
        return self.op(x, **kwargs)

    def clear_buffer(self, *args):
        return self.op.clear_buffer(*args)

    def set_buffer(self, *args):
        return self.op.set_buffer(*args)


class FFTBlocks(nn.Module):
    def __init__(self, hidden_size, num_layers, ffn_kernel_size=9, dropout=0.0,
                 num_heads=2, use_pos_embed=True, use_last_norm=True,
                 use_pos_embed_alpha=True):
        super().__init__()
        self.num_layers = num_layers
        embed_dim = self.hidden_size = hidden_size
        self.dropout = dropout
        self.use_pos_embed = use_pos_embed
        self.use_last_norm = use_last_norm
        if use_pos_embed:
            self.max_source_positions = DEFAULT_MAX_TARGET_POSITIONS
            self.padding_idx = 0
            self.pos_embed_alpha = nn.Parameter(torch.Tensor([1])) if use_pos_embed_alpha else 1
            self.embed_positions = SinusoidalPositionalEmbedding(
                embed_dim, self.padding_idx, init_size=DEFAULT_MAX_TARGET_POSITIONS,
            )

        self.layers = nn.ModuleList([])
        self.layers.extend([
            TransformerEncoderLayer(self.hidden_size, self.dropout,
                                    kernel_size=ffn_kernel_size, num_heads=num_heads)
            for _ in range(self.num_layers)
        ])
        if self.use_last_norm:
            self.layer_norm = nn.LayerNorm(embed_dim)
        else:
            self.layer_norm = None

    def forward(self, x, padding_mask=None, attn_mask=None, return_hiddens=False):
        """
        :param x: [B, T, C]
        :param padding_mask: [B, T]
        :return: [B, T, C] or [L, B, T, C]
        """
        padding_mask = x.abs().sum(-1).eq(0).data if padding_mask is None else padding_mask
        nonpadding_mask_TB = 1 - padding_mask.transpose(0, 1).float()[:, :, None]  # [T, B, 1]
        if self.use_pos_embed:
            positions = self.pos_embed_alpha * self.embed_positions(x[..., 0])
            x = x + positions
            x = F.dropout(x, p=self.dropout, training=self.training)
        # B x T x C -> T x B x C
        x = x.transpose(0, 1) * nonpadding_mask_TB
        hiddens = []
        for layer in self.layers:
            x = layer(x, encoder_padding_mask=padding_mask, attn_mask=attn_mask) * nonpadding_mask_TB
            hiddens.append(x)
        if self.use_last_norm:
            x = self.layer_norm(x) * nonpadding_mask_TB
        if return_hiddens:
            x = torch.stack(hiddens, 0)  # [L, T, B, C]
            x = x.transpose(1, 2)  # [L, B, T, C]
        else:
            x = x.transpose(0, 1)  # [B, T, C]
        return x


class FastSpeechEncoder(FFTBlocks):
    def __init__(self, dict_size, hidden_size=256, num_layers=4, kernel_size=9, num_heads=2,
                 dropout=0.0):
        super().__init__(hidden_size, num_layers, kernel_size, num_heads=num_heads,
                         use_pos_embed=False, dropout=dropout)  # use_pos_embed_alpha for compatibility
        self.embed_tokens = Embedding(dict_size, hidden_size, 0)
        self.embed_scale = math.sqrt(hidden_size)
        self.padding_idx = 0
        self.embed_positions = SinusoidalPositionalEmbedding(
            hidden_size, self.padding_idx, init_size=DEFAULT_MAX_TARGET_POSITIONS,
        )

    def forward(self, txt_tokens, attn_mask=None):
        """
        :param txt_tokens: [B, T]
        :return: encoder output, [B, T, C]
        """
        encoder_padding_mask = txt_tokens.eq(self.padding_idx).data
        x = self.forward_embedding(txt_tokens)  # [B, T, H]
        if self.num_layers > 0:
            x = super(FastSpeechEncoder, self).forward(x, encoder_padding_mask, attn_mask=attn_mask)
        return x

    def forward_embedding(self, txt_tokens):
        # embed tokens and positions
        x = self.embed_scale * self.embed_tokens(txt_tokens)
        positions = self.embed_positions(txt_tokens)
        x = x + positions
        x = F.dropout(x, p=self.dropout, training=self.training)
        return x


class FastSpeechDecoder(FFTBlocks):
    def __init__(self, hidden_size=256, num_layers=4, kernel_size=9, num_heads=2):
        super().__init__(hidden_size, num_layers, kernel_size, num_heads=num_heads)
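Editor's note: a hedged sanity-check sketch for the FFT-block stack above, not part of the diff. The import path is assumed from this repo's layout; token id 0 is the padding index, so real tokens start at 1:

import torch
# Assumed import path for illustration only.
from preprocess.tools.note_transcription.modules.commons.transformer import FastSpeechEncoder

enc = FastSpeechEncoder(dict_size=100, hidden_size=256, num_layers=4,
                        kernel_size=9, num_heads=2)
tokens = torch.randint(1, 100, (2, 32))  # [batch, seq]; index 0 is reserved for padding
out = enc(tokens)                        # [batch, seq, hidden_size] = [2, 32, 256]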
preprocess/tools/note_transcription/modules/commons/wavenet.py
ADDED
@@ -0,0 +1,109 @@
import torch
from torch import nn
from packaging import version


def fused_add_tanh_sigmoid_multiply(input_a, input_b, n_channels):
    n_channels_int = n_channels[0]
    in_act = input_a + input_b
    t_act = torch.tanh(in_act[:, :n_channels_int, :])
    s_act = torch.sigmoid(in_act[:, n_channels_int:, :])
    acts = t_act * s_act
    return acts


jit_fused_add_tanh_sigmoid_multiply = fused_add_tanh_sigmoid_multiply


def script_function():
    if version.parse(torch.__version__) >= version.parse('2.0'):
        global jit_fused_add_tanh_sigmoid_multiply
        jit_fused_add_tanh_sigmoid_multiply = torch.jit.script(fused_add_tanh_sigmoid_multiply)


class WN(torch.nn.Module):
    def __init__(self, hidden_size, kernel_size, dilation_rate, n_layers, c_cond=0,
                 p_dropout=0, share_cond_layers=False, is_BTC=False):
        super(WN, self).__init__()
        assert kernel_size % 2 == 1
        assert hidden_size % 2 == 0
        self.is_BTC = is_BTC
        self.hidden_size = hidden_size
        self.kernel_size = kernel_size
        self.dilation_rate = dilation_rate
        self.n_layers = n_layers
        self.gin_channels = c_cond
        self.p_dropout = p_dropout
        self.share_cond_layers = share_cond_layers

        self.in_layers = torch.nn.ModuleList()
        self.res_skip_layers = torch.nn.ModuleList()
        self.drop = nn.Dropout(p_dropout)

        if c_cond != 0 and not share_cond_layers:
            cond_layer = torch.nn.Conv1d(c_cond, 2 * hidden_size * n_layers, 1)
            self.cond_layer = torch.nn.utils.weight_norm(cond_layer, name='weight')

        for i in range(n_layers):
            dilation = dilation_rate ** i
            padding = int((kernel_size * dilation - dilation) / 2)
            in_layer = torch.nn.Conv1d(hidden_size, 2 * hidden_size, kernel_size,
                                       dilation=dilation, padding=padding)
            in_layer = torch.nn.utils.weight_norm(in_layer, name='weight')
            self.in_layers.append(in_layer)

            # last one is not necessary
            if i < n_layers - 1:
                res_skip_channels = 2 * hidden_size
            else:
                res_skip_channels = hidden_size

            res_skip_layer = torch.nn.Conv1d(hidden_size, res_skip_channels, 1)
            res_skip_layer = torch.nn.utils.weight_norm(res_skip_layer, name='weight')
            self.res_skip_layers.append(res_skip_layer)

        script_function()

    def forward(self, x, nonpadding=None, cond=None):
        if self.is_BTC:
            x = x.transpose(1, 2)
            cond = cond.transpose(1, 2) if cond is not None else None
            nonpadding = nonpadding.transpose(1, 2) if nonpadding is not None else None
        if nonpadding is None:
            nonpadding = 1
        output = torch.zeros_like(x)
        n_channels_tensor = torch.IntTensor([self.hidden_size])

        if cond is not None and not self.share_cond_layers:
            cond = self.cond_layer(cond)

        for i in range(self.n_layers):
            x_in = self.in_layers[i](x)
            x_in = self.drop(x_in)
            if cond is not None:
                cond_offset = i * 2 * self.hidden_size
                cond_l = cond[:, cond_offset:cond_offset + 2 * self.hidden_size, :]
            else:
                cond_l = torch.zeros_like(x_in)

            if version.parse(torch.__version__) >= version.parse('2.0'):
                acts = jit_fused_add_tanh_sigmoid_multiply(x_in, cond_l, n_channels_tensor)
            else:
                acts = fused_add_tanh_sigmoid_multiply(x_in, cond_l, n_channels_tensor)

            res_skip_acts = self.res_skip_layers[i](acts)
            if i < self.n_layers - 1:
                x = (x + res_skip_acts[:, :self.hidden_size, :]) * nonpadding
                output = output + res_skip_acts[:, self.hidden_size:, :]
            else:
                output = output + res_skip_acts
        output = output * nonpadding
        if self.is_BTC:
            output = output.transpose(1, 2)
        return output

    def remove_weight_norm(self):
        def remove_weight_norm(m):
            try:
                nn.utils.remove_weight_norm(m)
            except ValueError:  # this module didn't have weight norm
                return

        self.apply(remove_weight_norm)
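Editor's note: a small sketch of driving the WN block above with frame-level conditioning (e.g. mel frames), not part of the diff. The import path is assumed, and `is_BTC=True` lets you pass [batch, time, channel] tensors directly:

import torch
# Assumed import path for illustration only.
from preprocess.tools.note_transcription.modules.commons.wavenet import WN

wn = WN(hidden_size=64, kernel_size=3, dilation_rate=2, n_layers=4,
        c_cond=80, is_BTC=True)
x = torch.randn(2, 100, 64)     # [batch, time, hidden]
cond = torch.randn(2, 100, 80)  # frame-aligned conditioning, e.g. 80-bin mel frames
y = wn(x, cond=cond)            # [batch, time, hidden], same layout as the input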
preprocess/tools/note_transcription/modules/pe/__init__.py
ADDED
@@ -0,0 +1 @@
"""Pitch extractor modules for ROSVOT."""
preprocess/tools/note_transcription/modules/pe/rmvpe/__init__.py
ADDED
@@ -0,0 +1,6 @@
from .constants import *
from .model import E2E0
from .utils import to_local_average_f0, to_viterbi_f0
from .inference import RMVPE
from .spec import MelSpectrogram
from .extractor import extract