JeffreyZhou798 committed
Commit 7ee408c (verified)
1 Parent(s): 65e3901

Upload 111 files

This view is limited to 50 files because it contains too many changes.
Files changed (50)
  1. preprocess/README.md +155 -0
  2. preprocess/pipeline.py +161 -0
  3. preprocess/requirements.txt +34 -0
  4. preprocess/tools/__init__.py +53 -0
  5. preprocess/tools/f0_extraction.py +527 -0
  6. preprocess/tools/g2p.py +72 -0
  7. preprocess/tools/lyric_transcription.py +283 -0
  8. preprocess/tools/midi_editor/README.md +170 -0
  9. preprocess/tools/midi_editor/README_CN.md +170 -0
  10. preprocess/tools/midi_editor/eslint.config.js +23 -0
  11. preprocess/tools/midi_editor/index.html +13 -0
  12. preprocess/tools/midi_editor/package-lock.json +0 -0
  13. preprocess/tools/midi_editor/package.json +39 -0
  14. preprocess/tools/midi_editor/postcss.config.js +6 -0
  15. preprocess/tools/midi_editor/public/vite.svg +1 -0
  16. preprocess/tools/midi_editor/src/App.css +834 -0
  17. preprocess/tools/midi_editor/src/App.tsx +675 -0
  18. preprocess/tools/midi_editor/src/components/AudioTrack.tsx +182 -0
  19. preprocess/tools/midi_editor/src/components/LyricTable.tsx +301 -0
  20. preprocess/tools/midi_editor/src/components/PianoRoll.tsx +704 -0
  21. preprocess/tools/midi_editor/src/constants.ts +8 -0
  22. preprocess/tools/midi_editor/src/i18n.ts +196 -0
  23. preprocess/tools/midi_editor/src/index.css +37 -0
  24. preprocess/tools/midi_editor/src/lib/midi.ts +224 -0
  25. preprocess/tools/midi_editor/src/main.tsx +10 -0
  26. preprocess/tools/midi_editor/src/store/useMidiStore.ts +78 -0
  27. preprocess/tools/midi_editor/src/types.ts +17 -0
  28. preprocess/tools/midi_editor/tailwind.config.js +33 -0
  29. preprocess/tools/midi_editor/tsconfig.app.json +28 -0
  30. preprocess/tools/midi_editor/tsconfig.json +7 -0
  31. preprocess/tools/midi_editor/tsconfig.node.json +26 -0
  32. preprocess/tools/midi_editor/vite.config.ts +7 -0
  33. preprocess/tools/midi_parser.py +598 -0
  34. preprocess/tools/note_transcription/__init__.py +0 -0
  35. preprocess/tools/note_transcription/model.py +531 -0
  36. preprocess/tools/note_transcription/modules/__init__.py +1 -0
  37. preprocess/tools/note_transcription/modules/commons/__init__.py +1 -0
  38. preprocess/tools/note_transcription/modules/commons/conformer/__init__.py +1 -0
  39. preprocess/tools/note_transcription/modules/commons/conformer/conformer.py +96 -0
  40. preprocess/tools/note_transcription/modules/commons/conformer/espnet_positional_embedding.py +113 -0
  41. preprocess/tools/note_transcription/modules/commons/conformer/espnet_transformer_attn.py +198 -0
  42. preprocess/tools/note_transcription/modules/commons/conformer/layers.py +260 -0
  43. preprocess/tools/note_transcription/modules/commons/conv.py +175 -0
  44. preprocess/tools/note_transcription/modules/commons/layers.py +85 -0
  45. preprocess/tools/note_transcription/modules/commons/rel_transformer.py +378 -0
  46. preprocess/tools/note_transcription/modules/commons/rnn.py +261 -0
  47. preprocess/tools/note_transcription/modules/commons/transformer.py +751 -0
  48. preprocess/tools/note_transcription/modules/commons/wavenet.py +109 -0
  49. preprocess/tools/note_transcription/modules/pe/__init__.py +1 -0
  50. preprocess/tools/note_transcription/modules/pe/rmvpe/__init__.py +6 -0
preprocess/README.md ADDED
@@ -0,0 +1,155 @@
1
+ # 🎵 SoulX-Singer-Preprocess
2
+
3
+ This module provides a comprehensive **singing transcription and editing toolkit** for real-world music audio, covering the full pipeline from vocal extraction to high-level annotation, optimized for SVS (Singing Voice Synthesis) dataset construction. By integrating state-of-the-art models, it transforms raw audio into structured singing data and supports the **customizable creation and editing of lyric-aligned MIDI scores**.
4
+
5
+
6
+ ## ✨ Features
7
+
8
+ The toolkit includes the following core modules:
9
+
10
+ - 🎤 **Clean Dry Vocal Extraction**
11
+ Extracts the lead vocal track from polyphonic music audio and applies dereverberation.
12
+
13
+ - 📝 **Lyrics Transcription**
14
+ Automatically transcribes lyrics from the clean vocal track.
15
+
16
+ - 🎶 **Note Transcription**
17
+ Converts singing voice into note-level representations for SVS.
18
+
19
+ - 🎼 **MIDI Editor**
20
+ Supports customizable creation and editing of MIDI scores integrated with lyrics.
21
+
22
+
23
+ ## 🔧 Python Environment
24
+
25
+ Before running the pipeline, set up the Python environment as follows:
26
+
27
+ 1. **Install Conda** (if not already installed): https://docs.conda.io/en/latest/miniconda.html
28
+
29
+ 2. **Activate or create a conda environment** (recommended Python 3.10):
30
+
31
+ - If you already have the `soulxsinger` environment:
32
+
33
+ ```bash
34
+ conda activate soulxsinger
35
+ ```
36
+
37
+ - Otherwise, create it first:
38
+
39
+ ```bash
40
+ conda create -n soulxsinger -y python=3.10
41
+ conda activate soulxsinger
42
+ ```
43
+
44
+ 3. **Install dependencies** from the `preprocess` directory:
45
+
46
+ ```bash
47
+ cd preprocess
48
+ pip install -r requirements.txt
49
+ ```
50
+
51
+ ## 📁 Data Preparation
52
+
53
+ Before running the pipeline, prepare the following inputs:
54
+
55
+ - **Prompt audio**
56
+ Reference audio that provides timbre and style.
57
+
58
+ - **Target audio**
59
+ Original vocal or music audio to be processed and transcribed.
60
+
61
+ Configure the corresponding parameters in:
62
+
63
+ ```
64
+ example/preprocess.sh
65
+ ```
66
+
67
+ Typical configuration includes (see the sketch after this list):
68
+ - Input / output paths
69
+ - Module enable switches
70
+
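+ The exact contents of `example/preprocess.sh` are not reproduced in this commit. As a rough, hedged sketch based on the CLI flags defined in `preprocess/pipeline.py` (shown later in this diff), the script is expected to look roughly like the following, with placeholder paths:
+
+ ```bash
+ # Hypothetical sketch of example/preprocess.sh -- paths are placeholders,
+ # the flags mirror the argparse options in preprocess/pipeline.py.
+ audio_path=example/audio/target_song.mp3     # target audio to transcribe
+ save_dir=example/transcriptions/music        # directory for intermediate results
+
+ python -m preprocess.pipeline \
+ --audio_path "${audio_path}" \
+ --save_dir "${save_dir}" \
+ --language Mandarin \
+ --device cuda:0 \
+ --vocal_sep True \
+ --midi_transcribe True \
+ --max_merge_duration 60000
+ ```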
71
+ ## 🚀 Usage
72
+
73
+ After configuring `preprocess.sh`, run the transcription pipeline with:
74
+
75
+ ```bash
76
+ bash example/preprocess.sh
77
+ ```
78
+
79
+ The script will automatically execute the following steps:
80
+
81
+ 1. **Vocal separation and dereverberation**
82
+ 2. **F0 extraction and voice activity detection (VAD)**
83
+ 3. **Lyrics transcription**
84
+ 4. **Note transcription**
85
+
86
+ ---
87
+
88
+ After the pipeline completes, you will obtain **SoulX-Singer–style metadata** that can be directly used for Singing Voice Synthesis (SVS).
89
+
90
+ **Output paths:**
91
+ - The final metadata (**JSON file**) is written **in the same directory as your input audio**, with the **same filename** (e.g. `audio.mp3` → `audio.json`)
92
+ - All **intermediate results** (separated vocal and accompaniment, F0, VAD outputs, etc.) are also saved under the configured **`save_dir`**; a rough layout sketch follows.
93
+
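+ As a rough sketch derived from the paths that `preprocess/pipeline.py` writes in this commit (exact contents depend on the enabled modules), `save_dir` will look roughly like:
+
+ ```
+ save_dir/
+ ├── vocal.wav          # separated (or original) lead vocal
+ ├── acc.wav            # accompaniment (only when vocal separation is enabled)
+ ├── vocal_f0.npy       # frame-level F0 of the full vocal track
+ ├── cut_wavs/          # VAD-cut segments and their *_f0.npy files
+ ├── long_cut_wavs/     # merged segments and their *_f0.npy files
+ └── metadata.json      # final SoulX-Singer-style metadata
+ ```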
94
+ ⚠️ **Important Note**
95
+
96
+ Transcription errors—especially in **lyrics** and **note annotations**—can significantly affect the final SVS quality. We **strongly recommend manually reviewing and correcting** the generated metadata before inference.
97
+
98
+ To support this, we provide a **MIDI Editor** for editing lyrics, phoneme alignment, note pitches, and durations. The workflow is:
99
+
100
+ **Export metadata to MIDI** → edit in the MIDI Editor → **Import edited MIDI back to metadata** for SVS.
101
+
102
+ ---
103
+
104
+ #### Step 1: Metadata → MIDI (for editing)
105
+
106
+ Convert SoulX-Singer metadata to a MIDI file so you can open it in the MIDI Editor:
107
+
108
+ ```bash
109
+ preprocess_root=example/transcriptions/music
110
+
111
+ python -m preprocess.tools.midi_parser \
112
+ --meta2midi \
113
+ --meta "${preprocess_root}/metadata.json" \
114
+ --midi "${preprocess_root}/vocal.mid"
115
+ ```
116
+
117
+ #### Step 2: Edit in the MIDI Editor
118
+
119
+ Open the MIDI Editor (see [MIDI Editor Tutorial](tools/midi_editor/README.md)), load `vocal.mid`, and correct lyrics, pitches, or durations as needed. Save the result as e.g. `vocal_edited.mid`.
120
+
121
+ #### Step 3: MIDI → Metadata (for SoulX-Singer inference)
122
+
123
+ Convert the edited MIDI back into SoulX-Singer-style metadata (and cut wavs) for SVS:
124
+
125
+ ```bash
126
+ python -m preprocess.tools.midi_parser \
127
+ --midi2meta \
128
+ --midi "${preprocess_root}/vocal_edited.mid" \
129
+ --meta "${preprocess_root}/edit_metadata.json" \
130
+ --vocal "${preprocess_root}/vocal.wav"
131
+ ```
132
+
133
+ Use `edit_metadata.json` (and the wavs under `edit_cut_wavs`) as the target metadata in your inference pipeline.
134
+
135
+
136
+ ## 🔗 References & Dependencies
137
+
138
+ This project builds upon the following excellent open-source works:
139
+
140
+ ### 🎧 Vocal Separation & Dereverberation
141
+ - [Music Source Separation Training](https://github.com/ZFTurbo/Music-Source-Separation-Training)
142
+ - [Lead Vocal Separation](https://huggingface.co/becruily/mel-band-roformer-karaoke)
143
+ - [Vocal Dereverberation](https://huggingface.co/anvuew/dereverb_mel_band_roformer)
144
+
145
+ ### 🎼 F0 Extraction
146
+ - [RMVPE](https://github.com/Dream-High/RMVPE)
147
+
148
+ ### 📝 Lyrics Transcription (ASR)
149
+ - [Paraformer](https://modelscope.cn/models/iic/speech_seaco_paraformer_large_asr_nat-zh-cn-16k-common-vocab8404-pytorch)
150
+ - [Parakeet-tdt-0.6b-v2](https://huggingface.co/nvidia/parakeet-tdt-0.6b-v2)
151
+
152
+ ### 🎶 Note Transcription
153
+ - [ROSVOT](https://github.com/RickyL-2000/ROSVOT)
154
+
155
+ We sincerely thank the authors of these repositories for their exceptional open-source contributions, which have been fundamental to the development of this toolkit.
preprocess/pipeline.py ADDED
@@ -0,0 +1,161 @@
1
+ import json
2
+ import shutil
3
+ import soundfile as sf
4
+ from pathlib import Path
5
+ import librosa
6
+
7
+ from preprocess.utils import convert_metadata, merge_short_segments
8
+
9
+ from preprocess.tools import (
10
+ F0Extractor,
11
+ VocalDetector,
12
+ VocalSeparator,
13
+ NoteTranscriber,
14
+ LyricTranscriber,
15
+ )
16
+
17
+
18
+ class PreprocessPipeline:
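+ """End-to-end preprocessing pipeline for one song.
+
+ Runs optional vocal separation and dereverberation, RMVPE f0 extraction and,
+ when ``midi_transcribe`` is enabled, VAD segmentation, lyric transcription,
+ note transcription and short-segment merging, finally writing
+ SoulX-Singer-style ``metadata.json`` under ``save_dir``.
+ """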
19
+ def __init__(self, device: str, language: str, save_dir: str, vocal_sep: bool = True, max_merge_duration: int = 60000, midi_transcribe: bool = True):
20
+ self.device = device
21
+ self.language = language
22
+ self.save_dir = save_dir
23
+ self.vocal_sep = vocal_sep
24
+ self.max_merge_duration = max_merge_duration
25
+ self.midi_transcribe = midi_transcribe
26
+
27
+ if vocal_sep:
28
+ self.vocal_separator = VocalSeparator(
29
+ sep_model_path="pretrained_models/SoulX-Singer-Preprocess/mel-band-roformer-karaoke/mel_band_roformer_karaoke_becruily.ckpt",
30
+ sep_config_path="pretrained_models/SoulX-Singer-Preprocess/mel-band-roformer-karaoke/config_karaoke_becruily.yaml",
31
+ der_model_path="pretrained_models/SoulX-Singer-Preprocess/dereverb_mel_band_roformer/dereverb_mel_band_roformer_anvuew_sdr_19.1729.ckpt",
32
+ der_config_path="pretrained_models/SoulX-Singer-Preprocess/dereverb_mel_band_roformer/dereverb_mel_band_roformer_anvuew.yaml",
33
+ device=device
34
+ )
35
+ else:
36
+ self.vocal_separator = None
37
+ self.f0_extractor = F0Extractor(
38
+ model_path="pretrained_models/SoulX-Singer-Preprocess/rmvpe/rmvpe.pt",
39
+ device=device,
40
+ )
41
+ if self.midi_transcribe:
42
+ self.vocal_detector = VocalDetector(
43
+ cut_wavs_output_dir=f"{save_dir}/cut_wavs",
44
+ )
45
+ self.lyric_transcriber = LyricTranscriber(
46
+ zh_model_path="pretrained_models/SoulX-Singer-Preprocess/speech_seaco_paraformer_large_asr_nat-zh-cn-16k-common-vocab8404-pytorch",
47
+ en_model_path="pretrained_models/SoulX-Singer-Preprocess/parakeet-tdt-0.6b-v2/parakeet-tdt-0.6b-v2.nemo",
48
+ device=device
49
+ )
50
+ self.note_transcriber = NoteTranscriber(
51
+ rosvot_model_path="pretrained_models/SoulX-Singer-Preprocess/rosvot/rosvot/model.pt",
52
+ rwbd_model_path="pretrained_models/SoulX-Singer-Preprocess/rosvot/rwbd/model.pt",
53
+ device=device
54
+ )
55
+ else:
56
+ self.vocal_detector = None
57
+ self.lyric_transcriber = None
58
+ self.note_transcriber = None
59
+
60
+ def run(
61
+ self,
62
+ audio_path: str,
63
+ vocal_sep: bool = None,
64
+ max_merge_duration: int = None,
65
+ language: str = None,
66
+ ) -> None:
67
+ vocal_sep = self.vocal_sep if vocal_sep is None else vocal_sep
68
+ max_merge_duration = self.max_merge_duration if max_merge_duration is None else max_merge_duration
69
+ language = self.language if language is None else language
70
+ output_dir = Path(self.save_dir)
71
+ output_dir.mkdir(parents=True, exist_ok=True)
72
+
73
+ if vocal_sep:
74
+ # Perform vocal/accompaniment separation
75
+ sep = self.vocal_separator.process(audio_path)
76
+ vocal = sep.vocals_dereverbed.T
77
+ acc = sep.accompaniment.T
78
+ sample_rate = sep.sample_rate
79
+
80
+ vocal_path = output_dir / "vocal.wav"
81
+ acc_path = output_dir / "acc.wav"
82
+ sf.write(vocal_path, vocal, sample_rate)
83
+ sf.write(acc_path, acc, sample_rate)
84
+ else:
85
+ # Use the original audio as vocal source (no separation)
86
+ vocal, sample_rate = librosa.load(audio_path, sr=None, mono=True)
87
+ vocal_path = output_dir / "vocal.wav"
88
+ sf.write(vocal_path, vocal, sample_rate)
89
+
90
+ vocal_f0 = self.f0_extractor.process(str(vocal_path), f0_path=str(vocal_path).replace(".wav", "_f0.npy"))
91
+
92
+ if not self.midi_transcribe or self.vocal_detector is None or self.lyric_transcriber is None or self.note_transcriber is None:
93
+ return
94
+
95
+ segments = self.vocal_detector.process(str(vocal_path), f0=vocal_f0)
96
+
97
+ metadata = []
98
+ for seg in segments:
99
+ self.f0_extractor.process(seg["wav_fn"], f0_path=seg["wav_fn"].replace(".wav", "_f0.npy"))
100
+ words, durs = self.lyric_transcriber.process(
101
+ seg["wav_fn"], language
102
+ )
103
+ seg["words"] = words
104
+ seg["word_durs"] = durs
105
+ seg["language"] = language
106
+ metadata.append(
107
+ self.note_transcriber.process(seg, segment_info=seg)
108
+ )
109
+
110
+ merged = merge_short_segments(
111
+ vocal,
112
+ sample_rate,
113
+ metadata,
114
+ output_dir / "long_cut_wavs",
115
+ max_duration_ms=max_merge_duration,
116
+ )
117
+
118
+ final_metadata = []
119
+
120
+ for item in merged:
121
+ self.f0_extractor.process(item.wav_fn, f0_path=item.wav_fn.replace(".wav", "_f0.npy"))
122
+ final_metadata.append(convert_metadata(item))
123
+
124
+ with open(output_dir / "metadata.json", "w", encoding="utf-8") as f:
125
+ json.dump(final_metadata, f, ensure_ascii=False, indent=2)
126
+
127
+ shutil.copy(output_dir / "metadata.json", audio_path.replace(".wav", ".json").replace(".mp3", ".json").replace(".flac", ".json"))
128
+
129
+
130
+ def main(args):
131
+ pipeline = PreprocessPipeline(
132
+ device=args.device,
133
+ language=args.language,
134
+ save_dir=args.save_dir,
135
+ vocal_sep=args.vocal_sep,
136
+ max_merge_duration=args.max_merge_duration,
137
+ midi_transcribe=args.midi_transcribe,
138
+ )
139
+ pipeline.run(
140
+ audio_path=args.audio_path,
141
+ language=args.language,
142
+ )
143
+
144
+
145
+ if __name__ == "__main__":
146
+ import argparse
147
+
148
+ parser = argparse.ArgumentParser()
149
+ parser.add_argument("--audio_path", type=str, required=True, help="Path to the input audio file")
150
+ parser.add_argument("--save_dir", type=str, required=True, help="Directory to save the output files")
151
+ parser.add_argument("--language", type=str, default="Mandarin", help="Language of the audio")
152
+ parser.add_argument("--device", type=str, default="cuda:0", help="Device to run the models on")
153
+ parser.add_argument("--vocal_sep", type=str, default="True", help="Whether to perform vocal separation")
154
+ parser.add_argument("--max_merge_duration", type=int, default=60000, help="Maximum merged segment duration in milliseconds")
155
+ parser.add_argument("--midi_transcribe", type=str, default="True", help="Whether to do MIDI transcription")
156
+ args = parser.parse_args()
157
+
158
+ args.vocal_sep = args.vocal_sep.lower() == "true"
159
+ args.midi_transcribe = args.midi_transcribe.lower() == "true"
160
+
161
+ main(args)
preprocess/requirements.txt ADDED
@@ -0,0 +1,34 @@
1
+ beartype==0.22.9
2
+ einops==0.8.2
3
+ funasr==1.3.0
4
+ g2p_en==2.1.0
5
+ g2pM==0.1.2.5
6
+ librosa==0.11.0
7
+ loralib==0.1.2
8
+ matplotlib==3.10.8
9
+ mido==1.3.3
10
+ ml_collections==1.1.0
11
+ nemo_toolkit==2.6.1
12
+ nltk==3.9.2
13
+ numba==0.63.1
14
+ numpy==2.2.6
15
+ omegaconf==2.3.0
16
+ packaging==24.2
17
+ praat-parselmouth==0.4.7
18
+ pretty_midi==0.2.11
19
+ pyloudnorm==0.2.0
20
+ pyworld==0.3.5
21
+ rotary_embedding_torch==0.8.9
22
+ sageattention==1.0.6
23
+ scikit_learn==1.7.2
24
+ scipy==1.15.3
25
+ six==1.17.0
26
+ setuptools==81.0.0
27
+ scikit_image==0.25.2
28
+ soundfile==0.13.1
29
+ ToJyutping==3.2.0
30
+ torch==2.10.0
31
+ torchaudio==2.10.0
32
+ tqdm==4.67.1
33
+ wandb==0.24.2
34
+ webrtcvad==2.0.10
preprocess/tools/__init__.py ADDED
@@ -0,0 +1,53 @@
1
+ """Preprocess tools.
2
+
3
+ This package provides a thin, stable import surface for common preprocess components.
4
+
5
+ Examples:
6
+ from preprocess.tools import (
7
+ F0Extractor,
8
+ VocalDetector,
+ VocalSeparator,
+ NoteTranscriber,
+ LyricTranscriber,
14
+ )
15
+
16
+ Note:
17
+ Keep these imports lightweight. If a tool pulls heavy dependencies at import time,
18
+ consider switching to lazy imports.
19
+ """
20
+
21
+ from __future__ import annotations
22
+
23
+ # Core tools
24
+ from .f0_extraction import F0Extractor
25
+ from .vocal_detection import VocalDetector
26
+
27
+ # Some tools may live outside this package in different layouts across branches.
28
+ # Keep the public surface stable while avoiding hard import failures.
29
+ try:
30
+ from .vocal_separation.model import VocalSeparator # type: ignore
31
+ except Exception: # pragma: no cover
32
+ VocalSeparator = None # type: ignore
33
+
34
+ try:
35
+ from .note_transcription.model import NoteTranscriber # type: ignore
36
+ except Exception: # pragma: no cover
37
+ NoteTranscriber = None # type: ignore
38
+ try:
39
+ from .lyric_transcription import LyricTranscriber
40
+ except Exception: # pragma: no cover
41
+ LyricTranscriber = None # type: ignore
42
+
43
+ __all__ = [
44
+ "F0Extractor",
45
+ "VocalDetector",
46
+ ]
47
+
48
+ if VocalSeparator is not None:
49
+ __all__.append("VocalSeparator")
50
+ if LyricTranscriber is not None:
51
+ __all__.append("LyricTranscriber")
52
+ if NoteTranscriber is not None:
53
+ __all__.append("NoteTranscriber")
preprocess/tools/f0_extraction.py ADDED
@@ -0,0 +1,527 @@
1
+ # https://github.com/Dream-High/RMVPE
2
+ import math
3
+ import time
4
+ import librosa
5
+ import numpy as np
6
+ from librosa.filters import mel
7
+ from scipy.interpolate import interp1d
8
+
9
+ from typing import Optional
10
+
11
+ import torch
12
+ import torch.nn as nn
13
+ import torch.nn.functional as F
14
+
15
+
16
+ class BiGRU(nn.Module):
17
+ def __init__(self, input_features, hidden_features, num_layers):
18
+ super(BiGRU, self).__init__()
19
+ self.gru = nn.GRU(
20
+ input_features,
21
+ hidden_features,
22
+ num_layers=num_layers,
23
+ batch_first=True,
24
+ bidirectional=True,
25
+ )
26
+
27
+ def forward(self, x):
28
+ return self.gru(x)[0]
29
+
30
+
31
+ class ConvBlockRes(nn.Module):
32
+ def __init__(self, in_channels, out_channels, momentum=0.01):
33
+ super(ConvBlockRes, self).__init__()
34
+ self.conv = nn.Sequential(
35
+ nn.Conv2d(
36
+ in_channels=in_channels,
37
+ out_channels=out_channels,
38
+ kernel_size=(3, 3),
39
+ stride=(1, 1),
40
+ padding=(1, 1),
41
+ bias=False,
42
+ ),
43
+ nn.BatchNorm2d(out_channels, momentum=momentum),
44
+ nn.ReLU(),
45
+ nn.Conv2d(
46
+ in_channels=out_channels,
47
+ out_channels=out_channels,
48
+ kernel_size=(3, 3),
49
+ stride=(1, 1),
50
+ padding=(1, 1),
51
+ bias=False,
52
+ ),
53
+ nn.BatchNorm2d(out_channels, momentum=momentum),
54
+ nn.ReLU(),
55
+ )
56
+ if in_channels != out_channels:
57
+ self.shortcut = nn.Conv2d(in_channels, out_channels, (1, 1))
58
+
59
+ def forward(self, x):
60
+ if not hasattr(self, "shortcut"):
61
+ return self.conv(x) + x
62
+ else:
63
+ return self.conv(x) + self.shortcut(x)
64
+
65
+
66
+ class ResEncoderBlock(nn.Module):
67
+ def __init__(self, in_channels, out_channels, kernel_size, n_blocks=1, momentum=0.01):
68
+ super(ResEncoderBlock, self).__init__()
69
+ self.n_blocks = n_blocks
70
+ self.conv = nn.ModuleList()
71
+ self.conv.append(ConvBlockRes(in_channels, out_channels, momentum))
72
+ for i in range(n_blocks - 1):
73
+ self.conv.append(ConvBlockRes(out_channels, out_channels, momentum))
74
+ self.kernel_size = kernel_size
75
+ if self.kernel_size is not None:
76
+ self.pool = nn.AvgPool2d(kernel_size=kernel_size)
77
+
78
+ def forward(self, x):
79
+ for conv in self.conv:
80
+ x = conv(x)
81
+ if self.kernel_size is not None:
82
+ return x, self.pool(x)
83
+ else:
84
+ return x
85
+
86
+
87
+ class Encoder(nn.Module):
88
+ def __init__(self, in_channels, in_size, n_encoders, kernel_size, n_blocks, out_channels=16, momentum=0.01):
89
+ super(Encoder, self).__init__()
90
+ self.n_encoders = n_encoders
91
+ self.bn = nn.BatchNorm2d(in_channels, momentum=momentum)
92
+ self.layers = nn.ModuleList()
93
+ self.latent_channels = []
94
+ for i in range(self.n_encoders):
95
+ self.layers.append(
96
+ ResEncoderBlock(in_channels, out_channels, kernel_size, n_blocks, momentum=momentum)
97
+ )
98
+ self.latent_channels.append([out_channels, in_size])
99
+ in_channels = out_channels
100
+ out_channels *= 2
101
+ in_size //= 2
102
+ self.out_size = in_size
103
+ self.out_channel = out_channels
104
+
105
+ def forward(self, x):
106
+ concat_tensors = []
107
+ x = self.bn(x)
108
+ for layer in self.layers:
109
+ t, x = layer(x)
110
+ concat_tensors.append(t)
111
+ return x, concat_tensors
112
+
113
+
114
+ class Intermediate(nn.Module):
115
+ def __init__(self, in_channels, out_channels, n_inters, n_blocks, momentum=0.01):
116
+ super(Intermediate, self).__init__()
117
+ self.n_inters = n_inters
118
+ self.layers = nn.ModuleList()
119
+ self.layers.append(ResEncoderBlock(in_channels, out_channels, None, n_blocks, momentum))
120
+ for i in range(self.n_inters - 1):
121
+ self.layers.append(ResEncoderBlock(out_channels, out_channels, None, n_blocks, momentum))
122
+
123
+ def forward(self, x):
124
+ for layer in self.layers:
125
+ x = layer(x)
126
+ return x
127
+
128
+
129
+ class ResDecoderBlock(nn.Module):
130
+ def __init__(self, in_channels, out_channels, stride, n_blocks=1, momentum=0.01):
131
+ super(ResDecoderBlock, self).__init__()
132
+ out_padding = (0, 1) if stride == (1, 2) else (1, 1)
133
+ self.n_blocks = n_blocks
134
+ self.conv1 = nn.Sequential(
135
+ nn.ConvTranspose2d(
136
+ in_channels=in_channels,
137
+ out_channels=out_channels,
138
+ kernel_size=(3, 3),
139
+ stride=stride,
140
+ padding=(1, 1),
141
+ output_padding=out_padding,
142
+ bias=False,
143
+ ),
144
+ nn.BatchNorm2d(out_channels, momentum=momentum),
145
+ nn.ReLU(),
146
+ )
147
+ self.conv2 = nn.ModuleList()
148
+ self.conv2.append(ConvBlockRes(out_channels * 2, out_channels, momentum))
149
+ for i in range(n_blocks - 1):
150
+ self.conv2.append(ConvBlockRes(out_channels, out_channels, momentum))
151
+
152
+ def forward(self, x, concat_tensor):
153
+ x = self.conv1(x)
154
+ x = torch.cat((x, concat_tensor), dim=1)
155
+ for conv2 in self.conv2:
156
+ x = conv2(x)
157
+ return x
158
+
159
+
160
+ class Decoder(nn.Module):
161
+ def __init__(self, in_channels, n_decoders, stride, n_blocks, momentum=0.01):
162
+ super(Decoder, self).__init__()
163
+ self.layers = nn.ModuleList()
164
+ self.n_decoders = n_decoders
165
+ for i in range(self.n_decoders):
166
+ out_channels = in_channels // 2
167
+ self.layers.append(
168
+ ResDecoderBlock(in_channels, out_channels, stride, n_blocks, momentum)
169
+ )
170
+ in_channels = out_channels
171
+
172
+ def forward(self, x, concat_tensors):
173
+ for i, layer in enumerate(self.layers):
174
+ x = layer(x, concat_tensors[-1 - i])
175
+ return x
176
+
177
+
178
+ class DeepUnet(nn.Module):
179
+ def __init__(self, kernel_size, n_blocks, en_de_layers=5, inter_layers=4, in_channels=1, en_out_channels=16):
180
+ super(DeepUnet, self).__init__()
181
+ self.encoder = Encoder(in_channels, 128, en_de_layers, kernel_size, n_blocks, en_out_channels)
182
+ self.intermediate = Intermediate(
183
+ self.encoder.out_channel // 2,
184
+ self.encoder.out_channel,
185
+ inter_layers,
186
+ n_blocks,
187
+ )
188
+ self.decoder = Decoder(self.encoder.out_channel, en_de_layers, kernel_size, n_blocks)
189
+
190
+ def forward(self, x):
191
+ x, concat_tensors = self.encoder(x)
192
+ x = self.intermediate(x)
193
+ x = self.decoder(x, concat_tensors)
194
+ return x
195
+
196
+
197
+ class E2E(nn.Module):
198
+ def __init__(self, n_blocks, n_gru, kernel_size, en_de_layers=5, inter_layers=4, in_channels=1, en_out_channels=16):
199
+ super(E2E, self).__init__()
200
+ self.unet = DeepUnet(kernel_size, n_blocks, en_de_layers, inter_layers, in_channels, en_out_channels)
201
+ self.cnn = nn.Conv2d(en_out_channels, 3, (3, 3), padding=(1, 1))
202
+ if n_gru:
203
+ self.fc = nn.Sequential(
204
+ BiGRU(3 * 128, 256, n_gru),
205
+ nn.Linear(512, 360),
206
+ nn.Dropout(0.25),
207
+ nn.Sigmoid(),
208
+ )
209
+ else:
210
+ self.fc = nn.Sequential(
211
+ nn.Linear(3 * 128, 360),
212
+ nn.Dropout(0.25),
213
+ nn.Sigmoid()
214
+ )
215
+
216
+ def forward(self, mel):
217
+ mel = mel.transpose(-1, -2).unsqueeze(1)
218
+ x = self.cnn(self.unet(mel)).transpose(1, 2).flatten(-2)
219
+ x = self.fc(x)
220
+ return x
221
+
222
+
223
+
224
+ class MelSpectrogram(torch.nn.Module):
225
+ def __init__(self, is_half, n_mel_channels, sampling_rate, win_length, hop_length,
226
+ n_fft=None, mel_fmin=0, mel_fmax=None, clamp=1e-5):
227
+ super().__init__()
228
+ n_fft = win_length if n_fft is None else n_fft
229
+ self.hann_window = {}
230
+ mel_basis = mel(
231
+ sr=sampling_rate,
232
+ n_fft=n_fft,
233
+ n_mels=n_mel_channels,
234
+ fmin=mel_fmin,
235
+ fmax=mel_fmax,
236
+ htk=True,
237
+ )
238
+ mel_basis = torch.from_numpy(mel_basis).float()
239
+ self.register_buffer("mel_basis", mel_basis)
240
+ self.n_fft = win_length if n_fft is None else n_fft
241
+ self.hop_length = hop_length
242
+ self.win_length = win_length
243
+ self.sampling_rate = sampling_rate
244
+ self.n_mel_channels = n_mel_channels
245
+ self.clamp = clamp
246
+ self.is_half = is_half
247
+
248
+ def forward(self, audio, keyshift=0, speed=1, center=True):
249
+ factor = 2 ** (keyshift / 12)
250
+ n_fft_new = int(np.round(self.n_fft * factor))
251
+ win_length_new = int(np.round(self.win_length * factor))
252
+ hop_length_new = int(np.round(self.hop_length * speed))
253
+
254
+ keyshift_key = str(keyshift) + "_" + str(audio.device)
255
+ if keyshift_key not in self.hann_window:
256
+ self.hann_window[keyshift_key] = torch.hann_window(win_length_new).to(audio.device)
257
+
258
+ fft = torch.stft(
259
+ audio,
260
+ n_fft=n_fft_new,
261
+ hop_length=hop_length_new,
262
+ win_length=win_length_new,
263
+ window=self.hann_window[keyshift_key],
264
+ center=center,
265
+ return_complex=True,
266
+ )
267
+ magnitude = torch.sqrt(fft.real.pow(2) + fft.imag.pow(2))
268
+
269
+ if keyshift != 0:
270
+ size = self.n_fft // 2 + 1
271
+ resize = magnitude.size(1)
272
+ if resize < size:
273
+ magnitude = F.pad(magnitude, (0, 0, 0, size - resize))
274
+ magnitude = magnitude[:, :size, :] * self.win_length / win_length_new
275
+
276
+ mel_output = torch.matmul(self.mel_basis, magnitude)
277
+ if self.is_half:
278
+ mel_output = mel_output.half()
279
+ log_mel_spec = torch.log(torch.clamp(mel_output, min=self.clamp))
280
+ return log_mel_spec
281
+
282
+
283
+
284
+ class RMVPE:
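+ """Thin inference wrapper around the RMVPE pitch-estimation network.
+
+ Computes a 128-bin log-mel spectrogram (16 kHz input, hop 160), runs the E2E
+ model to obtain per-frame salience over 360 cent bins, and decodes it into
+ f0 in Hz by local weighted averaging; frames below ``thred`` are set to 0.
+ """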
285
+ def __init__(self, model_path: str, is_half, device=None):
286
+ self.is_half = is_half
287
+ if device is None:
288
+ device = "cuda:0" if torch.cuda.is_available() else "cpu"
289
+ self.device = torch.device(device) if isinstance(device, str) else device
290
+
291
+ self.mel_extractor = MelSpectrogram(
292
+ is_half=is_half,
293
+ n_mel_channels=128,
294
+ sampling_rate=16000,
295
+ win_length=1024,
296
+ hop_length=160,
297
+ n_fft=None,
298
+ mel_fmin=30,
299
+ mel_fmax=8000
300
+ ).to(self.device)
301
+
302
+ model = E2E(n_blocks=4, n_gru=1, kernel_size=(2, 2))
303
+ ckpt = torch.load(model_path, map_location=self.device)
304
+ model.load_state_dict(ckpt)
305
+ model.eval()
306
+
307
+ if is_half:
308
+ model = model.half()
309
+ else:
310
+ model = model.float()
311
+
312
+ self.model = model.to(self.device)
313
+
314
+ cents_mapping = 20 * np.arange(360) + 1997.3794084376191
315
+ self.cents_mapping = np.pad(cents_mapping, (4, 4)) # 368
316
+
317
+ def mel2hidden(self, mel):
318
+ with torch.no_grad():
319
+ n_frames = mel.shape[-1]
320
+ n_pad = 32 * ((n_frames - 1) // 32 + 1) - n_frames
321
+ if n_pad > 0:
322
+ mel = F.pad(mel, (0, n_pad), mode="constant")
323
+ mel = mel.half() if self.is_half else mel.float()
324
+ hidden = self.model(mel)
325
+ return hidden[:, :n_frames]
326
+
327
+ def decode(self, hidden, thred=0.03):
328
+ cents_pred = self.to_local_average_cents(hidden, thred=thred)
329
+ f0 = 10 * (2 ** (cents_pred / 1200))
330
+ f0[f0 == 10] = 0
331
+ return f0
332
+
333
+ def infer_from_audio(self, audio, thred=0.03):
334
+ if not torch.is_tensor(audio):
335
+ audio = torch.from_numpy(audio)
336
+
337
+ mel = self.mel_extractor(audio.float().to(self.device).unsqueeze(0), center=True)
338
+ hidden = self.mel2hidden(mel)
339
+ hidden = hidden.squeeze(0).cpu().numpy()
340
+
341
+ if self.is_half:
342
+ hidden = hidden.astype("float32")
343
+
344
+ f0 = self.decode(hidden, thred=thred)
345
+ return f0
346
+
347
+ def to_local_average_cents(self, salience, thred=0.05):
348
+ center = np.argmax(salience, axis=1)
349
+ salience = np.pad(salience, ((0, 0), (4, 4)))
350
+ center += 4
351
+
352
+ todo_salience = []
353
+ todo_cents_mapping = []
354
+ starts = center - 4
355
+ ends = center + 5
356
+
357
+ for idx in range(salience.shape[0]):
358
+ todo_salience.append(salience[:, starts[idx]:ends[idx]][idx])
359
+ todo_cents_mapping.append(self.cents_mapping[starts[idx]:ends[idx]])
360
+
361
+ todo_salience = np.array(todo_salience)
362
+ todo_cents_mapping = np.array(todo_cents_mapping)
363
+ product_sum = np.sum(todo_salience * todo_cents_mapping, 1)
364
+ weight_sum = np.sum(todo_salience, 1)
365
+ divided = product_sum / weight_sum
366
+
367
+ maxx = np.max(salience, axis=1)
368
+ divided[maxx <= thred] = 0
369
+
370
+ return divided
371
+
372
+ class F0Extractor:
373
+ """Extract frame-level f0 from singing voice.
374
+
375
+ Wrapper around an RMVPE network that:
376
+ 1) loads the checkpoint once in ``__init__``
377
+ 2) exposes a simple :py:meth:`process` API and optionally saves ``*_f0.npy``.
378
+ """
379
+ def __init__(
380
+ self,
381
+ model_path: str,
382
+ device: str = "cpu",
383
+ *,
384
+ is_half: bool = False,
385
+ input_sr: int = 16000,
386
+ target_sr: int = 24000,
387
+ hop_size: int = 480,
388
+ max_duration: float = 300,
389
+ thred: float = 0.03,
390
+ verbose: bool = True,
391
+ ):
392
+ """Initialize the f0 extractor.
393
+
394
+ Args:
395
+ model_path: Path to RMVPE checkpoint.
396
+ device: Torch device string, e.g. ``"cuda:0"`` / ``"cpu"``.
397
+ is_half: Whether to run the model in fp16.
398
+ input_sr: Input resample rate used by RMVPE frontend.
399
+ target_sr: Target sample rate for the output f0 grid.
400
+ hop_size: Target hop size for the output f0 grid.
401
+ max_duration: Max duration (seconds) for interpolation grid.
402
+ thred: Voicing threshold used when decoding salience.
403
+ verbose: Whether to print verbose logs.
404
+ """
405
+ self.model_path = model_path
406
+ self.input_sr = input_sr
407
+ self.target_sr = target_sr
408
+ self.hop_size = hop_size
409
+ self.max_duration = max_duration
410
+ self.thred = thred
411
+
412
+ self.verbose = verbose
413
+
414
+ self.model = RMVPE(model_path, is_half=is_half, device=device)
415
+
416
+ if self.verbose:
417
+ print(
418
+ "[f0 extraction] init success:",
419
+ f"device={device}",
420
+ f"model_path={model_path}",
421
+ f"is_half={is_half}",
422
+ f"input_sr={input_sr}",
423
+ f"target_sr={target_sr}",
424
+ f"hop_size={hop_size}",
425
+ f"thred={thred}",
426
+ )
427
+
428
+ @staticmethod
429
+ def interpolate_f0(
430
+ f0_16k: np.ndarray,
431
+ original_length: int,
432
+ original_sr: int,
433
+ *,
434
+ target_sr: int = 48000,
435
+ hop_size: int = 256,
436
+ max_duration: float = 20.0,
437
+ ) -> np.ndarray:
438
+ """Interpolate f0 from RMVPE's 16k hop grid to target mel hop grid."""
439
+ mel_target_sr = target_sr
440
+ mel_hop_size = hop_size
441
+ mel_max_duration = max_duration
442
+
443
+ batch_max_length = int(mel_max_duration * mel_target_sr / mel_hop_size)
444
+ duration_in_seconds = original_length / original_sr
445
+ effective_target_length = int(duration_in_seconds * mel_target_sr)
446
+ original_frames = math.ceil(effective_target_length / mel_hop_size)
447
+ target_frames = min(original_frames, batch_max_length)
448
+
449
+ rmvpe_hop = 160
450
+ t_16k = np.arange(len(f0_16k)) * (rmvpe_hop / 16000.0)
451
+ t_target = np.arange(target_frames) * (mel_hop_size / float(mel_target_sr))
452
+
453
+ if len(f0_16k) > 0:
454
+ f_interp = interp1d(
455
+ t_16k,
456
+ f0_16k,
457
+ kind="linear",
458
+ bounds_error=False,
459
+ fill_value=0.0,
460
+ assume_sorted=True,
461
+ )
462
+ f0 = f_interp(t_target)
463
+ else:
464
+ f0 = np.zeros(target_frames)
465
+
466
+ if len(f0) != target_frames:
467
+ f0 = (
468
+ f0[:target_frames]
469
+ if len(f0) > target_frames
470
+ else np.pad(f0, (0, target_frames - len(f0)), "constant")
471
+ )
472
+
473
+ return f0
474
+
475
+ def process(self, audio_path: str, *, f0_path: str | None = None, verbose: Optional[bool] = None) -> np.ndarray:
476
+ """Run f0 extraction for a single wav.
477
+
478
+ Args:
479
+ audio_path: Path to the input wav file.
480
+ f0_path: if is not None, save the f0 data to this path.
481
+ verbose: Override instance-level verbose flag for this call.
482
+
483
+ Returns:
484
+ np.ndarray: shape ``[T]``, f0 in Hz (0 for unvoiced).
485
+ """
486
+ verbose = self.verbose if verbose is None else verbose
487
+ if verbose:
488
+ print(f"[f0 extraction] process: start: {audio_path}")
489
+ t0 = time.time()
490
+
491
+ audio, _ = librosa.load(audio_path, sr=self.input_sr)
492
+ f0_16k = self.model.infer_from_audio(audio, thred=self.thred)
493
+ f0 = self.interpolate_f0(
494
+ f0_16k,
495
+ original_length=audio.shape[-1],
496
+ original_sr=self.input_sr,
497
+ target_sr=self.target_sr,
498
+ hop_size=self.hop_size,
499
+ max_duration=self.max_duration,
500
+ )
501
+
502
+ if verbose:
503
+ dt = time.time() - t0
504
+ voiced_ratio = float(np.mean(f0 > 0)) if len(f0) else 0.0
505
+ print(
506
+ "[f0 extraction] process: done:",
507
+ f"frames={len(f0)}",
508
+ f"voiced_ratio={voiced_ratio:.3f}",
509
+ f"time={dt:.3f}s",
510
+ )
511
+ if f0_path is not None:
512
+ np.save(f0_path, f0)
513
+
514
+ return f0
515
+
516
+
517
+ if __name__ == "__main__":
518
+ model_path = (
519
+ "pretrained_models/SoulX-Singer-Preprocess/rmvpe/rmvpe.pt"
520
+ )
521
+ audio_path = "example/audio/zh_prompt.mp3"
522
+
523
+ pe = F0Extractor(
524
+ model_path,
525
+ device="cuda",
526
+ )
527
+ f0 = pe.process(audio_path, f0_path="example/audio/zh_prompt_f0.npy")
preprocess/tools/g2p.py ADDED
@@ -0,0 +1,72 @@
1
+ import re
2
+
3
+ import ToJyutping
4
+ from g2pM import G2pM
5
+ from g2p_en import G2p as G2pE
6
+
7
+ _EN_WORD_RE = re.compile(r"^[A-Za-z]+(?:'[A-Za-z]+)*$")
8
+ _ZH_WORD_RE = re.compile(r"[\u4e00-\u9fff]")
9
+
10
+ EN_FLAG = "en_"
11
+ YUE_FLAG = "yue_"
12
+ ZH_FLAG = "zh_"
13
+
14
+ g2p_zh = G2pM()
15
+ g2p_en = G2pE()
16
+
17
+
18
+ def is_chinese_char(word: str) -> bool:
19
+ if len(word) != 1:
20
+ return False
21
+ return bool(_ZH_WORD_RE.fullmatch(word))
22
+
23
+ def is_english_word(word: str) -> bool:
24
+ if not word:
25
+ return False
26
+ return bool(_EN_WORD_RE.fullmatch(word))
27
+
28
+ def g2p_cantonese(sent):
29
+ return ToJyutping.get_jyutping_list(sent) # with tone
30
+
31
+ def g2p_mandarin(sent):
32
+ return g2p_zh(sent, tone=True, char_split=False)
33
+
34
+ def g2p_english(word):
35
+ return g2p_en(word)
36
+
37
+ def g2p_transform(words, lang):
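+ """Convert a word list into phoneme tokens for the given language.
+
+ "<SP>" tokens are kept as-is; English words are run through g2p_en and joined
+ with "-" under the ``en_`` prefix; Chinese characters are batched through
+ Mandarin (g2pM) or Cantonese (ToJyutping) G2P and written back to their
+ original positions with ``zh_`` / ``yue_`` prefixes. Anything else becomes "<SP>".
+ """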
38
+
39
+ zh_words = []
40
+ transformed_words = [0] * len(words)
41
+
42
+ for idx, w in enumerate(words):
43
+ if w == "<SP>":
44
+ transformed_words[idx] = w
45
+ continue
46
+
47
+ w = w.replace("?", "").replace(".", "").replace("!", "").replace(",", "")
48
+
49
+ if is_chinese_char(w):
50
+ zh_words.append([idx, w])
51
+ else:
52
+ if is_english_word(w):
53
+ w = EN_FLAG + "-".join(g2p_english(w.lower()))
54
+ else:
55
+ w = "<SP>"
56
+ transformed_words[idx] = w
57
+
58
+ sent = "".join([k[1] for k in zh_words])
59
+
60
+ # zh (zh and yue) transformer to g2p
61
+ if len(sent) > 0:
62
+ if lang == "Cantonese":
63
+ g2pm_rst = g2p_cantonese(sent) # with tone
64
+ g2pm_rst = [YUE_FLAG + k[1] for k in g2pm_rst]
65
+ else:
66
+ g2pm_rst = g2p_mandarin(sent)
67
+ g2pm_rst = [ZH_FLAG + k for k in g2pm_rst]
68
+ for p, w in zip([k[0] for k in zh_words], g2pm_rst):
69
+ transformed_words[p] = w
70
+
71
+ return transformed_words
72
+
preprocess/tools/lyric_transcription.py ADDED
@@ -0,0 +1,283 @@
1
+ # https://modelscope.cn/models/iic/speech_seaco_paraformer_large_asr_nat-zh-cn-16k-common-vocab8404-pytorch/summary
2
+ # https://huggingface.co/nvidia/parakeet-tdt-0.6b-v2
3
+ import os
4
+ import re
5
+ import time
6
+ from typing import Any, Dict, List, Tuple
7
+
8
+ import librosa
9
+ import numpy as np
10
+ from funasr import AutoModel
11
+
12
+
13
+ def _build_words_with_gaps(raw_words, raw_timestamps, wav_fn: str):
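+ """Build word and duration lists from ASR output, inserting "<SP>" gap tokens.
+
+ Gaps between consecutive word timestamps, and any trailing gap up to the end
+ of the wav, become "<SP>" entries so the durations cover the full audio length.
+ """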
14
+ words, word_durs = [], []
15
+ prev = 0.0
16
+ for w, t in zip(raw_words, raw_timestamps):
17
+ s, e = float(t[0]), float(t[1])
18
+ if s > prev:
19
+ words.append("<SP>")
20
+ word_durs.append(s - prev)
21
+ words.append(w)
22
+ word_durs.append(e - s)
23
+ prev = e
24
+
25
+ wav_len = librosa.get_duration(path=wav_fn)
26
+ if wav_len > prev:
27
+ if len(words) == 0:
28
+ words.append("<SP>")
29
+ word_durs.append(wav_len)
30
+ return words, word_durs
31
+ if words[-1] != "<SP>":
32
+ words.append("<SP>")
33
+ word_durs.append(wav_len - prev)
34
+ else:
35
+ word_durs[-1] += wav_len - prev
36
+
37
+ return words, word_durs
38
+
39
+ def _word_dur_post_process(words, word_durs, f0):
40
+ """Post-process word durations using f0 to better place silences.
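+
+ Each "<SP>" span is re-examined on the f0 frame grid (24 kHz, hop 480): spans
+ that are almost entirely voiced are merged into the preceding word, and
+ partially voiced spans are split so that only the unvoiced tail remains "<SP>".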
41
+ """
42
+ # f0 time grid parameters
43
+ sr = 24000 # f0 sample rate
44
+ hop_length = 480 # f0 hop length
45
+
46
+ # Convert word durations (seconds) to frame boundaries on the f0 grid.
47
+ boundaries = np.cumsum([
48
+ 0,
49
+ *[
50
+ int(dur * sr / hop_length)
51
+ for dur in word_durs
52
+ ],
53
+ ]).tolist()
54
+
55
+ sil_tolerance = 5 # tolerance frames for silence detection
56
+ ext_tolerance = 5 # tolerance frames for vocal extension
57
+
58
+ new_words: list[str] = []
59
+ new_word_durs: list[float] = []
60
+ if words:
61
+ new_words.append(words[0])
62
+ new_word_durs.append(word_durs[0])
63
+
64
+ for i in range(1, len(words)):
65
+ word = words[i]
66
+ if word == "<SP>":
67
+ start_frame = boundaries[i]
68
+ end_frame = boundaries[i + 1]
69
+
70
+ num_frames = end_frame - start_frame
71
+ frame_idx = start_frame
72
+
73
+ # Find first region with at least 5 consecutive "unvoiced" frames.
74
+ unvoiced_count = 0
75
+ while frame_idx < end_frame:
76
+ if f0[frame_idx] <= 1: # unvoiced
77
+ unvoiced_count += 1
78
+ if unvoiced_count >= sil_tolerance:
79
+ frame_idx -= sil_tolerance - 1 # back to the last voiced frame
80
+ break
81
+ else:
82
+ unvoiced_count = 0
83
+ frame_idx += 1
84
+
85
+ voice_frames = frame_idx - start_frame
86
+
87
+ if voice_frames >= int(num_frames * 0.9): # over 90% voiced
88
+ # Treat the whole "<SP>" as silence and merge into previous word.
89
+ new_word_durs[-1] += word_durs[i]
90
+ elif voice_frames >= ext_tolerance: # over 5 frames voiced
91
+ # Split the "<SP>" into two parts: leading silence and tail kept as "<SP>".
92
+ dur = voice_frames * hop_length / sr
93
+ new_word_durs[-1] += dur
94
+ new_words.append("<SP>")
95
+ new_word_durs.append(word_durs[i] - dur)
96
+ else:
97
+ # Too short to adjust, keep as-is.
98
+ new_words.append(word)
99
+ new_word_durs.append(word_durs[i])
100
+ else:
101
+ new_words.append(word)
102
+ new_word_durs.append(word_durs[i])
103
+
104
+ return new_words, new_word_durs
105
+
106
+
107
+ class _ASRZhModel:
108
+ """Mandarin/Cantonese ASR wrapper."""
109
+
110
+ def __init__(self, model_path: str, device: str):
111
+ self.model = AutoModel(
112
+ model=model_path,
113
+ disable_update=True,
114
+ device=device,
115
+ )
116
+
117
+ def process(self, wav_fn):
118
+ out = self.model.generate(wav_fn, output_timestamp=True)[0]
119
+ raw_words = out["text"].replace("@", "").split(" ")
120
+ raw_timestamps = [[t[0] / 1000, t[1] / 1000] for t in out["timestamp"]]
121
+ words, word_durs = _build_words_with_gaps(raw_words, raw_timestamps, wav_fn)
122
+
123
+ f0_path = os.path.splitext(wav_fn)[0] + "_f0.npy"
124
+ if os.path.exists(f0_path):
125
+ words, word_durs = _word_dur_post_process(
126
+ words, word_durs, np.load(f0_path)
127
+ )
128
+
129
+ return words, word_durs
130
+
131
+
132
+ class _ASREnModel:
133
+ """English ASR wrapper for NeMo Parakeet-TDT."""
134
+
135
+ def __init__(self, model_path: str, device: str):
136
+ try:
137
+ import nemo.collections.asr as nemo_asr # type: ignore
138
+ except Exception as e: # pragma: no cover
139
+ raise ImportError(
140
+ "NeMo (nemo_toolkit) is required for ASR English but is not available in this Python env. "
141
+ "Install it in the active environment, then retry."
142
+ ) from e
143
+
144
+ self.model = nemo_asr.models.ASRModel.restore_from(
145
+ restore_path=model_path,
146
+ map_location=device,
147
+ )
148
+ self.model.eval()
149
+
150
+ @staticmethod
151
+ def _clean_word(word: str) -> str:
152
+ return re.sub(r"[\?\.,:]", "", word).strip()
153
+
154
+ @staticmethod
155
+ def _extract_word_segments(output: Any) -> List[Dict[str, Any]]:
156
+ ts = getattr(output, "timestamp", None)
157
+ if not ts or not isinstance(ts, dict):
158
+ return []
159
+ word_ts = ts.get("word")
160
+ return word_ts if isinstance(word_ts, list) else []
161
+
162
+ def process(self, wav_fn: str) -> Tuple[List[str], List[float]]:
163
+ outputs = self.model.transcribe(
164
+ [wav_fn],
165
+ timestamps=True,
166
+ batch_size=1,
167
+ num_workers=0,
168
+ )
169
+ output = outputs[0] if outputs else None
170
+
171
+ raw_words: List[str] = []
172
+ raw_timestamps: List[List[float]] = []
173
+ if output is not None:
174
+ for w in self._extract_word_segments(output):
175
+ s, e = float(w.get("start", 0.0)), float(w.get("end", 0.0))
176
+ word = self._clean_word(str(w.get("word", "")))
177
+ if word:
178
+ raw_words.append(word)
179
+ raw_timestamps.append([s, e])
180
+
181
+ words, durs = _build_words_with_gaps(raw_words, raw_timestamps, wav_fn)
182
+
183
+ f0_path = os.path.splitext(wav_fn)[0] + "_f0.npy"
184
+ if os.path.exists(f0_path):
185
+ words, durs = _word_dur_post_process(
186
+ words, durs, np.load(f0_path)
187
+ )
188
+
189
+ return words, durs
190
+
191
+
192
+ class LyricTranscriber:
193
+ """Transcribe lyrics from singing voice segment
194
+ """
195
+
196
+ def __init__(
197
+ self,
198
+ zh_model_path: str,
199
+ en_model_path: str,
200
+ device: str = "cuda",
201
+ *,
202
+ verbose: bool = True,
203
+ ):
204
+ """Initialize lyric transcriber.
205
+
206
+ Args:
207
+ zh_model_path (str): Path to the Chinese model file.
208
+ en_model_path (str): Path to the English model file.
209
+ device (str): Device to use for tensor operations.
210
+ verbose (bool): Whether to print verbose logs.
211
+ """
212
+ self.verbose = verbose
213
+ self.device = device
214
+ self.zh_model_path = zh_model_path
215
+ self.en_model_path = en_model_path
216
+
217
+ if self.verbose:
218
+ print(
219
+ "[lyric transcription] init: start:",
220
+ f"device={device}",
221
+ f"model_path={zh_model_path}",
222
+ )
223
+
224
+ # Always initialize Chinese ASR.
225
+ self.zh_model = _ASRZhModel(device=device, model_path=zh_model_path)
226
+
227
+ # English ASR will be lazily initialized on first English request to avoid long waiting cost when importing NeMo
228
+ self.en_model = None
229
+
230
+ if self.verbose:
231
+ print("[lyric transcription] init: success")
232
+
233
+ def process(self, wav_fn, language: str | None = "Mandarin", *, verbose: bool | None = None):
234
+ """ Lyric transcriber process
235
+
236
+ Args:
237
+ wav_fn (str): Path to the audio file.
238
+ language (str | None): Language of the audio. Defaults to "Mandarin". Supports "Mandarin", "Cantonese" and "English".
239
+ verbose (bool | None): Whether to print verbose logs. Defaults to None.
240
+ """
241
+ v = self.verbose if verbose is None else verbose
242
+ if language not in {"Mandarin", "Cantonese", "English"}:
243
+ raise ValueError(f"Unsupported language: {language}, should be one of ['Mandarin', 'Cantonese', 'English']")
244
+ if v:
245
+ print(f"[lyric transcription] process: start: wav_fn={wav_fn} language={language}")
246
+ t0 = time.time()
247
+
248
+ lang = (language or "auto").lower()
249
+ if lang in {"english"}:
250
+ if self.en_model is None:
251
+ # Lazy-load NeMo model only when English is actually used.
252
+ if v:
253
+ print("[lyric transcription] init English ASR start, please make sure NeMo is installed and wait for a while")
254
+ self.en_model = _ASREnModel(model_path=self.en_model_path, device=self.device)
255
+ if v:
256
+ print("[lyric transcription] init English ASR success")
257
+ out = self.en_model.process(wav_fn)
258
+ else:
259
+ out = self.zh_model.process(wav_fn)
260
+
261
+ if v:
262
+ words, durs = out
263
+ n_words = len(words) if isinstance(words, list) else 0
264
+ dur_sum = float(sum(durs)) if isinstance(durs, list) else 0.0
265
+ dt = time.time() - t0
266
+ print(
267
+ "[lyric transcription] process: done:",
268
+ f"n_words={n_words}",
269
+ f"dur_sum={dur_sum:.3f}s",
270
+ f"time={dt:.3f}s",
271
+ )
272
+
273
+ return out
274
+
275
+
276
+ if __name__ == "__main__":
277
+ m = LyricTranscriber(
278
+ zh_model_path="pretrained_models/SoulX-Singer-Preprocess/speech_seaco_paraformer_large_asr_nat-zh-cn-16k-common-vocab8404-pytorch",
279
+ en_model_path="pretrained_models/SoulX-Singer-Preprocess/parakeet-tdt-0.6b-v2/parakeet-tdt-0.6b-v2.nemo",
280
+ device="cuda"
281
+ )
282
+ print(m.process("example/audio/zh_prompt.mp3", language="Mandarin"))
283
+ print(m.process("example/audio/en_prompt.mp3", language="English"))
preprocess/tools/midi_editor/README.md ADDED
@@ -0,0 +1,170 @@
1
+ # 🎹 MIDI Editor - Web-based Singing MIDI Editor
2
+
3
+ [English](README.md) | [简体中文](README_CN.md)
4
+
5
+ A full-featured web MIDI editor for singing voice preprocessing. It supports real-time drag editing of MIDI notes, lyric editing, audio waveform alignment, and importing/exporting MIDI files with lyrics.
6
+
7
+ ![MIDI Editor](https://img.shields.io/badge/React-19.2-blue) ![TypeScript](https://img.shields.io/badge/TypeScript-5.9-blue) ![Vite](https://img.shields.io/badge/Vite-7.2-purple)
8
+
9
+ ## ✨ Features
10
+
11
+ ### 🎼 Piano Roll Editing
12
+
13
+ - **Visual note editing**: Full range from C1 to C8 with intuitive piano key layout
14
+ - **Drag operations**:
15
+ - Move notes: drag note blocks to adjust position and pitch
16
+ - Resize start: drag the left edge to adjust start time
17
+ - Resize end: drag the right edge to adjust end time
18
+ - **Quick pitch adjust**:
19
+ - Command/Ctrl + Up/Down to nudge selected note pitch
20
+ - Use the Transpose control in the toolbar to shift all notes at once
21
+ - **Double-click to add**: Add new notes quickly in empty areas
22
+ - **Piano key preview**: Click a key on the left to audition the pitch
23
+
24
+ ### 🔍 Zoom & Navigation
25
+
26
+ - **Horizontal zoom**
27
+ - **Vertical zoom**
28
+ - **Dynamic snapping**: finer snap granularity at higher zoom (min 0.01s)
29
+ - **Auto scroll**: keep the playhead visible during playback
30
+
31
+ ### 📝 Lyric Editing
32
+
33
+ - **Inline editing**: edit lyrics for each note in the side list
34
+ - **Batch fill**: enter lyrics and auto-fill notes in order
35
+ - **Fill from selection**: start batch fill from the currently selected note
36
+ - **Precise fields**: edit PITCH, START, and END directly
37
+ - **Confirm edits**: press Enter or click ✓ to confirm, avoiding accidental changes
38
+
39
+ ### 🎵 Audio Alignment
40
+
41
+ - **Waveform display**: import audio to display waveform, synced with the MIDI timeline
42
+ - **Formats**: MP3, WAV, OGG, FLAC, M4A, AAC
43
+ - **Sync playback**: play audio and MIDI together with independent volume control
44
+ - **Click to seek**: click waveform or timeline to seek
45
+
46
+ ### ⚠️ Overlap Detection
47
+
48
+ - **Visual highlight**: overlapping notes blink in red
49
+ - **One-click fix**: remove all overlaps automatically
50
+
51
+ ### 📥 Import & Export
52
+
53
+ - **MIDI import**: parse standard MIDI files with automatic lyric metadata extraction
54
+ - **MIDI export**: export MIDI files with lyric information
55
+
56
+ ### 🎨 UI & UX
57
+
58
+ - **Theme toggle**: light and dark modes
59
+ - **Responsive layout**: adapts to window size
60
+ - **SVG grid**: cross-browser grid rendering
61
+ - **Status feedback**: real-time state and error tips
62
+
63
+ ## 🚀 Quick Start
64
+
65
+ ### Requirements
66
+
67
+ - Node.js 18+
68
+ - npm or yarn
69
+
70
+ ### Install
71
+
72
+ ```bash
73
+ # Install dependencies
74
+ npm install
75
+
76
+ # Start dev server
77
+ npm run dev
78
+
79
+ # Expose to LAN
80
+ npm run dev -- --host 0.0.0.0
81
+ ```
82
+
83
+ ### Build
84
+
85
+ ```bash
86
+ # Build for production
87
+ npm run build
88
+
89
+ # Preview build
90
+ npm run preview
91
+ ```
92
+
93
+ ## 📖 Usage
94
+
95
+ ### Basic Workflow
96
+
97
+ 1. **Import MIDI**: click Import MIDI and select a .mid file
98
+ 2. **Edit notes**: drag notes in the piano roll to adjust time and pitch
99
+ 3. **Add lyrics**: edit lyrics in the right-side list, or use batch fill
100
+ 4. **Align audio** (optional): import reference audio for side-by-side editing
101
+ 5. **Export**: click Export MIDI to save
102
+
103
+ ### Shortcuts
104
+
105
+ | Action | Description |
106
+ |------|------|
107
+ | Double-click piano roll | Add a new note |
108
+ | Double-click note | Edit lyric |
109
+ | Drag note | Move note and pitch |
110
+ | Drag note edges | Resize note |
111
+ | Backspace / Delete | Delete selected note |
112
+ | Enter | Confirm value edits |
113
+ | Escape | Cancel value edits |
114
+ | Ctrl(Command) + Wheel | Horizontal zoom |
115
+ | Ctrl(Command) + Shift(Option) + Wheel | Vertical zoom |
116
+
117
+ ### Playback Controls
118
+
119
+ | Button | Description |
120
+ |------|------|
121
+ | ⏮ | Go to start |
122
+ | ⏪ | Back 2 seconds |
123
+ | ▶ / ⏸ | Play / Pause |
124
+ | ⏩ | Forward 2 seconds |
125
+ | ⏭ | Go to end |
126
+ | Selection | Play selected region |
127
+
128
+ ## 🛠 Tech Stack
129
+
130
+ - **Frontend**: React 19 + TypeScript
131
+ - **Build**: Vite 7
132
+ - **State**: Zustand
133
+ - **Audio**: Tone.js
134
+ - **Waveform**: WaveSurfer.js
135
+ - **MIDI**: @tonejs/midi
136
+ - **Styles**: CSS with custom variables
137
+
138
+ ## 📁 Project Structure
139
+
140
+ ```
141
+ .
142
+ ├── eslint.config.js
143
+ ├── index.html
144
+ ├── package.json
145
+ ├── postcss.config.js
146
+ ├── README.md
147
+ ├── README_CN.md
148
+ ├── tailwind.config.js
149
+ ├── tsconfig.app.json
150
+ ├── tsconfig.json
151
+ ├── tsconfig.node.json
152
+ ├── vite.config.ts
153
+ ├── public/
154
+ └── src/
155
+ ├── App.css # Main styles (theme variables, layout, components)
156
+ ├── App.tsx # Main app component (transport, import/export, transpose)
157
+ ├── constants.ts # Constants (grid width, row height, pitch range)
158
+ ├── i18n.ts # Internationalization (zh/en translations, smart lyric tokenizer)
159
+ ├── index.css # Global styles (Tailwind, root font, theme gradients)
160
+ ├── main.tsx # React entry point
161
+ ├── types.ts # Type definitions (NoteEvent, TimeSignature, etc.)
162
+ ├── components/
163
+ │ ├── AudioTrack.tsx # Audio waveform display component
164
+ │ ├── LyricTable.tsx # Lyric editing table component
165
+ │ └── PianoRoll.tsx # Piano roll editor component
166
+ ├── lib/
167
+ │ └── midi.ts # MIDI import/export utilities (UTF-8 lyric encoding)
168
+ └── store/
169
+ └── useMidiStore.ts # Zustand state management
170
+ ```
preprocess/tools/midi_editor/README_CN.md ADDED
@@ -0,0 +1,170 @@
1
+ # 🎹 MIDI Editor - 网页端歌声 MIDI 编辑器
2
+
3
+ [English](README.md) | [简体中文](README_CN.md)
4
+
5
+ 一个功能完整的网页端歌声 MIDI 文件编辑器。支持实时拖拽调整 MIDI 音符、歌词编辑、音频波形对齐,以及导入导出含歌词的 MIDI 文件。
6
+
7
+ ![MIDI Editor](https://img.shields.io/badge/React-19.2-blue) ![TypeScript](https://img.shields.io/badge/TypeScript-5.9-blue) ![Vite](https://img.shields.io/badge/Vite-7.2-purple)
8
+
9
+ ## ✨ 功能特性
10
+
11
+ ### 🎼 钢琴卷帘编辑
12
+
13
+ - **可视化音符编辑**:支持 C1-C8 全音域显示,直观的钢琴键布局
14
+ - **拖拽操作**:
15
+ - 移动音符:拖拽音符块调整位置和音高
16
+ - 调整音头:拖拽音符左边缘调整开始时间
17
+ - 调整音尾:拖拽音符右边缘调整结束时间
18
+ - **快捷音高调整**:
19
+ - Command/Ctrl + 上/下键调整选中音符的音高
20
+ - 通过功能区的移调功能来整体移动音高
21
+ - **双击添加**:在钢琴卷帘空白处双击快速添加新音符
22
+ - **钢琴键试听**:点击左侧钢琴键可试听对应音高
23
+
24
+ ### 🔍 缩放与导航
25
+
26
+ - **水平缩放**
27
+ - **垂直缩放**
28
+ - **动态精度**:缩放越大,音符调整的 snap 粒度越精细(最小 0.01 秒)
29
+ - **自动滚动**:播放时播放头自动保持可见
30
+
31
+ ### 📝 歌词编辑
32
+
33
+ - **实时编辑**:右侧列表直接编辑每个音符的歌词
34
+ - **批量填充**:输入一段歌词,按字顺序自动填充到音符
35
+ - **从选中开始**:批量填充可从当前选中的音符开始
36
+ - **精确调整**:可直接编辑 PITCH(音高)、START(开始时间)、END(结束时间)
37
+ - **确认机制**:修改数值后按 Enter 或点击 ✓ 确认,避免误操作
38
+
39
+ ### 🎵 音频对齐
40
+
41
+ - **波形显示**:导入音频后显示波形,与 MIDI 同步滚动
42
+ - **格式支持**:MP3、WAV、OGG、FLAC、M4A、AAC
43
+ - **同步播放**:音频与 MIDI 同步播放,可分别调整音量大小
44
+ - **点击定位**:点击波形或时间尺可快速定位播放位置
45
+
46
+ ### ⚠️ 重叠检测
47
+
48
+ - **可视化标注**:时间重叠的音符显示为红色并闪烁
49
+ - **一键修复**:点击消除重叠按钮自动修复所有重叠
50
+
51
+ ### 📥 导入导出
52
+
53
+ - **MIDI 导入**:支持标准 MIDI 文件,自动解析歌词元数据
54
+ - **MIDI 导出**:导出包含歌词信息的 MIDI 文件
55
+
56
+ ### 🎨 界面特性
57
+
58
+ - **主题切换**:支持浅色/深色主题
59
+ - **响应式布局**:自适应窗口大小
60
+ - **SVG 网格**:跨浏览器兼容的网格渲染
61
+ - **状态提示**:实时显示操作状态和错误信息
62
+
63
+ ## 🚀 快速开始
64
+
65
+ ### 环境要求
66
+
67
+ - Node.js 18+
68
+ - npm 或 yarn
69
+
70
+ ### 安装
71
+
72
+ ```bash
73
+ # 安装依赖
74
+ npm install
75
+
76
+ # 启动开发服务器
77
+ npm run dev
78
+
79
+ # 在局域网启动
80
+ npm run dev -- --host 0.0.0.0
81
+ ```
82
+
83
+ ### 构建
84
+
85
+ ```bash
86
+ # 构建生产版本
87
+ npm run build
88
+
89
+ # 预览构建结果
90
+ npm run preview
91
+ ```
92
+
93
+ ## 📖 使用指南
94
+
95
+ ### 基本工作流
96
+
97
+ 1. **导入 MIDI**:点击导入 MIDI 按钮选择 .mid 文件
98
+ 2. **编辑音符**:在钢琴卷帘中拖拽调整音符位置和时长
99
+ 3. **添加歌词**:在右侧列表中输入句级别的歌词或单字编辑
100
+ 4. **对齐音频**(可选):导入参考音频进行对照编辑
101
+ 5. **导出文件**:点击导出含歌词 MIDI 保存文件
102
+
103
+ ### 快捷操作
104
+
105
+ | 操作 | 说明 |
106
+ |------|------|
107
+ | 双击钢琴卷帘 | 添加新音符 |
108
+ | 双击音符 | 修改歌词 |
109
+ | 拖拽音符 | 移动音符位置/音高 |
110
+ | 拖拽音符边缘 | 调整音符时长 |
111
+ | Backspace / Delete | 删除选中音符 |
112
+ | Enter | 确认数值修改 |
113
+ | Escape | 取消数值修改 |
114
+ | Ctrl(Command) + 滚轮 | 水平缩放 |
115
+ | Ctrl(Command) + Shift(Option) + 滚轮 | 垂直缩放 |
116
+
117
+ ### 播放控制
118
+
119
+ | 按钮 | 功能 |
120
+ |------|------|
121
+ | ⏮ | 回到开头 |
122
+ | ⏪ | 后退 2 秒 |
123
+ | ▶ / ⏸ | 播放 / 暂停 |
124
+ | ⏩ | 前进 2 秒 |
125
+ | ⏭ | 跳到结尾 |
126
+ | 选定区域 | 播放选定区域 |
127
+
128
+ ## 🛠 技术栈
129
+
130
+ - **前端框架**:React 19 + TypeScript
131
+ - **构建工具**:Vite 7
132
+ - **状态管理**:Zustand
133
+ - **音频引擎**:Tone.js
134
+ - **波形显示**:WaveSurfer.js
135
+ - **MIDI 解析**:@tonejs/midi
136
+ - **样式**:CSS(自定义变量主题)
137
+
138
+ ## 📁 项目结构
139
+
140
+ ```
141
+ .
142
+ ├── eslint.config.js
143
+ ├── index.html
144
+ ├── package.json
145
+ ├── postcss.config.js
146
+ ├── README.md
147
+ ├── README_CN.md
148
+ ├── tailwind.config.js
149
+ ├── tsconfig.app.json
150
+ ├── tsconfig.json
151
+ ├── tsconfig.node.json
152
+ ├── vite.config.ts
153
+ ├── public/
154
+ └── src/
155
+ ├── App.css # 主样式(含主题变量、布局、组件样式)
156
+ ├── App.tsx # 主应用组件(走带、导入导出、移调等)
157
+ ├── constants.ts # 常量定义(网格宽度、行高、音域范围)
158
+ ├── i18n.ts # 国际化(中英文翻译、歌词智能分词器)
159
+ ├── index.css # 全局样式(Tailwind、根字体、主题渐变)
160
+ ├── main.tsx # React 入口
161
+ ├── types.ts # 类型定义(NoteEvent、TimeSignature 等)
162
+ ├── components/
163
+ │ ├── AudioTrack.tsx # 音频波形显示组件
164
+ │ ├── LyricTable.tsx # 歌词编辑表格组件
165
+ │ └── PianoRoll.tsx # 钢琴卷帘编辑器组件
166
+ ├── lib/
167
+ │ └── midi.ts # MIDI 导入导出工具(含 UTF-8 歌词编解码)
168
+ └── store/
169
+ └── useMidiStore.ts # Zustand 状态管理
170
+ ```
preprocess/tools/midi_editor/eslint.config.js ADDED
@@ -0,0 +1,23 @@
1
+ import js from '@eslint/js'
2
+ import globals from 'globals'
3
+ import reactHooks from 'eslint-plugin-react-hooks'
4
+ import reactRefresh from 'eslint-plugin-react-refresh'
5
+ import tseslint from 'typescript-eslint'
6
+ import { defineConfig, globalIgnores } from 'eslint/config'
7
+
8
+ export default defineConfig([
9
+ globalIgnores(['dist']),
10
+ {
11
+ files: ['**/*.{ts,tsx}'],
12
+ extends: [
13
+ js.configs.recommended,
14
+ tseslint.configs.recommended,
15
+ reactHooks.configs.flat.recommended,
16
+ reactRefresh.configs.vite,
17
+ ],
18
+ languageOptions: {
19
+ ecmaVersion: 2020,
20
+ globals: globals.browser,
21
+ },
22
+ },
23
+ ])
preprocess/tools/midi_editor/index.html ADDED
@@ -0,0 +1,13 @@
1
+ <!doctype html>
2
+ <html lang="en">
3
+ <head>
4
+ <meta charset="UTF-8" />
5
+ <link rel="icon" type="image/svg+xml" href="/vite.svg" />
6
+ <meta name="viewport" content="width=device-width, initial-scale=1.0" />
7
+ <title>SoulX-Singer MIDI Editor</title>
8
+ </head>
9
+ <body>
10
+ <div id="root"></div>
11
+ <script type="module" src="/src/main.tsx"></script>
12
+ </body>
13
+ </html>
preprocess/tools/midi_editor/package-lock.json ADDED
The diff for this file is too large to render. See raw diff
 
preprocess/tools/midi_editor/package.json ADDED
@@ -0,0 +1,39 @@
1
+ {
2
+ "name": "midi-editor",
3
+ "private": true,
4
+ "version": "0.0.0",
5
+ "type": "module",
6
+ "scripts": {
7
+ "dev": "vite",
8
+ "build": "tsc -b && vite build",
9
+ "lint": "eslint .",
10
+ "preview": "vite preview"
11
+ },
12
+ "dependencies": {
13
+ "@tonejs/midi": "^2.0.28",
14
+ "class-variance-authority": "^0.7.1",
15
+ "nanoid": "^5.1.6",
16
+ "react": "^19.2.0",
17
+ "react-dom": "^19.2.0",
18
+ "tone": "^15.1.22",
19
+ "wavesurfer.js": "^7.12.1",
20
+ "zustand": "^5.0.10"
21
+ },
22
+ "devDependencies": {
23
+ "@eslint/js": "^9.39.1",
24
+ "@types/node": "^24.10.1",
25
+ "@types/react": "^19.2.5",
26
+ "@types/react-dom": "^19.2.3",
27
+ "@vitejs/plugin-react": "^5.1.1",
28
+ "autoprefixer": "^10.4.20",
29
+ "eslint": "^9.39.1",
30
+ "eslint-plugin-react-hooks": "^7.0.1",
31
+ "eslint-plugin-react-refresh": "^0.4.24",
32
+ "globals": "^16.5.0",
33
+ "postcss": "^8.4.47",
34
+ "tailwindcss": "^3.4.15",
35
+ "typescript": "~5.9.3",
36
+ "typescript-eslint": "^8.46.4",
37
+ "vite": "^7.2.4"
38
+ }
39
+ }
preprocess/tools/midi_editor/postcss.config.js ADDED
@@ -0,0 +1,6 @@
1
+ export default {
2
+ plugins: {
3
+ tailwindcss: {},
4
+ autoprefixer: {},
5
+ },
6
+ }
preprocess/tools/midi_editor/public/vite.svg ADDED
preprocess/tools/midi_editor/src/App.css ADDED
@@ -0,0 +1,834 @@
1
+ .app-shell {
2
+ padding: 24px;
3
+ color: var(--text-primary);
4
+ width: 100%;
5
+ max-width: 100%;
6
+ margin: 0;
7
+ height: 100vh;
8
+ max-height: 100vh;
9
+ display: flex;
10
+ flex-direction: column;
11
+ overflow: hidden;
12
+ box-sizing: border-box;
13
+ }
14
+
15
+ .topbar {
16
+ display: flex;
17
+ align-items: center;
18
+ justify-content: space-between;
19
+ gap: 24px;
20
+ background: var(--panel-strong);
21
+ border: 1px solid var(--border-subtle);
22
+ border-radius: 16px;
23
+ padding: 20px 24px;
24
+ box-shadow: var(--shadow-panel);
25
+ }
26
+
27
+ .topbar h1 {
28
+ margin: 4px 0 0 0;
29
+ font-size: 26px;
30
+ letter-spacing: -0.5px;
31
+ }
32
+
33
+ .eyebrow {
34
+ margin: 0;
35
+ text-transform: uppercase;
36
+ font-size: 12px;
37
+ letter-spacing: 2px;
38
+ color: var(--text-muted);
39
+ }
40
+
41
+ .muted {
42
+ margin: 6px 0 0 0;
43
+ color: var(--text-muted);
44
+ }
45
+
46
+ .actions {
47
+ display: flex;
48
+ gap: 10px;
49
+ align-items: center;
50
+ }
51
+
52
+ .transpose-group {
53
+ display: flex;
54
+ align-items: center;
55
+ }
56
+
57
+ .transpose-select {
58
+ padding: 10px 10px;
59
+ border-radius: 12px;
60
+ border: 1px solid var(--border-soft);
61
+ background: var(--button-soft-bg);
62
+ color: var(--button-soft-text);
63
+ font-weight: 600;
64
+ font-size: 14px;
65
+ cursor: pointer;
66
+ outline: none;
67
+ appearance: none;
68
+ -webkit-appearance: none;
69
+ background-image: url("data:image/svg+xml,%3Csvg xmlns='http://www.w3.org/2000/svg' width='10' height='6'%3E%3Cpath d='M0 0l5 6 5-6z' fill='%23888'/%3E%3C/svg%3E");
70
+ background-repeat: no-repeat;
71
+ background-position: right 10px center;
72
+ padding-right: 26px;
73
+ transition: transform 140ms ease, box-shadow 140ms ease, background 140ms ease;
74
+ }
75
+
76
+ .transpose-select:hover {
77
+ transform: translateY(-1px);
78
+ }
79
+
80
+ .transpose-select:focus {
81
+ border-color: var(--accent);
82
+ }
83
+
84
+ .icon-toggle {
85
+ width: 40px;
86
+ height: 40px;
87
+ border-radius: 999px;
88
+ border: 1px solid var(--border-soft);
89
+ background: var(--button-ghost-bg);
90
+ color: var(--button-ghost-text);
91
+ display: inline-flex;
92
+ align-items: center;
93
+ justify-content: center;
94
+ font-size: 18px;
95
+ cursor: pointer;
96
+ }
97
+
98
+ .icon-toggle:hover {
99
+ transform: translateY(-1px);
100
+ }
101
+
102
+ .lang-label {
103
+ font-size: 14px;
104
+ font-weight: 700;
105
+ line-height: 1;
106
+ }
107
+
108
+ .audio-bar {
109
+ margin-top: 14px;
110
+ padding: 12px 16px;
111
+ border-radius: 14px;
112
+ background: var(--panel-strong);
113
+ border: 1px solid var(--border-subtle);
114
+ display: flex;
115
+ align-items: center;
116
+ justify-content: space-between;
117
+ gap: 16px;
118
+ }
119
+
120
+ .audio-left {
121
+ display: flex;
122
+ align-items: center;
123
+ gap: 12px;
124
+ }
125
+
126
+ .audio-hint {
127
+ color: var(--text-muted);
128
+ font-size: 12px;
129
+ }
130
+
131
+ .audio-right {
132
+ display: flex;
133
+ align-items: center;
134
+ gap: 20px;
135
+ }
136
+
137
+ .volume-control {
138
+ display: flex;
139
+ align-items: center;
140
+ gap: 8px;
141
+ }
142
+
143
+ .volume-label {
144
+ font-size: 12px;
145
+ color: var(--text-muted);
146
+ min-width: 32px;
147
+ }
148
+
149
+ .volume-slider {
150
+ width: 80px;
151
+ height: 4px;
152
+ cursor: pointer;
153
+ accent-color: var(--accent);
154
+ }
155
+
156
+ .volume-value {
157
+ font-size: 11px;
158
+ color: var(--text-muted);
159
+ min-width: 36px;
160
+ text-align: right;
161
+ }
162
+
163
+ .toggle {
164
+ display: inline-flex;
165
+ align-items: center;
166
+ gap: 8px;
167
+ font-size: 13px;
168
+ color: var(--text-primary);
169
+ }
170
+
171
+ .panel {
172
+ margin-top: 18px;
173
+ background: var(--panel);
174
+ border: 1px solid var(--border-subtle);
175
+ border-radius: 16px;
176
+ padding: 18px;
177
+ box-shadow: var(--shadow-panel);
178
+ display: flex;
179
+ flex-direction: column;
180
+ flex: 1;
181
+ min-height: 0;
182
+ overflow: hidden;
183
+ }
184
+
185
+ .panel-split {
186
+ display: grid;
187
+ grid-template-columns: minmax(0, 1fr) 360px;
188
+ gap: 16px;
189
+ align-items: stretch;
190
+ flex: 1;
191
+ min-height: 0;
192
+ max-height: 100%;
193
+ overflow: hidden;
194
+ }
195
+
196
+ .panel-main {
197
+ min-width: 0;
198
+ display: flex;
199
+ flex-direction: column;
200
+ min-height: 0;
201
+ max-height: 100%;
202
+ overflow: hidden;
203
+ }
204
+
205
+ .panel-side {
206
+ display: flex;
207
+ flex-direction: column;
208
+ gap: 16px;
209
+ width: 360px;
210
+ max-width: 360px;
211
+ /* Use absolute positioning to enforce height */
212
+ position: relative;
213
+ overflow: hidden;
214
+ }
215
+
216
+ .controls {
217
+ display: grid;
218
+ grid-template-columns: repeat(auto-fit, minmax(180px, 1fr));
219
+ gap: 14px;
220
+ align-items: center;
221
+ background: var(--panel-soft);
222
+ padding: 12px 14px;
223
+ border-radius: 12px;
224
+ border: 1px solid var(--border-soft);
225
+ flex-shrink: 0;
226
+ }
227
+
228
+ .controls label {
229
+ display: block;
230
+ font-size: 12px;
231
+ text-transform: uppercase;
232
+ letter-spacing: 1px;
233
+ color: var(--text-muted);
234
+ margin-bottom: 4px;
235
+ }
236
+
237
+ .controls input[type='number'] {
238
+ width: 100%;
239
+ padding: 10px 12px;
240
+ border-radius: 10px;
241
+ border: 1px solid var(--border-soft);
242
+ background: var(--input-bg);
243
+ color: var(--text-primary);
244
+ }
245
+
246
+ .timesig {
247
+ display: flex;
248
+ align-items: center;
249
+ gap: 6px;
250
+ }
251
+
252
+ .timesig span {
253
+ font-weight: 700;
254
+ color: var(--text-muted);
255
+ }
256
+
257
+ .transport {
258
+ display: flex;
259
+ gap: 6px;
260
+ align-items: center;
261
+ grid-column: 1 / -1;
262
+ }
263
+
264
+ .transport button {
265
+ padding: 6px 10px !important;
266
+ font-size: 13px !important;
267
+ min-width: 0;
268
+ }
269
+
270
+ .status {
271
+ grid-column: 1 / -1;
272
+ color: var(--text-muted);
273
+ font-size: 13px;
274
+ }
275
+
276
+ .transport-divider {
277
+ width: 1px;
278
+ height: 20px;
279
+ background: var(--border-soft);
280
+ margin: 0 2px;
281
+ flex-shrink: 0;
282
+ }
283
+
284
+ .selection-btn {
285
+ font-size: 12px !important;
286
+ padding: 6px 10px !important;
287
+ white-space: nowrap;
288
+ }
289
+
290
+ .selection-btn.active {
291
+ background: var(--accent) !important;
292
+ color: white !important;
293
+ }
294
+
295
+ .button,
296
+ .actions button,
297
+ .transport button,
298
+ .ghost,
299
+ .primary,
300
+ .json-btn,
301
+ .soft {
302
+ cursor: pointer;
303
+ border-radius: 12px;
304
+ border: 1px solid transparent;
305
+ padding: 10px 14px;
306
+ font-weight: 600;
307
+ transition: transform 140ms ease, box-shadow 140ms ease, background 140ms ease, border 140ms ease;
308
+ color: #0f1528;
309
+ }
310
+
311
+ .ghost {
312
+ background: var(--button-ghost-bg);
313
+ color: var(--button-ghost-text);
314
+ border-color: var(--border-soft);
315
+ }
316
+
317
+ .primary {
318
+ background: linear-gradient(135deg, var(--accent), var(--accent-strong));
319
+ color: var(--button-primary-text);
320
+ box-shadow: 0 8px 26px rgba(72, 228, 194, 0.2);
321
+ }
322
+
323
+ .json-btn {
324
+ background: linear-gradient(135deg, #f59e0b, #d97706);
325
+ color: #fff;
326
+ box-shadow: 0 8px 26px rgba(245, 158, 11, 0.2);
327
+ }
328
+
329
+ .soft {
330
+ background: var(--button-soft-bg);
331
+ color: var(--button-soft-text);
332
+ border: 1px solid var(--border-soft);
333
+ }
334
+
335
+ .ghost:disabled,
336
+ .primary:disabled,
337
+ .json-btn:disabled,
338
+ .soft:disabled {
339
+ opacity: 0.6;
340
+ cursor: not-allowed;
341
+ }
342
+
343
+ .ghost:hover,
344
+ .primary:hover,
345
+ .json-btn:hover,
346
+ .soft:hover {
347
+ transform: translateY(-1px);
348
+ }
349
+
350
+ .piano-shell {
351
+ border-radius: 12px;
352
+ background: var(--panel-strong);
353
+ border: 1px solid var(--border-subtle);
354
+ overflow: hidden;
355
+ flex: 1;
356
+ min-height: 0;
357
+ max-height: 100%;
358
+ display: flex;
359
+ flex-direction: column;
360
+ }
361
+
362
+ .ruler {
363
+ position: relative;
364
+ height: 32px;
365
+ background: var(--panel-soft);
366
+ border-bottom: 1px solid var(--border-soft);
367
+ min-width: 100%;
368
+ }
369
+
370
+ .ruler-shell {
371
+ display: flex;
372
+ }
373
+
374
+ .ruler-spacer {
375
+ background: var(--panel-soft);
376
+ border-bottom: 1px solid var(--border-soft);
377
+ height: 32px;
378
+ }
379
+
380
+ .ruler-scroll {
381
+ overflow: hidden;
382
+ flex: 1;
383
+ height: 32px;
384
+ cursor: pointer;
385
+ }
386
+
387
+ .measure-mark {
388
+ position: absolute;
389
+ top: 0;
390
+ height: 100%;
391
+ display: flex;
392
+ flex-direction: column;
393
+ align-items: flex-start;
394
+ font-size: 10px;
395
+ color: var(--text-muted);
396
+ padding-left: 4px;
397
+ border-left: 1px solid var(--border-soft);
398
+ }
399
+
400
+ .measure-mark span {
401
+ margin-top: 2px;
402
+ }
403
+
404
+ .ruler-playhead {
405
+ position: absolute;
406
+ top: 0;
407
+ width: 2px;
408
+ height: 100%;
409
+ background: #ff7043;
410
+ pointer-events: none;
411
+ z-index: 10;
412
+ }
413
+
414
+ .ruler-scroll.selecting {
415
+ cursor: crosshair;
416
+ }
417
+
418
+ .selection-range {
419
+ position: absolute;
420
+ top: 0;
421
+ height: 100%;
422
+ background: rgba(66, 165, 245, 0.35);
423
+ border-left: 2px solid #42a5f5;
424
+ border-right: 2px solid #42a5f5;
425
+ pointer-events: none;
426
+ z-index: 5;
427
+ }
428
+
429
+ .grid-selection-range {
430
+ position: absolute;
431
+ top: 0;
432
+ background: rgba(66, 165, 245, 0.15);
433
+ border-left: 2px dashed #42a5f5;
434
+ border-right: 2px dashed #42a5f5;
435
+ pointer-events: none;
436
+ z-index: 1;
437
+ }
438
+
439
+ .roll-body {
440
+ display: flex;
441
+ flex: 1;
442
+ min-height: 0;
443
+ overflow: hidden;
444
+ }
445
+
446
+ .pitch-rail {
447
+ background: var(--panel-strong);
448
+ border-right: 1px solid var(--border-subtle);
449
+ color: var(--text-primary);
450
+ font-size: 12px;
451
+ text-align: right;
452
+ overflow: hidden;
453
+ flex-shrink: 0;
454
+ height: 100%;
455
+ }
456
+
457
+ .pitch-cell {
458
+ border-bottom: 1px solid var(--border-soft);
459
+ display: flex;
460
+ align-items: center;
461
+ justify-content: flex-end;
462
+ padding: 0 4px;
463
+ font-variant-numeric: tabular-nums;
464
+ box-sizing: border-box;
465
+ }
466
+
467
+ .pitch-white {
468
+ background: rgba(255, 255, 255, 0.06);
469
+ color: var(--text-primary);
470
+ }
471
+
472
+ .pitch-black {
473
+ background: rgba(0, 0, 0, 0.35);
474
+ color: rgba(233, 238, 247, 0.9);
475
+ }
476
+
477
+ .pitch-c {
478
+ background: rgba(100, 150, 255, 0.15);
479
+ font-weight: 600;
480
+ }
481
+
482
+ .pitch-label {
483
+ font-size: 10px;
484
+ }
485
+
486
+ .roll-grid {
487
+ position: relative;
488
+ overflow: auto;
489
+ flex: 1;
490
+ min-height: 0;
491
+ background-color: var(--grid-bg);
492
+ }
493
+
494
+ .grid-content {
495
+ background-color: var(--grid-bg);
496
+ }
497
+
498
+ .grid-svg {
499
+ shape-rendering: crispEdges;
500
+ }
501
+
502
+ .grid-overlay {
503
+ position: relative;
504
+ }
505
+
506
+ .note-chip {
507
+ position: absolute;
508
+ background: linear-gradient(135deg, var(--accent), var(--accent-strong));
509
+ border-radius: 6px;
510
+ border: 1px solid rgba(255, 255, 255, 0.16);
511
+ box-shadow: 0 10px 22px rgba(0, 0, 0, 0.25);
512
+ display: flex;
513
+ align-items: center;
514
+ justify-content: center;
515
+ color: var(--note-text);
516
+ font-weight: 700;
517
+ user-select: none;
518
+ box-sizing: border-box;
519
+ }
520
+
521
+ .note-active {
522
+ outline: 2px solid #ff7043;
523
+ z-index: 2;
524
+ }
525
+
526
+ .note-overlap {
527
+ background: linear-gradient(135deg, #ef5350 0%, #ff7043 100%) !important;
528
+ animation: pulse-overlap 1s ease-in-out infinite;
529
+ }
530
+
531
+ /* Selected overlapping note - more visible outline */
532
+ .note-overlap.note-active {
533
+ outline: 3px solid #1e40af;
534
+ outline-offset: 1px;
535
+ box-shadow: 0 0 12px rgba(30, 64, 175, 0.8);
536
+ animation: none;
537
+ }
538
+
539
+ @keyframes pulse-overlap {
540
+ 0%, 100% { opacity: 1; }
541
+ 50% { opacity: 0.7; }
542
+ }
543
+
544
+ .playhead {
545
+ position: absolute;
546
+ top: 0;
547
+ width: 2px;
548
+ background: #ff7043;
549
+ box-shadow: 0 0 12px rgba(255, 112, 67, 0.6);
550
+ pointer-events: none;
551
+ z-index: 20;
552
+ }
553
+
554
+ .pitch-rail-inner {
555
+ will-change: transform;
556
+ }
557
+
558
+ .note-label {
559
+ width: 100%;
560
+ text-align: center;
561
+ font-size: 12px;
562
+ padding: 0 12px;
563
+ overflow: hidden;
564
+ text-overflow: ellipsis;
565
+ white-space: nowrap;
566
+ }
567
+
568
+ .note-handle {
569
+ position: absolute;
570
+ top: 0;
571
+ width: 8px;
572
+ height: 100%;
573
+ background: rgba(255, 255, 255, 0.25);
574
+ cursor: ew-resize;
575
+ }
576
+
577
+ .note-handle.start {
578
+ left: 0;
579
+ border-radius: 6px 0 0 6px;
580
+ }
581
+
582
+ .note-handle.end {
583
+ right: 0;
584
+ border-radius: 0 6px 6px 0;
585
+ }
586
+
587
+ .lyric-container {
588
+ flex: 1;
589
+ min-height: 0;
590
+ position: relative;
591
+ }
592
+
593
+ .lyric-card {
594
+ border: 1px solid rgba(255, 255, 255, 0.06);
595
+ border-radius: 12px;
596
+ background: var(--panel-soft);
597
+ overflow: hidden;
598
+ display: flex;
599
+ flex-direction: column;
600
+ /* Force fixed height with absolute positioning */
601
+ position: absolute;
602
+ top: 0;
603
+ left: 0;
604
+ right: 0;
605
+ bottom: 0;
606
+ }
607
+
608
+ .lyric-bulk {
609
+ display: flex;
610
+ gap: 8px;
611
+ padding: 10px 12px;
612
+ border-bottom: 1px solid rgba(255, 255, 255, 0.06);
613
+ align-items: center;
614
+ }
615
+
616
+ .lyric-bulk-input {
617
+ flex: 1;
618
+ padding: 8px 10px;
619
+ border-radius: 10px;
620
+ border: 1px solid var(--border-soft);
621
+ background: var(--input-bg);
622
+ color: var(--text-primary);
623
+ resize: vertical;
624
+ }
625
+
626
+ .lyric-header,
627
+ .lyric-row {
628
+ display: grid;
629
+ grid-template-columns: 1.4fr 0.5fr 0.5fr 0.5fr;
630
+ gap: 8px;
631
+ padding: 10px 12px;
632
+ align-items: center;
633
+ }
634
+
635
+ .lyric-header {
636
+ font-size: 12px;
637
+ text-transform: uppercase;
638
+ letter-spacing: 1px;
639
+ color: var(--text-muted);
640
+ border-bottom: 1px solid rgba(255, 255, 255, 0.06);
641
+ }
642
+
643
+ .lyric-list {
644
+ overflow-y: auto;
645
+ overflow-x: hidden;
646
+ flex: 1;
647
+ min-height: 0;
648
+ }
649
+
650
+ .lyric-row {
651
+ border-bottom: 1px solid rgba(255, 255, 255, 0.04);
652
+ }
653
+
654
+ .lyric-row:hover {
655
+ background: rgba(255, 255, 255, 0.03);
656
+ }
657
+
658
+ .lyric-row-active {
659
+ background: rgba(72, 228, 194, 0.08);
660
+ border-left: 3px solid #48e4c2;
661
+ }
662
+
663
+ .lyric-input {
664
+ width: 100%;
665
+ padding: 8px 10px;
666
+ border-radius: 10px;
667
+ border: 1px solid var(--border-soft);
668
+ background: var(--input-bg);
669
+ color: var(--text-primary);
670
+ }
671
+
672
+ .lyric-meta {
673
+ color: var(--text-muted);
674
+ font-variant-numeric: tabular-nums;
675
+ }
676
+
677
+ .editable-cell {
678
+ position: relative;
679
+ display: flex;
680
+ align-items: center;
681
+ gap: 2px;
682
+ }
683
+
684
+ .lyric-meta-input {
685
+ width: 100%;
686
+ padding: 2px 4px;
687
+ border: 1px solid transparent;
688
+ border-radius: 4px;
689
+ background: transparent;
690
+ color: var(--text-muted);
691
+ font-size: 12px;
692
+ font-variant-numeric: tabular-nums;
693
+ text-align: center;
694
+ outline: none;
695
+ transition: border-color 0.15s, background-color 0.15s;
696
+ }
697
+
698
+ .lyric-meta-input:hover {
699
+ background: var(--surface-elevated);
700
+ }
701
+
702
+ .lyric-meta-input:focus {
703
+ border-color: var(--accent);
704
+ background: var(--surface-elevated);
705
+ color: var(--text-primary);
706
+ }
707
+
708
+ .lyric-meta-dirty {
709
+ border-color: #f59e0b !important;
710
+ background: rgba(245, 158, 11, 0.1) !important;
711
+ }
712
+
713
+ .confirm-btn {
714
+ flex-shrink: 0;
715
+ width: 18px;
716
+ height: 18px;
717
+ padding: 0;
718
+ border: none;
719
+ border-radius: 4px;
720
+ background: #22c55e;
721
+ color: white;
722
+ font-size: 12px;
723
+ font-weight: bold;
724
+ cursor: pointer;
725
+ display: flex;
726
+ align-items: center;
727
+ justify-content: center;
728
+ transition: background 0.15s;
729
+ }
730
+
731
+ .confirm-btn:hover {
732
+ background: #16a34a;
733
+ }
734
+
735
+ /* Hide number input spinners */
736
+ .lyric-meta-input::-webkit-outer-spin-button,
737
+ .lyric-meta-input::-webkit-inner-spin-button {
738
+ -webkit-appearance: none;
739
+ margin: 0;
740
+ }
741
+
742
+ .lyric-meta-input[type=number] {
743
+ -moz-appearance: textfield;
744
+ }
745
+
746
+ .lyric-empty {
747
+ padding: 16px;
748
+ color: var(--text-muted);
749
+ text-align: center;
750
+ }
751
+
752
+ .audio-track {
753
+ display: grid;
754
+ grid-template-columns: 80px 1fr;
755
+ gap: 12px;
756
+ align-items: center;
757
+ padding: 12px 14px;
758
+ border-radius: 12px;
759
+ border: 1px solid var(--border-soft);
760
+ background: var(--panel-soft);
761
+ margin-bottom: 12px;
762
+ flex-shrink: 0;
763
+ }
764
+
765
+ .audio-track-label {
766
+ font-size: 12px;
767
+ text-transform: uppercase;
768
+ letter-spacing: 1px;
769
+ color: var(--text-muted);
770
+ }
771
+
772
+ .audio-wave {
773
+ width: 100%;
774
+ height: 80px;
775
+ min-height: 80px;
776
+ }
777
+
778
+ :root {
779
+ --text-primary: #e9eef7;
780
+ --text-muted: rgba(233, 238, 247, 0.7);
781
+ --panel: rgba(13, 16, 28, 0.8);
782
+ --panel-strong: rgba(16, 21, 35, 0.95);
783
+ --panel-soft: rgba(255, 255, 255, 0.03);
784
+ --border-subtle: rgba(255, 255, 255, 0.08);
785
+ --border-soft: rgba(255, 255, 255, 0.12);
786
+ --input-bg: rgba(255, 255, 255, 0.06);
787
+ --grid-bg: rgba(14, 18, 30, 0.9);
788
+ --grid-line-minor: rgba(233, 238, 247, 0.08);
789
+ --grid-line-major: rgba(233, 238, 247, 0.16);
790
+ --accent: #48e4c2;
791
+ --accent-strong: #4b64bc;
792
+ --note-text: #0b1122;
793
+ --button-ghost-bg: rgba(233, 238, 247, 0.18);
794
+ --button-ghost-text: #ffffff;
795
+ --button-soft-bg: rgba(255, 255, 255, 0.14);
796
+ --button-soft-text: #ffffff;
797
+ --button-primary-text: #0b1122;
798
+ --shadow-panel: 0 18px 40px rgba(0, 0, 0, 0.32);
799
+ }
800
+
801
+ :root[data-theme='light'] {
802
+ --text-primary: #1b2238;
803
+ --text-muted: rgba(27, 34, 56, 0.7);
804
+ --panel: rgba(255, 255, 255, 0.9);
805
+ --panel-strong: rgba(250, 252, 255, 0.98);
806
+ --panel-soft: rgba(15, 23, 42, 0.04);
807
+ --border-subtle: rgba(15, 23, 42, 0.12);
808
+ --border-soft: rgba(15, 23, 42, 0.16);
809
+ --input-bg: rgba(15, 23, 42, 0.06);
810
+ --grid-bg: rgba(248, 250, 255, 0.95);
811
+ --grid-line-minor: rgba(15, 23, 42, 0.12);
812
+ --grid-line-major: rgba(15, 23, 42, 0.24);
813
+ --accent: #3f8cff;
814
+ --accent-strong: #4b64bc;
815
+ --note-text: #ffffff;
816
+ --button-ghost-bg: rgba(15, 23, 42, 0.06);
817
+ --button-ghost-text: #1b2238;
818
+ --button-soft-bg: rgba(15, 23, 42, 0.06);
819
+ --button-soft-text: #1b2238;
820
+ --button-primary-text: #0b1122;
821
+ --shadow-panel: 0 18px 40px rgba(15, 23, 42, 0.15);
822
+ }
823
+
824
+ .sr-only {
825
+ position: absolute;
826
+ width: 1px;
827
+ height: 1px;
828
+ padding: 0;
829
+ margin: -1px;
830
+ overflow: hidden;
831
+ clip: rect(0, 0, 0, 0);
832
+ white-space: nowrap;
833
+ border: 0;
834
+ }
preprocess/tools/midi_editor/src/App.tsx ADDED
@@ -0,0 +1,675 @@
1
+ import { useCallback, useEffect, useMemo, useRef, useState } from 'react'
2
+ import * as Tone from 'tone'
3
+ import { PianoRoll } from './components/PianoRoll'
4
+ import { LyricTable } from './components/LyricTable'
5
+ import { AudioTrack } from './components/AudioTrack'
6
+ import { useMidiStore } from './store/useMidiStore'
7
+ import { exportMidi, importMidiFile } from './lib/midi'
8
+ import type { TimeSignature } from './types'
9
+ import type { Lang } from './i18n'
10
+ import { getTranslations } from './i18n'
11
+ import { BASE_GRID_SECOND_WIDTH, BASE_ROW_HEIGHT, LOW_NOTE, HIGH_NOTE } from './constants'
12
+ import './App.css'
13
+
14
+ type PlayEvent = {
15
+ time: number
16
+ midi: number
17
+ duration: number
18
+ velocity: number
19
+ }
20
+
21
+ function App() {
22
+ const {
23
+ notes,
24
+ tempo,
25
+ timeSignature,
26
+ selectedId,
27
+ playhead,
28
+ ppq,
29
+ addNote,
30
+ updateNote,
31
+ removeNote,
32
+ setNotes,
33
+ setTempo,
34
+ setTimeSignature,
35
+ setPpq,
36
+ select,
37
+ setPlayhead,
38
+ } = useMidiStore()
39
+
40
+ const [lang, setLang] = useState<Lang>('zh')
41
+ const t = getTranslations(lang)
42
+
43
+ const [status, setStatus] = useState(t.ready)
44
+ const [isPlaying, setIsPlaying] = useState(false)
45
+ const [theme, setTheme] = useState<'dark' | 'light'>('light')
46
+ const [audioUrl, setAudioUrl] = useState<string | null>(null)
47
+ const [audioDuration, setAudioDuration] = useState(0)
48
+ const [midiVolume, setMidiVolume] = useState(80) // 0-100
49
+ const [audioVolume, setAudioVolume] = useState(80) // 0-100
50
+ const [horizontalZoom, setHorizontalZoom] = useState(1)
51
+ const [verticalZoom, setVerticalZoom] = useState(1)
52
+ const [focusLyricId, setFocusLyricId] = useState<string | null>(null)
53
+ // Selection range for loop playback (in seconds)
54
+ const [selectionStart, setSelectionStart] = useState<number | null>(null)
55
+ const [selectionEnd, setSelectionEnd] = useState<number | null>(null)
56
+ const [isSelectingRange, setIsSelectingRange] = useState(false)
57
+ const fileInputRef = useRef<HTMLInputElement | null>(null)
58
+ const audioInputRef = useRef<HTMLInputElement | null>(null)
59
+ const audioRef = useRef<HTMLAudioElement | null>(null)
60
+ const partRef = useRef<Tone.Part<PlayEvent> | null>(null)
61
+ const synthRef = useRef<Tone.PolySynth | null>(null)
62
+ const rafRef = useRef<number | null>(null)
63
+ const audioScrollRef = useRef<HTMLDivElement | null>(null)
64
+
65
+ useEffect(() => {
66
+ return () => {
67
+ stopPlayback()
68
+ synthRef.current?.dispose()
69
+ }
70
+ }, [])
71
+
72
+ useEffect(() => {
73
+ document.documentElement.dataset.theme = theme
74
+ }, [theme])
75
+
76
+ // Update status text when language changes
77
+ useEffect(() => {
78
+ setStatus(t.ready)
79
+ }, [lang])
80
+
81
+ // Sync audio volume - also trigger when audioUrl changes (new audio loaded)
82
+ useEffect(() => {
83
+ if (audioRef.current) {
84
+ audioRef.current.volume = audioVolume / 100
85
+ }
86
+ }, [audioVolume, audioUrl])
87
+
88
+ // Sync MIDI synth volume
89
+ useEffect(() => {
90
+ if (synthRef.current) {
91
+ // Convert 0-100 to dB scale (-60 to 0)
92
+ const dbValue = midiVolume === 0 ? -Infinity : (midiVolume / 100) * 60 - 60
93
+ synthRef.current.volume.value = dbValue
94
+ }
95
+ }, [midiVolume])
96
+
97
+ useEffect(() => {
98
+ if (!audioUrl) return
99
+ return () => {
100
+ URL.revokeObjectURL(audioUrl)
101
+ }
102
+ }, [audioUrl])
103
+
104
+ const ensureSynth = async () => {
105
+ await Tone.start()
106
+ if (!synthRef.current) {
107
+ synthRef.current = new Tone.PolySynth(Tone.Synth).toDestination()
108
+ // Apply current volume
109
+ const dbValue = midiVolume === 0 ? -Infinity : (midiVolume / 100) * 60 - 60
110
+ synthRef.current.volume.value = dbValue
111
+ }
112
+ }
113
+
114
+ const playPreviewNote = useCallback(async (midi: number) => {
115
+ await ensureSynth()
116
+ const frequency = Tone.Frequency(midi, 'midi').toFrequency()
117
+ synthRef.current?.triggerAttackRelease(frequency, '8n', Tone.now(), 0.7)
118
+ }, [midiVolume])
119
+
120
+ useEffect(() => {
121
+ const onKeyDown = (event: KeyboardEvent) => {
122
+ if (!selectedId) return
123
+ const target = event.target as HTMLElement | null
124
+ if (target && ['INPUT', 'TEXTAREA'].includes(target.tagName)) return
125
+
126
+ // Delete note
127
+ if (event.key === 'Backspace' || event.key === 'Delete') {
128
+ event.preventDefault()
129
+ removeNote(selectedId)
130
+ select(null)
131
+ return
132
+ }
133
+
134
+ // Cmd/Ctrl + Up/Down to adjust pitch
135
+ const isCmdOrCtrl = event.metaKey || event.ctrlKey
136
+ if (isCmdOrCtrl && (event.key === 'ArrowUp' || event.key === 'ArrowDown')) {
137
+ event.preventDefault()
138
+ const selectedNote = notes.find(n => n.id === selectedId)
139
+ if (!selectedNote) return
140
+
141
+ const delta = event.key === 'ArrowUp' ? 1 : -1
142
+ const newMidi = Math.max(LOW_NOTE, Math.min(HIGH_NOTE, selectedNote.midi + delta))
143
+
144
+ if (newMidi !== selectedNote.midi) {
145
+ updateNote(selectedId, { midi: newMidi })
146
+ playPreviewNote(newMidi)
147
+ }
148
+ }
149
+ }
150
+ window.addEventListener('keydown', onKeyDown)
151
+ return () => window.removeEventListener('keydown', onKeyDown)
152
+ }, [selectedId, notes, removeNote, select, updateNote, playPreviewNote])
153
+
154
+ const noteEvents = useMemo<PlayEvent[]>(
155
+ () =>
156
+ notes.map((note) => ({
157
+ time: (60 / tempo) * note.start,
158
+ duration: (60 / tempo) * note.duration,
159
+ midi: note.midi,
160
+ velocity: note.velocity,
161
+ })),
162
+ [notes, tempo],
163
+ )
164
+
165
+ const beatToSeconds = (beat: number) => beat * (60 / tempo)
166
+ const secondsToBeat = (seconds: number) => seconds / (60 / tempo)
167
+ const seekBySeconds = (deltaSeconds: number) => {
168
+ const maxNoteEnd = notes.reduce((acc, n) => Math.max(acc, n.start + n.duration), 0)
169
+ const maxBeat = Math.max(secondsToBeat(audioDuration), maxNoteEnd)
170
+ const nextSeconds = Math.max(0, Math.min(beatToSeconds(maxBeat), beatToSeconds(playhead) + deltaSeconds))
171
+ seekToBeat(secondsToBeat(nextSeconds))
172
+ }
173
+
174
+ const gridSecondWidth = BASE_GRID_SECOND_WIDTH * horizontalZoom
175
+ const rowHeight = BASE_ROW_HEIGHT * verticalZoom
176
+
177
+ // Calculate MIDI content width to sync with audio track
178
+ const midiContentWidth = useMemo(() => {
179
+ const noteEndSeconds = notes.reduce((acc, n) => {
180
+ const endBeat = n.start + n.duration
181
+ return Math.max(acc, beatToSeconds(endBeat))
182
+ }, 8)
183
+ const maxSeconds = Math.max(noteEndSeconds + 10, audioDuration + 10, 30)
184
+ return maxSeconds * gridSecondWidth
185
+ }, [notes, audioDuration, gridSecondWidth, beatToSeconds])
186
+
187
+ const seekToBeat = (beat: number) => {
188
+ setPlayhead(beat)
189
+ Tone.Transport.seconds = beatToSeconds(beat)
190
+ if (audioRef.current) {
191
+ audioRef.current.currentTime = beatToSeconds(beat)
192
+ }
193
+ }
194
+
195
+ const schedulePlayback = async () => {
196
+ if (!notes.length && !audioUrl) return
197
+ await ensureSynth()
198
+ partRef.current?.dispose()
199
+ Tone.Transport.cancel()
200
+ Tone.Transport.stop()
201
+ Tone.Transport.bpm.value = tempo
202
+
203
+ // Determine playback range
204
+ const hasSelection = selectionStart !== null && selectionEnd !== null && selectionEnd > selectionStart
205
+ const startSeconds = hasSelection ? selectionStart : beatToSeconds(playhead)
206
+ const endSeconds = hasSelection ? selectionEnd : null
207
+
208
+ Tone.Transport.seconds = startSeconds
209
+
210
+ // Filter notes within selection range if applicable
211
+ const filteredEvents = hasSelection
212
+ ? noteEvents.filter(e => e.time >= startSeconds && e.time < endSeconds!)
213
+ : noteEvents
214
+
215
+ if (filteredEvents.length) {
216
+ partRef.current = new Tone.Part((time, event) => {
217
+ if (midiVolume === 0) return
218
+ const frequency = Tone.Frequency(event.midi, 'midi').toFrequency()
219
+ synthRef.current?.triggerAttackRelease(frequency, event.duration, time, event.velocity)
220
+ }, filteredEvents)
221
+ partRef.current.start(0)
222
+ }
223
+ Tone.Transport.start()
224
+ if (audioRef.current && audioUrl) {
225
+ audioRef.current.currentTime = startSeconds
226
+ if (audioVolume > 0) {
227
+ audioRef.current.play().catch(() => null)
228
+ }
229
+ }
230
+ setIsPlaying(true)
231
+ setStatus(hasSelection ? t.selectionPlayback : t.playing)
232
+
233
+ const tick = () => {
234
+ const seconds =
235
+ audioRef.current && audioUrl && !audioRef.current.paused
236
+ ? audioRef.current.currentTime
237
+ : Tone.Transport.seconds
238
+
239
+ // Stop at selection end
240
+ if (endSeconds !== null && seconds >= endSeconds) {
241
+ pausePlayback()
242
+ seekToBeat(secondsToBeat(selectionStart!))
243
+ setStatus(t.selectionDone)
244
+ return
245
+ }
246
+
247
+ const beat = seconds / (60 / tempo)
248
+ setPlayhead(beat)
249
+ rafRef.current = requestAnimationFrame(tick)
250
+ }
251
+ rafRef.current = requestAnimationFrame(tick)
252
+ }
253
+
254
+ const stopPlayback = () => {
255
+ Tone.Transport.stop()
256
+ Tone.Transport.cancel()
257
+ partRef.current?.dispose()
258
+ partRef.current = null
259
+ setIsPlaying(false)
260
+ setPlayhead(0)
261
+ if (audioRef.current) {
262
+ audioRef.current.pause()
263
+ audioRef.current.currentTime = 0
264
+ }
265
+ if (rafRef.current) {
266
+ cancelAnimationFrame(rafRef.current)
267
+ rafRef.current = null
268
+ }
269
+ }
270
+
271
+ const pausePlayback = () => {
272
+ Tone.Transport.stop()
273
+ partRef.current?.dispose()
274
+ partRef.current = null
275
+ setIsPlaying(false)
276
+ if (audioRef.current) {
277
+ audioRef.current.pause()
278
+ }
279
+ if (rafRef.current) {
280
+ cancelAnimationFrame(rafRef.current)
281
+ rafRef.current = null
282
+ }
283
+ }
284
+
285
+ const handlePlayToggle = async () => {
286
+ if (isPlaying) {
287
+ pausePlayback()
288
+ setStatus(t.paused)
289
+ } else {
290
+ await schedulePlayback()
291
+ }
292
+ }
293
+
294
+ const handleImportClick = () => fileInputRef.current?.click()
295
+ const handleAudioImportClick = () => audioInputRef.current?.click()
296
+
297
+ const handleFileChange = async (event: React.ChangeEvent<HTMLInputElement>) => {
298
+ const file = event.target.files?.[0]
299
+ if (!file) return
300
+
301
+ try {
302
+ const snapshot = await importMidiFile(file)
303
+ setNotes(snapshot.notes)
304
+ setTempo(snapshot.tempo)
305
+ setTimeSignature(snapshot.timeSignature as TimeSignature)
306
+ setPpq(snapshot.ppq) // Preserve original ppq for accurate export
307
+ setStatus(t.imported(file.name))
308
+ } catch (error) {
309
+ console.error(error)
310
+ setStatus(t.importFailed)
311
+ } finally {
312
+ event.target.value = ''
313
+ }
314
+ }
315
+
316
+ const handleAudioChange = (event: React.ChangeEvent<HTMLInputElement>) => {
317
+ const file = event.target.files?.[0]
318
+ if (!file) return
319
+
320
+ // Validate audio file type
321
+ const validAudioTypes = ['audio/mpeg', 'audio/wav', 'audio/ogg', 'audio/flac', 'audio/mp4', 'audio/aac', 'audio/x-m4a']
322
+ const validExtensions = ['.mp3', '.wav', '.ogg', '.flac', '.m4a', '.aac']
323
+ const fileName = file.name.toLowerCase()
324
+ const isValidType = validAudioTypes.includes(file.type) || file.type.startsWith('audio/')
325
+ const isValidExtension = validExtensions.some(ext => fileName.endsWith(ext))
326
+
327
+ if (!isValidType && !isValidExtension) {
328
+ setStatus(t.unsupportedFormat(validExtensions.join(', ')))
329
+ event.target.value = ''
330
+ return
331
+ }
332
+
333
+ const url = URL.createObjectURL(file)
334
+ setAudioUrl(url)
335
+ setStatus(t.audioImported(file.name))
336
+ event.target.value = ''
337
+ }
338
+
339
+ // Fix overlapping notes by trimming the first note to end where the second begins
340
+ // Returns the number of fixed overlaps
341
+ const fixOverlaps = (): number => {
342
+ const sortedNotes = [...notes].sort((a, b) => a.start - b.start)
343
+ let fixCount = 0
344
+
345
+ for (let i = 0; i < sortedNotes.length - 1; i++) {
346
+ const noteA = sortedNotes[i]
347
+ const noteB = sortedNotes[i + 1]
348
+ const noteAEnd = noteA.start + noteA.duration
349
+
350
+ // If noteA overlaps with noteB
351
+ if (noteAEnd > noteB.start) {
352
+ // Trim noteA to end at noteB's start
353
+ const newDuration = Math.max(0.01, noteB.start - noteA.start)
354
+ updateNote(noteA.id, { duration: newDuration })
355
+ fixCount++
356
+ }
357
+ }
358
+
359
+ return fixCount
360
+ }
361
+
362
+ // UI handler for fix overlaps button
363
+ const handleFixOverlaps = () => {
364
+ const fixCount = fixOverlaps()
365
+ if (fixCount > 0) {
366
+ setStatus(t.fixedOverlaps(fixCount))
367
+ } else {
368
+ setStatus(t.noOverlaps)
369
+ }
370
+ }
371
+
372
+ const handleExport = () => {
373
+ // Auto-fix overlaps before export
374
+ fixOverlaps()
375
+
376
+ // Get the latest notes from store (after fix, zustand set is synchronous)
377
+ const latestNotes = useMidiStore.getState().notes
378
+
379
+ const blob = exportMidi({ notes: latestNotes, tempo, timeSignature, ppq })
380
+ const url = URL.createObjectURL(blob)
381
+ const anchor = document.createElement('a')
382
+ anchor.href = url
383
+ anchor.download = 'vocal-midi.mid'
384
+ anchor.click()
385
+ URL.revokeObjectURL(url)
386
+ setStatus(t.exported)
387
+ }
388
+
389
+ const handleTranspose = (semitones: number) => {
390
+ if (semitones === 0 || !notes.length) return
391
+ for (const note of notes) {
392
+ const newMidi = Math.max(0, Math.min(127, note.midi + semitones))
393
+ updateNote(note.id, { midi: newMidi })
394
+ }
395
+ setStatus(t.transposed(semitones))
396
+ }
397
+
398
+ return (
399
+ <div className="app-shell">
400
+ <header className="topbar">
401
+ <div>
402
+ <p className="eyebrow">{t.eyebrow}</p>
403
+ <h1>{t.title}</h1>
404
+ <p className="muted">{t.subtitle}</p>
405
+ </div>
406
+ <div className="actions">
407
+ <button className="primary" onClick={handleImportClick}>
408
+ {t.importMidi}
409
+ </button>
410
+ <button className="primary" onClick={handleExport}>
411
+ {t.exportMidi}
412
+ </button>
413
+ <div className="transpose-group" title={t.transposeTooltip}>
414
+ <select
415
+ className="transpose-select"
416
+ value={0}
417
+ onChange={(e) => {
418
+ const val = Number(e.target.value)
419
+ if (val !== 0) handleTranspose(val)
420
+ e.target.value = '0'
421
+ }}
422
+ >
423
+ <option value={0}>{t.transpose}</option>
424
+ {Array.from({ length: 24 }, (_, i) => i - 12)
425
+ .filter(v => v !== 0)
426
+ .reverse()
427
+ .map(v => (
428
+ <option key={v} value={v}>
429
+ {v > 0 ? `+${v}` : v}
430
+ </option>
431
+ ))}
432
+ </select>
433
+ </div>
434
+ <button className="soft" onClick={handleFixOverlaps} title={t.fixOverlapsTooltip}>
435
+ {t.fixOverlaps}
436
+ </button>
437
+ <button className="icon-toggle" onClick={() => setTheme(theme === 'dark' ? 'light' : 'dark')}>
438
+ {theme === 'dark' ? (
439
+ <span className="icon" aria-label={t.switchToLight}>
440
+ ☀️
441
+ </span>
442
+ ) : (
443
+ <span className="icon" aria-label={t.switchToDark}>
444
+ 🌙
445
+ </span>
446
+ )}
447
+ </button>
448
+ <button
449
+ className="icon-toggle"
450
+ onClick={() => setLang(lang === 'zh' ? 'en' : 'zh')}
451
+ title={lang === 'zh' ? 'Switch to English' : '切换到中文'}
452
+ >
453
+ <span className="lang-label">{lang === 'zh' ? 'EN' : '中'}</span>
454
+ </button>
455
+ <input ref={fileInputRef} type="file" accept=".mid,.midi" className="sr-only" onChange={handleFileChange} />
456
+ </div>
457
+ </header>
458
+
459
+ <section className="audio-bar">
460
+ <div className="audio-left">
461
+ <button className="ghost" onClick={handleAudioImportClick}>
462
+ {t.importAudio}
463
+ </button>
464
+ <input
465
+ ref={audioInputRef}
466
+ type="file"
467
+ accept=".mp3,.wav,.ogg,.flac,.m4a,.aac"
468
+ className="sr-only"
469
+ onChange={handleAudioChange}
470
+ />
471
+ <span className="audio-hint">{t.audioHint}</span>
472
+ </div>
473
+ <div className="audio-right">
474
+ <div className="volume-control">
475
+ <span className="volume-label">{t.midiLabel}</span>
476
+ <input
477
+ type="range"
478
+ min={0}
479
+ max={100}
480
+ value={midiVolume}
481
+ onChange={(e) => setMidiVolume(Number(e.target.value))}
482
+ className="volume-slider"
483
+ />
484
+ <span className="volume-value">{midiVolume}%</span>
485
+ </div>
486
+ <div className="volume-control">
487
+ <span className="volume-label">{t.audioLabel}</span>
488
+ <input
489
+ type="range"
490
+ min={0}
491
+ max={100}
492
+ value={audioVolume}
493
+ onChange={(e) => setAudioVolume(Number(e.target.value))}
494
+ className="volume-slider"
495
+ />
496
+ <span className="volume-value">{audioVolume}%</span>
497
+ </div>
498
+ </div>
499
+ </section>
500
+
501
+ <section className="panel panel-split">
502
+ <div className="panel-main">
503
+ {audioUrl && (
504
+ <AudioTrack
505
+ key={audioUrl}
506
+ ref={audioScrollRef}
507
+ audioUrl={audioUrl}
508
+ muted={audioVolume === 0}
509
+ onSeek={(seconds) => seekToBeat(secondsToBeat(seconds))}
510
+ playheadSeconds={beatToSeconds(playhead)}
511
+ gridSecondWidth={gridSecondWidth}
512
+ minContentWidth={midiContentWidth}
513
+ />
514
+ )}
515
+ <PianoRoll
516
+ notes={notes}
517
+ selectedId={selectedId}
518
+ timeSignature={timeSignature}
519
+ tempo={tempo}
520
+ playhead={playhead}
521
+ selectionStart={selectionStart}
522
+ selectionEnd={selectionEnd}
523
+ onAddNote={addNote}
524
+ onSelect={select}
525
+ onUpdateNote={updateNote}
526
+ onSeek={seekToBeat}
527
+ onScroll={(left) => {
528
+ if (audioScrollRef.current) {
529
+ audioScrollRef.current.scrollLeft = left
530
+ }
531
+ }}
532
+ onZoom={(deltaH, deltaV) => {
533
+ if (deltaH !== 0) {
534
+ setHorizontalZoom(prev => Math.max(0.5, prev + deltaH))
535
+ }
536
+ if (deltaV !== 0) {
537
+ setVerticalZoom(prev => Math.max(0.6, Math.min(2.5, prev + deltaV)))
538
+ }
539
+ }}
540
+ onPlayNote={playPreviewNote}
541
+ onFocusLyric={(noteId) => {
542
+ select(noteId)
543
+ setFocusLyricId(noteId)
544
+ }}
545
+ onSelectionChange={(start, end) => {
546
+ setSelectionStart(start)
547
+ setSelectionEnd(end)
548
+ }}
549
+ isSelectingRange={isSelectingRange}
550
+ audioDuration={audioDuration}
551
+ gridSecondWidth={gridSecondWidth}
552
+ rowHeight={rowHeight}
553
+ />
554
+ </div>
555
+ <aside className="panel-side">
556
+ <div className="controls">
557
+ <div className="toggle" style={{ justifyContent: 'space-between' }}>
558
+ <span>{t.horizontalZoom}</span>
559
+ <input
560
+ type="range"
561
+ min={0.5}
562
+ max={10}
563
+ step={0.1}
564
+ value={Math.min(horizontalZoom, 10)}
565
+ onChange={(e) => setHorizontalZoom(Number(e.target.value))}
566
+ style={{ width: '140px' }}
567
+ />
568
+ <span style={{ width: 42, textAlign: 'right' }}>{horizontalZoom.toFixed(1)}x</span>
569
+ </div>
570
+ <div className="toggle" style={{ justifyContent: 'space-between' }}>
571
+ <span>{t.verticalZoom}</span>
572
+ <input
573
+ type="range"
574
+ min={0.6}
575
+ max={2.5}
576
+ step={0.1}
577
+ value={verticalZoom}
578
+ onChange={(e) => setVerticalZoom(Number(e.target.value))}
579
+ style={{ width: '140px' }}
580
+ />
581
+ <span style={{ width: 42, textAlign: 'right' }}>{verticalZoom.toFixed(1)}x</span>
582
+ </div>
583
+ <div className="transport">
584
+ <button
585
+ className="soft"
586
+ onClick={() => {
587
+ setPlayhead(0)
588
+ seekToBeat(0)
589
+ }}
590
+ title={t.goToStart}
591
+ >
592
+
593
+ </button>
594
+ <button
595
+ className="soft"
596
+ onClick={() => seekBySeconds(-2)}
597
+ title={t.back2s}
598
+ >
599
+
600
+ </button>
601
+ <button
602
+ className="primary"
603
+ onClick={handlePlayToggle}
604
+ disabled={!notes.length && !audioUrl}
605
+ title={isPlaying ? t.pause : (selectionStart !== null && selectionEnd !== null ? t.playSelection : t.play)}
606
+ >
607
+ {isPlaying ? '⏸' : '▶'}
608
+ </button>
609
+ <button
610
+ className="soft"
611
+ onClick={() => seekBySeconds(2)}
612
+ title={t.forward2s}
613
+ >
614
+
615
+ </button>
616
+ <button
617
+ className="soft"
618
+ onClick={() => {
619
+ const maxNoteEnd = notes.reduce((acc, n) => Math.max(acc, n.start + n.duration), 0)
620
+ seekToBeat(Math.max(secondsToBeat(audioDuration), maxNoteEnd))
621
+ }}
622
+ title={t.goToEnd}
623
+ >
624
+
625
+ </button>
626
+ <span className="transport-divider" />
627
+ <button
628
+ className={`soft selection-btn ${isSelectingRange ? 'active' : ''}`}
629
+ onClick={() => {
630
+ if (isSelectingRange) {
631
+ // Exiting selection mode - auto clear selection
632
+ setIsSelectingRange(false)
633
+ setSelectionStart(null)
634
+ setSelectionEnd(null)
635
+ } else {
636
+ setIsSelectingRange(true)
637
+ }
638
+ }}
639
+ title={isSelectingRange ? t.exitSelectMode : t.setRangeTooltip}
640
+ >
641
+ {isSelectingRange ? `📍 ${t.selectingRange}` : `📍 ${t.setRange}`}
642
+ </button>
643
+ </div>
644
+ <div className="status">{status}</div>
645
+ </div>
646
+ <div className="lyric-container">
647
+ <LyricTable
648
+ notes={notes}
649
+ selectedId={selectedId}
650
+ tempo={tempo}
651
+ focusLyricId={focusLyricId}
652
+ lang={lang}
653
+ onSelect={select}
654
+ onUpdate={updateNote}
655
+ onFocusHandled={() => setFocusLyricId(null)}
656
+ />
657
+ </div>
658
+ </aside>
659
+ </section>
660
+ <audio
661
+ ref={audioRef}
662
+ src={audioUrl ?? undefined}
663
+ preload="auto"
664
+ className="sr-only"
665
+ onLoadedMetadata={(e) => {
666
+ setAudioDuration(e.currentTarget.duration)
667
+ // Ensure volume is set when audio loads
668
+ e.currentTarget.volume = audioVolume / 100
669
+ }}
670
+ />
671
+ </div>
672
+ )
673
+ }
674
+
675
+ export default App
preprocess/tools/midi_editor/src/components/AudioTrack.tsx ADDED
@@ -0,0 +1,182 @@
1
+ import { useEffect, useRef, forwardRef, useState } from 'react'
2
+ import WaveSurfer from 'wavesurfer.js'
3
+ import { PITCH_WIDTH } from '../constants'
4
+
5
+ export type AudioTrackProps = {
6
+ audioUrl: string | null
7
+ muted: boolean
8
+ onSeek: (seconds: number) => void
9
+ mediaElement?: HTMLAudioElement | null
10
+ playheadSeconds: number
11
+ gridSecondWidth: number
12
+ minContentWidth?: number // Minimum width to match MIDI editor area
13
+ }
14
+
15
+ export const AudioTrack = forwardRef<HTMLDivElement, AudioTrackProps>(
16
+ ({ audioUrl, muted, onSeek, playheadSeconds, gridSecondWidth, minContentWidth = 0 }, ref) => {
17
+ const containerRef = useRef<HTMLDivElement | null>(null)
18
+ const waveRef = useRef<WaveSurfer | null>(null)
19
+ const [waveWidth, setWaveWidth] = useState(0)
20
+
21
+ useEffect(() => {
22
+ if (!containerRef.current) return
23
+ if (!audioUrl) {
24
+ try {
25
+ waveRef.current?.destroy()
26
+ } catch {
27
+ // ignore teardown errors
28
+ }
29
+ waveRef.current = null
30
+ setWaveWidth(0)
31
+ return
32
+ }
33
+
34
+ let cancelled = false
35
+
36
+ // Clean up existing instance
37
+ if (waveRef.current) {
38
+ try {
39
+ waveRef.current.destroy()
40
+ } catch {
41
+ // ignore teardown errors
42
+ }
43
+ }
44
+
45
+ waveRef.current = WaveSurfer.create({
46
+ container: containerRef.current,
47
+ waveColor: '#4b64bc',
48
+ progressColor: '#4b64bc',
49
+ cursorColor: 'transparent',
50
+ barWidth: 2,
51
+ barGap: 2,
52
+ height: 60,
53
+ normalize: true,
54
+ minPxPerSec: gridSecondWidth,
55
+ interact: false,
56
+ hideScrollbar: true,
57
+ autoScroll: false,
58
+ })
59
+
60
+ waveRef.current.load(audioUrl).catch(() => null)
61
+ waveRef.current.on('error', () => null)
62
+
63
+ waveRef.current.on('ready', () => {
64
+ if (cancelled || !waveRef.current) return
65
+ const duration = waveRef.current.getDuration()
66
+ const requiredWidth = duration * gridSecondWidth
67
+ setWaveWidth(requiredWidth)
68
+ })
69
+
70
+ return () => {
71
+ cancelled = true
72
+ try {
73
+ waveRef.current?.destroy()
74
+ } catch {
75
+ // ignore teardown errors
76
+ }
77
+ waveRef.current = null
78
+ }
79
+ }, [audioUrl, gridSecondWidth])
80
+
81
+ useEffect(() => {
82
+ if (!waveRef.current) return
83
+ waveRef.current.setOptions({
84
+ waveColor: muted ? '#9aa6b2' : '#4b64bc',
85
+ progressColor: muted ? '#c0c9d4' : '#4b64bc',
86
+ })
87
+ }, [muted])
88
+
89
+ if (!audioUrl) return null
90
+
91
+ // Content width should be at least as wide as MIDI editor
92
+ const contentWidth = Math.max(waveWidth, minContentWidth)
93
+
94
+ return (
95
+ <div
96
+ className="audio-track-row"
97
+ style={{
98
+ display: 'flex',
99
+ borderBottom: '1px solid var(--border-soft)',
100
+ height: '70px',
101
+ flexShrink: 0
102
+ }}
103
+ >
104
+ <div
105
+ className="audio-gutter"
106
+ style={{
107
+ width: PITCH_WIDTH,
108
+ flexShrink: 0,
109
+ background: 'var(--panel-strong)',
110
+ borderRight: '1px solid var(--border-subtle)',
111
+ display: 'flex',
112
+ alignItems: 'center',
113
+ justifyContent: 'center',
114
+ fontSize: '11px',
115
+ color: 'var(--text-muted)',
116
+ fontWeight: 600,
117
+ }}
118
+ >
119
+ AUDIO
120
+ </div>
121
+
122
+ {/* Scroll Mask - Controlled by parent via ref */}
123
+ <div
124
+ ref={ref}
125
+ className="audio-scroll-mask"
126
+ style={{
127
+ flex: 1,
128
+ overflow: 'hidden',
129
+ position: 'relative',
130
+ background: 'var(--panel-soft)',
131
+ }}
132
+ onClick={(e) => {
133
+ const rect = e.currentTarget.getBoundingClientRect()
134
+ const scrollMask = e.currentTarget as HTMLDivElement
135
+ const x = e.clientX - rect.left + scrollMask.scrollLeft
136
+ const seconds = x / gridSecondWidth
137
+ onSeek(seconds)
138
+ }}
139
+ >
140
+ {/* Container that matches MIDI editor width */}
141
+ <div
142
+ className="audio-content"
143
+ style={{
144
+ width: contentWidth > 0 ? contentWidth : '100%',
145
+ height: '100%',
146
+ position: 'relative'
147
+ }}
148
+ >
149
+ {/* WaveSurfer container - only as wide as audio */}
150
+ <div
151
+ ref={containerRef}
152
+ className="wave-container"
153
+ style={{
154
+ width: waveWidth > 0 ? waveWidth : '100%',
155
+ height: '100%',
156
+ position: 'absolute',
157
+ left: 0,
158
+ top: 0
159
+ }}
160
+ />
161
+
162
+ {/* Custom Playhead */}
163
+ <div
164
+ className="audio-playhead"
165
+ style={{
166
+ position: 'absolute',
167
+ top: 0,
168
+ bottom: 0,
169
+ width: '2px',
170
+ background: '#ff7043',
171
+ boxShadow: '0 0 12px rgba(255, 112, 67, 0.6)',
172
+ left: playheadSeconds * gridSecondWidth,
173
+ zIndex: 10,
174
+ pointerEvents: 'none',
175
+ }}
176
+ />
177
+ </div>
178
+ </div>
179
+ </div>
180
+ )
181
+ }
182
+ )
preprocess/tools/midi_editor/src/components/LyricTable.tsx ADDED
@@ -0,0 +1,301 @@
1
+ import { useEffect, useMemo, useRef, useState } from 'react'
2
+ import type { NoteEvent } from '../types'
3
+ import type { Lang } from '../i18n'
4
+ import { getTranslations, tokenizeLyrics } from '../i18n'
5
+
6
+ export type LyricTableProps = {
7
+ notes: NoteEvent[]
8
+ selectedId: string | null
9
+ tempo: number
10
+ focusLyricId: string | null
11
+ lang: Lang
12
+ onSelect: (id: string | null) => void
13
+ onUpdate: (id: string, patch: Partial<NoteEvent>) => void
14
+ onScrollToNote?: (noteId: string) => void
15
+ onFocusHandled?: () => void
16
+ }
17
+
18
+ const formatSeconds = (beats: number, tempo: number) => {
19
+ const seconds = beats * (60 / tempo)
20
+ return Number.parseFloat(seconds.toFixed(2))
21
+ }
22
+
23
+ const secondsToBeats = (seconds: number, tempo: number) => {
24
+ return seconds * (tempo / 60)
25
+ }
26
+
27
+ // Editable cell with confirmation
28
+ function EditableCell({
29
+ value,
30
+ noteId,
31
+ field,
32
+ tempo,
33
+ onConfirm,
34
+ confirmTitle,
35
+ type = 'number',
36
+ min,
37
+ step
38
+ }: {
39
+ value: number
40
+ noteId: string
41
+ field: 'midi' | 'start' | 'end'
42
+ tempo: number
43
+ onConfirm: (noteId: string, field: string, value: number) => void
44
+ confirmTitle?: string
45
+ type?: string
46
+ min?: number
47
+ step?: number
48
+ }) {
49
+ const displayValue = field === 'midi' ? value : formatSeconds(value, tempo)
50
+ const [localValue, setLocalValue] = useState(String(displayValue))
51
+ const [isDirty, setIsDirty] = useState(false)
52
+ const inputRef = useRef<HTMLInputElement>(null)
53
+
54
+ // Sync with external value when it changes (and not dirty)
55
+ useEffect(() => {
56
+ if (!isDirty) {
57
+ setLocalValue(String(displayValue))
58
+ }
59
+ }, [displayValue, isDirty])
60
+
61
+ const handleChange = (e: React.ChangeEvent<HTMLInputElement>) => {
62
+ setLocalValue(e.target.value)
63
+ setIsDirty(true)
64
+ }
65
+
66
+ const handleConfirm = () => {
67
+ const parsed = parseFloat(localValue)
68
+ if (!isNaN(parsed)) {
69
+ if (field === 'midi') {
70
+ if (parsed >= 0 && parsed <= 127) {
71
+ onConfirm(noteId, field, Math.round(parsed))
72
+ }
73
+ } else {
74
+ if (parsed >= 0) {
75
+ onConfirm(noteId, field, secondsToBeats(parsed, tempo))
76
+ }
77
+ }
78
+ }
79
+ setIsDirty(false)
80
+ }
81
+
82
+ const handleKeyDown = (e: React.KeyboardEvent) => {
83
+ if (e.key === 'Enter') {
84
+ e.preventDefault()
85
+ handleConfirm()
86
+ inputRef.current?.blur()
87
+ } else if (e.key === 'Escape') {
88
+ setLocalValue(String(displayValue))
89
+ setIsDirty(false)
90
+ inputRef.current?.blur()
91
+ }
92
+ }
93
+
94
+ const handleBlur = () => {
95
+ if (isDirty) {
96
+ // Reset to original on blur without confirm
97
+ setLocalValue(String(displayValue))
98
+ setIsDirty(false)
99
+ }
100
+ }
101
+
102
+ return (
103
+ <div className="editable-cell">
104
+ <input
105
+ ref={inputRef}
106
+ className={`lyric-meta-input ${isDirty ? 'lyric-meta-dirty' : ''}`}
107
+ type={type}
108
+ min={min}
109
+ step={step}
110
+ value={localValue}
111
+ onChange={handleChange}
112
+ onKeyDown={handleKeyDown}
113
+ onBlur={handleBlur}
114
+ onClick={(e) => e.stopPropagation()}
115
+ />
116
+ {isDirty && (
117
+ <button
118
+ className="confirm-btn"
119
+ onMouseDown={(e) => {
120
+ e.preventDefault() // Prevent input blur
121
+ e.stopPropagation()
122
+ }}
123
+ onClick={(e) => {
124
+ e.stopPropagation()
125
+ handleConfirm()
126
+ }}
127
+ title={confirmTitle}
128
+ >
129
+
130
+ </button>
131
+ )}
132
+ </div>
133
+ )
134
+ }
135
+
136
+ export function LyricTable({ notes, selectedId, tempo, focusLyricId, lang, onSelect, onUpdate, onScrollToNote, onFocusHandled }: LyricTableProps) {
137
+ const t = getTranslations(lang)
138
+ const listRef = useRef<HTMLDivElement | null>(null)
139
+ const inputRefs = useRef<Map<string, HTMLInputElement>>(new Map())
140
+ const sorted = useMemo(() => [...notes].sort((a, b) => a.start - b.start), [notes])
141
+
142
+ // Scroll to selected note (no auto-focus on single click)
143
+ useEffect(() => {
144
+ if (!selectedId || !listRef.current) return
145
+
146
+ const target = listRef.current.querySelector<HTMLDivElement>(`[data-note-id="${selectedId}"]`)
147
+ if (target) {
148
+ target.scrollIntoView({ block: 'nearest', behavior: 'smooth' })
149
+ }
150
+ }, [selectedId])
151
+
152
+ // Focus lyric input when requested (double-click on note or click on list row)
153
+ useEffect(() => {
154
+ if (!focusLyricId) return
155
+
156
+ const input = inputRefs.current.get(focusLyricId)
157
+ if (input) {
158
+ setTimeout(() => {
159
+ input.focus()
160
+ input.select()
161
+ }, 50)
162
+ }
163
+ onFocusHandled?.()
164
+ }, [focusLyricId, onFocusHandled])
165
+
166
+ // Fill lyrics from selected note onwards
167
+ // Uses smart tokenizer: CJK chars -> one per note, English words -> one per note
168
+ const handleBulkFill = (bulkText: string) => {
169
+ if (!sorted.length) return
170
+ const tokens = tokenizeLyrics(bulkText)
171
+ if (!tokens.length) return
172
+
173
+ let startIndex = 0
174
+ if (selectedId) {
175
+ const selectedIndex = sorted.findIndex(n => n.id === selectedId)
176
+ if (selectedIndex >= 0) {
177
+ startIndex = selectedIndex
178
+ }
179
+ }
180
+
181
+ let tokenIndex = 0
182
+ for (let i = startIndex; i < sorted.length && tokenIndex < tokens.length; i++) {
183
+ onUpdate(sorted[i].id, { lyric: tokens[tokenIndex] })
184
+ tokenIndex++
185
+ }
186
+ }
187
+
188
+ const handleRowClick = (noteId: string) => {
189
+ onSelect(noteId)
190
+ onScrollToNote?.(noteId)
191
+ }
192
+
193
+ const handleFieldConfirm = (noteId: string, field: string, value: number) => {
194
+ const note = notes.find(n => n.id === noteId)
195
+ if (!note) return
196
+
197
+ if (field === 'midi') {
198
+ onUpdate(noteId, { midi: value })
199
+ } else if (field === 'start') {
200
+ // Keep END the same, adjust duration accordingly
201
+ const currentEnd = note.start + note.duration
202
+ const newDuration = Math.max(0.01, currentEnd - value)
203
+ onUpdate(noteId, { start: value, duration: newDuration })
204
+ } else if (field === 'end') {
205
+ // End changed, update duration
206
+ const newDuration = Math.max(0.01, value - note.start)
207
+ onUpdate(noteId, { duration: newDuration })
208
+ }
209
+ }
210
+
211
+ return (
212
+ <div className="lyric-card">
213
+ <div className="lyric-bulk">
214
+ <textarea
215
+ className="lyric-bulk-input"
216
+ rows={2}
217
+ placeholder={selectedId ? t.fillPlaceholderSelected : t.fillPlaceholderDefault}
218
+ onKeyDown={(e) => {
219
+ if (e.key === 'Enter' && !e.shiftKey) {
220
+ e.preventDefault()
221
+ handleBulkFill(e.currentTarget.value)
222
+ }
223
+ }}
224
+ />
225
+ <button
226
+ className="soft"
227
+ type="button"
228
+ onClick={(e) => {
229
+ const textarea = e.currentTarget.previousElementSibling as HTMLTextAreaElement
230
+ handleBulkFill(textarea.value)
231
+ }}
232
+ >
233
+ {t.fillButton.split('\n').map((line, i) => (
234
+ <span key={i}>{line}{i === 0 && <br/>}</span>
235
+ ))}
236
+ </button>
237
+ </div>
238
+ <div className="lyric-header" style={{ flexShrink: 0 }}>
239
+ <div>LYRIC</div>
240
+ <div>PITCH</div>
241
+ <div>START</div>
242
+ <div>END</div>
243
+ </div>
244
+ <div className="lyric-list" ref={listRef}>
245
+ {sorted.map((note) => (
246
+ <div
247
+ key={note.id}
248
+ className={`lyric-row ${selectedId === note.id ? 'lyric-row-active' : ''}`}
249
+ data-note-id={note.id}
250
+ onClick={() => handleRowClick(note.id)}
251
+ >
252
+ <input
253
+ ref={(el) => {
254
+ if (el) {
255
+ inputRefs.current.set(note.id, el)
256
+ } else {
257
+ inputRefs.current.delete(note.id)
258
+ }
259
+ }}
260
+ className="lyric-input"
261
+ value={note.lyric}
262
+ placeholder={t.lyricPlaceholder}
263
+ onChange={(event) => onUpdate(note.id, { lyric: event.target.value })}
264
+ onClick={(e) => e.stopPropagation()}
265
+ />
266
+ <EditableCell
267
+ value={note.midi}
268
+ noteId={note.id}
269
+ field="midi"
270
+ tempo={tempo}
271
+ onConfirm={handleFieldConfirm}
272
+ confirmTitle={t.confirmEdit}
273
+ min={0}
274
+ />
275
+ <EditableCell
276
+ value={note.start}
277
+ noteId={note.id}
278
+ field="start"
279
+ tempo={tempo}
280
+ onConfirm={handleFieldConfirm}
281
+ confirmTitle={t.confirmEdit}
282
+ min={0}
283
+ step={0.01}
284
+ />
285
+ <EditableCell
286
+ value={note.start + note.duration}
287
+ noteId={note.id}
288
+ field="end"
289
+ tempo={tempo}
290
+ onConfirm={handleFieldConfirm}
291
+ confirmTitle={t.confirmEdit}
292
+ min={0}
293
+ step={0.01}
294
+ />
295
+ </div>
296
+ ))}
297
+ {sorted.length === 0 && <div className="lyric-empty">{t.emptyHint}</div>}
298
+ </div>
299
+ </div>
300
+ )
301
+ }
preprocess/tools/midi_editor/src/components/PianoRoll.tsx ADDED
@@ -0,0 +1,704 @@
+ import { useEffect, useMemo, useRef, useState, useCallback, memo } from 'react'
+ import type React from 'react'
+ import type { NoteEvent, TimeSignature } from '../types'
+ import { PITCH_WIDTH, LOW_NOTE, HIGH_NOTE } from '../constants'
+
+ const midiToName = (midi: number) => {
+ const names = ['C', 'C#', 'D', 'D#', 'E', 'F', 'F#', 'G', 'G#', 'A', 'A#', 'B']
+ const octave = Math.floor(midi / 12) - 1
+ return `${names[midi % 12]}${octave}`
+ }
+
+ // Memoized note component to prevent unnecessary re-renders
+ const NoteChip = memo(function NoteChip({
+ note,
+ left,
+ top,
+ width,
+ height,
+ fontSize,
+ isSelected,
+ isOverlapping,
+ onPointerDown,
+ onDoubleClick,
+ }: {
+ note: NoteEvent
+ left: number
+ top: number
+ width: number
+ height: number
+ fontSize: number
+ isSelected: boolean
+ isOverlapping: boolean
+ onPointerDown: (event: React.PointerEvent<HTMLDivElement>, mode: 'move' | 'resize-start' | 'resize-end') => void
+ onDoubleClick: (event: React.MouseEvent<HTMLDivElement>) => void
+ }) {
+ return (
+ <div
+ className={`note-chip ${isSelected ? 'note-active' : ''} ${isOverlapping ? 'note-overlap' : ''}`}
+ style={{
+ left,
+ top: top + 1,
+ width,
+ height,
+ willChange: 'transform', // GPU acceleration hint
+ }}
+ onPointerDown={(e) => onPointerDown(e, 'move')}
+ onDoubleClick={onDoubleClick}
+ >
+ <div className="note-label" style={{ fontSize }}>
+ <span>{note.lyric || '\u00a0'}</span>
+ </div>
+ <div className="note-handle start" onPointerDown={(e) => { e.stopPropagation(); onPointerDown(e, 'resize-start') }} />
+ <div className="note-handle end" onPointerDown={(e) => { e.stopPropagation(); onPointerDown(e, 'resize-end') }} />
+ </div>
+ )
+ })
+
+ // Dynamic snap based on zoom level - higher zoom = finer snap
+ const getSnapSeconds = (gridSecondWidth: number) => {
+ // At base width (80px/s), snap is 0.1s
+ // At 2x zoom (160px/s), snap is 0.05s
+ // At 4x zoom (320px/s), snap is 0.025s
+ // At 8x zoom (640px/s), snap is 0.01s
+ const baseSnap = 0.1
+ const zoomFactor = gridSecondWidth / 80
+ return Math.max(0.01, baseSnap / zoomFactor)
+ }
+
+ const snapSeconds = (value: number, gridSecondWidth: number) => {
+ const snap = getSnapSeconds(gridSecondWidth)
+ return Math.max(0, Math.round(value / snap) * snap)
+ }
+
+ export type PianoRollProps = {
+ notes: NoteEvent[]
+ selectedId: string | null
+ timeSignature: TimeSignature
+ tempo: number
+ playhead: number // in beats
+ selectionStart: number | null // in seconds
+ selectionEnd: number | null // in seconds
+ onAddNote: (note: Partial<NoteEvent>) => NoteEvent
+ onUpdateNote: (id: string, patch: Partial<NoteEvent>) => void
+ onSelect: (id: string | null) => void
+ onSeek: (beat: number) => void
+ onScroll?: (left: number) => void
+ onZoom?: (deltaH: number, deltaV: number) => void
+ onPlayNote?: (midi: number) => void
+ onFocusLyric?: (noteId: string) => void
+ onSelectionChange?: (start: number | null, end: number | null) => void
+ isSelectingRange?: boolean
+ audioDuration?: number
+ gridSecondWidth: number
+ rowHeight: number
+ }
+
+ export function PianoRoll({
+ notes,
+ selectedId,
+ timeSignature: _timeSignature,
+ tempo,
+ playhead,
+ selectionStart,
+ selectionEnd,
+ onAddNote,
+ onSelect,
+ onUpdateNote,
+ onSeek,
+ onScroll,
+ onZoom,
+ onPlayNote,
+ onFocusLyric,
+ onSelectionChange,
+ isSelectingRange = false,
+ audioDuration = 0,
+ gridSecondWidth,
+ rowHeight
+ }: PianoRollProps) {
+ const scrollContainerRef = useRef<HTMLDivElement | null>(null)
+ const rulerScrollRef = useRef<HTMLDivElement | null>(null)
+ const [scrollTop, setScrollTop] = useState(0)
+ const [scrollLeft, setScrollLeft] = useState(0)
+ const [viewportWidth, setViewportWidth] = useState(800)
+ const [viewportHeight, setViewportHeight] = useState(400)
+ const dragRef = useRef<{
+ id: string
+ mode: 'move' | 'resize-start' | 'resize-end'
+ originX: number
+ originY: number
+ startSeconds: number
+ durationSeconds: number
+ midi: number
+ lastMidi?: number // Track last midi for pitch change sound
+ } | null>(null)
+
+ // Selection drag state
+ const selectionDragRef = useRef<{
+ startX: number
+ startSeconds: number
+ } | null>(null)
+
+ // Store callbacks in refs to avoid stale closures in event handlers
+ const onPlayNoteRef = useRef(onPlayNote)
+ const onUpdateNoteRef = useRef(onUpdateNote)
+
+ useEffect(() => {
+ onPlayNoteRef.current = onPlayNote
+ onUpdateNoteRef.current = onUpdateNote
+ }, [onPlayNote, onUpdateNote])
+
+ // Conversion helpers
+ const beatToSeconds = useCallback((beat: number) => beat * (60 / tempo), [tempo])
+ const secondsToBeat = useCallback((seconds: number) => seconds / (60 / tempo), [tempo])
+
+ // Calculate dimensions
+ const totalRows = HIGH_NOTE - LOW_NOTE + 1
+ const contentHeight = totalRows * rowHeight
+ const [containerWidth, setContainerWidth] = useState(1200)
+
+ // Track container size
+ useEffect(() => {
+ const container = scrollContainerRef.current
+ if (!container) return
+
+ const observer = new ResizeObserver((entries) => {
+ for (const entry of entries) {
+ setContainerWidth(entry.contentRect.width)
+ setViewportWidth(entry.contentRect.width)
+ setViewportHeight(entry.contentRect.height)
+ }
+ })
+ observer.observe(container)
+ return () => observer.disconnect()
+ }, [])
+
+ const maxSeconds = useMemo(() => {
+ const noteEndSeconds = notes.reduce((acc, n) => {
+ const endBeat = n.start + n.duration
+ return Math.max(acc, beatToSeconds(endBeat))
+ }, 8)
+ // Ensure grid extends at least 2x the visible area for smoother scrolling
+ const minSecondsForView = (containerWidth / gridSecondWidth) * 2
+ return Math.max(noteEndSeconds + 10, audioDuration + 10, minSecondsForView, 30)
+ }, [notes, audioDuration, beatToSeconds, containerWidth, gridSecondWidth])
+
+ const contentWidth = maxSeconds * gridSecondWidth
+
+ // Drag handlers - use refs to avoid stale closure issues
+ const handlePointerMove = useCallback((event: PointerEvent) => {
+ const drag = dragRef.current
+ if (!drag) return
+
+ const dxSeconds = (event.clientX - drag.originX) / gridSecondWidth
+ const dy = (event.clientY - drag.originY) / rowHeight
+
+ if (drag.mode === 'move') {
+ const nextSeconds = snapSeconds(drag.startSeconds + dxSeconds, gridSecondWidth)
+ const nextMidi = Math.min(HIGH_NOTE, Math.max(LOW_NOTE, Math.round(drag.midi - dy)))
+
+ // Play sound when pitch changes
+ if (nextMidi !== drag.lastMidi && onPlayNoteRef.current) {
+ onPlayNoteRef.current(nextMidi)
+ drag.lastMidi = nextMidi
+ }
+
+ onUpdateNoteRef.current(drag.id, {
+ start: secondsToBeat(nextSeconds),
+ midi: nextMidi
+ })
+ }
+
+ if (drag.mode === 'resize-start') {
+ const nextSeconds = snapSeconds(drag.startSeconds + dxSeconds, gridSecondWidth)
+ const delta = drag.startSeconds - nextSeconds
+ const nextDurationSeconds = Math.max(0.05, drag.durationSeconds + delta)
+ onUpdateNoteRef.current(drag.id, {
+ start: secondsToBeat(nextSeconds),
+ duration: secondsToBeat(nextDurationSeconds)
+ })
+ }
+
+ if (drag.mode === 'resize-end') {
+ const nextDurationSeconds = Math.max(0.05, snapSeconds(drag.durationSeconds + dxSeconds, gridSecondWidth))
+ onUpdateNoteRef.current(drag.id, { duration: secondsToBeat(nextDurationSeconds) })
+ }
+ }, [gridSecondWidth, rowHeight, secondsToBeat])
+
+ const handlePointerUp = useCallback(() => {
+ dragRef.current = null
+ window.removeEventListener('pointermove', handlePointerMove)
+ window.removeEventListener('pointerup', handlePointerUp)
+ }, [handlePointerMove])
+
+ useEffect(() => {
+ return () => {
+ window.removeEventListener('pointermove', handlePointerMove)
+ window.removeEventListener('pointerup', handlePointerUp)
+ }
+ }, [handlePointerMove, handlePointerUp])
+
+ // Scroll sync
+ useEffect(() => {
+ const container = scrollContainerRef.current
+ const ruler = rulerScrollRef.current
+ if (!container || !ruler) return
+
+ const handleScroll = () => {
+ ruler.scrollLeft = container.scrollLeft
+ setScrollTop(container.scrollTop)
+ setScrollLeft(container.scrollLeft)
+ if (onScroll) onScroll(container.scrollLeft)
+ }
+
+ container.addEventListener('scroll', handleScroll)
+ return () => container.removeEventListener('scroll', handleScroll)
+ }, [onScroll])
+
+ // Zoom support via wheel/trackpad
+ // Mac: Cmd+scroll (horizontal zoom), Cmd+Shift+scroll (vertical zoom), or two-finger pinch
+ // Windows/Linux: Ctrl+scroll (horizontal zoom), Ctrl+Shift+scroll (vertical zoom)
+ useEffect(() => {
+ const container = scrollContainerRef.current
+ if (!container || !onZoom) return
+
+ const handleWheel = (e: WheelEvent) => {
+ // Ctrl (Windows/Linux/pinch) or Cmd (Mac) triggers zoom
+ const isZoomTrigger = e.ctrlKey || e.metaKey
+
+ if (isZoomTrigger) {
+ e.preventDefault()
+ e.stopPropagation()
+
+ // Use deltaY for zoom amount, normalize for different input methods
+ // Pinch gestures typically have smaller delta values
+ let delta = -e.deltaY
+ if (Math.abs(delta) > 10) {
+ // Likely a mouse wheel, scale down
+ delta = delta * 0.01
+ } else {
+ // Likely a trackpad pinch, scale appropriately
+ delta = delta * 0.05
+ }
+
+ // Shift or Alt/Option for vertical zoom, otherwise horizontal
+ if (e.shiftKey || e.altKey) {
+ onZoom(0, delta)
+ } else {
+ onZoom(delta, 0)
+ }
+ }
+ }
+
+ container.addEventListener('wheel', handleWheel, { passive: false })
+ return () => container.removeEventListener('wheel', handleWheel)
+ }, [onZoom])
+
+ // Playhead auto-scroll
+ useEffect(() => {
+ if (!scrollContainerRef.current) return
+ const container = scrollContainerRef.current
+ const playheadX = beatToSeconds(playhead) * gridSecondWidth
+ const viewStart = container.scrollLeft
+ const viewEnd = viewStart + container.clientWidth
+
+ if (playheadX > viewEnd) {
+ container.scrollLeft = playheadX
+ } else if (playheadX < viewStart) {
+ container.scrollLeft = playheadX
+ }
+ }, [playhead, gridSecondWidth, beatToSeconds])
+
+ // Selection auto-scroll
+ useEffect(() => {
+ if (!scrollContainerRef.current || !selectedId) return
+ const note = notes.find((n) => n.id === selectedId)
+ if (!note) return
+ const container = scrollContainerRef.current
+ const noteX = beatToSeconds(note.start) * gridSecondWidth
+ const noteY = (HIGH_NOTE - note.midi) * rowHeight
+
+ const viewStart = container.scrollLeft
+ const viewEnd = viewStart + container.clientWidth
+ if (noteX < viewStart + 50 || noteX > viewEnd - 50) {
+ container.scrollLeft = Math.max(0, noteX - container.clientWidth * 0.35)
+ }
+
+ const viewTop = container.scrollTop
+ const viewBottom = viewTop + container.clientHeight
+ if (noteY < viewTop || noteY > viewBottom - rowHeight) {
+ container.scrollTop = Math.max(0, noteY - container.clientHeight * 0.4)
+ }
+ }, [selectedId, notes, gridSecondWidth, rowHeight, beatToSeconds])
+
+ const handleGridDoubleClick = (event: React.MouseEvent<HTMLDivElement>) => {
+ // Only add note if clicking on empty space (not on a note)
+ const target = event.target as HTMLElement
+ if (target.closest('.note-chip')) return
+
+ if (!scrollContainerRef.current) return
+ const container = scrollContainerRef.current
+ const rect = container.getBoundingClientRect()
+ const x = event.clientX - rect.left + container.scrollLeft
+ const y = event.clientY - rect.top + container.scrollTop
+
+ const seconds = snapSeconds(x / gridSecondWidth, gridSecondWidth)
+ const pitch = Math.min(HIGH_NOTE, Math.max(LOW_NOTE, HIGH_NOTE - Math.floor(y / rowHeight)))
+
+ const created = onAddNote({
+ start: secondsToBeat(seconds),
+ midi: pitch,
+ duration: secondsToBeat(0.5),
+ lyric: ''
+ })
+ onSelect(created.id)
+ }
+
+ const startDrag = (
+ event: React.PointerEvent<HTMLDivElement>,
+ note: NoteEvent,
+ mode: 'move' | 'resize-start' | 'resize-end',
+ ) => {
+ event.preventDefault()
+ event.stopPropagation()
+ dragRef.current = {
+ id: note.id,
+ mode,
+ originX: event.clientX,
+ originY: event.clientY,
+ startSeconds: beatToSeconds(note.start),
+ durationSeconds: beatToSeconds(note.duration),
+ midi: note.midi,
+ lastMidi: note.midi, // Initialize last midi
+ }
+ window.addEventListener('pointermove', handlePointerMove)
+ window.addEventListener('pointerup', handlePointerUp)
+ onSelect(note.id)
+
+ // Play sound when clicking/selecting note
+ if (onPlayNote) {
+ onPlayNote(note.midi)
+ }
+ }
+
+ // Second-based ruler labels
+ const secondLabels = useMemo(() => {
+ const labels = [] as Array<{ left: number; label: string }>
+ const totalSeconds = Math.ceil(maxSeconds)
+ for (let s = 0; s <= totalSeconds; s += 1) {
+ labels.push({ left: s * gridSecondWidth, label: `${s}s` })
+ }
+ return labels
+ }, [maxSeconds, gridSecondWidth])
+
+ // Piano keys
+ const pitchRows = useMemo(() => {
+ const rows = [] as Array<{ midi: number; isBlack: boolean; label: string; isC: boolean }>
+ const black = new Set([1, 3, 6, 8, 10])
+ for (let p = HIGH_NOTE; p >= LOW_NOTE; p -= 1) {
+ const name = midiToName(p)
+ const isC = p % 12 === 0
+ rows.push({ midi: p, isBlack: black.has(p % 12), label: name, isC })
+ }
+ return rows
+ }, [])
+
+ // Detect overlapping notes using optimized sweep line algorithm
+ const overlappingNoteIds = useMemo(() => {
+ if (notes.length < 2) return new Set<string>()
+
+ const overlapping = new Set<string>()
+ const sortedNotes = [...notes].sort((a, b) => a.start - b.start)
+ const EPSILON = 0.05 // Tolerance for floating point comparison
+
+ // Use a sliding window approach - more efficient for typical music data
+ // Active notes: notes that haven't ended yet
+ const activeNotes: typeof sortedNotes = []
+
+ for (const note of sortedNotes) {
+ // Remove notes that have ended before current note starts
+ while (activeNotes.length > 0) {
+ const firstActive = activeNotes[0]
+ const firstActiveEnd = firstActive.start + firstActive.duration
+ if (firstActiveEnd <= note.start + EPSILON) {
+ activeNotes.shift()
+ } else {
+ break
+ }
+ }
+
+ // Check overlap with remaining active notes
+ for (const activeNote of activeNotes) {
+ const activeEnd = activeNote.start + activeNote.duration
+ if (note.start < activeEnd - EPSILON) {
+ overlapping.add(activeNote.id)
+ overlapping.add(note.id)
+ }
+ }
+
+ // Add current note to active set (maintain sorted order by end time)
+ const noteEnd = note.start + note.duration
+ let insertIndex = activeNotes.length
+ for (let i = 0; i < activeNotes.length; i++) {
+ const aEnd = activeNotes[i].start + activeNotes[i].duration
+ if (noteEnd < aEnd) {
+ insertIndex = i
+ break
+ }
+ }
+ activeNotes.splice(insertIndex, 0, note)
+ }
+ return overlapping
+ }, [notes])
+
+ // Calculate visible area with buffer for smooth scrolling
+ const BUFFER_PX = 200 // Render notes slightly outside viewport for smooth scrolling
+ const visibleArea = useMemo(() => {
+ return {
+ left: Math.max(0, scrollLeft - BUFFER_PX),
+ right: scrollLeft + viewportWidth + BUFFER_PX,
+ top: Math.max(0, scrollTop - BUFFER_PX),
+ bottom: scrollTop + viewportHeight + BUFFER_PX,
+ }
+ }, [scrollLeft, scrollTop, viewportWidth, viewportHeight])
+
+ // Filter notes to only render visible ones (virtualization)
+ const visibleNotes = useMemo(() => {
+ return notes.filter(note => {
+ const noteSeconds = beatToSeconds(note.start)
+ const noteDurationSeconds = beatToSeconds(note.duration)
+ const noteLeft = noteSeconds * gridSecondWidth
+ const noteRight = noteLeft + noteDurationSeconds * gridSecondWidth
+ const noteTop = (HIGH_NOTE - note.midi) * rowHeight
+ const noteBottom = noteTop + rowHeight
+
+ // Check if note intersects with visible area
+ const horizontallyVisible = noteRight >= visibleArea.left && noteLeft <= visibleArea.right
+ const verticallyVisible = noteBottom >= visibleArea.top && noteTop <= visibleArea.bottom
+
+ return horizontallyVisible && verticallyVisible
+ })
+ }, [notes, visibleArea, gridSecondWidth, rowHeight, beatToSeconds])
+
+ // Calculate visible grid lines (virtualization)
+ const visibleGridLines = useMemo(() => {
+ const startSecond = Math.max(0, Math.floor(visibleArea.left / gridSecondWidth) - 1)
+ const endSecond = Math.ceil(visibleArea.right / gridSecondWidth) + 1
+ const startRow = Math.max(0, Math.floor(visibleArea.top / rowHeight) - 1)
+ const endRow = Math.min(totalRows, Math.ceil(visibleArea.bottom / rowHeight) + 1)
+
+ return {
+ horizontalLines: Array.from({ length: endRow - startRow + 1 }, (_, i) => startRow + i),
+ verticalLines: Array.from({ length: endSecond - startSecond + 1 }, (_, i) => startSecond + i),
+ }
+ }, [visibleArea, gridSecondWidth, rowHeight, totalRows])
+
+ const playheadSeconds = beatToSeconds(playhead)
+
+ // Selection drag handlers
+ const handleRulerPointerDown = (event: React.PointerEvent<HTMLDivElement>) => {
+ if (!isSelectingRange) {
+ // Normal click to seek
+ const rect = event.currentTarget.getBoundingClientRect()
+ const x = event.clientX - rect.left + (rulerScrollRef.current?.scrollLeft ?? 0)
+ const seconds = x / gridSecondWidth
+ onSeek(secondsToBeat(seconds))
+ return
+ }
+
+ // Start selection drag
+ const rect = event.currentTarget.getBoundingClientRect()
+ const x = event.clientX - rect.left + (rulerScrollRef.current?.scrollLeft ?? 0)
+ const seconds = Math.max(0, x / gridSecondWidth)
+
+ selectionDragRef.current = {
+ startX: event.clientX,
+ startSeconds: seconds,
+ }
+
+ onSelectionChange?.(seconds, seconds)
+
+ const handleSelectionMove = (e: PointerEvent) => {
+ if (!selectionDragRef.current) return
+ const currentX = e.clientX - rect.left + (rulerScrollRef.current?.scrollLeft ?? 0)
+ const currentSeconds = Math.max(0, currentX / gridSecondWidth)
+ const start = Math.min(selectionDragRef.current.startSeconds, currentSeconds)
+ const end = Math.max(selectionDragRef.current.startSeconds, currentSeconds)
+ onSelectionChange?.(start, end)
+ }
+
+ const handleSelectionUp = () => {
+ selectionDragRef.current = null
+ window.removeEventListener('pointermove', handleSelectionMove)
+ window.removeEventListener('pointerup', handleSelectionUp)
+ }
+
+ window.addEventListener('pointermove', handleSelectionMove)
+ window.addEventListener('pointerup', handleSelectionUp)
+ }
+
+ return (
+ <div className="piano-shell">
+ {/* Ruler */}
+ <div className="ruler-shell">
+ <div className="ruler-spacer" style={{ width: PITCH_WIDTH, flexShrink: 0 }} />
+ <div
+ ref={rulerScrollRef}
+ className={`ruler-scroll ${isSelectingRange ? 'selecting' : ''}`}
+ onPointerDown={handleRulerPointerDown}
+ >
+ <div className="ruler" style={{ width: contentWidth }}>
+ {secondLabels.map((mark) => (
+ <div key={mark.left} className="measure-mark" style={{ left: mark.left }}>
+ <span>{mark.label}</span>
+ </div>
+ ))}
+ {/* Selection range indicator */}
+ {selectionStart !== null && selectionEnd !== null && selectionEnd > selectionStart && (
+ <div
+ className="selection-range"
+ style={{
+ left: selectionStart * gridSecondWidth,
+ width: (selectionEnd - selectionStart) * gridSecondWidth
+ }}
+ />
+ )}
+ {/* Ruler playhead indicator */}
+ <div
+ className="ruler-playhead"
+ style={{ left: playheadSeconds * gridSecondWidth }}
+ />
+ </div>
+ </div>
+ </div>
+
+ {/* Main content area */}
+ <div className="roll-body">
+ {/* Piano keys - synced with vertical scroll */}
+ <div className="pitch-rail" style={{ width: PITCH_WIDTH }}>
+ <div
+ className="pitch-rail-inner"
+ style={{
+ transform: `translateY(${-scrollTop}px)`,
+ height: contentHeight
+ }}
+ >
+ {pitchRows.map((pitch) => (
+ <div
+ key={pitch.midi}
+ className={`pitch-cell ${pitch.isBlack ? 'pitch-black' : 'pitch-white'} ${pitch.isC ? 'pitch-c' : ''}`}
+ style={{ height: rowHeight, cursor: 'pointer' }}
+ onClick={() => onPlayNote?.(pitch.midi)}
+ onMouseDown={(e) => e.preventDefault()}
+ >
+ <span className="pitch-label">{pitch.label}</span>
+ </div>
+ ))}
+ </div>
+ </div>
+
+ {/* Scrollable grid area */}
+ <div
+ ref={scrollContainerRef}
+ className="roll-grid"
+ onDoubleClick={handleGridDoubleClick}
+ >
+ <div
+ className="grid-content"
+ style={{
+ width: contentWidth,
+ height: contentHeight,
+ position: 'relative'
+ }}
+ >
+ {/* SVG Grid - virtualized for performance */}
+ <svg
+ className="grid-svg"
+ width={contentWidth}
+ height={contentHeight}
+ style={{ position: 'absolute', top: 0, left: 0, pointerEvents: 'none' }}
+ >
+ {/* Horizontal lines (pitch rows) - only visible ones */}
+ {visibleGridLines.horizontalLines.map(i => (
+ <line
+ key={`h-${i}`}
+ x1={visibleArea.left}
+ y1={i * rowHeight}
+ x2={visibleArea.right}
+ y2={i * rowHeight}
+ stroke="var(--grid-line-minor)"
+ strokeWidth={1}
+ />
+ ))}
+ {/* Vertical lines (seconds) - only visible ones */}
+ {visibleGridLines.verticalLines.map(i => (
+ <line
+ key={`v-${i}`}
+ x1={i * gridSecondWidth}
+ y1={visibleArea.top}
+ x2={i * gridSecondWidth}
+ y2={visibleArea.bottom}
+ stroke="var(--grid-line-minor)"
+ strokeWidth={1}
+ />
+ ))}
+ </svg>
+
+ {/* Selection range in grid */}
+ {selectionStart !== null && selectionEnd !== null && selectionEnd > selectionStart && (
+ <div
+ className="grid-selection-range"
+ style={{
+ left: selectionStart * gridSecondWidth,
+ width: (selectionEnd - selectionStart) * gridSecondWidth,
+ height: contentHeight
+ }}
+ />
+ )}
+
+ {/* Playhead */}
+ <div
+ className="playhead"
+ style={{
+ left: playheadSeconds * gridSecondWidth,
+ height: contentHeight
+ }}
+ />
+
+ {/* Notes - virtualized: only render visible notes */}
+ {visibleNotes.map((note) => {
+ const noteSeconds = beatToSeconds(note.start)
+ const noteDurationSeconds = beatToSeconds(note.duration)
+ const left = noteSeconds * gridSecondWidth
+ const top = (HIGH_NOTE - note.midi) * rowHeight
+ const noteWidthPx = Math.max(noteDurationSeconds * gridSecondWidth, 4)
+ const noteHeight = rowHeight - 2
+ const isOverlapping = overlappingNoteIds.has(note.id)
+ // Dynamic font size based on row height (base: 12px at 20px row height)
+ const fontSize = Math.max(10, Math.min(24, rowHeight * 0.6))
+
+ return (
+ <NoteChip
+ key={note.id}
+ note={note}
+ left={left}
+ top={top}
+ width={noteWidthPx}
+ height={noteHeight}
+ fontSize={fontSize}
+ isSelected={selectedId === note.id}
+ isOverlapping={isOverlapping}
+ onPointerDown={(event, mode) => startDrag(event, note, mode)}
+ onDoubleClick={(event) => {
+ event.stopPropagation()
+ onFocusLyric?.(note.id)
+ }}
+ />
+ )
+ })}
+ </div>
+ </div>
+ </div>
+ </div>
+ )
+ }
preprocess/tools/midi_editor/src/constants.ts ADDED
@@ -0,0 +1,8 @@
+ // Base values used for scaling; actual runtime values are derived in components
+ export const BASE_GRID_SECOND_WIDTH = 80
+ export const BASE_ROW_HEIGHT = 20
+ export const PITCH_WIDTH = 60
+ // C-1 to C8 range (MIDI note numbers)
+ // LOW_NOTE = 0 to support SP markers (pitch=0) in some MIDI files
+ export const LOW_NOTE = 0 // C-1 (also supports pitch=0 for SP markers)
+ export const HIGH_NOTE = 108 // C8
preprocess/tools/midi_editor/src/i18n.ts ADDED
@@ -0,0 +1,196 @@
+ export type Lang = 'zh' | 'en'
+
+ const zh = {
+ // Header
+ eyebrow: '歌声 MIDI 编辑器',
+ title: 'SoulX-Singer MIDI Editor',
+ subtitle: '导入、拖拽、实时修改歌词并导出标准 MIDI。',
+ switchToLight: '切换到亮色',
+ switchToDark: '切换到暗色',
+ importJson: '导入 JSON',
+ exportJson: '导出 JSON',
+ importMidi: '导入 MIDI',
+ exportMidi: '导出 MIDI',
+ transpose: '移调',
+ transposeTooltip: '整体升降调:所有音符的音高同步改变',
+ transposed: (n: number) => `已移调 ${n > 0 ? '+' : ''}${n} 半音`,
+ fixOverlaps: '消除重叠',
+ fixOverlapsTooltip: '自动消除重叠:将重叠音符的音尾提前到下一个音的音头',
+ jsonImported: (name: string) => `已从 JSON 载入 ${name}`,
+ jsonImportFailed: 'JSON 导入失败,请确认文件格式正确',
+ jsonExported: '已导出 META JSON 文件',
+
+ // Audio bar
+ importAudio: '对齐音频导入',
+ audioHint: '导入后显示音频波形并与 MIDI 同步走带',
+ midiLabel: 'MIDI',
+ audioLabel: '音频',
+
+ // Controls
+ horizontalZoom: '水平缩放',
+ verticalZoom: '垂直缩放',
+ goToStart: '回到开头',
+ back2s: '后退 2 秒',
+ pause: '暂停',
+ playSelection: '播放选区',
+ play: '播放',
+ forward2s: '前进 2 秒',
+ goToEnd: '回到结尾',
+ selectingRange: '选区中',
+ setRange: '设选区',
+ exitSelectMode: '退出选区模式(并清除选区)',
+ setRangeTooltip: '设置选区:在时间轴上拖拽选择播放范围',
+
+ // Status
+ ready: '准备就绪',
+ selectionPlayback: '选区回放中...',
+ playing: '正在回放...',
+ selectionDone: '选区播放完毕',
+ paused: '已暂停',
+ imported: (name: string) => `已载入 ${name}`,
+ importFailed: '导入失败,请确认文件合法',
+ audioImported: (name: string) => `已载入音频 ${name}`,
+ unsupportedFormat: (exts: string) => `不支持的文件格式,请选择音频文件(${exts})`,
+ fixedOverlaps: (count: number) => `已修复 ${count} 个重叠音符`,
+ noOverlaps: '没有检测到重叠音符',
+ exported: '已导出包含歌词的 MIDI 文件',
+
+ // Lyric table
+ fillPlaceholderSelected: '从选中音符开始按词/字填充',
+ fillPlaceholderDefault: '输入歌词,点击按词/字填充',
+ fillButton: '按词\n填充',
+ lyricPlaceholder: '输入歌词',
+ emptyHint: '导入或双击钢琴卷帘以添加音符',
+ confirmEdit: '确认修改 (Enter)',
+ }
+
+ const en: typeof zh = {
+ // Header
+ eyebrow: 'Vocal MIDI Editor',
+ title: 'SoulX-Singer MIDI Editor',
+ subtitle: 'Import, drag, edit lyrics in real-time, and export standard MIDI.',
+ switchToLight: 'Switch to light',
+ switchToDark: 'Switch to dark',
+ importJson: 'Import JSON',
+ exportJson: 'Export JSON',
+ importMidi: 'Import MIDI',
+ exportMidi: 'Export MIDI',
+ transpose: 'Transpose',
+ transposeTooltip: 'Transpose all notes up or down by semitones',
+ transposed: (n: number) => `Transposed ${n > 0 ? '+' : ''}${n} semitone(s)`,
+ fixOverlaps: 'Fix Overlaps',
+ fixOverlapsTooltip: 'Auto fix overlaps: trim note end to the start of the next note',
+ jsonImported: (name: string) => `Loaded from JSON ${name}`,
+ jsonImportFailed: 'JSON import failed, please check the file format',
+ jsonExported: 'Exported META JSON file',
+
+ // Audio bar
+ importAudio: 'Import Audio',
+ audioHint: 'Display audio waveform synced with MIDI transport',
+ midiLabel: 'MIDI',
+ audioLabel: 'Audio',
+
+ // Controls
+ horizontalZoom: 'H-Zoom',
+ verticalZoom: 'V-Zoom',
+ goToStart: 'Go to start',
+ back2s: 'Back 2s',
+ pause: 'Pause',
+ playSelection: 'Play selection',
+ play: 'Play',
+ forward2s: 'Forward 2s',
+ goToEnd: 'Go to end',
+ selectingRange: 'Selecting',
+ setRange: 'Select',
+ exitSelectMode: 'Exit selection mode (and clear selection)',
+ setRangeTooltip: 'Set selection: drag on the timeline to select playback range',
+
+ // Status
+ ready: 'Ready',
+ selectionPlayback: 'Playing selection...',
+ playing: 'Playing...',
+ selectionDone: 'Selection playback done',
+ paused: 'Paused',
+ imported: (name: string) => `Loaded ${name}`,
+ importFailed: 'Import failed, please check the file',
+ audioImported: (name: string) => `Loaded audio ${name}`,
+ unsupportedFormat: (exts: string) => `Unsupported format, please select an audio file (${exts})`,
+ fixedOverlaps: (count: number) => `Fixed ${count} overlapping note(s)`,
+ noOverlaps: 'No overlapping notes detected',
+ exported: 'Exported MIDI file with lyrics',
+
+ // Lyric table
+ fillPlaceholderSelected: 'Fill words from selected note',
+ fillPlaceholderDefault: 'Enter lyrics, click fill button',
+ fillButton: 'Fill\nWords',
+ lyricPlaceholder: 'Type lyric',
+ emptyHint: 'Import or double-click piano roll to add notes',
+ confirmEdit: 'Confirm (Enter)',
+ }
+
+ const translations: Record<Lang, typeof zh> = { zh, en }
+
+ export type Translations = typeof zh
+
+ export function getTranslations(lang: Lang): Translations {
+ return translations[lang]
+ }
+
+ // Smart tokenizer for lyrics: CJK characters are individual tokens, Latin words are grouped
+ function isCJK(char: string): boolean {
+ const code = char.codePointAt(0) || 0
+ return (
+ (code >= 0x4E00 && code <= 0x9FFF) || // CJK Unified Ideographs
+ (code >= 0x3400 && code <= 0x4DBF) || // CJK Extension A
+ (code >= 0x20000 && code <= 0x2A6DF) || // CJK Extension B
+ (code >= 0x3040 && code <= 0x309F) || // Hiragana
+ (code >= 0x30A0 && code <= 0x30FF) || // Katakana
+ (code >= 0xAC00 && code <= 0xD7AF) // Hangul Syllables
+ )
+ }
+
+ /**
+ * Tokenize lyrics text for note filling.
+ * - CJK characters: each character becomes one token (one per note)
+ * - Latin/English words: each space-separated word becomes one token (one per note)
+ * - Mixed text is handled correctly
+ *
+ * Examples:
+ * "你好世界" -> ["你", "好", "世", "界"]
+ * "hello world" -> ["hello", "world"]
+ * "I love 你" -> ["I", "love", "你"]
+ * "something wrong" -> ["something", "wrong"]
+ */
+ export function tokenizeLyrics(text: string): string[] {
+ const tokens: string[] = []
+ const cleaned = text.trim()
+ if (!cleaned) return tokens
+
+ let i = 0
+ while (i < cleaned.length) {
+ const char = cleaned[i]
+
+ // Skip whitespace
+ if (/\s/.test(char)) {
+ i++
+ continue
+ }
+
+ // CJK character - each is a separate token
+ if (isCJK(char)) {
+ tokens.push(char)
+ i++
+ continue
+ }
+
+ // Latin/number/other - collect until whitespace or CJK
+ let word = ''
+ while (i < cleaned.length && !/\s/.test(cleaned[i]) && !isCJK(cleaned[i])) {
+ word += cleaned[i]
+ i++
+ }
+ if (word) tokens.push(word)
+ }
+
+ return tokens
+ }
preprocess/tools/midi_editor/src/index.css ADDED
@@ -0,0 +1,37 @@
+ @tailwind base;
+ @tailwind components;
+ @tailwind utilities;
+
+ :root {
+ font-family: 'Space Grotesk', 'IBM Plex Sans', system-ui, sans-serif;
+ color: var(--text-primary);
+ background: radial-gradient(circle at 20% 20%, rgba(72, 228, 194, 0.08), transparent 35%),
+ radial-gradient(circle at 80% 0%, rgba(75, 100, 188, 0.24), transparent 40%),
+ #0f1528;
+ text-rendering: optimizeLegibility;
+ -webkit-font-smoothing: antialiased;
+ }
+
+ :root[data-theme='light'] {
+ background: radial-gradient(circle at 20% 20%, rgba(63, 140, 255, 0.08), transparent 35%),
+ radial-gradient(circle at 80% 0%, rgba(75, 100, 188, 0.14), transparent 40%),
+ #f5f7fb;
+ }
+
+ * {
+ box-sizing: border-box;
+ }
+
+ body {
+ margin: 0;
+ min-height: 100vh;
+ background: transparent;
+ }
+
+ #root {
+ min-height: 100vh;
+ }
+
+ a {
+ color: inherit;
+ }
preprocess/tools/midi_editor/src/lib/midi.ts ADDED
@@ -0,0 +1,224 @@
+ import { Midi } from '@tonejs/midi'
+ import { writeMidi } from 'midi-file'
+ import type { MidiData, MidiEvent } from 'midi-file'
+ import type { NoteEvent, ProjectSnapshot, TimeSignature } from '../types'
+
+ const DEFAULT_SIGNATURE: TimeSignature = [4, 4]
+
+ // Decode UTF-8 byte string (latin1 encoded) to proper Unicode string
+ // This matches: text.encode("latin1").decode("utf-8") in Python
+ function decodeUtf8ByteString(byteString: string): string {
+ try {
+ const bytes = new Uint8Array(byteString.length)
+ for (let i = 0; i < byteString.length; i++) {
+ bytes[i] = byteString.charCodeAt(i)
+ }
+ return new TextDecoder('utf-8').decode(bytes)
+ } catch {
+ return byteString
+ }
+ }
+
+ // Encode Unicode string to UTF-8 byte string (latin1 encoding)
+ // This matches: text.encode("utf-8").decode("latin1") in Python
+ function encodeUtf8ByteString(text: string): string {
+ const bytes = new TextEncoder().encode(text)
+ let output = ''
+ bytes.forEach((b) => {
+ output += String.fromCharCode(b)
+ })
+ return output
+ }
+
+ export async function importMidiFile(file: File): Promise<ProjectSnapshot> {
+ const buffer = await file.arrayBuffer()
+ return parseMidiBuffer(buffer)
+ }
+
+ export async function parseMidiBuffer(buffer: ArrayBuffer): Promise<ProjectSnapshot> {
+ const midi = new Midi(buffer)
+ const tempo = midi.header.tempos[0]?.bpm ?? 120
+ const timeSignature = (midi.header.timeSignatures[0]?.timeSignature as TimeSignature | undefined) ?? DEFAULT_SIGNATURE
+
+ // Merge notes from all tracks and sort by ticks then by midi (for stable ordering)
+ const allNotes = midi.tracks
+ .flatMap(t => t.notes)
+ .sort((a, b) => a.ticks - b.ticks || a.midi - b.midi)
+
+ // Get lyrics from header.meta and sort by ticks
+ const lyricEvents = midi.header.meta
+ .filter((event) => event.type === 'lyrics')
+ .sort((a, b) => a.ticks - b.ticks)
+
+ // Match lyrics to notes by tick position
+ // Each lyric should be consumed by exactly one note at the same tick
+ const lyricsByTick = new Map<number, string[]>()
+ for (const event of lyricEvents) {
+ const existing = lyricsByTick.get(event.ticks) || []
+ existing.push(decodeUtf8ByteString(event.text))
+ lyricsByTick.set(event.ticks, existing)
+ }
+
+ // Track which lyrics have been used at each tick position
+ const usedLyricIndices = new Map<number, number>()
+
+ const notes: NoteEvent[] = allNotes.map((note, index) => {
+ const beat = note.ticks / midi.header.ppq
+ const durationBeats = note.durationTicks / midi.header.ppq
+
+ let lyric = ''
+
+ // First try exact tick match
+ const lyricsAtTick = lyricsByTick.get(note.ticks)
+ if (lyricsAtTick && lyricsAtTick.length > 0) {
+ const usedIndex = usedLyricIndices.get(note.ticks) || 0
+ if (usedIndex < lyricsAtTick.length) {
+ lyric = lyricsAtTick[usedIndex]
+ usedLyricIndices.set(note.ticks, usedIndex + 1)
+ }
+ }
+
+ // If no exact match, try nearby ticks (within small tolerance)
+ if (!lyric) {
+ const tolerance = midi.header.ppq / 100 // Very small tolerance
+ for (const [tick, lyrics] of lyricsByTick.entries()) {
+ if (Math.abs(tick - note.ticks) <= tolerance) {
+ const usedIndex = usedLyricIndices.get(tick) || 0
+ if (usedIndex < lyrics.length) {
+ lyric = lyrics[usedIndex]
+ usedLyricIndices.set(tick, usedIndex + 1)
+ break
+ }
+ }
+ }
+ }
+
+ return {
+ id: `${index}-${note.midi}-${Math.round(note.ticks)}`,
+ midi: note.midi,
+ start: beat,
+ duration: Math.max(durationBeats, 0.0625),
+ velocity: note.velocity,
+ lyric,
+ }
+ })
+
+ return { tempo, timeSignature, notes, ppq: midi.header.ppq }
+ }
+
+ // Used to add absoluteTime property for sorting
+ type WithAbsoluteTime<T> = T & { absoluteTime: number }
+
+ export function exportMidi(snapshot: ProjectSnapshot): Blob {
+ const ppq = snapshot.ppq ?? 480 // Use original ppq if available, otherwise default to 480
+ const microsecondsPerBeat = Math.round(60000000 / snapshot.tempo) // Convert BPM to microseconds per beat
+
+ // Sort notes by start time, then by midi for stable ordering
+ const sortedNotes = [...snapshot.notes].sort((a, b) => a.start - b.start || a.midi - b.midi)
+
+ // Build events for a single track containing both lyrics and notes
+ // Event order at same tick: note_off (0) < lyrics (1) < note_on (2)
+ // This matches meta.py's tg2midi implementation
+ const events: Array<WithAbsoluteTime<MidiEvent>> = []
+
+ // Add all note events and their corresponding lyrics
+ sortedNotes.forEach((note) => {
+ const startTicks = Math.round(note.start * ppq)
+ const endTicks = Math.round((note.start + note.duration) * ppq)
+ const velocity = Math.round(note.velocity * 127)
+
+ // Add lyric event at the same tick as note_on (but will be sorted before it)
+ const lyricText = note.lyric ?? ''
+ const encodedLyric = encodeUtf8ByteString(lyricText)
+
+ // Lyric event - sort key 1 (after note_off, before note_on)
+ events.push({
+ absoluteTime: startTicks,
+ deltaTime: 0,
+ meta: true,
+ type: 'lyrics',
+ text: encodedLyric,
+ _sortKey: 1,
+ } as WithAbsoluteTime<MidiEvent> & { _sortKey: number })
+
+ // Note on event - sort key 2 (after lyrics)
+ events.push({
+ absoluteTime: startTicks,
+ deltaTime: 0,
+ type: 'noteOn',
+ channel: 0,
+ noteNumber: note.midi,
+ velocity: velocity,
+ _sortKey: 2,
+ } as WithAbsoluteTime<MidiEvent> & { _sortKey: number })
+
+ // Note off event - sort key 0 (before everything at same tick)
+ events.push({
+ absoluteTime: endTicks,
+ deltaTime: 0,
+ type: 'noteOff',
+ channel: 0,
+ noteNumber: note.midi,
+ velocity: 0,
+ _sortKey: 0,
+ } as WithAbsoluteTime<MidiEvent> & { _sortKey: number })
+ })
+
+ // Sort events by absoluteTime, then by _sortKey
+ events.sort((a, b) => {
+ const aKey = (a as { _sortKey?: number })._sortKey ?? 1
+ const bKey = (b as { _sortKey?: number })._sortKey ?? 1
+ return a.absoluteTime - b.absoluteTime || aKey - bKey
+ })
+
+ // Convert absolute time to delta time
+ let lastTick = 0
+ events.forEach(event => {
+ event.deltaTime = event.absoluteTime - lastTick
+ lastTick = event.absoluteTime
+ delete (event as { absoluteTime?: number }).absoluteTime
+ delete (event as { _sortKey?: number })._sortKey
+ })
+
+ // Build the MIDI track with header events
+ const track: MidiEvent[] = [
+ // Set tempo
+ {
+ deltaTime: 0,
+ meta: true,
+ type: 'setTempo',
+ microsecondsPerBeat: microsecondsPerBeat,
+ },
+ // Time signature
+ {
+ deltaTime: 0,
+ meta: true,
+ type: 'timeSignature',
+ numerator: snapshot.timeSignature[0],
+ denominator: snapshot.timeSignature[1],
+ metronome: 24,
+ thirtyseconds: 8,
+ },
+ // All note and lyric events
+ ...events,
+ // End of track
+ {
+ deltaTime: 0,
+ meta: true,
+ type: 'endOfTrack',
+ },
+ ]
+
+ // Build MIDI data structure
+ const midiData: MidiData = {
+ header: {
+ format: 0, // Single track format (type 0)
+ numTracks: 1,
+ ticksPerBeat: ppq,
+ },
+ tracks: [track],
+ }
+
+ const bytes = writeMidi(midiData)
+ return new Blob([new Uint8Array(bytes)], { type: 'audio/midi' })
+ }
preprocess/tools/midi_editor/src/main.tsx ADDED
@@ -0,0 +1,10 @@
+ import { StrictMode } from 'react'
+ import { createRoot } from 'react-dom/client'
+ import './index.css'
+ import App from './App.tsx'
+
+ createRoot(document.getElementById('root')!).render(
+ <StrictMode>
+ <App />
+ </StrictMode>,
+ )
preprocess/tools/midi_editor/src/store/useMidiStore.ts ADDED
@@ -0,0 +1,78 @@
+ import { nanoid } from 'nanoid'
+ import { create } from 'zustand'
+ import type { NoteEvent, TimeSignature } from '../types'
+
+ const clamp = (value: number, min: number, max: number) =>
+ Math.min(Math.max(value, min), max)
+
+ export type MidiStore = {
+ tempo: number
+ timeSignature: TimeSignature
+ notes: NoteEvent[]
+ selectedId: string | null
+ playhead: number
+ ppq: number | undefined // Ticks per quarter note (for preserving original MIDI timing)
+ addNote: (partial?: Partial<NoteEvent>) => NoteEvent
+ updateNote: (id: string, partial: Partial<NoteEvent>) => void
+ removeNote: (id: string) => void
+ setNotes: (notes: NoteEvent[]) => void
+ setTempo: (tempo: number) => void
+ setTimeSignature: (sig: TimeSignature) => void
+ setPpq: (ppq: number | undefined) => void
+ select: (id: string | null) => void
+ setLyric: (id: string, lyric: string) => void
+ setPlayhead: (beat: number) => void
+ clear: () => void
+ }
+
+ const defaultNotes: NoteEvent[] = [
+ { id: nanoid(), midi: 64, start: 0, duration: 1.5, velocity: 0.9, lyric: 'la' },
+ { id: nanoid(), midi: 67, start: 1.5, duration: 1.5, velocity: 0.85, lyric: 'na' },
+ { id: nanoid(), midi: 69, start: 3, duration: 2, velocity: 0.8, lyric: 'ah' },
+ ]
+
+ export const useMidiStore = create<MidiStore>((set) => ({
+ tempo: 110,
+ timeSignature: [4, 4],
+ notes: defaultNotes,
+ selectedId: null,
+ playhead: 0,
+ ppq: undefined,
+ addNote: (partial = {}) => {
+ const note: NoteEvent = {
+ id: nanoid(),
+ midi: partial.midi ?? 64,
+ start: partial.start ?? 0,
+ duration: partial.duration ?? 1,
+ velocity: clamp(partial.velocity ?? 0.85, 0, 1),
+ lyric: partial.lyric ?? '',
+ }
+ set((state) => ({ notes: [...state.notes, note] }))
+ return note
+ },
+ updateNote: (id, partial) => {
+ set((state) => ({
+ notes: state.notes.map((note) =>
+ note.id === id
+ ? {
+ ...note,
+ ...partial,
+ duration: Math.max(partial.duration ?? note.duration, 0.0625),
+ }
+ : note,
+ ),
+ }))
+ },
+ removeNote: (id) => set((state) => ({ notes: state.notes.filter((n) => n.id !== id) })),
+ setNotes: (notes) => set(() => ({ notes })),
+ setTempo: (tempo) => set(() => ({ tempo: clamp(tempo, 30, 240) })),
+ setTimeSignature: (sig) => set(() => ({ timeSignature: sig })),
+ setPpq: (ppq) => set(() => ({ ppq })),
+ select: (id) => set(() => ({ selectedId: id })),
+ setLyric: (id, lyric) =>
+ set((state) => ({
+ notes: state.notes.map((note) => (note.id === id ? { ...note, lyric } : note)),
+ })),
+ setPlayhead: (beat) => set(() => ({ playhead: Math.max(beat, 0) })),
+ clear: () => set(() => ({ notes: [], selectedId: null })),
+ }))
preprocess/tools/midi_editor/src/types.ts ADDED
@@ -0,0 +1,17 @@
+ export type NoteEvent = {
+ id: string
+ midi: number
+ start: number // in beats
+ duration: number // in beats
+ velocity: number
+ lyric: string
+ }
+
+ export type TimeSignature = [number, number]
+
+ export type ProjectSnapshot = {
+ tempo: number
+ timeSignature: TimeSignature
+ notes: NoteEvent[]
+ ppq?: number // Ticks per quarter note (for preserving original MIDI timing)
+ }
preprocess/tools/midi_editor/tailwind.config.js ADDED
@@ -0,0 +1,33 @@
+ /** @type {import('tailwindcss').Config} */
+ export default {
+ content: ['./index.html', './src/**/*.{ts,tsx,js,jsx}'],
+ theme: {
+ extend: {
+ fontFamily: {
+ display: ['"Space Grotesk"', '"IBM Plex Sans"', 'system-ui', 'sans-serif'],
+ mono: ['"JetBrains Mono"', 'ui-monospace', 'SFMono-Regular', 'monospace'],
+ },
+ colors: {
+ ink: {
+ 50: '#f4f7fb',
+ 100: '#dfe7f5',
+ 200: '#beceec',
+ 300: '#95addf',
+ 400: '#6a87ce',
+ 500: '#4b64bc',
+ 600: '#3b4ea7',
+ 700: '#32418a',
+ 800: '#2c376f',
+ 900: '#262f5c',
+ },
+ ember: '#ff7043',
+ mint: '#48e4c2',
+ },
+ boxShadow: {
+ panel: '0 14px 35px rgba(0, 0, 0, 0.25)',
+ },
+ },
+ },
+ plugins: [],
+ }
+
preprocess/tools/midi_editor/tsconfig.app.json ADDED
@@ -0,0 +1,28 @@
+ {
+ "compilerOptions": {
+ "tsBuildInfoFile": "./node_modules/.tmp/tsconfig.app.tsbuildinfo",
+ "target": "ES2022",
+ "useDefineForClassFields": true,
+ "lib": ["ES2022", "DOM", "DOM.Iterable"],
+ "module": "ESNext",
+ "types": ["vite/client"],
+ "skipLibCheck": true,
+
+ /* Bundler mode */
+ "moduleResolution": "bundler",
+ "allowImportingTsExtensions": true,
+ "verbatimModuleSyntax": true,
+ "moduleDetection": "force",
+ "noEmit": true,
+ "jsx": "react-jsx",
+
+ /* Linting */
+ "strict": true,
+ "noUnusedLocals": true,
+ "noUnusedParameters": true,
+ "erasableSyntaxOnly": true,
+ "noFallthroughCasesInSwitch": true,
+ "noUncheckedSideEffectImports": true
+ },
+ "include": ["src"]
+ }
preprocess/tools/midi_editor/tsconfig.json ADDED
@@ -0,0 +1,7 @@
+ {
+ "files": [],
+ "references": [
+ { "path": "./tsconfig.app.json" },
+ { "path": "./tsconfig.node.json" }
+ ]
+ }
preprocess/tools/midi_editor/tsconfig.node.json ADDED
@@ -0,0 +1,26 @@
+ {
+ "compilerOptions": {
+ "tsBuildInfoFile": "./node_modules/.tmp/tsconfig.node.tsbuildinfo",
+ "target": "ES2023",
+ "lib": ["ES2023"],
+ "module": "ESNext",
+ "types": ["node"],
+ "skipLibCheck": true,
+
+ /* Bundler mode */
+ "moduleResolution": "bundler",
+ "allowImportingTsExtensions": true,
+ "verbatimModuleSyntax": true,
+ "moduleDetection": "force",
+ "noEmit": true,
+
+ /* Linting */
+ "strict": true,
+ "noUnusedLocals": true,
+ "noUnusedParameters": true,
+ "erasableSyntaxOnly": true,
+ "noFallthroughCasesInSwitch": true,
+ "noUncheckedSideEffectImports": true
+ },
+ "include": ["vite.config.ts"]
+ }
preprocess/tools/midi_editor/vite.config.ts ADDED
@@ -0,0 +1,7 @@
1
+ import { defineConfig } from 'vite'
2
+ import react from '@vitejs/plugin-react'
3
+
4
+ // https://vite.dev/config/
5
+ export default defineConfig({
6
+ plugins: [react()],
7
+ })
preprocess/tools/midi_parser.py ADDED
@@ -0,0 +1,598 @@
1
+ """
2
+ SoulX-Singer MIDI <-> metadata converter.
3
+
4
+ Converts between SoulX-Singer-style metadata JSON (with note_text, note_dur,
5
+ note_pitch, note_type per segment) and standard MIDI files. Uses an internal
6
+ Note dataclass (start_s, note_dur, note_text, note_pitch, note_type) as the
7
+ intermediate representation.
8
+ """
9
+ import os
10
+ import json
11
+ import shutil
12
+ from dataclasses import dataclass
13
+ from typing import Any, List, Tuple, Union
14
+
15
+ import librosa
16
+ import mido
17
+ from soundfile import write
18
+
19
+ from .f0_extraction import F0Extractor
20
+ from .g2p import g2p_transform
21
+
22
+
23
+ # Audio, MIDI and segmentation constants
24
+ SAMPLE_RATE = 44100 # Audio sample rate for any wav cuts during midi2meta
25
+ MIDI_TICKS_PER_BEAT = 500 # The number of MIDI ticks per beat; affects the time resolution of MIDI output and conversion accuracy.
26
+ MIDI_TEMPO = 500000 # Microseconds per beat (120 BPM)
27
+ MIDI_TIME_SIGNATURE = (4, 4) # Default time signature; not critical for conversion but included in MIDI output.
28
+ MIDI_VELOCITY = 64 # Default velocity for note_on events; not critical for conversion but required for MIDI format.
29
+ END_EXTENSION_SEC = 0.4 # Extend each segment end by this much silence (sec) to give the model more context
30
+ MAX_GAP_SEC = 2.0 # Gap threshold to split segments in midi2meta (sec)
31
+ MAX_SEGMENT_DUR_SUM_SEC = 60.0 # Max total duration sum of notes in a single metadata segment before splitting into multiple segments (sec)
32
+ SILENCE_THRESHOLD_SEC = 0.2 # Threshold to insert explicit <SP> note for long silences between notes in midi2notes (sec)
33
+
34
+
35
+ @dataclass
36
+ class Note:
37
+ """Single note: text, duration (seconds), pitch (MIDI), type. start_s is absolute start time in seconds (for ordering / MIDI)."""
38
+ start_s: float
39
+ note_dur: float
40
+ note_text: str
41
+ note_pitch: int
42
+ note_type: int
43
+
44
+ @property
45
+ def end_s(self) -> float:
46
+ return self.start_s + self.note_dur
47
+
48
+
49
+ def _seconds_to_ticks(seconds: float, ticks_per_beat: int, tempo: int) -> int:
50
+ """Convert seconds to MIDI ticks based on tempo and ticks per beat."""
51
+ return int(round(seconds * ticks_per_beat * 1_000_000 / tempo))
52
+
53
+
54
+ def _append_segment_to_meta(
55
+ meta_data: List[dict],
56
+ meta_path_str: str,
57
+ cut_wavs_output_dir: str | None,
58
+ vocal_file: str | None,
59
+ language: str,
60
+ audio_data: Any | None,
61
+ pitch_extractor: F0Extractor | None,
62
+ note_start: List[float],
63
+ note_end: List[float],
64
+ note_text: List[Any],
65
+ note_pitch: List[Any],
66
+ note_type: List[Any],
67
+ note_dur: List[float],
68
+ ) -> None:
69
+ """Helper function for midi2meta to append the current segment (accumulated in note_*) to meta_data list, with optional wav cut and pitch extraction."""
70
+ if not all((note_start, note_end, note_text, note_pitch, note_type, note_dur)):
71
+ return
72
+
73
+ base_name = os.path.splitext(os.path.basename(meta_path_str))[0]
74
+ item_name = f"{base_name}_{len(meta_data)}"
75
+ wav_fn = None
76
+ if cut_wavs_output_dir and vocal_file and audio_data is not None:
77
+ wav_fn = os.path.join(cut_wavs_output_dir, f"{item_name}.wav")
78
+ end_pad = int(END_EXTENSION_SEC * SAMPLE_RATE)
79
+ start_sample = max(0, int(note_start[0] * SAMPLE_RATE))
80
+ end_sample = min(len(audio_data), int(note_end[-1] * SAMPLE_RATE) + end_pad)
81
+
82
+ end_pad_dur = (end_sample / SAMPLE_RATE - note_end[-1]) if end_sample > int(note_end[-1] * SAMPLE_RATE) else 0.0
83
+ if end_pad_dur > 0:
84
+ note_dur = note_dur + [end_pad_dur]
85
+ note_text = note_text + ["<SP>"]
86
+ note_pitch = note_pitch + [0]
87
+ note_type = note_type + [1]
88
+ start_ms = int(start_sample / SAMPLE_RATE * 1000)
89
+ end_ms = int(end_sample / SAMPLE_RATE * 1000)
90
+ write(wav_fn, audio_data[start_sample:end_sample], SAMPLE_RATE)
91
+ else:
92
+ start_ms = int(note_start[0] * 1000)
93
+ end_ms = int(note_end[-1] * 1000)
94
+
95
+ if pitch_extractor is not None:
96
+ if not wav_fn or not os.path.isfile(wav_fn):
97
+ raise FileNotFoundError(f"Segment wav file not found: {wav_fn}")
98
+ f0 = pitch_extractor.process(wav_fn)
99
+ else:
100
+ f0 = []
101
+
102
+ note_text_list = list(note_text)
103
+ note_pitch_list = list(note_pitch)
104
+ note_type_list = list(note_type)
105
+ note_dur_list = list(note_dur)
106
+
107
+ meta_data.append(
108
+ {
109
+ "index": item_name,
110
+ "language": language,
111
+ "time": [start_ms, end_ms],
112
+ "duration": " ".join(str(round(x, 2)) for x in note_dur_list),
113
+ "text": " ".join(note_text_list),
114
+ "phoneme": " ".join(g2p_transform(note_text_list, language)),
115
+ "note_pitch": " ".join(str(x) for x in note_pitch_list),
116
+ "note_type": " ".join(str(x) for x in note_type_list),
117
+ "f0": " ".join(str(round(float(x), 1)) for x in f0),
118
+ }
119
+ )
120
+
121
+
122
+ def meta2notes(meta_path: str) -> List[Note]:
123
+ """Parse SoulX-Singer metadata JSON into a flat list of Note (absolute start_s)."""
124
+ with open(meta_path, "r", encoding="utf-8") as f:
125
+ segments = json.load(f)
126
+ if not isinstance(segments, list):
127
+ raise ValueError(f"Metadata must be a list of segments, got {type(segments).__name__}")
128
+ if not segments:
129
+ raise ValueError("Metadata has no segments.")
130
+
131
+ notes: List[Note] = []
132
+ for seg in segments:
133
+ offset_s = seg["time"][0] / 1000
134
+ words = [str(x).replace("<AP>", "<SP>") for x in seg["text"].split()]
135
+ word_durs = [float(x) for x in seg["duration"].split()]
136
+ pitches = [int(x) for x in seg["note_pitch"].split()]
137
+ types = [int(x) if words[i] != "<SP>" else 1 for i, x in enumerate(seg["note_type"].split())]
138
+ if len(words) != len(word_durs) or len(word_durs) != len(pitches) or len(pitches) != len(types):
139
+ raise ValueError(
140
+ f"Length mismatch in segment {seg.get('item_name', '?')}: "
141
+ "note_text, note_dur, note_pitch, note_type must have same length"
142
+ )
143
+ current_s = offset_s
144
+ for text, dur, pitch, type_ in zip(words, word_durs, pitches, types):
145
+ notes.append(
146
+ Note(
147
+ start_s=current_s,
148
+ note_dur=float(dur),
149
+ note_text=str(text),
150
+ note_pitch=int(pitch),
151
+ note_type=int(type_),
152
+ )
153
+ )
154
+ current_s += float(dur)
155
+ return notes
156
+
157
+
158
+ def notes2meta(
159
+ notes: List[Note],
160
+ meta_path: str,
161
+ vocal_file: str | None,
162
+ language: str,
163
+ pitch_extractor: F0Extractor | None,
164
+ ) -> None:
165
+ """Write SoulX-Singer metadata JSON from a list of Note (segmenting + wav cuts)."""
166
+ meta_path_str = str(meta_path)
167
+
168
+ cut_wavs_output_dir = None
169
+ if vocal_file:
170
+ cut_wavs_output_dir = os.path.join(os.path.dirname(vocal_file), "cut_wavs_tmp")
171
+ os.makedirs(cut_wavs_output_dir, exist_ok=True)
172
+
173
+ note_text: List[Any] = []
174
+ note_pitch: List[Any] = []
175
+ note_type: List[Any] = []
176
+ note_dur: List[float] = []
177
+ note_start: List[float] = []
178
+ note_end: List[float] = []
179
+ meta_data: List[dict] = []
180
+ audio_data = None
181
+ if vocal_file:
182
+ audio_data, _ = librosa.load(vocal_file, sr=SAMPLE_RATE, mono=True)
183
+ dur_sum = 0.0
184
+
185
+ def flush_current_segment() -> None:
186
+ nonlocal dur_sum
187
+ _append_segment_to_meta(
188
+ meta_data,
189
+ meta_path_str,
190
+ cut_wavs_output_dir,
191
+ vocal_file,
192
+ language,
193
+ audio_data,
194
+ pitch_extractor,
195
+ note_start,
196
+ note_end,
197
+ note_text,
198
+ note_pitch,
199
+ note_type,
200
+ note_dur,
201
+ )
202
+ note_text.clear()
203
+ note_pitch.clear()
204
+ note_type.clear()
205
+ note_dur.clear()
206
+ note_start.clear()
207
+ note_end.clear()
208
+ dur_sum = 0.0
209
+
210
+ def append_note(start: float, end: float, text: str, pitch: int, type_: int) -> None:
211
+ nonlocal dur_sum
212
+ duration = end - start
213
+ if duration <= 0:
214
+ return
215
+
216
+ if len(note_text) > 0 and text == "<SP>" and note_text[-1] == "<SP>":
217
+ note_dur[-1] += duration
218
+ note_end[-1] = end
219
+ else:
220
+ note_text.append(text)
221
+ note_pitch.append(pitch)
222
+ note_type.append(type_)
223
+ note_dur.append(duration)
224
+ note_start.append(start)
225
+ note_end.append(end)
226
+ dur_sum += duration
227
+
228
+ for note in notes:
229
+ start = float(note.start_s)
230
+ end = float(note.end_s)
231
+ text = note.note_text
232
+ pitch = note.note_pitch
233
+ type_ = note.note_type
234
+
235
+ if text == "" or pitch == "" or type_ == "":
236
+ append_note(start, end, "<SP>", 0, 1)
237
+ continue
238
+
239
+ # cut the segment when it ends with a long <SP> note
240
+ if (
241
+ len(note_text) > 0
242
+ and note_text[-1] == "<SP>"
243
+ and note_dur[-1] > MAX_GAP_SEC
244
+ ):
245
+ note_text.pop()
246
+ note_pitch.pop()
247
+ note_type.pop()
248
+ note_dur.pop()
249
+ note_start.pop()
250
+ note_end.pop()
251
+
252
+ dur_sum = sum(note_dur)
253
+ flush_current_segment()
254
+
255
+ # cut the segment if adding the current note would exceed the max duration sum threshold
256
+ if dur_sum + (end - start) > MAX_SEGMENT_DUR_SUM_SEC and len(note_text) > 0:
257
+ flush_current_segment()
258
+
259
+ append_note(start, end, text, int(pitch), int(type_))
260
+
261
+ if note_text:
262
+ flush_current_segment()
263
+
264
+ with open(meta_path_str, "w", encoding="utf-8") as f:
265
+ json.dump(meta_data, f, ensure_ascii=False, indent=2)
266
+
267
+ if cut_wavs_output_dir:
268
+ try:
269
+ shutil.rmtree(cut_wavs_output_dir, ignore_errors=True)
270
+ except Exception:
271
+ pass
272
+
273
+
274
+ def notes2midi(
275
+ notes: List[Note],
276
+ midi_path: str,
277
+ ) -> None:
278
+ """Write MIDI file from a list of Note."""
279
+ if not notes:
280
+ raise ValueError("Empty note list.")
281
+
282
+ events: List[Tuple[int, int, Union[mido.Message, mido.MetaMessage]]] = []
283
+ for n in notes:
284
+ start_s = n.start_s
285
+ end_s = n.end_s
286
+ if end_s <= start_s:
287
+ continue
288
+
289
+ start_ticks = _seconds_to_ticks(
290
+ start_s, MIDI_TICKS_PER_BEAT, MIDI_TEMPO
291
+ )
292
+ end_ticks = _seconds_to_ticks(
293
+ end_s, MIDI_TICKS_PER_BEAT, MIDI_TEMPO
294
+ )
295
+ if end_ticks <= start_ticks:
296
+ end_ticks = start_ticks + 1
297
+
298
+ lyric = n.note_text
299
+ # Some DAWs store lyric text as latin1-compatible bytes; keep best-effort round-trip.
300
+ try:
301
+ lyric = lyric.encode("utf-8").decode("latin1")
302
+ except (UnicodeEncodeError, UnicodeDecodeError):
303
+ pass
304
+ if n.note_type == 3:
305
+ lyric = "-"
306
+
307
+ events.append(
308
+ (start_ticks, 1, mido.MetaMessage("lyrics", text=lyric, time=0))
309
+ )
310
+ events.append(
311
+ (
312
+ start_ticks,
313
+ 2,
314
+ mido.Message(
315
+ "note_on",
316
+ note=n.note_pitch,
317
+ velocity=MIDI_VELOCITY,
318
+ time=0,
319
+ ),
320
+ )
321
+ )
322
+ events.append(
323
+ (
324
+ end_ticks,
325
+ 0,
326
+ mido.Message("note_off", note=n.note_pitch, velocity=0, time=0),
327
+ )
328
+ )
329
+
330
+ events.sort(key=lambda x: (x[0], x[1]))
331
+
332
+ mid = mido.MidiFile(ticks_per_beat=MIDI_TICKS_PER_BEAT)
333
+ track = mido.MidiTrack()
334
+ mid.tracks.append(track)
335
+
336
+ track.append(mido.MetaMessage("set_tempo", tempo=MIDI_TEMPO, time=0))
337
+ track.append(
338
+ mido.MetaMessage(
339
+ "time_signature",
340
+ numerator=MIDI_TIME_SIGNATURE[0],
341
+ denominator=MIDI_TIME_SIGNATURE[1],
342
+ time=0,
343
+ )
344
+ )
345
+
346
+ last_tick = 0
347
+ for tick, _, msg in events:
348
+ msg.time = max(0, tick - last_tick)
349
+ track.append(msg)
350
+ last_tick = tick
351
+
352
+ track.append(mido.MetaMessage("end_of_track", time=0))
353
+ mid.save(midi_path)
354
+
355
+
356
+ def midi2notes(midi_path: str) -> List[Note]:
357
+ """Parse MIDI file into a list of Note."""
358
+ mid = mido.MidiFile(midi_path)
359
+ ticks_per_beat = mid.ticks_per_beat
360
+ tempo = 500000
361
+
362
+ raw_notes: List[dict] = []
363
+ lyrics: List[Tuple[int, str]] = []
364
+
365
+ for track in mid.tracks:
366
+ abs_ticks = 0
367
+ active = {}
368
+ for msg in track:
369
+ abs_ticks += msg.time
370
+ if msg.type == "set_tempo":
371
+ tempo = msg.tempo
372
+ elif msg.type == "lyrics":
373
+ text = msg.text
374
+ try:
375
+ text = text.encode("latin1").decode("utf-8")
376
+ except Exception:
377
+ pass
378
+ lyrics.append((abs_ticks, text))
379
+ elif msg.type == "note_on":
380
+ key = (msg.channel, msg.note)
381
+ if msg.velocity > 0:
382
+ active[key] = (abs_ticks, msg.velocity)
383
+ else:
384
+ if key in active:
385
+ start_ticks, vel = active.pop(key)
386
+ raw_notes.append(
387
+ {
388
+ "midi": msg.note,
389
+ "start_ticks": start_ticks,
390
+ "duration_ticks": abs_ticks - start_ticks,
391
+ "velocity": vel,
392
+ "lyric": "",
393
+ }
394
+ )
395
+ elif msg.type == "note_off":
396
+ key = (msg.channel, msg.note)
397
+ if key in active:
398
+ start_ticks, vel = active.pop(key)
399
+ raw_notes.append(
400
+ {
401
+ "midi": msg.note,
402
+ "start_ticks": start_ticks,
403
+ "duration_ticks": abs_ticks - start_ticks,
404
+ "velocity": vel,
405
+ "lyric": "",
406
+ }
407
+ )
408
+
409
+ if not raw_notes:
410
+ raise ValueError("No notes found in MIDI file")
411
+
412
+ for n in raw_notes:
413
+ n["end_ticks"] = n["start_ticks"] + n["duration_ticks"]
414
+
415
+ raw_notes.sort(key=lambda n: n["start_ticks"])
416
+ lyrics.sort(key=lambda x: x[0])
417
+
418
+ trimmed = []
419
+ # Remove/trim overlaps so generated notes are strictly non-overlapping in tick domain.
420
+ for note in raw_notes:
421
+ while trimmed:
422
+ prev = trimmed[-1]
423
+ if note["start_ticks"] < prev["end_ticks"]:
424
+ prev["end_ticks"] = note["start_ticks"]
425
+ prev["duration_ticks"] = prev["end_ticks"] - prev["start_ticks"]
426
+ if prev["duration_ticks"] <= 0:
427
+ trimmed.pop()
428
+ continue
429
+ break
430
+ trimmed.append(note)
431
+ raw_notes = trimmed
432
+
433
+ tolerance = ticks_per_beat // 100
434
+ # Attach lyrics near note_on positions with a small tick tolerance.
435
+ lyric_idx = 0
436
+ for note in raw_notes:
437
+ while lyric_idx < len(lyrics) and lyrics[lyric_idx][0] < note["start_ticks"] - tolerance:
438
+ lyric_idx += 1
439
+ if lyric_idx < len(lyrics):
440
+ lyric_ticks, lyric_text = lyrics[lyric_idx]
441
+ if abs(lyric_ticks - note["start_ticks"]) <= tolerance:
442
+ note["lyric"] = lyric_text
443
+ lyric_idx += 1
444
+
445
+ def ticks_to_seconds(ticks: int) -> float:
446
+ return (ticks / ticks_per_beat) * (tempo / 1_000_000)
447
+
448
+ result: List[Note] = []
449
+ prev_end_s = 0.0
450
+ for idx, n in enumerate(raw_notes):
451
+ start_s = ticks_to_seconds(n["start_ticks"])
452
+ end_s = ticks_to_seconds(n["end_ticks"])
453
+ if prev_end_s > start_s:
454
+ start_s = prev_end_s
455
+ dur_s = end_s - start_s
456
+ if dur_s <= 0:
457
+ continue
458
+
459
+ lyric = n.get("lyric", "")
460
+ # SoulX-Singer convention mapping from lyric token to note_type/text.
461
+ if not lyric:
462
+ note_type = 2
463
+ text = "啦"
464
+ elif lyric == "<SP>":
465
+ note_type = 1
466
+ text = "<SP>"
467
+ elif lyric == "-":
468
+ note_type = 3
469
+ text = raw_notes[idx - 1].get("lyric", "-") if idx > 0 else "-"
470
+ else:
471
+ note_type = 2
472
+ text = lyric
473
+
474
+ if start_s - prev_end_s > SILENCE_THRESHOLD_SEC:
475
+ # Explicitly represent long gaps as <SP> notes.
476
+ result.append(
477
+ Note(
478
+ start_s=prev_end_s,
479
+ note_dur=start_s - prev_end_s,
480
+ note_text="<SP>",
481
+ note_pitch=0,
482
+ note_type=1,
483
+ )
484
+ )
485
+ else:
486
+ if len(result) > 0:
487
+ result[-1].note_dur = start_s - result[-1].start_s
488
+
489
+ result.append(
490
+ Note(
491
+ start_s=start_s,
492
+ note_dur=dur_s,
493
+ note_text=text,
494
+ note_pitch=n["midi"],
495
+ note_type=note_type,
496
+ )
497
+ )
498
+ prev_end_s = end_s
499
+
500
+ return result
501
+
502
+
503
+ class MidiParser:
504
+ def __init__(
505
+ self,
506
+ rmvpe_model_path: str,
507
+ device: str = "cuda",
508
+ ) -> None:
509
+ self.rmvpe_model_path = rmvpe_model_path
510
+ self.device = device
511
+ self.pitch_extractor: F0Extractor | None = None
512
+
513
+ def _get_pitch_extractor(self) -> F0Extractor:
514
+ if self.pitch_extractor is None:
515
+ self.pitch_extractor = F0Extractor(
516
+ self.rmvpe_model_path,
517
+ device=self.device,
518
+ verbose=False,
519
+ )
520
+ return self.pitch_extractor
521
+
522
+ def midi2meta(
523
+ self,
524
+ midi_path: str,
525
+ meta_path: str,
526
+ vocal_file: str | None = None,
527
+ language: str = "Mandarin",
528
+ ) -> None:
529
+ meta_dir = os.path.dirname(meta_path)
530
+ if meta_dir:
531
+ os.makedirs(meta_dir, exist_ok=True)
532
+
533
+ notes = midi2notes(midi_path)
534
+ pitch_extractor = self._get_pitch_extractor() if vocal_file else None
535
+ notes2meta(
536
+ notes,
537
+ meta_path,
538
+ vocal_file,
539
+ language,
540
+ pitch_extractor=pitch_extractor,
541
+ )
542
+ print(f"Saved Meta to {meta_path}")
543
+
544
+ def meta2midi(self, meta_path: str, midi_path: str) -> None:
545
+ notes = meta2notes(meta_path)
546
+ notes2midi(notes, midi_path)
547
+ print(f"Saved MIDI to {midi_path}")
548
+
549
+ if __name__ == "__main__":
550
+ import argparse
551
+
552
+ parser = argparse.ArgumentParser(
553
+ description="Convert SoulX-Singer metadata JSON <-> MIDI."
554
+ )
555
+ parser.add_argument("--meta", type=str, help="Path to metadata JSON")
556
+ parser.add_argument("--midi", type=str, help="Path to MIDI file")
557
+ parser.add_argument("--vocal", type=str, default=None, help="Path to vocal wav (optional for midi2meta)")
558
+ parser.add_argument("--language", type=str, default="Mandarin", help="Lyric language for metadata phoneme conversion (default: Mandarin)")
559
+ parser.add_argument(
560
+ "--meta2midi",
561
+ action="store_true",
562
+ help="Convert meta -> midi (requires --meta and --midi)",
563
+ )
564
+ parser.add_argument(
565
+ "--midi2meta",
566
+ action="store_true",
567
+ help="Convert midi -> meta (requires --midi and --meta; --vocal is optional)",
568
+ )
569
+ parser.add_argument(
570
+ "--rmvpe_model_path",
571
+ type=str,
572
+ help="Path to RMVPE model",
573
+ default="pretrained_models/SoulX-Singer-Preprocess/rmvpe/rmvpe.pt",
574
+ )
575
+ parser.add_argument(
576
+ "--device",
577
+ type=str,
578
+ help="Device to use for RMVPE",
579
+ default="cuda",
580
+ )
581
+ args = parser.parse_args()
582
+ midi_parser = MidiParser(
583
+ rmvpe_model_path=args.rmvpe_model_path,
584
+ device=args.device,
585
+ )
586
+
587
+ if args.meta2midi:
588
+ if not args.meta or not args.midi:
589
+ parser.error("--meta2midi requires --meta and --midi")
590
+ midi_parser.meta2midi(args.meta, args.midi)
591
+ elif args.midi2meta:
592
+ if not args.midi or not args.meta:
593
+ parser.error(
594
+ "--midi2meta requires --midi and --meta"
595
+ )
596
+ midi_parser.midi2meta(args.midi, args.meta, args.vocal, args.language)
597
+ else:
598
+ parser.print_help()
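A minimal usage sketch of the converters above (paths are placeholders; it assumes the repo root is on `PYTHONPATH` so `preprocess` imports as a package). With `MIDI_TICKS_PER_BEAT = 500` and `MIDI_TEMPO = 500000` (120 BPM), `_seconds_to_ticks` maps one second to 1000 ticks.

```python
# Sketch: programmatic meta <-> MIDI conversion (paths below are placeholders).
from preprocess.tools.midi_parser import MidiParser, Note, notes2midi

# Build a tiny note list by hand and write it as a MIDI file (no models needed).
notes = [
    Note(start_s=0.0, note_dur=0.5, note_text="<SP>", note_pitch=0, note_type=1),
    Note(start_s=0.5, note_dur=0.4, note_text="你", note_pitch=62, note_type=2),
    Note(start_s=0.9, note_dur=0.6, note_text="好", note_pitch=64, note_type=2),
]
notes2midi(notes, "example_out.mid")

# MIDI -> metadata; without a vocal file no wav cutting or f0 extraction happens,
# so the RMVPE checkpoint path is only stored, not loaded.
parser = MidiParser(
    rmvpe_model_path="pretrained_models/SoulX-Singer-Preprocess/rmvpe/rmvpe.pt",
    device="cuda",
)
parser.midi2meta("example_out.mid", "example_meta.json", vocal_file=None, language="Mandarin")
```

The same round-trip should be reachable from the command line via `python -m preprocess.tools.midi_parser --midi2meta --midi in.mid --meta out.json`, provided the directories are importable as packages.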
preprocess/tools/note_transcription/__init__.py ADDED
File without changes
preprocess/tools/note_transcription/model.py ADDED
@@ -0,0 +1,531 @@
1
+ # https://github.com/RickyL-2000/ROSVOT
2
+ import math
3
+ import sys
4
+ import traceback
5
+ import json
6
+ import time
7
+ from pathlib import Path
8
+ from typing import Any, Dict, Optional
9
+
10
+ import librosa
11
+ import numpy as np
12
+ import torch
13
+ import matplotlib.pyplot as plt
14
+
15
+ from .utils.os_utils import safe_path
16
+ from .utils.commons.hparams import set_hparams
17
+ from .utils.commons.ckpt_utils import load_ckpt
18
+ from .utils.commons.dataset_utils import pad_or_cut_xd
19
+ from .utils.audio.mel import MelNet
20
+ from .utils.audio.pitch_utils import (
21
+ norm_interp_f0,
22
+ denorm_f0,
23
+ f0_to_coarse,
24
+ boundary2Interval,
25
+ save_midi,
26
+ midi_to_hz,
27
+ )
28
+ from .utils.rosvot_utils import (
29
+ get_mel_len,
30
+ align_word,
31
+ regulate_real_note_itv,
32
+ regulate_ill_slur,
33
+ bd_to_durs,
34
+ )
35
+ from .modules.pe.rmvpe import RMVPE
36
+ from .modules.rosvot.rosvot import MidiExtractor, WordbdExtractor
37
+
38
+
39
+ @torch.no_grad()
40
+ def infer_sample(
41
+ item: Dict[str, Any],
42
+ hparams: Dict[str, Any],
43
+ models: Dict[str, Any],
44
+ device: torch.device,
45
+ *,
46
+ save_dir: Optional[str] = None,
47
+ apply_rwbd: Optional[bool] = None,
48
+ # outputs
49
+ save_plot: bool = False,
50
+ no_save_midi: bool = True,
51
+ no_save_npy: bool = True,
52
+ verbose: bool = False,
53
+ ) -> Dict[str, Any]:
54
+ if "item_name" not in item or "wav_fn" not in item:
55
+ raise ValueError('item must contain keys: "item_name" and "wav_fn"')
56
+
57
+ item_name = item["item_name"]
58
+ wav_src = item["wav_fn"]
59
+
60
+ # Decide RWBD usage
61
+ if apply_rwbd is None:
62
+ apply_rwbd_ = ("word_durs" not in item)
63
+ else:
64
+ apply_rwbd_ = bool(apply_rwbd)
65
+
66
+ # Models
67
+ model = models["model"]
68
+ mel_net = models["mel_net"]
69
+ pe = models.get("pe")
70
+ wbd_predictor = models.get("wbd_predictor")
71
+
72
+ if wbd_predictor is None and apply_rwbd_:
73
+ raise ValueError("apply_rwbd is True but wbd_predictor model is not provided in models")
74
+
75
+ # ---- Prepare Data ----
76
+ if isinstance(wav_src, str):
77
+ wav, _ = librosa.core.load(wav_src, sr=hparams["audio_sample_rate"])
78
+ else:
79
+ wav = wav_src
80
+ if not isinstance(wav, np.ndarray):
81
+ wav = np.asarray(wav)
82
+ wav = wav.astype(np.float32)
83
+
84
+ # Calculate timestamps and alignment lengths
85
+ wav_len_samples = wav.shape[-1]
86
+ mel_len = get_mel_len(wav_len_samples, hparams["hop_size"])
87
+
88
+ # Word boundary preparation
89
+ mel2word = None
90
+ word_durs_filtered = None
91
+
92
+ if not apply_rwbd_:
93
+ if "word_durs" not in item:
94
+ raise ValueError('apply_rwbd=False but item has no "word_durs"')
95
+
96
+ wd_raw = list(item["word_durs"])
97
+ min_word_dur = hparams.get("min_word_dur", 20) / 1000
98
+ word_durs_filtered = []
99
+
100
+ for i, wd in enumerate(wd_raw):
101
+ if wd < min_word_dur:
102
+ if i == 0 and len(wd_raw) > 1:
103
+ wd_raw[i + 1] += wd
104
+ elif len(word_durs_filtered) > 0:
105
+ word_durs_filtered[-1] += wd
106
+ else:
107
+ word_durs_filtered.append(wd)
108
+
109
+ mel2word, _ = align_word(word_durs_filtered, mel_len, hparams["hop_size"], hparams["audio_sample_rate"])
110
+ mel2word = np.asarray(mel2word)
111
+ if mel2word.size > 0 and mel2word[0] == 0:
112
+ mel2word = mel2word + 1
113
+
114
+ mel2word_len = int(np.sum(mel2word > 0))
115
+ real_len = min(mel_len, mel2word_len)
116
+ else:
117
+ real_len = min(mel_len, hparams["max_frames"])
118
+
119
+ T = math.ceil(min(real_len, hparams["max_frames"]) / hparams["frames_multiple"]) * hparams["frames_multiple"]
120
+
121
+ # ---- Input Tensors & Padding ----
122
+ target_samples = T * hparams["hop_size"]
123
+ wav_t = torch.from_numpy(wav).float().to(device).unsqueeze(0) # [1, L]
124
+ if wav_t.shape[-1] < target_samples:
125
+ wav_t = pad_or_cut_xd(wav_t, target_samples, 1)
126
+
127
+ # ---- Pitch Extraction ----
128
+ if pe is not None:
129
+ f0s, uvs = pe.get_pitch_batch(
130
+ wav_t,
131
+ sample_rate=hparams["audio_sample_rate"],
132
+ hop_size=hparams["hop_size"],
133
+ lengths=[real_len],
134
+ fmax=hparams["f0_max"],
135
+ fmin=hparams["f0_min"],
136
+ )
137
+ f0_1d, uv_1d = norm_interp_f0(f0s[0][:T])
138
+ f0_t = pad_or_cut_xd(torch.FloatTensor(f0_1d).to(device), T, 0).unsqueeze(0)
139
+ uv_t = pad_or_cut_xd(torch.FloatTensor(uv_1d).to(device), T, 0).long().unsqueeze(0)
140
+ pitch_coarse = f0_to_coarse(denorm_f0(f0_t, uv_t)).to(device)
141
+ f0_np = denorm_f0(f0_t, uv_t)[0].detach().cpu().numpy()[:real_len]
142
+ else:
143
+ f0_t = uv_t = pitch_coarse = None
144
+ f0_np = None
145
+
146
+ # ---- Mel Extraction ----
147
+ mel = mel_net(wav_t) # [1, T_padded, C]
148
+ mel = pad_or_cut_xd(mel, T, 1)
149
+
150
+ # Construct non-padding mask
151
+ mel_nonpadding_mask = torch.zeros(1, T, device=device)
152
+ mel_nonpadding_mask[:, :real_len] = 1.0
153
+
154
+ # Apply mask to mel (zero out padding)
155
+ mel = (mel.transpose(1, 2) * mel_nonpadding_mask.unsqueeze(1)).transpose(1, 2)
156
+ # Re-calculate non_padding bool mask
157
+ mel_nonpadding = mel.abs().sum(-1) > 0
158
+
159
+ # ---- Word Boundary ----
160
+ word_durs_used = None
161
+ if apply_rwbd_:
162
+ mel_input = mel[:, :, : hparams.get("wbd_use_mel_bins", 80)]
163
+ wbd_outputs = wbd_predictor(
164
+ mel=mel_input,
165
+ pitch=pitch_coarse,
166
+ uv=uv_t,
167
+ non_padding=mel_nonpadding,
168
+ train=False,
169
+ )
170
+ word_bd = wbd_outputs["word_bd_pred"] # [1, T]
171
+ else:
172
+ # Construct word_bd from provided durs
173
+ mel2word_t = pad_or_cut_xd(torch.LongTensor(mel2word).to(device), T, 0)
174
+ word_bd = torch.zeros_like(mel2word_t)
175
+ # Vectorized check
176
+ word_bd[1:] = (mel2word_t[1:] != mel2word_t[:-1]).long()
177
+ word_bd[real_len:] = 0
178
+ word_bd = word_bd.unsqueeze(0) # [1, T]
179
+
180
+ word_durs_used = np.array(word_durs_filtered)
181
+
182
+ # ---- Main Inference ----
183
+ mel_input = mel[:, :, : hparams.get("use_mel_bins", 80)]
184
+ outputs = model(
185
+ mel=mel_input,
186
+ word_bd=word_bd,
187
+ pitch=pitch_coarse,
188
+ uv=uv_t,
189
+ non_padding=mel_nonpadding,
190
+ train=False,
191
+ )
192
+
193
+ note_lengths = outputs["note_lengths"].detach().cpu().numpy()
194
+ note_bd_pred = outputs["note_bd_pred"][0].detach().cpu().numpy()[:real_len]
195
+ note_pred = outputs["note_pred"][0].detach().cpu().numpy()[: note_lengths[0]]
196
+ note_bd_logits = torch.sigmoid(outputs["note_bd_logits"])[0].detach().cpu().numpy()[:real_len]
197
+
198
+ if note_pred.shape == (0,):
199
+ if verbose:
200
+ print(f"skip {item_name}: no notes detected")
201
+ return {
202
+ "item_name": item_name,
203
+ "pitches": [],
204
+ "note_durs": [],
205
+ "note2words": None,
206
+ }
207
+
208
+ # ---- Post-Processing & Regulation ----
209
+ note_itv_pred = boundary2Interval(note_bd_pred)
210
+ note2words = None
211
+
212
+ if apply_rwbd_:
213
+ word_bd_np = outputs['word_bd_pred'][0].detach().cpu().numpy()[:real_len]
214
+ word_durs_derived = np.array(bd_to_durs(word_bd_np)) * hparams['hop_size'] / hparams['audio_sample_rate']
215
+ word_durs_for_reg = word_durs_derived
216
+ word_bd_for_reg = word_bd_np
217
+ else:
218
+ word_bd_for_reg = word_bd[0].detach().cpu().numpy()[:real_len]
219
+ word_durs_for_reg = word_durs_used
220
+
221
+ should_regulate = hparams.get("infer_regulate_real_note_itv", True) and (not apply_rwbd_)
222
+
223
+ if should_regulate and (word_durs_for_reg is not None):
224
+ try:
225
+ note_itv_pred_secs, note2words = regulate_real_note_itv(
226
+ note_itv_pred,
227
+ note_bd_pred,
228
+ word_bd_for_reg,
229
+ word_durs_for_reg,
230
+ hparams["hop_size"],
231
+ hparams["audio_sample_rate"],
232
+ )
233
+ note_pred, note_itv_pred_secs, note2words = regulate_ill_slur(note_pred, note_itv_pred_secs, note2words)
234
+ except Exception as err:
235
+ if verbose:
236
+ _, exc_value, exc_tb = sys.exc_info()
237
+ tb = traceback.extract_tb(exc_tb)[-1]
238
+ print(f"postprocess failed: {err}: {exc_value} in {tb[0]}:{tb[1]} '{tb[2]}' in {tb[3]}")
239
+ # Fallback
240
+ note_itv_pred_secs = note_itv_pred * hparams["hop_size"] / hparams["audio_sample_rate"]
241
+ note2words = None
242
+ else:
243
+ note_itv_pred_secs = note_itv_pred * hparams["hop_size"] / hparams["audio_sample_rate"]
244
+
245
+ # ---- Output ----
246
+ note_durs = [float((itv[1] - itv[0])) for itv in note_itv_pred_secs]
247
+
248
+ out = {
249
+ "item_name": item_name,
250
+ "pitches": note_pred.tolist(),
251
+ "note_durs": note_durs,
252
+ "note2words": note2words.tolist() if note2words is not None else None,
253
+ }
254
+
255
+ # ---- Saving ----
256
+ if save_dir is not None:
257
+ save_dir_path = Path(save_dir)
258
+ save_dir_path.mkdir(parents=True, exist_ok=True)
259
+ fn = str(item_name)
260
+
261
+ if not no_save_midi:
262
+ save_midi(note_pred, note_itv_pred_secs, safe_path(save_dir_path / "midi" / f"{fn}.mid"))
263
+
264
+ if not no_save_npy:
265
+ np.save(safe_path(save_dir_path / "npy" / f"[note]{fn}.npy"), out, allow_pickle=True)
266
+
267
+ if save_plot:
268
+ fig = plt.figure()
269
+ if f0_np is not None:
270
+ plt.plot(f0_np, color="red", label="f0")
271
+
272
+ midi_pred = np.zeros(note_bd_pred.shape[0], dtype=np.float32)
273
+ itvs = np.round(note_itv_pred_secs * hparams["audio_sample_rate"] / hparams["hop_size"]).astype(int)
274
+ for i, itv in enumerate(itvs):
275
+ midi_pred[itv[0] : itv[1]] = note_pred[i]
276
+ plt.plot(midi_to_hz(midi_pred), color="blue", label="pred midi")
277
+ plt.plot(note_bd_logits * 100, color="green", label="note bd logits x100")
278
+ plt.legend()
279
+ plt.tight_layout()
280
+ plt.savefig(safe_path(save_dir_path / "plot" / f"[MIDI]{fn}.png"), format="png")
281
+ plt.close(fig)
282
+
283
+ return out
284
+
285
+
286
+ def load_rosvot_models(ckpt, config="", wbd_ckpt="", wbd_config="", device="cuda:0", verbose=False, thr=0.85):
287
+ """
288
+ Load models once to reuse across multiple items.
289
+ """
290
+ dev = torch.device(device)
291
+
292
+ # 1. Hparams
293
+ config_path = Path(ckpt).with_name("config.yaml") if config == "" else config
294
+ pe_ckpt = Path(ckpt).parent.parent / "rmvpe/model.pt"
295
+ hparams = set_hparams(
296
+ config=config_path,
297
+ print_hparams=verbose,
298
+ hparams_str=f"note_bd_threshold={thr}",
299
+ )
300
+
301
+ # 2. Main Model
302
+ model = MidiExtractor(hparams)
303
+ load_ckpt(model, ckpt, verbose=verbose)
304
+ model.eval().to(dev)
305
+
306
+ # 3. MelNet
307
+ mel_net = MelNet(hparams)
308
+ mel_net.to(dev)
309
+
310
+ # 4. Pitch Extractor
311
+ pe = None
312
+ if hparams.get("use_pitch_embed", False):
313
+ pe = RMVPE(pe_ckpt, device=dev)
314
+
315
+ # 5. Word Boundary Predictor (optional but we load if ckpt provided or needed)
316
+ wbd_predictor = None
317
+ if wbd_ckpt:
318
+ wbd_config_path = Path(wbd_ckpt).with_name("config.yaml") if wbd_config == "" else wbd_config
319
+ wbd_hparams = set_hparams(
320
+ config=wbd_config_path,
321
+ print_hparams=False,
322
+ hparams_str="",
323
+ )
324
+ hparams.update({
325
+ "wbd_use_mel_bins": wbd_hparams["use_mel_bins"],
326
+ "min_word_dur": wbd_hparams["min_word_dur"],
327
+ })
328
+ wbd_predictor = WordbdExtractor(wbd_hparams)
329
+ load_ckpt(wbd_predictor, wbd_ckpt, verbose=verbose)
330
+ wbd_predictor.eval().to(dev)
331
+
332
+ models = {
333
+ "model": model,
334
+ "mel_net": mel_net,
335
+ "pe": pe,
336
+ "wbd_predictor": wbd_predictor
337
+ }
338
+ return hparams, models
339
+
340
+
341
+ class NoteTranscriber:
342
+ """Note transcription wrapper based on ROSVOT.
343
+
344
+ Loads ROSVOT and optional RWBD models once in ``__init__`` and
345
+ exposes a :py:meth:`process` API that turns an item dict into
346
+ aligned note metadata for downstream SVS.
347
+ """
348
+
349
+ def __init__(
350
+ self,
351
+ rosvot_model_path: str,
352
+ rwbd_model_path: str,
353
+ *,
354
+ rosvot_config_path: str = "",
355
+ rwbd_config_path: str = "",
356
+ device: str = "cuda:0",
357
+ thr: float = 0.85,
358
+ verbose: bool = True,
359
+ ):
360
+ """Initialize the note transcriber.
361
+
362
+ Args:
363
+ rosvot_model_path: Path to the main ROSVOT checkpoint.
364
+ rosvot_config_path: Optional config YAML path for ROSVOT.
365
+ rwbd_model_path: Path to the word-boundary (RWBD) checkpoint.
366
+ rwbd_config_path: Optional config YAML path for RWBD.
367
+ device: Torch device string, e.g. ``"cuda:0"`` / ``"cpu"``.
368
+ thr: Note boundary threshold.
369
+ verbose: Whether to print verbose logs.
370
+ """
371
+ self.verbose = verbose
372
+ self.device = torch.device(device)
373
+ self.hparams, self.models = load_rosvot_models(
374
+ ckpt=rosvot_model_path,
375
+ config=rosvot_config_path,
376
+ wbd_ckpt=rwbd_model_path,
377
+ wbd_config=rwbd_config_path,
378
+ device=device,
379
+ verbose=verbose,
380
+ thr=thr,
381
+ )
382
+
383
+ if self.verbose:
384
+ print(
385
+ "[note transcription] init success:",
386
+ f"device={self.device}",
387
+ f"rosvot_model_path={rosvot_model_path}",
388
+ f"rwbd_model_path={rwbd_model_path if rwbd_model_path else 'None'}",
389
+ f"thr={thr}",
390
+ )
391
+
392
+ def process(
393
+ self,
394
+ item: Dict[str, Any],
395
+ *,
396
+ segment_info: Optional[Dict[str, Any]] = None,
397
+ save_dir: Optional[str] = None,
398
+ apply_rwbd: Optional[bool] = None,
399
+ save_plot: bool = False,
400
+ no_save_midi: bool = True,
401
+ no_save_npy: bool = True,
402
+ verbose: Optional[bool] = None,
403
+ ) -> Dict[str, Any]:
404
+ """Run ROSVOT on a single item and post-process outputs.
405
+
406
+ Args:
407
+ item: Input metadata dict with at least ``item_name`` and ``wav_fn``.
408
+ segment_info: Optional segment metadata for sliced audio.
409
+ save_dir: Optional directory for debug artifacts (plots, midis).
410
+ apply_rwbd: Whether to run RWBD-based word boundary refinement.
411
+ save_plot: Whether to save diagnostic plots.
412
+ no_save_midi: If True, skip saving midi.
413
+ no_save_npy: If True, skip saving numpy intermediates.
414
+ verbose: Override instance-level verbose flag for this call.
415
+
416
+ Returns:
417
+ Dict with aligned note information for downstream SVS.
418
+ """
419
+ v = self.verbose if verbose is None else verbose
420
+ if v:
421
+ item_name = item.get("item_name", "")
422
+ wav_fn = item.get("wav_fn", "")
423
+ print(f"[note transcription] process: start: item_name={item_name} wav_fn={wav_fn}")
424
+ t0 = time.time()
425
+
426
+ rosvot_out = infer_sample(
427
+ item,
428
+ self.hparams,
429
+ self.models,
430
+ device=self.device,
431
+ save_dir=save_dir,
432
+ apply_rwbd=apply_rwbd,
433
+ save_plot=save_plot,
434
+ no_save_midi=no_save_midi,
435
+ no_save_npy=no_save_npy,
436
+ verbose=v,
437
+ )
438
+
439
+ out = self.post_process(
440
+ metadata=item,
441
+ segment_info=segment_info,
442
+ rosvot_out=rosvot_out,
443
+ )
444
+
445
+ if v:
446
+ dt = time.time() - t0
447
+ print(
448
+ "[note transcription] process: done:",
449
+ f"item_name={out.get('item_name','')}",
450
+ f"n_notes={len(out.get('note_pitch', []) or [])}",
451
+ f"time={dt:.3f}s",
452
+ )
453
+
454
+ return out
455
+
456
+ @staticmethod
457
+ def _normalize_note2words(note2words: list[int]) -> list[int]:
458
+ if not note2words:
459
+ return []
460
+ normalized = [note2words[0]]
461
+ for idx in range(1, len(note2words)):
462
+ if note2words[idx] < normalized[-1]:
463
+ normalized.append(normalized[-1])
464
+ else:
465
+ normalized.append(note2words[idx])
466
+ return normalized
467
+
468
+ @staticmethod
469
+ def _build_ep_types(note2words: list[int], align_words: list[str]) -> list[int]:
470
+ ep_types: list[int] = []
471
+ prev = -1
472
+ for i, w in zip(note2words, align_words):
473
+ if w == "<SP>":
474
+ ep_types.append(1)
475
+ else:
476
+ ep_types.append(2 if i != prev else 3)
477
+ prev = i
478
+ return ep_types
479
+
480
+ def post_process(
481
+ self,
482
+ *,
483
+ metadata: Dict[str, Any],
484
+ segment_info: Dict[str, Any],
485
+ rosvot_out: Dict[str, Any],
486
+ ) -> Dict[str, Any]:
487
+ """Build aligned note metadata using ROSVOT outputs."""
488
+ note2words_raw = rosvot_out.get("note2words") or []
489
+ note2words = self._normalize_note2words(note2words_raw)
490
+ align_words = [
491
+ metadata["words"][idx - 1]
492
+ for idx in note2words_raw
493
+ if 0 < idx <= len(metadata["words"])
494
+ ]
495
+ ep_types = self._build_ep_types(note2words, align_words) if align_words else []
496
+
497
+ return {
498
+ "item_name": rosvot_out.get("item_name", "") if not segment_info else segment_info["item_name"],
499
+ "wav_fn": metadata.get("wav_fn", "") if not segment_info else segment_info["wav_fn"],
500
+ "origin_wav_fn": metadata.get("origin_wav_fn", "") if not segment_info else segment_info["origin_wav_fn"],
501
+ "start_time_ms": "" if not segment_info else segment_info["start_time_ms"],
502
+ "end_time_ms": "" if not segment_info else segment_info["end_time_ms"],
503
+ "language": metadata.get("language", ""),
504
+ "note_text": align_words,
505
+ "note_dur": rosvot_out.get("note_durs", []),
506
+ "note_type": ep_types,
507
+ "note_pitch": rosvot_out.get("pitches", []),
508
+ }
509
+
510
+ if __name__ == "__main__":
511
+
512
+ item = {
513
+ 'item_name': 'vocal_0',
514
+ 'wav_fn': 'example/audio/zh_prompt.mp3',
515
+ 'start_time_ms': 320,
516
+ 'end_time_ms': 10687,
517
+ 'origin_wav_fn': 'example/audio/zh_prompt.mp3',
518
+ 'duration': 10367,
519
+ 'words': ['<SP>', '除', '了', '想', '你', '<SP>', '除', '了', '爱', '你', '<SP>', '我', '什', '么', '什', '么', '都', '愿', '意'],
520
+ 'word_durs': [0.21, 0.36, 0.26, 0.7000000000000001, 0.96, 0.3800000000000001, 0.43999999999999995, 0.3799999999999999, 0.6400000000000001, 0.9600000000000002, 1.1199999999999999, 0.28000000000000025, 0.3799999999999999, 0.3199999999999994, 0.3200000000000003, 0.3799999999999999, 0.3200000000000003, 0.5, 1.457981859410431],
521
+ 'language': 'Mandarin'
522
+ }
523
+
524
+ m = NoteTranscriber(
525
+ rosvot_model_path="pretrained_models/SoulX-Singer-Preprocess/rosvot/rosvot/model.pt",
526
+ rwbd_model_path="pretrained_models/SoulX-Singer-Preprocess/rosvot/rwbd/model.pt",
527
+ device="cuda"
528
+ )
529
+ out = m.process(item, segment_info=item)
530
+
531
+ print(out)
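The `note_type` convention produced by `post_process` (1 = `<SP>` silence, 2 = first note of a word, 3 = slurred continuation of the same word) comes straight from `_build_ep_types`; a small illustrative check, assuming the package and its torch dependencies are importable:

```python
# note_type convention: 1 = <SP>, 2 = word onset, 3 = slur continuation of the same word.
from preprocess.tools.note_transcription.model import NoteTranscriber

note2words = NoteTranscriber._normalize_note2words([1, 2, 2, 3, 3])
align_words = ["<SP>", "除", "除", "了", "了"]
print(NoteTranscriber._build_ep_types(note2words, align_words))
# -> [1, 2, 3, 2, 3]
```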
preprocess/tools/note_transcription/modules/__init__.py ADDED
@@ -0,0 +1 @@
1
+ """ROSVOT model submodules."""
preprocess/tools/note_transcription/modules/commons/__init__.py ADDED
@@ -0,0 +1 @@
1
+ """Common ROSVOT layers and utilities."""
preprocess/tools/note_transcription/modules/commons/conformer/__init__.py ADDED
@@ -0,0 +1 @@
1
+ """Conformer layers for ROSVOT."""
preprocess/tools/note_transcription/modules/commons/conformer/conformer.py ADDED
@@ -0,0 +1,96 @@
1
+ from torch import nn
2
+ from .espnet_positional_embedding import RelPositionalEncoding, ScaledPositionalEncoding, PositionalEncoding
3
+ from .espnet_transformer_attn import RelPositionMultiHeadedAttention, MultiHeadedAttention
4
+ from .layers import Swish, ConvolutionModule, EncoderLayer, MultiLayeredConv1d
5
+ from ..layers import Embedding
6
+
7
+
8
+ class ConformerLayers(nn.Module):
9
+ def __init__(self, hidden_size, num_layers, kernel_size=9, dropout=0.0, num_heads=4,
10
+ use_last_norm=True, save_hidden=False):
11
+ super().__init__()
12
+ self.use_last_norm = use_last_norm
13
+ self.layers = nn.ModuleList()
14
+ positionwise_layer = MultiLayeredConv1d
15
+ positionwise_layer_args = (hidden_size, hidden_size * 4, 1, dropout)
16
+ self.pos_embed = RelPositionalEncoding(hidden_size, dropout)
17
+ self.encoder_layers = nn.ModuleList([EncoderLayer(
18
+ hidden_size,
19
+ RelPositionMultiHeadedAttention(num_heads, hidden_size, 0.0),
20
+ positionwise_layer(*positionwise_layer_args),
21
+ positionwise_layer(*positionwise_layer_args),
22
+ ConvolutionModule(hidden_size, kernel_size, Swish()),
23
+ dropout,
24
+ ) for _ in range(num_layers)])
25
+ if self.use_last_norm:
26
+ self.layer_norm = nn.LayerNorm(hidden_size)
27
+ else:
28
+ self.layer_norm = nn.Linear(hidden_size, hidden_size)
29
+ self.save_hidden = save_hidden
30
+ if save_hidden:
31
+ self.hiddens = []
32
+
33
+ def forward(self, x, padding_mask=None):
34
+ """
35
+
36
+ :param x: [B, T, H]
37
+ :param padding_mask: [B, T]
38
+ :return: [B, T, H]
39
+ """
40
+ self.hiddens = []
41
+ nonpadding_mask = x.abs().sum(-1) > 0
42
+ x = self.pos_embed(x)
43
+ for l in self.encoder_layers:
44
+ x, mask = l(x, nonpadding_mask[:, None, :])
45
+ if self.save_hidden:
46
+ self.hiddens.append(x[0])
47
+ x = x[0]
48
+ x = self.layer_norm(x) * nonpadding_mask.float()[:, :, None]
49
+ return x
50
+
51
+ class FastConformerLayers(ConformerLayers):
52
+ def __init__(self, hidden_size, num_layers, kernel_size=9, dropout=0.0, num_heads=4,
53
+ use_last_norm=True, save_hidden=False):
54
+ super(ConformerLayers, self).__init__()
55
+ self.use_last_norm = use_last_norm
56
+ self.layers = nn.ModuleList()
57
+ positionwise_layer = MultiLayeredConv1d
58
+ positionwise_layer_args = (hidden_size, hidden_size * 4, 1, dropout)
59
+ self.pos_embed = PositionalEncoding(hidden_size, dropout)
60
+ self.encoder_layers = nn.ModuleList([EncoderLayer(
61
+ hidden_size,
62
+ MultiHeadedAttention(num_heads, hidden_size, 0.0, flash=True),
63
+ positionwise_layer(*positionwise_layer_args),
64
+ positionwise_layer(*positionwise_layer_args),
65
+ ConvolutionModule(hidden_size, kernel_size, Swish()),
66
+ dropout,
67
+ ) for _ in range(num_layers)])
68
+ if self.use_last_norm:
69
+ self.layer_norm = nn.LayerNorm(hidden_size)
70
+ else:
71
+ self.layer_norm = nn.Linear(hidden_size, hidden_size)
72
+ self.save_hidden = save_hidden
73
+ if save_hidden:
74
+ self.hiddens = []
75
+
76
+ class ConformerEncoder(ConformerLayers):
77
+ def __init__(self, hidden_size, dict_size, num_layers=None):
78
+ conformer_enc_kernel_size = 9
79
+ super().__init__(hidden_size, num_layers, conformer_enc_kernel_size)
80
+ self.embed = Embedding(dict_size, hidden_size, padding_idx=0)
81
+
82
+ def forward(self, x):
83
+ """
84
+
85
+ :param x: source token ids [B, T]
86
+ :return: [B x T x C]
87
+ """
88
+ x = self.embed(x) # [B, T, H]
89
+ x = super(ConformerEncoder, self).forward(x)
90
+ return x
91
+
92
+
93
+ class ConformerDecoder(ConformerLayers):
94
+ def __init__(self, hidden_size, num_layers):
95
+ conformer_dec_kernel_size = 9
96
+ super().__init__(hidden_size, num_layers, conformer_dec_kernel_size)
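Shape-wise, `ConformerLayers` maps `[B, T, H]` to `[B, T, H]` and derives its padding mask from all-zero feature rows; a minimal sketch, assuming the relative imports of this package resolve (e.g. run from the repo root):

```python
import torch
from preprocess.tools.note_transcription.modules.commons.conformer.conformer import ConformerLayers

layers = ConformerLayers(hidden_size=128, num_layers=2, kernel_size=9, num_heads=4)
x = torch.randn(2, 50, 128)   # [B, T, H]
x[:, 40:] = 0.0               # all-zero rows are treated as padding frames
with torch.no_grad():
    y = layers(x)
print(y.shape)                # torch.Size([2, 50, 128])
```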
preprocess/tools/note_transcription/modules/commons/conformer/espnet_positional_embedding.py ADDED
@@ -0,0 +1,113 @@
1
+ import math
2
+ import torch
3
+
4
+
5
+ class PositionalEncoding(torch.nn.Module):
6
+ """Positional encoding.
7
+ Args:
8
+ d_model (int): Embedding dimension.
9
+ dropout_rate (float): Dropout rate.
10
+ max_len (int): Maximum input length.
11
+ reverse (bool): Whether to reverse the input position.
12
+ """
13
+
14
+ def __init__(self, d_model, dropout_rate, max_len=5000, reverse=False):
15
+ """Construct an PositionalEncoding object."""
16
+ super(PositionalEncoding, self).__init__()
17
+ self.d_model = d_model
18
+ self.reverse = reverse
19
+ self.xscale = math.sqrt(self.d_model)
20
+ self.dropout = torch.nn.Dropout(p=dropout_rate)
21
+ self.pe = None
22
+ self.extend_pe(torch.tensor(0.0).expand(1, max_len))
23
+
24
+ def extend_pe(self, x):
25
+ """Reset the positional encodings."""
26
+ if self.pe is not None:
27
+ if self.pe.size(1) >= x.size(1):
28
+ if self.pe.dtype != x.dtype or self.pe.device != x.device:
29
+ self.pe = self.pe.to(dtype=x.dtype, device=x.device)
30
+ return
31
+ pe = torch.zeros(x.size(1), self.d_model)
32
+ if self.reverse:
33
+ position = torch.arange(
34
+ x.size(1) - 1, -1, -1.0, dtype=torch.float32
35
+ ).unsqueeze(1)
36
+ else:
37
+ position = torch.arange(0, x.size(1), dtype=torch.float32).unsqueeze(1)
38
+ div_term = torch.exp(
39
+ torch.arange(0, self.d_model, 2, dtype=torch.float32)
40
+ * -(math.log(10000.0) / self.d_model)
41
+ )
42
+ pe[:, 0::2] = torch.sin(position * div_term)
43
+ pe[:, 1::2] = torch.cos(position * div_term)
44
+ pe = pe.unsqueeze(0)
45
+ self.pe = pe.to(device=x.device, dtype=x.dtype)
46
+
47
+ def forward(self, x: torch.Tensor):
48
+ """Add positional encoding.
49
+ Args:
50
+ x (torch.Tensor): Input tensor (batch, time, `*`).
51
+ Returns:
52
+ torch.Tensor: Encoded tensor (batch, time, `*`).
53
+ """
54
+ self.extend_pe(x)
55
+ x = x * self.xscale + self.pe[:, : x.size(1)]
56
+ return self.dropout(x)
57
+
58
+
59
+ class ScaledPositionalEncoding(PositionalEncoding):
60
+ """Scaled positional encoding module.
61
+ See Sec. 3.2 https://arxiv.org/abs/1809.08895
62
+ Args:
63
+ d_model (int): Embedding dimension.
64
+ dropout_rate (float): Dropout rate.
65
+ max_len (int): Maximum input length.
66
+ """
67
+
68
+ def __init__(self, d_model, dropout_rate, max_len=5000):
69
+ """Initialize class."""
70
+ super().__init__(d_model=d_model, dropout_rate=dropout_rate, max_len=max_len)
71
+ self.alpha = torch.nn.Parameter(torch.tensor(1.0))
72
+
73
+ def reset_parameters(self):
74
+ """Reset parameters."""
75
+ self.alpha.data = torch.tensor(1.0)
76
+
77
+ def forward(self, x):
78
+ """Add positional encoding.
79
+ Args:
80
+ x (torch.Tensor): Input tensor (batch, time, `*`).
81
+ Returns:
82
+ torch.Tensor: Encoded tensor (batch, time, `*`).
83
+ """
84
+ self.extend_pe(x)
85
+ x = x + self.alpha * self.pe[:, : x.size(1)]
86
+ return self.dropout(x)
87
+
88
+
89
+ class RelPositionalEncoding(PositionalEncoding):
90
+ """Relative positional encoding module.
91
+ See : Appendix B in https://arxiv.org/abs/1901.02860
92
+ Args:
93
+ d_model (int): Embedding dimension.
94
+ dropout_rate (float): Dropout rate.
95
+ max_len (int): Maximum input length.
96
+ """
97
+
98
+ def __init__(self, d_model, dropout_rate, max_len=5000):
99
+ """Initialize class."""
100
+ super().__init__(d_model, dropout_rate, max_len, reverse=True)
101
+
102
+ def forward(self, x):
103
+ """Compute positional encoding.
104
+ Args:
105
+ x (torch.Tensor): Input tensor (batch, time, `*`).
106
+ Returns:
107
+ torch.Tensor: Encoded tensor (batch, time, `*`).
108
+ torch.Tensor: Positional embedding tensor (1, time, `*`).
109
+ """
110
+ self.extend_pe(x)
111
+ x = x * self.xscale
112
+ pos_emb = self.pe[:, : x.size(1)]
113
+ return self.dropout(x), self.dropout(pos_emb)
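The table built in `extend_pe` is the standard sinusoidal encoding, `PE[pos, 2i] = sin(pos / 10000^{2i/d})` and `PE[pos, 2i+1] = cos(pos / 10000^{2i/d})`; the same computation in isolation:

```python
import math
import torch

d_model, max_len = 8, 4
position = torch.arange(0, max_len, dtype=torch.float32).unsqueeze(1)   # [T, 1]
div_term = torch.exp(
    torch.arange(0, d_model, 2, dtype=torch.float32) * -(math.log(10000.0) / d_model)
)                                                                        # [d_model / 2]
pe = torch.zeros(max_len, d_model)
pe[:, 0::2] = torch.sin(position * div_term)
pe[:, 1::2] = torch.cos(position * div_term)
print(pe[0])  # position 0: sin terms are 0, cos terms are 1
```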
preprocess/tools/note_transcription/modules/commons/conformer/espnet_transformer_attn.py ADDED
@@ -0,0 +1,198 @@
1
+ #!/usr/bin/env python3
2
+ # -*- coding: utf-8 -*-
3
+
4
+ # Copyright 2019 Shigeki Karita
5
+ # Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
6
+
7
+ """Multi-Head Attention layer definition."""
8
+
9
+ from packaging import version
10
+ import math
11
+
12
+ import numpy
13
+ import torch
14
+ from torch import nn
15
+
16
+
17
+ class MultiHeadedAttention(nn.Module):
18
+ """Multi-Head Attention layer.
19
+ Args:
20
+ n_head (int): The number of heads.
21
+ n_feat (int): The number of features.
22
+ dropout_rate (float): Dropout rate.
23
+ """
24
+
25
+ def __init__(self, n_head, n_feat, dropout_rate, flash=False):
26
+ """Construct an MultiHeadedAttention object."""
27
+ super(MultiHeadedAttention, self).__init__()
28
+ assert n_feat % n_head == 0
29
+ # We assume d_v always equals d_k
30
+ self.d_k = n_feat // n_head
31
+ self.h = n_head
32
+ self.linear_q = nn.Linear(n_feat, n_feat)
33
+ self.linear_k = nn.Linear(n_feat, n_feat)
34
+ self.linear_v = nn.Linear(n_feat, n_feat)
35
+ self.linear_out = nn.Linear(n_feat, n_feat)
36
+ self.attn = None
37
+ self.dropout = nn.Dropout(p=dropout_rate)
38
+ self.dropout_rate = dropout_rate
39
+ self.flash = flash
40
+
41
+ def forward_qkv(self, query, key, value):
42
+ """Transform query, key and value.
43
+ Args:
44
+ query (torch.Tensor): Query tensor (#batch, time1, size).
45
+ key (torch.Tensor): Key tensor (#batch, time2, size).
46
+ value (torch.Tensor): Value tensor (#batch, time2, size).
47
+ Returns:
48
+ torch.Tensor: Transformed query tensor (#batch, n_head, time1, d_k).
49
+ torch.Tensor: Transformed key tensor (#batch, n_head, time2, d_k).
50
+ torch.Tensor: Transformed value tensor (#batch, n_head, time2, d_k).
51
+ """
52
+ n_batch = query.size(0)
53
+ q = self.linear_q(query).view(n_batch, -1, self.h, self.d_k)
54
+ k = self.linear_k(key).view(n_batch, -1, self.h, self.d_k)
55
+ v = self.linear_v(value).view(n_batch, -1, self.h, self.d_k)
56
+ q = q.transpose(1, 2) # (batch, head, time1, d_k)
57
+ k = k.transpose(1, 2) # (batch, head, time2, d_k)
58
+ v = v.transpose(1, 2) # (batch, head, time2, d_k)
59
+
60
+ return q, k, v
61
+
62
+ def forward_attention(self, value, scores, mask):
63
+ """Compute attention context vector.
64
+ Args:
65
+ value (torch.Tensor): Transformed value (#batch, n_head, time2, d_k).
66
+ scores (torch.Tensor): Attention score (#batch, n_head, time1, time2).
67
+ mask (torch.Tensor): Mask (#batch, 1, time2) or (#batch, time1, time2).
68
+ Returns:
69
+ torch.Tensor: Transformed value (#batch, time1, d_model)
70
+ weighted by the attention score (#batch, time1, time2).
71
+ """
72
+ n_batch = value.size(0)
73
+ if mask is not None:
74
+ mask = mask.unsqueeze(1).eq(0) # (batch, 1, *, time2)
75
+ min_value = float(
76
+ numpy.finfo(torch.tensor(0, dtype=scores.dtype).numpy().dtype).min
77
+ )
78
+ scores = scores.masked_fill(mask, min_value)
79
+ self.attn = torch.softmax(scores, dim=-1).masked_fill(
80
+ mask, 0.0
81
+ ) # (batch, head, time1, time2)
82
+ else:
83
+ self.attn = torch.softmax(scores, dim=-1) # (batch, head, time1, time2)
84
+
85
+ p_attn = self.dropout(self.attn)
86
+ x = torch.matmul(p_attn, value) # (batch, head, time1, d_k)
87
+ x = (
88
+ x.transpose(1, 2).contiguous().view(n_batch, -1, self.h * self.d_k)
89
+ ) # (batch, time1, d_model)
90
+
91
+ return self.linear_out(x) # (batch, time1, d_model)
92
+
93
+ def forward(self, query, key, value, mask):
94
+ """Compute scaled dot product attention.
95
+ Args:
96
+ query (torch.Tensor): Query tensor (#batch, time1, size).
97
+ key (torch.Tensor): Key tensor (#batch, time2, size).
98
+ value (torch.Tensor): Value tensor (#batch, time2, size).
99
+ mask (torch.Tensor): Mask tensor (#batch, 1, time2) or
100
+ (#batch, time1, time2).
101
+ Returns:
102
+ torch.Tensor: Output tensor (#batch, time1, d_model).
103
+ """
104
+ q, k, v = self.forward_qkv(query, key, value)
105
+ if version.parse(torch.__version__) >= version.parse("2.0") and self.flash:
106
+ n_batch = value.size(0)
107
+ x = torch.nn.functional.scaled_dot_product_attention(
108
+ q, k, v, attn_mask=mask.unsqueeze(1) if mask is not None else None, dropout_p=self.dropout_rate)
109
+ x = (
110
+ x.transpose(1, 2).contiguous().view(n_batch, -1, self.h * self.d_k)
111
+ ) # (batch, time1, d_model)
112
+ return self.linear_out(x)
113
+ else:
114
+ scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.d_k)
115
+ return self.forward_attention(v, scores, mask)
116
+
117
+
118
+ class RelPositionMultiHeadedAttention(MultiHeadedAttention):
119
+ """Multi-Head Attention layer with relative position encoding.
120
+ Paper: https://arxiv.org/abs/1901.02860
121
+ Args:
122
+ n_head (int): The number of heads.
123
+ n_feat (int): The number of features.
124
+ dropout_rate (float): Dropout rate.
125
+ """
126
+
127
+ def __init__(self, n_head, n_feat, dropout_rate):
128
+ """Construct an RelPositionMultiHeadedAttention object."""
129
+ super().__init__(n_head, n_feat, dropout_rate)
130
+ # linear transformation for positional encoding
131
+ self.linear_pos = nn.Linear(n_feat, n_feat, bias=False)
132
+ # these two learnable bias are used in matrix c and matrix d
133
+ # as described in https://arxiv.org/abs/1901.02860 Section 3.3
134
+ self.pos_bias_u = nn.Parameter(torch.Tensor(self.h, self.d_k))
135
+ self.pos_bias_v = nn.Parameter(torch.Tensor(self.h, self.d_k))
136
+ torch.nn.init.xavier_uniform_(self.pos_bias_u)
137
+ torch.nn.init.xavier_uniform_(self.pos_bias_v)
138
+
139
+ def rel_shift(self, x, zero_triu=False):
140
+ """Compute relative positinal encoding.
141
+ Args:
142
+ x (torch.Tensor): Input tensor (batch, time, size).
143
+ zero_triu (bool): If true, return the lower triangular part of the matrix.
144
+ Returns:
145
+ torch.Tensor: Output tensor.
146
+ """
147
+ zero_pad = torch.zeros((*x.size()[:3], 1), device=x.device, dtype=x.dtype)
148
+ x_padded = torch.cat([zero_pad, x], dim=-1)
149
+
150
+ x_padded = x_padded.view(*x.size()[:2], x.size(3) + 1, x.size(2))
151
+ x = x_padded[:, :, 1:].view_as(x)
152
+
153
+ if zero_triu:
154
+ ones = torch.ones((x.size(2), x.size(3)))
155
+ x = x * torch.tril(ones, x.size(3) - x.size(2))[None, None, :, :]
156
+
157
+ return x
158
+
159
+ def forward(self, query, key, value, pos_emb, mask):
160
+ """Compute 'Scaled Dot Product Attention' with rel. positional encoding.
161
+ Args:
162
+ query (torch.Tensor): Query tensor (#batch, time1, size).
163
+ key (torch.Tensor): Key tensor (#batch, time2, size).
164
+ value (torch.Tensor): Value tensor (#batch, time2, size).
165
+ pos_emb (torch.Tensor): Positional embedding tensor (#batch, time2, size).
166
+ mask (torch.Tensor): Mask tensor (#batch, 1, time2) or
167
+ (#batch, time1, time2).
168
+ Returns:
169
+ torch.Tensor: Output tensor (#batch, time1, d_model).
170
+ """
171
+ q, k, v = self.forward_qkv(query, key, value)
172
+ q = q.transpose(1, 2) # (batch, time1, head, d_k)
173
+
174
+ n_batch_pos = pos_emb.size(0)
175
+ p = self.linear_pos(pos_emb).view(n_batch_pos, -1, self.h, self.d_k)
176
+ p = p.transpose(1, 2) # (batch, head, time1, d_k)
177
+
178
+ # (batch, head, time1, d_k)
179
+ q_with_bias_u = (q + self.pos_bias_u).transpose(1, 2)
180
+ # (batch, head, time1, d_k)
181
+ q_with_bias_v = (q + self.pos_bias_v).transpose(1, 2)
182
+
183
+ # compute attention score
184
+ # first compute matrix a and matrix c
185
+ # as described in https://arxiv.org/abs/1901.02860 Section 3.3
186
+ # (batch, head, time1, time2)
187
+ matrix_ac = torch.matmul(q_with_bias_u, k.transpose(-2, -1))
188
+
189
+ # compute matrix b and matrix d
190
+ # (batch, head, time1, time2)
191
+ matrix_bd = torch.matmul(q_with_bias_v, p.transpose(-2, -1))
192
+ matrix_bd = self.rel_shift(matrix_bd)
193
+
194
+ scores = (matrix_ac + matrix_bd) / math.sqrt(
195
+ self.d_k
196
+ ) # (batch, head, time1, time2)
197
+
198
+ return self.forward_attention(v, scores, mask)
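When `flash=True` and PyTorch >= 2.0, the forward pass above delegates the core computation to `torch.nn.functional.scaled_dot_product_attention` over per-head tensors; a standalone sketch of that underlying call (here with a boolean mask where `True` means attend):

```python
import torch
import torch.nn.functional as F

batch, heads, time, d_k = 2, 4, 10, 16
q = torch.randn(batch, heads, time, d_k)
k = torch.randn(batch, heads, time, d_k)
v = torch.randn(batch, heads, time, d_k)
mask = torch.ones(batch, time, time, dtype=torch.bool)   # (#batch, time1, time2), True = attend

out = F.scaled_dot_product_attention(q, k, v, attn_mask=mask.unsqueeze(1))
print(out.shape)  # torch.Size([2, 4, 10, 16])
```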
preprocess/tools/note_transcription/modules/commons/conformer/layers.py ADDED
@@ -0,0 +1,260 @@
1
+ from torch import nn
2
+ import torch
3
+
4
+ from ..layers import LayerNorm
5
+
6
+
7
+ class ConvolutionModule(nn.Module):
8
+ """ConvolutionModule in Conformer model.
9
+ Args:
10
+ channels (int): The number of channels of conv layers.
11
+ kernel_size (int): Kernel size of conv layers.
12
+ """
13
+
14
+ def __init__(self, channels, kernel_size, activation=nn.ReLU(), bias=True):
15
+ """Construct a ConvolutionModule object."""
16
+ super(ConvolutionModule, self).__init__()
17
+ # kernel_size should be an odd number for 'SAME' padding
18
+ assert (kernel_size - 1) % 2 == 0
19
+
20
+ self.pointwise_conv1 = nn.Conv1d(
21
+ channels,
22
+ 2 * channels,
23
+ kernel_size=1,
24
+ stride=1,
25
+ padding=0,
26
+ bias=bias,
27
+ )
28
+ self.depthwise_conv = nn.Conv1d(
29
+ channels,
30
+ channels,
31
+ kernel_size,
32
+ stride=1,
33
+ padding=(kernel_size - 1) // 2,
34
+ groups=channels,
35
+ bias=bias,
36
+ )
37
+ self.norm = nn.BatchNorm1d(channels)
38
+ self.pointwise_conv2 = nn.Conv1d(
39
+ channels,
40
+ channels,
41
+ kernel_size=1,
42
+ stride=1,
43
+ padding=0,
44
+ bias=bias,
45
+ )
46
+ self.activation = activation
47
+
48
+ def forward(self, x):
49
+ """Compute convolution module.
50
+ Args:
51
+ x (torch.Tensor): Input tensor (#batch, time, channels).
52
+ Returns:
53
+ torch.Tensor: Output tensor (#batch, time, channels).
54
+ """
55
+ # exchange the temporal dimension and the feature dimension
56
+ x = x.transpose(1, 2)
57
+
58
+ # GLU mechanism
59
+ x = self.pointwise_conv1(x) # (batch, 2*channel, time)
60
+ x = nn.functional.glu(x, dim=1) # (batch, channel, time)
61
+
62
+ # 1D Depthwise Conv
63
+ x = self.depthwise_conv(x)
64
+ x = self.activation(self.norm(x))
65
+
66
+ x = self.pointwise_conv2(x)
67
+
68
+ return x.transpose(1, 2)
69
+
70
+
71
+ class MultiLayeredConv1d(torch.nn.Module):
72
+ """Multi-layered conv1d for Transformer block.
73
+ This is a module of multi-layered conv1d designed
74
+ to replace positionwise feed-forward network
75
+ in Transformer block, which is introduced in
76
+ `FastSpeech: Fast, Robust and Controllable Text to Speech`_.
77
+ .. _`FastSpeech: Fast, Robust and Controllable Text to Speech`:
78
+ https://arxiv.org/pdf/1905.09263.pdf
79
+ """
80
+
81
+ def __init__(self, in_chans, hidden_chans, kernel_size, dropout_rate):
82
+ """Initialize MultiLayeredConv1d module.
83
+ Args:
84
+ in_chans (int): Number of input channels.
85
+ hidden_chans (int): Number of hidden channels.
86
+ kernel_size (int): Kernel size of conv1d.
87
+ dropout_rate (float): Dropout rate.
88
+ """
89
+ super(MultiLayeredConv1d, self).__init__()
90
+ self.w_1 = torch.nn.Conv1d(
91
+ in_chans,
92
+ hidden_chans,
93
+ kernel_size,
94
+ stride=1,
95
+ padding=(kernel_size - 1) // 2,
96
+ )
97
+ self.w_2 = torch.nn.Conv1d(
98
+ hidden_chans,
99
+ in_chans,
100
+ kernel_size,
101
+ stride=1,
102
+ padding=(kernel_size - 1) // 2,
103
+ )
104
+ self.dropout = torch.nn.Dropout(dropout_rate)
105
+
106
+ def forward(self, x):
107
+ """Calculate forward propagation.
108
+ Args:
109
+ x (torch.Tensor): Batch of input tensors (B, T, in_chans).
110
+ Returns:
111
+ torch.Tensor: Batch of output tensors (B, T, in_chans).
112
+ """
113
+ x = torch.relu(self.w_1(x.transpose(-1, 1))).transpose(-1, 1)
114
+ return self.w_2(self.dropout(x).transpose(-1, 1)).transpose(-1, 1)
115
+
116
+
117
+ class Swish(torch.nn.Module):
118
+ """Construct a Swish object."""
119
+
120
+ def forward(self, x):
121
+ """Return Swish activation function."""
122
+ return x * torch.sigmoid(x)
123
+
124
+
125
+ class EncoderLayer(nn.Module):
126
+ """Encoder layer module.
127
+ Args:
128
+ size (int): Input dimension.
129
+ self_attn (torch.nn.Module): Self-attention module instance.
130
+ `MultiHeadedAttention` or `RelPositionMultiHeadedAttention` instance
131
+ can be used as the argument.
132
+ feed_forward (torch.nn.Module): Feed-forward module instance.
133
+ `PositionwiseFeedForward`, `MultiLayeredConv1d`, or `Conv1dLinear` instance
134
+ can be used as the argument.
135
+ feed_forward_macaron (torch.nn.Module): Additional feed-forward module instance.
136
+ `PositionwiseFeedForward`, `MultiLayeredConv1d`, or `Conv1dLinear` instance
137
+ can be used as the argument.
138
+ conv_module (torch.nn.Module): Convolution module instance.
139
+ `ConvolutionModule` instance can be used as the argument.
140
+ dropout_rate (float): Dropout rate.
141
+ normalize_before (bool): Whether to use layer_norm before the first block.
142
+ concat_after (bool): Whether to concat attention layer's input and output.
143
+ if True, additional linear will be applied.
144
+ i.e. x -> x + linear(concat(x, att(x)))
145
+ if False, no additional linear will be applied. i.e. x -> x + att(x)
146
+ """
147
+
148
+ def __init__(
149
+ self,
150
+ size,
151
+ self_attn,
152
+ feed_forward,
153
+ feed_forward_macaron,
154
+ conv_module,
155
+ dropout_rate,
156
+ normalize_before=True,
157
+ concat_after=False,
158
+ ):
159
+ """Construct an EncoderLayer object."""
160
+ super(EncoderLayer, self).__init__()
161
+ self.self_attn = self_attn
162
+ self.feed_forward = feed_forward
163
+ self.feed_forward_macaron = feed_forward_macaron
164
+ self.conv_module = conv_module
165
+ self.norm_ff = LayerNorm(size) # for the FNN module
166
+ self.norm_mha = LayerNorm(size) # for the MHA module
167
+ if feed_forward_macaron is not None:
168
+ self.norm_ff_macaron = LayerNorm(size)
169
+ self.ff_scale = 0.5
170
+ else:
171
+ self.ff_scale = 1.0
172
+ if self.conv_module is not None:
173
+ self.norm_conv = LayerNorm(size) # for the CNN module
174
+ self.norm_final = LayerNorm(size) # for the final output of the block
175
+ self.dropout = nn.Dropout(dropout_rate)
176
+ self.size = size
177
+ self.normalize_before = normalize_before
178
+ self.concat_after = concat_after
179
+ if self.concat_after:
180
+ self.concat_linear = nn.Linear(size + size, size)
181
+
182
+ def forward(self, x_input, mask, cache=None):
183
+ """Compute encoded features.
184
+ Args:
185
+ x_input (Union[Tuple, torch.Tensor]): Input tensor w/ or w/o pos emb.
186
+ - w/ pos emb: Tuple of tensors [(#batch, time, size), (1, time, size)].
187
+ - w/o pos emb: Tensor (#batch, time, size).
188
+ mask (torch.Tensor): Mask tensor for the input (#batch, time).
189
+ cache (torch.Tensor): Cache tensor of the input (#batch, time - 1, size).
190
+ Returns:
191
+ torch.Tensor: Output tensor (#batch, time, size).
192
+ torch.Tensor: Mask tensor (#batch, time).
193
+ """
194
+ if isinstance(x_input, tuple):
195
+ x, pos_emb = x_input[0], x_input[1]
196
+ else:
197
+ x, pos_emb = x_input, None
198
+
199
+ # whether to use macaron style
200
+ if self.feed_forward_macaron is not None:
201
+ residual = x
202
+ if self.normalize_before:
203
+ x = self.norm_ff_macaron(x)
204
+ x = residual + self.ff_scale * self.dropout(self.feed_forward_macaron(x))
205
+ if not self.normalize_before:
206
+ x = self.norm_ff_macaron(x)
207
+
208
+ # multi-headed self-attention module
209
+ residual = x
210
+ if self.normalize_before:
211
+ x = self.norm_mha(x)
212
+
213
+ if cache is None:
214
+ x_q = x
215
+ else:
216
+ assert cache.shape == (x.shape[0], x.shape[1] - 1, self.size)
217
+ x_q = x[:, -1:, :]
218
+ residual = residual[:, -1:, :]
219
+ mask = None if mask is None else mask[:, -1:, :]
220
+
221
+ if pos_emb is not None:
222
+ x_att = self.self_attn(x_q, x, x, pos_emb, mask)
223
+ else:
224
+ x_att = self.self_attn(x_q, x, x, mask)
225
+
226
+ if self.concat_after:
227
+ x_concat = torch.cat((x, x_att), dim=-1)
228
+ x = residual + self.concat_linear(x_concat)
229
+ else:
230
+ x = residual + self.dropout(x_att)
231
+ if not self.normalize_before:
232
+ x = self.norm_mha(x)
233
+
234
+ # convolution module
235
+ if self.conv_module is not None:
236
+ residual = x
237
+ if self.normalize_before:
238
+ x = self.norm_conv(x)
239
+ x = residual + self.dropout(self.conv_module(x))
240
+ if not self.normalize_before:
241
+ x = self.norm_conv(x)
242
+
243
+ # feed forward module
244
+ residual = x
245
+ if self.normalize_before:
246
+ x = self.norm_ff(x)
247
+ x = residual + self.ff_scale * self.dropout(self.feed_forward(x))
248
+ if not self.normalize_before:
249
+ x = self.norm_ff(x)
250
+
251
+ if self.conv_module is not None:
252
+ x = self.norm_final(x)
253
+
254
+ if cache is not None:
255
+ x = torch.cat([cache, x], dim=1)
256
+
257
+ if pos_emb is not None:
258
+ return (x, pos_emb), mask
259
+
260
+ return x, mask
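ConvolutionModule preserves the sequence length (pointwise expansion, GLU, depthwise convolution with 'SAME' padding, pointwise projection), so it drops into a residual branch of the Conformer block. A small shape check, sketched here as an illustration only (the import path mirrors the file path and assumes the repository root is importable):

import torch
from preprocess.tools.note_transcription.modules.commons.conformer.layers import ConvolutionModule

conv = ConvolutionModule(channels=256, kernel_size=31)  # kernel size must be odd
x = torch.randn(2, 100, 256)  # (batch, time, channels)
y = conv(x)
print(y.shape)                # expected: torch.Size([2, 100, 256]) -- time and channels unchanged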
preprocess/tools/note_transcription/modules/commons/conv.py ADDED
@@ -0,0 +1,175 @@
1
+ import math
2
+ import torch
3
+ import torch.nn as nn
4
+ import torch.nn.functional as F
5
+
6
+ from .layers import LayerNorm, Embedding
7
+
8
+ class LambdaLayer(nn.Module):
9
+ def __init__(self, lambd):
10
+ super(LambdaLayer, self).__init__()
11
+ self.lambd = lambd
12
+
13
+ def forward(self, x):
14
+ return self.lambd(x)
15
+
16
+ def init_weights_func(m):
17
+ classname = m.__class__.__name__
18
+ if classname.find("Conv1d") != -1:
19
+ torch.nn.init.xavier_uniform_(m.weight)
20
+
21
+ def get_norm_builder(norm_type, channels, ln_eps=1e-6):
22
+ if norm_type == 'bn':
23
+ norm_builder = lambda: nn.BatchNorm1d(channels)
24
+ elif norm_type == 'in':
25
+ norm_builder = lambda: nn.InstanceNorm1d(channels, affine=True)
26
+ elif norm_type == 'gn':
27
+ norm_builder = lambda: nn.GroupNorm(8, channels)
28
+ elif norm_type == 'ln':
29
+ norm_builder = lambda: LayerNorm(channels, dim=1, eps=ln_eps)
30
+ else:
31
+ norm_builder = lambda: nn.Identity()
32
+ return norm_builder
33
+
34
+ def get_act_builder(act_type):
35
+ if act_type == 'gelu':
36
+ act_builder = lambda: nn.GELU()
37
+ elif act_type == 'relu':
38
+ act_builder = lambda: nn.ReLU(inplace=True)
39
+ elif act_type == 'leakyrelu':
40
+ act_builder = lambda: nn.LeakyReLU(negative_slope=0.01, inplace=True)
41
+ elif act_type == 'swish':
42
+ act_builder = lambda: nn.SiLU(inplace=True)
43
+ else:
44
+ act_builder = lambda: nn.Identity()
45
+ return act_builder
46
+
47
+ class ResidualBlock(nn.Module):
48
+ """Applies norm -> conv -> activation -> conv blocks n times with residual connections."""
49
+
50
+ def __init__(self, channels, kernel_size, dilation, n=2, norm_type='bn', dropout=0.0,
51
+ c_multiple=2, ln_eps=1e-12, act_type='gelu'):
52
+ super(ResidualBlock, self).__init__()
53
+
54
+ norm_builder = get_norm_builder(norm_type, channels, ln_eps)
55
+ act_builder = get_act_builder(act_type)
56
+
57
+ self.blocks = [
58
+ nn.Sequential(
59
+ norm_builder(),
60
+ nn.Conv1d(channels, c_multiple * channels, kernel_size, dilation=dilation,
61
+ padding=(dilation * (kernel_size - 1)) // 2),
62
+ LambdaLayer(lambda x: x * kernel_size ** -0.5),
63
+ act_builder(),
64
+ nn.Conv1d(c_multiple * channels, channels, 1, dilation=dilation),
65
+ )
66
+ for i in range(n)
67
+ ]
68
+
69
+ self.blocks = nn.ModuleList(self.blocks)
70
+ self.dropout = dropout
71
+
72
+ def forward(self, x):
73
+ nonpadding = (x.abs().sum(1) > 0).float()[:, None, :]
74
+ for b in self.blocks:
75
+ x_ = b(x)
76
+ if self.dropout > 0 and self.training:
77
+ x_ = F.dropout(x_, self.dropout, training=self.training)
78
+ x = x + x_
79
+ x = x * nonpadding
80
+ return x
81
+
82
+
83
+ class ConvBlocks(nn.Module):
84
+ """Decodes the expanded phoneme encoding into spectrograms"""
85
+
86
+ def __init__(self, hidden_size, out_dims, dilations, kernel_size,
87
+ norm_type='ln', layers_in_block=2, c_multiple=2,
88
+ dropout=0.0, ln_eps=1e-5,
89
+ init_weights=True, is_BTC=True, num_layers=None, post_net_kernel=3, act_type='gelu'):
90
+ super(ConvBlocks, self).__init__()
91
+ self.is_BTC = is_BTC
92
+ if num_layers is not None:
93
+ dilations = [1] * num_layers
94
+ self.res_blocks = nn.Sequential(
95
+ *[ResidualBlock(hidden_size, kernel_size, d,
96
+ n=layers_in_block, norm_type=norm_type, c_multiple=c_multiple,
97
+ dropout=dropout, ln_eps=ln_eps, act_type=act_type)
98
+ for d in dilations],
99
+ )
100
+ norm = get_norm_builder(norm_type, hidden_size, ln_eps)()
101
+ self.last_norm = norm
102
+ self.post_net1 = nn.Conv1d(hidden_size, out_dims, kernel_size=post_net_kernel,
103
+ padding=post_net_kernel // 2)
104
+ if init_weights:
105
+ self.apply(init_weights_func)
106
+
107
+ def forward(self, x, nonpadding=None):
108
+ """
109
+
110
+ :param x: [B, T, H]
111
+ :return: [B, T, H]
112
+ """
113
+ if self.is_BTC:
114
+ x = x.transpose(1, 2)
115
+ if nonpadding is None:
116
+ nonpadding = (x.abs().sum(1) > 0).float()[:, None, :]
117
+ elif self.is_BTC:
118
+ nonpadding = nonpadding.transpose(1, 2)
119
+ x = self.res_blocks(x) * nonpadding
120
+ x = self.last_norm(x) * nonpadding
121
+ x = self.post_net1(x) * nonpadding
122
+ if self.is_BTC:
123
+ x = x.transpose(1, 2)
124
+ return x
125
+
126
+
127
+ class TextConvEncoder(ConvBlocks):
128
+ def __init__(self, dict_size, hidden_size, out_dims, dilations, kernel_size,
129
+ norm_type='ln', layers_in_block=2, c_multiple=2,
130
+ dropout=0.0, ln_eps=1e-5, init_weights=True, num_layers=None, post_net_kernel=3):
131
+ super().__init__(hidden_size, out_dims, dilations, kernel_size,
132
+ norm_type, layers_in_block, c_multiple,
133
+ dropout, ln_eps, init_weights, num_layers=num_layers,
134
+ post_net_kernel=post_net_kernel)
135
+ self.embed_tokens = Embedding(dict_size, hidden_size, 0)
136
+ self.embed_scale = math.sqrt(hidden_size)
137
+
138
+ def forward(self, txt_tokens):
139
+ """
140
+
141
+ :param txt_tokens: [B, T]
142
+ :return: {
143
+ 'encoder_out': [B x T x C]
144
+ }
145
+ """
146
+ x = self.embed_scale * self.embed_tokens(txt_tokens)
147
+ return super().forward(x)
148
+
149
+
150
+ class ConditionalConvBlocks(ConvBlocks):
151
+ def __init__(self, hidden_size, c_cond, c_out, dilations, kernel_size,
152
+ norm_type='ln', layers_in_block=2, c_multiple=2,
153
+ dropout=0.0, ln_eps=1e-5, init_weights=True, is_BTC=True, num_layers=None):
154
+ super().__init__(hidden_size, c_out, dilations, kernel_size,
155
+ norm_type, layers_in_block, c_multiple,
156
+ dropout, ln_eps, init_weights, is_BTC=False, num_layers=num_layers)
157
+ self.g_prenet = nn.Conv1d(c_cond, hidden_size, 3, padding=1)
158
+ self.is_BTC_ = is_BTC
159
+ if init_weights:
160
+ self.g_prenet.apply(init_weights_func)
161
+
162
+ def forward(self, x, cond, nonpadding=None):
163
+ if self.is_BTC_:
164
+ x = x.transpose(1, 2)
165
+ cond = cond.transpose(1, 2)
166
+ if nonpadding is not None:
167
+ nonpadding = nonpadding.transpose(1, 2)
168
+ if nonpadding is None:
169
+ nonpadding = x.abs().sum(1)[:, None]
170
+ x = x + self.g_prenet(cond)
171
+ x = x * nonpadding
172
+ x = super(ConditionalConvBlocks, self).forward(x) # input needs to be BTC
173
+ if self.is_BTC_:
174
+ x = x.transpose(1, 2)
175
+ return x
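ConvBlocks stacks dilated residual convolution blocks, treats all-zero feature frames as padding, and projects to `out_dims` with a final convolution. A hypothetical usage sketch with arbitrary hyperparameters (not values taken from this repository):

import torch
from preprocess.tools.note_transcription.modules.commons.conv import ConvBlocks

net = ConvBlocks(hidden_size=192, out_dims=80, dilations=[1, 2, 4, 8], kernel_size=5)
x = torch.randn(2, 120, 192)  # (batch, time, hidden); all-zero frames would be masked out
y = net(x)
print(y.shape)                # expected: torch.Size([2, 120, 80])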
preprocess/tools/note_transcription/modules/commons/layers.py ADDED
@@ -0,0 +1,85 @@
1
+ import torch
2
+ from torch import nn
3
+ from torch.autograd import Function
4
+
5
+ class LayerNorm(torch.nn.LayerNorm):
6
+ """Layer normalization module.
7
+ :param int nout: output dim size
8
+ :param int dim: dimension to be normalized
9
+ """
10
+
11
+ def __init__(self, nout, dim=-1, eps=1e-5):
12
+ """Construct a LayerNorm object."""
13
+ super(LayerNorm, self).__init__(nout, eps=eps)
14
+ self.dim = dim
15
+
16
+ def forward(self, x):
17
+ """Apply layer normalization.
18
+ :param torch.Tensor x: input tensor
19
+ :return: layer normalized tensor
20
+ :rtype torch.Tensor
21
+ """
22
+ if self.dim == -1:
23
+ return super(LayerNorm, self).forward(x)
24
+ return super(LayerNorm, self).forward(x.transpose(1, -1)).transpose(1, -1)
25
+
26
+
27
+ class Reshape(nn.Module):
28
+ def __init__(self, *args):
29
+ super(Reshape, self).__init__()
30
+ self.shape = args
31
+
32
+ def forward(self, x):
33
+ return x.view(self.shape)
34
+
35
+
36
+ class Permute(nn.Module):
37
+ def __init__(self, *args):
38
+ super(Permute, self).__init__()
39
+ self.args = args
40
+
41
+ def forward(self, x):
42
+ return x.permute(self.args)
43
+
44
+
45
+ def Linear(in_features, out_features, bias=True, init_type='xavier'):
46
+ m = nn.Linear(in_features, out_features, bias)
47
+ if init_type == 'xavier':
48
+ nn.init.xavier_uniform_(m.weight)
49
+ elif init_type == 'kaiming':
50
+ nn.init.kaiming_normal_(m.weight, mode='fan_in')
51
+ if bias:
52
+ nn.init.constant_(m.bias, 0.)
53
+ return m
54
+
55
+
56
+ def Embedding(num_embeddings, embedding_dim, padding_idx=None, init_type='normal'):
57
+ m = nn.Embedding(num_embeddings, embedding_dim, padding_idx=padding_idx)
58
+ if init_type == 'normal':
59
+ nn.init.normal_(m.weight, mean=0, std=embedding_dim ** -0.5)
60
+ elif init_type == 'kaiming':
61
+ nn.init.kaiming_normal_(m.weight, mode='fan_in')
62
+ if padding_idx is not None:
63
+ nn.init.constant_(m.weight[padding_idx], 0)
64
+ return m
65
+
66
+
67
+ class GradientReverseFunction(Function):
68
+ @staticmethod
69
+ def forward(ctx, input, coeff=1.):
70
+ ctx.coeff = coeff
71
+ output = input * 1.0
72
+ return output
73
+
74
+ @staticmethod
75
+ def backward(ctx, grad_output):
76
+ return grad_output.neg() * ctx.coeff, None
77
+
78
+
79
+ class GRL(nn.Module):
80
+ def __init__(self):
81
+ super(GRL, self).__init__()
82
+
83
+ def forward(self, *input):
84
+ return GradientReverseFunction.apply(*input)
85
+
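GRL is a gradient reversal layer: the forward pass is the identity while the backward pass multiplies incoming gradients by `-coeff`, the usual trick for adversarial (e.g. domain- or speaker-confusion) branches. A quick illustrative check (import path assumed from the file location):

import torch
from preprocess.tools.note_transcription.modules.commons.layers import GRL

grl = GRL()
x = torch.ones(3, 4, requires_grad=True)
y = grl(x, 0.5)        # identity in the forward pass
y.sum().backward()
print(x.grad)          # every entry is -0.5: gradients are flipped and scaled by coeff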
preprocess/tools/note_transcription/modules/commons/rel_transformer.py ADDED
@@ -0,0 +1,378 @@
1
+ import math
2
+ import torch
3
+ from torch import nn
4
+ from torch.nn import functional as F
5
+
6
+ from .layers import Embedding
7
+
8
+
9
+ def convert_pad_shape(pad_shape):
10
+ l = pad_shape[::-1]
11
+ pad_shape = [item for sublist in l for item in sublist]
12
+ return pad_shape
13
+
14
+
15
+ def shift_1d(x):
16
+ x = F.pad(x, convert_pad_shape([[0, 0], [0, 0], [1, 0]]))[:, :, :-1]
17
+ return x
18
+
19
+
20
+ def sequence_mask(length, max_length=None):
21
+ if max_length is None:
22
+ max_length = length.max()
23
+ x = torch.arange(max_length, dtype=length.dtype, device=length.device)
24
+ return x.unsqueeze(0) < length.unsqueeze(1)
25
+
26
+
27
+ class Encoder(nn.Module):
28
+ def __init__(self, hidden_channels, filter_channels, n_heads, n_layers, kernel_size=1, p_dropout=0.,
29
+ window_size=None, block_length=None, pre_ln=False, **kwargs):
30
+ super().__init__()
31
+ self.hidden_channels = hidden_channels
32
+ self.filter_channels = filter_channels
33
+ self.n_heads = n_heads
34
+ self.n_layers = n_layers
35
+ self.kernel_size = kernel_size
36
+ self.p_dropout = p_dropout
37
+ self.window_size = window_size
38
+ self.block_length = block_length
39
+ self.pre_ln = pre_ln
40
+
41
+ self.drop = nn.Dropout(p_dropout)
42
+ self.attn_layers = nn.ModuleList()
43
+ self.norm_layers_1 = nn.ModuleList()
44
+ self.ffn_layers = nn.ModuleList()
45
+ self.norm_layers_2 = nn.ModuleList()
46
+ for i in range(self.n_layers):
47
+ self.attn_layers.append(
48
+ MultiHeadAttention(hidden_channels, hidden_channels, n_heads, window_size=window_size,
49
+ p_dropout=p_dropout, block_length=block_length))
50
+ self.norm_layers_1.append(LayerNorm(hidden_channels))
51
+ self.ffn_layers.append(
52
+ FFN(hidden_channels, hidden_channels, filter_channels, kernel_size, p_dropout=p_dropout))
53
+ self.norm_layers_2.append(LayerNorm(hidden_channels))
54
+ if pre_ln:
55
+ self.last_ln = LayerNorm(hidden_channels)
56
+
57
+ def forward(self, x, x_mask):
58
+ attn_mask = x_mask.unsqueeze(2) * x_mask.unsqueeze(-1)
59
+ for i in range(self.n_layers):
60
+ x = x * x_mask
61
+ x_ = x
62
+ if self.pre_ln:
63
+ x = self.norm_layers_1[i](x)
64
+ y = self.attn_layers[i](x, x, attn_mask)
65
+ y = self.drop(y)
66
+ x = x_ + y
67
+ if not self.pre_ln:
68
+ x = self.norm_layers_1[i](x)
69
+
70
+ x_ = x
71
+ if self.pre_ln:
72
+ x = self.norm_layers_2[i](x)
73
+ y = self.ffn_layers[i](x, x_mask)
74
+ y = self.drop(y)
75
+ x = x_ + y
76
+ if not self.pre_ln:
77
+ x = self.norm_layers_2[i](x)
78
+ if self.pre_ln:
79
+ x = self.last_ln(x)
80
+ x = x * x_mask
81
+ return x
82
+
83
+
84
+ class MultiHeadAttention(nn.Module):
85
+ def __init__(self, channels, out_channels, n_heads, window_size=None, heads_share=True, p_dropout=0.,
86
+ block_length=None, proximal_bias=False, proximal_init=False):
87
+ super().__init__()
88
+ assert channels % n_heads == 0
89
+
90
+ self.channels = channels
91
+ self.out_channels = out_channels
92
+ self.n_heads = n_heads
93
+ self.window_size = window_size
94
+ self.heads_share = heads_share
95
+ self.block_length = block_length
96
+ self.proximal_bias = proximal_bias
97
+ self.p_dropout = p_dropout
98
+ self.attn = None
99
+
100
+ self.k_channels = channels // n_heads
101
+ self.conv_q = nn.Conv1d(channels, channels, 1)
102
+ self.conv_k = nn.Conv1d(channels, channels, 1)
103
+ self.conv_v = nn.Conv1d(channels, channels, 1)
104
+ if window_size is not None:
105
+ n_heads_rel = 1 if heads_share else n_heads
106
+ rel_stddev = self.k_channels ** -0.5
107
+ self.emb_rel_k = nn.Parameter(torch.randn(n_heads_rel, window_size * 2 + 1, self.k_channels) * rel_stddev)
108
+ self.emb_rel_v = nn.Parameter(torch.randn(n_heads_rel, window_size * 2 + 1, self.k_channels) * rel_stddev)
109
+ self.conv_o = nn.Conv1d(channels, out_channels, 1)
110
+ self.drop = nn.Dropout(p_dropout)
111
+
112
+ nn.init.xavier_uniform_(self.conv_q.weight)
113
+ nn.init.xavier_uniform_(self.conv_k.weight)
114
+ if proximal_init:
115
+ self.conv_k.weight.data.copy_(self.conv_q.weight.data)
116
+ self.conv_k.bias.data.copy_(self.conv_q.bias.data)
117
+ nn.init.xavier_uniform_(self.conv_v.weight)
118
+
119
+ def forward(self, x, c, attn_mask=None):
120
+ q = self.conv_q(x)
121
+ k = self.conv_k(c)
122
+ v = self.conv_v(c)
123
+
124
+ x, self.attn = self.attention(q, k, v, mask=attn_mask)
125
+
126
+ x = self.conv_o(x)
127
+ return x
128
+
129
+ def attention(self, query, key, value, mask=None):
130
+ # reshape [b, d, t] -> [b, n_h, t, d_k]
131
+ b, d, t_s, t_t = (*key.size(), query.size(2))
132
+ query = query.view(b, self.n_heads, self.k_channels, t_t).transpose(2, 3)
133
+ key = key.view(b, self.n_heads, self.k_channels, t_s).transpose(2, 3)
134
+ value = value.view(b, self.n_heads, self.k_channels, t_s).transpose(2, 3)
135
+
136
+ scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(self.k_channels)
137
+ if self.window_size is not None:
138
+ assert t_s == t_t, "Relative attention is only available for self-attention."
139
+ key_relative_embeddings = self._get_relative_embeddings(self.emb_rel_k, t_s)
140
+ rel_logits = self._matmul_with_relative_keys(query, key_relative_embeddings)
141
+ rel_logits = self._relative_position_to_absolute_position(rel_logits)
142
+ scores_local = rel_logits / math.sqrt(self.k_channels)
143
+ scores = scores + scores_local
144
+ if self.proximal_bias:
145
+ assert t_s == t_t, "Proximal bias is only available for self-attention."
146
+ scores = scores + self._attention_bias_proximal(t_s).to(device=scores.device, dtype=scores.dtype)
147
+ if mask is not None:
148
+ scores = scores.masked_fill(mask == 0, -1e4)
149
+ if self.block_length is not None:
150
+ block_mask = torch.ones_like(scores).triu(-self.block_length).tril(self.block_length)
151
+ scores = scores * block_mask + -1e4 * (1 - block_mask)
152
+ p_attn = F.softmax(scores, dim=-1) # [b, n_h, t_t, t_s]
153
+ p_attn = self.drop(p_attn)
154
+ output = torch.matmul(p_attn, value)
155
+ if self.window_size is not None:
156
+ relative_weights = self._absolute_position_to_relative_position(p_attn)
157
+ value_relative_embeddings = self._get_relative_embeddings(self.emb_rel_v, t_s)
158
+ output = output + self._matmul_with_relative_values(relative_weights, value_relative_embeddings)
159
+ output = output.transpose(2, 3).contiguous().view(b, d, t_t) # [b, n_h, t_t, d_k] -> [b, d, t_t]
160
+ return output, p_attn
161
+
162
+ def _matmul_with_relative_values(self, x, y):
163
+ """
164
+ x: [b, h, l, m]
165
+ y: [h or 1, m, d]
166
+ ret: [b, h, l, d]
167
+ """
168
+ ret = torch.matmul(x, y.unsqueeze(0))
169
+ return ret
170
+
171
+ def _matmul_with_relative_keys(self, x, y):
172
+ """
173
+ x: [b, h, l, d]
174
+ y: [h or 1, m, d]
175
+ ret: [b, h, l, m]
176
+ """
177
+ ret = torch.matmul(x, y.unsqueeze(0).transpose(-2, -1))
178
+ return ret
179
+
180
+ def _get_relative_embeddings(self, relative_embeddings, length):
181
+ max_relative_position = 2 * self.window_size + 1
182
+ # Pad first before slice to avoid using cond ops.
183
+ pad_length = max(length - (self.window_size + 1), 0)
184
+ slice_start_position = max((self.window_size + 1) - length, 0)
185
+ slice_end_position = slice_start_position + 2 * length - 1
186
+ if pad_length > 0:
187
+ padded_relative_embeddings = F.pad(
188
+ relative_embeddings,
189
+ convert_pad_shape([[0, 0], [pad_length, pad_length], [0, 0]]))
190
+ else:
191
+ padded_relative_embeddings = relative_embeddings
192
+ used_relative_embeddings = padded_relative_embeddings[:, slice_start_position:slice_end_position]
193
+ return used_relative_embeddings
194
+
195
+ def _relative_position_to_absolute_position(self, x):
196
+ """
197
+ x: [b, h, l, 2*l-1]
198
+ ret: [b, h, l, l]
199
+ """
200
+ batch, heads, length, _ = x.size()
201
+ # Concat columns of pad to shift from relative to absolute indexing.
202
+ x = F.pad(x, convert_pad_shape([[0, 0], [0, 0], [0, 0], [0, 1]]))
203
+
204
+ # Concat extra elements so to add up to shape (len+1, 2*len-1).
205
+ x_flat = x.view([batch, heads, length * 2 * length])
206
+ x_flat = F.pad(x_flat, convert_pad_shape([[0, 0], [0, 0], [0, length - 1]]))
207
+
208
+ # Reshape and slice out the padded elements.
209
+ x_final = x_flat.view([batch, heads, length + 1, 2 * length - 1])[:, :, :length, length - 1:]
210
+ return x_final
211
+
212
+ def _absolute_position_to_relative_position(self, x):
213
+ """
214
+ x: [b, h, l, l]
215
+ ret: [b, h, l, 2*l-1]
216
+ """
217
+ batch, heads, length, _ = x.size()
218
+ # pad along column
219
+ x = F.pad(x, convert_pad_shape([[0, 0], [0, 0], [0, 0], [0, length - 1]]))
220
+ x_flat = x.view([batch, heads, length ** 2 + length * (length - 1)])
221
+ # add 0's in the beginning that will skew the elements after reshape
222
+ x_flat = F.pad(x_flat, convert_pad_shape([[0, 0], [0, 0], [length, 0]]))
223
+ x_final = x_flat.view([batch, heads, length, 2 * length])[:, :, :, 1:]
224
+ return x_final
225
+
226
+ def _attention_bias_proximal(self, length):
227
+ """Bias for self-attention to encourage attention to close positions.
228
+ Args:
229
+ length: an integer scalar.
230
+ Returns:
231
+ a Tensor with shape [1, 1, length, length]
232
+ """
233
+ r = torch.arange(length, dtype=torch.float32)
234
+ diff = torch.unsqueeze(r, 0) - torch.unsqueeze(r, 1)
235
+ return torch.unsqueeze(torch.unsqueeze(-torch.log1p(torch.abs(diff)), 0), 0)
236
+
237
+
238
+ class FFN(nn.Module):
239
+ def __init__(self, in_channels, out_channels, filter_channels, kernel_size, p_dropout=0., activation=None):
240
+ super().__init__()
241
+ self.in_channels = in_channels
242
+ self.out_channels = out_channels
243
+ self.filter_channels = filter_channels
244
+ self.kernel_size = kernel_size
245
+ self.p_dropout = p_dropout
246
+ self.activation = activation
247
+
248
+ self.conv_1 = nn.Conv1d(in_channels, filter_channels, kernel_size, padding=kernel_size // 2)
249
+ self.conv_2 = nn.Conv1d(filter_channels, out_channels, 1)
250
+ self.drop = nn.Dropout(p_dropout)
251
+
252
+ def forward(self, x, x_mask):
253
+ x = self.conv_1(x * x_mask)
254
+ if self.activation == "gelu":
255
+ x = x * torch.sigmoid(1.702 * x)
256
+ else:
257
+ x = torch.relu(x)
258
+ x = self.drop(x)
259
+ x = self.conv_2(x * x_mask)
260
+ return x * x_mask
261
+
262
+
263
+ class LayerNorm(nn.Module):
264
+ def __init__(self, channels, eps=1e-4):
265
+ super().__init__()
266
+ self.channels = channels
267
+ self.eps = eps
268
+
269
+ self.gamma = nn.Parameter(torch.ones(channels))
270
+ self.beta = nn.Parameter(torch.zeros(channels))
271
+
272
+ def forward(self, x):
273
+ n_dims = len(x.shape)
274
+ mean = torch.mean(x, 1, keepdim=True)
275
+ variance = torch.mean((x - mean) ** 2, 1, keepdim=True)
276
+
277
+ x = (x - mean) * torch.rsqrt(variance + self.eps)
278
+
279
+ shape = [1, -1] + [1] * (n_dims - 2)
280
+ x = x * self.gamma.view(*shape) + self.beta.view(*shape)
281
+ return x
282
+
283
+
284
+ class ConvReluNorm(nn.Module):
285
+ def __init__(self, in_channels, hidden_channels, out_channels, kernel_size, n_layers, p_dropout):
286
+ super().__init__()
287
+ self.in_channels = in_channels
288
+ self.hidden_channels = hidden_channels
289
+ self.out_channels = out_channels
290
+ self.kernel_size = kernel_size
291
+ self.n_layers = n_layers
292
+ self.p_dropout = p_dropout
293
+ assert n_layers > 1, "Number of layers should be larger than 1."
294
+
295
+ self.conv_layers = nn.ModuleList()
296
+ self.norm_layers = nn.ModuleList()
297
+ self.conv_layers.append(nn.Conv1d(in_channels, hidden_channels, kernel_size, padding=kernel_size // 2))
298
+ self.norm_layers.append(LayerNorm(hidden_channels))
299
+ self.relu_drop = nn.Sequential(
300
+ nn.ReLU(),
301
+ nn.Dropout(p_dropout))
302
+ for _ in range(n_layers - 1):
303
+ self.conv_layers.append(nn.Conv1d(hidden_channels, hidden_channels, kernel_size, padding=kernel_size // 2))
304
+ self.norm_layers.append(LayerNorm(hidden_channels))
305
+ self.proj = nn.Conv1d(hidden_channels, out_channels, 1)
306
+ self.proj.weight.data.zero_()
307
+ self.proj.bias.data.zero_()
308
+
309
+ def forward(self, x, x_mask):
310
+ x_org = x
311
+ for i in range(self.n_layers):
312
+ x = self.conv_layers[i](x * x_mask)
313
+ x = self.norm_layers[i](x)
314
+ x = self.relu_drop(x)
315
+ x = x_org + self.proj(x)
316
+ return x * x_mask
317
+
318
+
319
+ class RelTransformerEncoder(nn.Module):
320
+ def __init__(self,
321
+ n_vocab,
322
+ out_channels,
323
+ hidden_channels,
324
+ filter_channels,
325
+ n_heads,
326
+ n_layers,
327
+ kernel_size,
328
+ p_dropout=0.0,
329
+ window_size=4,
330
+ block_length=None,
331
+ prenet=True,
332
+ pre_ln=True,
333
+ ):
334
+
335
+ super().__init__()
336
+
337
+ self.n_vocab = n_vocab
338
+ self.out_channels = out_channels
339
+ self.hidden_channels = hidden_channels
340
+ self.filter_channels = filter_channels
341
+ self.n_heads = n_heads
342
+ self.n_layers = n_layers
343
+ self.kernel_size = kernel_size
344
+ self.p_dropout = p_dropout
345
+ self.window_size = window_size
346
+ self.block_length = block_length
347
+ self.prenet = prenet
348
+ if n_vocab > 0:
349
+ self.emb = Embedding(n_vocab, hidden_channels, padding_idx=0)
350
+
351
+ if prenet:
352
+ self.pre = ConvReluNorm(hidden_channels, hidden_channels, hidden_channels,
353
+ kernel_size=5, n_layers=3, p_dropout=0)
354
+ self.encoder = Encoder(
355
+ hidden_channels,
356
+ filter_channels,
357
+ n_heads,
358
+ n_layers,
359
+ kernel_size,
360
+ p_dropout,
361
+ window_size=window_size,
362
+ block_length=block_length,
363
+ pre_ln=pre_ln,
364
+ )
365
+
366
+ def forward(self, x, x_mask=None):
367
+ if self.n_vocab > 0:
368
+ x_lengths = (x > 0).long().sum(-1)
369
+ x = self.emb(x) * math.sqrt(self.hidden_channels) # [b, t, h]
370
+ else:
371
+ x_lengths = (x.abs().sum(-1) > 0).long().sum(-1)
372
+ x = torch.transpose(x, 1, -1) # [b, h, t]
373
+ x_mask = torch.unsqueeze(sequence_mask(x_lengths, x.size(2)), 1).to(x.dtype)
374
+
375
+ if self.prenet:
376
+ x = self.pre(x, x_mask)
377
+ x = self.encoder(x, x_mask)
378
+ return x.transpose(1, 2)
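RelTransformerEncoder embeds token IDs (index 0 is padding), infers the mask from non-zero tokens, optionally applies a convolutional prenet, and runs self-attention with windowed relative position embeddings. A hedged usage sketch; the hyperparameters below are placeholders, not values used by this repository:

import torch
from preprocess.tools.note_transcription.modules.commons.rel_transformer import RelTransformerEncoder

enc = RelTransformerEncoder(
    n_vocab=100, out_channels=192, hidden_channels=192, filter_channels=768,
    n_heads=2, n_layers=4, kernel_size=3,
)
tokens = torch.randint(1, 100, (2, 30))  # 0 is reserved for padding
h = enc(tokens)
print(h.shape)                           # expected: torch.Size([2, 30, 192])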
preprocess/tools/note_transcription/modules/commons/rnn.py ADDED
@@ -0,0 +1,261 @@
1
+ import torch
2
+ from torch import nn
3
+ import torch.nn.functional as F
4
+
5
+
6
+ class PreNet(nn.Module):
7
+ def __init__(self, in_dims, fc1_dims=256, fc2_dims=128, dropout=0.5):
8
+ super().__init__()
9
+ self.fc1 = nn.Linear(in_dims, fc1_dims)
10
+ self.fc2 = nn.Linear(fc1_dims, fc2_dims)
11
+ self.p = dropout
12
+
13
+ def forward(self, x):
14
+ x = self.fc1(x)
15
+ x = F.relu(x)
16
+ x = F.dropout(x, self.p, training=self.training)
17
+ x = self.fc2(x)
18
+ x = F.relu(x)
19
+ x = F.dropout(x, self.p, training=self.training)
20
+ return x
21
+
22
+
23
+ class HighwayNetwork(nn.Module):
24
+ def __init__(self, size):
25
+ super().__init__()
26
+ self.W1 = nn.Linear(size, size)
27
+ self.W2 = nn.Linear(size, size)
28
+ self.W1.bias.data.fill_(0.)
29
+
30
+ def forward(self, x):
31
+ x1 = self.W1(x)
32
+ x2 = self.W2(x)
33
+ g = torch.sigmoid(x2)
34
+ y = g * F.relu(x1) + (1. - g) * x
35
+ return y
36
+
37
+
38
+ class BatchNormConv(nn.Module):
39
+ def __init__(self, in_channels, out_channels, kernel, relu=True):
40
+ super().__init__()
41
+ self.conv = nn.Conv1d(in_channels, out_channels, kernel, stride=1, padding=kernel // 2, bias=False)
42
+ self.bnorm = nn.BatchNorm1d(out_channels)
43
+ self.relu = relu
44
+
45
+ def forward(self, x):
46
+ x = self.conv(x)
47
+ x = F.relu(x) if self.relu is True else x
48
+ return self.bnorm(x)
49
+
50
+
51
+ class ConvNorm(torch.nn.Module):
52
+ def __init__(self, in_channels, out_channels, kernel_size=1, stride=1,
53
+ padding=None, dilation=1, bias=True, w_init_gain='linear'):
54
+ super(ConvNorm, self).__init__()
55
+ if padding is None:
56
+ assert (kernel_size % 2 == 1)
57
+ padding = int(dilation * (kernel_size - 1) / 2)
58
+
59
+ self.conv = torch.nn.Conv1d(in_channels, out_channels,
60
+ kernel_size=kernel_size, stride=stride,
61
+ padding=padding, dilation=dilation,
62
+ bias=bias)
63
+
64
+ torch.nn.init.xavier_uniform_(
65
+ self.conv.weight, gain=torch.nn.init.calculate_gain(w_init_gain))
66
+
67
+ def forward(self, signal):
68
+ conv_signal = self.conv(signal)
69
+ return conv_signal
70
+
71
+
72
+ class CBHG(nn.Module):
73
+ def __init__(self, K, in_channels, channels, proj_channels, num_highways):
74
+ super().__init__()
75
+
76
+ # List of all rnns to call `flatten_parameters()` on
77
+ self._to_flatten = []
78
+
79
+ self.bank_kernels = [i for i in range(1, K + 1)]
80
+ self.conv1d_bank = nn.ModuleList()
81
+ for k in self.bank_kernels:
82
+ conv = BatchNormConv(in_channels, channels, k)
83
+ self.conv1d_bank.append(conv)
84
+
85
+ self.maxpool = nn.MaxPool1d(kernel_size=2, stride=1, padding=1)
86
+
87
+ self.conv_project1 = BatchNormConv(len(self.bank_kernels) * channels, proj_channels[0], 3)
88
+ self.conv_project2 = BatchNormConv(proj_channels[0], proj_channels[1], 3, relu=False)
89
+
90
+ # Fix the highway input if necessary
91
+ if proj_channels[-1] != channels:
92
+ self.highway_mismatch = True
93
+ self.pre_highway = nn.Linear(proj_channels[-1], channels, bias=False)
94
+ else:
95
+ self.highway_mismatch = False
96
+
97
+ self.highways = nn.ModuleList()
98
+ for i in range(num_highways):
99
+ hn = HighwayNetwork(channels)
100
+ self.highways.append(hn)
101
+
102
+ self.rnn = nn.GRU(channels, channels, batch_first=True, bidirectional=True)
103
+ self._to_flatten.append(self.rnn)
104
+
105
+ # Avoid fragmentation of RNN parameters and associated warning
106
+ self._flatten_parameters()
107
+
108
+ def forward(self, x):
109
+ # Although we `_flatten_parameters()` on init, when using DataParallel
110
+ # the model gets replicated, making it no longer guaranteed that the
111
+ # weights are contiguous in GPU memory. Hence, we must call it again
112
+ self._flatten_parameters()
113
+
114
+ # Save these for later
115
+ residual = x
116
+ seq_len = x.size(-1)
117
+ conv_bank = []
118
+
119
+ # Convolution Bank
120
+ for conv in self.conv1d_bank:
121
+ c = conv(x) # Convolution
122
+ conv_bank.append(c[:, :, :seq_len])
123
+
124
+ # Stack along the channel axis
125
+ conv_bank = torch.cat(conv_bank, dim=1)
126
+
127
+ # dump the last padding to fit residual
128
+ x = self.maxpool(conv_bank)[:, :, :seq_len]
129
+
130
+ # Conv1d projections
131
+ x = self.conv_project1(x)
132
+ x = self.conv_project2(x)
133
+
134
+ # Residual Connect
135
+ x = x + residual
136
+
137
+ # Through the highways
138
+ x = x.transpose(1, 2)
139
+ if self.highway_mismatch is True:
140
+ x = self.pre_highway(x)
141
+ for h in self.highways:
142
+ x = h(x)
143
+
144
+ # And then the RNN
145
+ x, _ = self.rnn(x)
146
+ return x
147
+
148
+ def _flatten_parameters(self):
149
+ """Calls `flatten_parameters` on all the rnns used by the WaveRNN. Used
150
+ to improve efficiency and avoid PyTorch yelling at us."""
151
+ [m.flatten_parameters() for m in self._to_flatten]
152
+
153
+
154
+ class TacotronEncoder(nn.Module):
155
+ def __init__(self, embed_dims, num_chars, cbhg_channels, K, num_highways, dropout):
156
+ super().__init__()
157
+ self.embedding = nn.Embedding(num_chars, embed_dims)
158
+ self.pre_net = PreNet(embed_dims, embed_dims, embed_dims, dropout=dropout)
159
+ self.cbhg = CBHG(K=K, in_channels=cbhg_channels, channels=cbhg_channels,
160
+ proj_channels=[cbhg_channels, cbhg_channels],
161
+ num_highways=num_highways)
162
+ self.proj_out = nn.Linear(cbhg_channels * 2, cbhg_channels)
163
+
164
+ def forward(self, x):
165
+ x = self.embedding(x)
166
+ x = self.pre_net(x)
167
+ x.transpose_(1, 2)
168
+ x = self.cbhg(x)
169
+ x = self.proj_out(x)
170
+ return x
171
+
172
+
173
+ class RNNEncoder(nn.Module):
174
+ def __init__(self, num_chars, embedding_dim, n_convolutions=3, kernel_size=5):
175
+ super(RNNEncoder, self).__init__()
176
+ self.embedding = nn.Embedding(num_chars, embedding_dim, padding_idx=0)
177
+ convolutions = []
178
+ for _ in range(n_convolutions):
179
+ conv_layer = nn.Sequential(
180
+ ConvNorm(embedding_dim,
181
+ embedding_dim,
182
+ kernel_size=kernel_size, stride=1,
183
+ padding=int((kernel_size - 1) / 2),
184
+ dilation=1, w_init_gain='relu'),
185
+ nn.BatchNorm1d(embedding_dim))
186
+ convolutions.append(conv_layer)
187
+ self.convolutions = nn.ModuleList(convolutions)
188
+
189
+ self.lstm = nn.LSTM(embedding_dim, int(embedding_dim / 2), 1,
190
+ batch_first=True, bidirectional=True)
191
+
192
+ def forward(self, x):
193
+ input_lengths = (x > 0).sum(-1)
194
+ input_lengths = input_lengths.cpu().numpy()
195
+
196
+ x = self.embedding(x)
197
+ x = x.transpose(1, 2) # [B, H, T]
198
+ for conv in self.convolutions:
199
+ x = F.dropout(F.relu(conv(x)), 0.5, self.training) + x
200
+ x = x.transpose(1, 2) # [B, T, H]
201
+
202
+ # PyTorch tensors are not reversible, hence the conversion
203
+ x = nn.utils.rnn.pack_padded_sequence(x, input_lengths, batch_first=True, enforce_sorted=False)
204
+
205
+ self.lstm.flatten_parameters()
206
+ outputs, _ = self.lstm(x)
207
+ outputs, _ = nn.utils.rnn.pad_packed_sequence(outputs, batch_first=True)
208
+
209
+ return outputs
210
+
211
+
212
+ class DecoderRNN(torch.nn.Module):
213
+ def __init__(self, hidden_size, decoder_rnn_dim, dropout):
214
+ super(DecoderRNN, self).__init__()
215
+ self.in_conv1d = nn.Sequential(
216
+ torch.nn.Conv1d(
217
+ in_channels=hidden_size,
218
+ out_channels=hidden_size,
219
+ kernel_size=9, padding=4,
220
+ ),
221
+ torch.nn.ReLU(),
222
+ torch.nn.Conv1d(
223
+ in_channels=hidden_size,
224
+ out_channels=hidden_size,
225
+ kernel_size=9, padding=4,
226
+ ),
227
+ )
228
+ self.ln = nn.LayerNorm(hidden_size)
229
+ if decoder_rnn_dim == 0:
230
+ decoder_rnn_dim = hidden_size * 2
231
+ self.rnn = torch.nn.LSTM(
232
+ input_size=hidden_size,
233
+ hidden_size=decoder_rnn_dim,
234
+ num_layers=1,
235
+ batch_first=True,
236
+ bidirectional=True,
237
+ dropout=dropout
238
+ )
239
+ self.rnn.flatten_parameters()
240
+ self.conv1d = torch.nn.Conv1d(
241
+ in_channels=decoder_rnn_dim * 2,
242
+ out_channels=hidden_size,
243
+ kernel_size=3,
244
+ padding=1,
245
+ )
246
+
247
+ def forward(self, x):
248
+ input_masks = x.abs().sum(-1).ne(0).data[:, :, None]
249
+ input_lengths = input_masks.sum([-1, -2])
250
+ input_lengths = input_lengths.cpu().numpy()
251
+
252
+ x = self.in_conv1d(x.transpose(1, 2)).transpose(1, 2)
253
+ x = self.ln(x)
254
+ x = nn.utils.rnn.pack_padded_sequence(x, input_lengths, batch_first=True, enforce_sorted=False)
255
+ self.rnn.flatten_parameters()
256
+ x, _ = self.rnn(x) # [B, T, C]
257
+ x, _ = nn.utils.rnn.pad_packed_sequence(x, batch_first=True)
258
+ x = x * input_masks
259
+ pre_mel = self.conv1d(x.transpose(1, 2)).transpose(1, 2) # [B, T, C]
260
+ pre_mel = pre_mel * input_masks
261
+ return pre_mel
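TacotronEncoder chains an embedding, a PreNet, the CBHG block (convolution bank, projections, highway layers, bidirectional GRU) and a projection back to `cbhg_channels`. A rough usage sketch with placeholder hyperparameters (the import path mirrors the file path and is an assumption):

import torch
from preprocess.tools.note_transcription.modules.commons.rnn import TacotronEncoder

enc = TacotronEncoder(embed_dims=256, num_chars=100, cbhg_channels=256,
                      K=8, num_highways=4, dropout=0.5)
tokens = torch.randint(0, 100, (2, 40))  # (batch, time) token IDs
h = enc(tokens)
print(h.shape)                           # expected: torch.Size([2, 40, 256])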
preprocess/tools/note_transcription/modules/commons/transformer.py ADDED
@@ -0,0 +1,751 @@
1
+ import math
2
+ import torch
3
+ from torch import nn
4
+ from torch.nn import Parameter, Linear
5
+ from .layers import LayerNorm, Embedding
6
+ from ...utils.nn.seq_utils import (
7
+ get_incremental_state,
8
+ set_incremental_state,
9
+ softmax,
10
+ make_positions,
11
+ )
12
+ import torch.nn.functional as F
13
+
14
+ DEFAULT_MAX_SOURCE_POSITIONS = 2000
15
+ DEFAULT_MAX_TARGET_POSITIONS = 2000
16
+
17
+
18
+ class SinusoidalPositionalEmbedding(nn.Module):
19
+ """This module produces sinusoidal positional embeddings of any length.
20
+
21
+ Padding symbols are ignored.
22
+ """
23
+
24
+ def __init__(self, embedding_dim, padding_idx, init_size=1024):
25
+ super().__init__()
26
+ self.embedding_dim = embedding_dim
27
+ self.padding_idx = padding_idx
28
+ self.weights = SinusoidalPositionalEmbedding.get_embedding(
29
+ init_size,
30
+ embedding_dim,
31
+ padding_idx,
32
+ )
33
+ self.register_buffer('_float_tensor', torch.FloatTensor(1))
34
+
35
+ @staticmethod
36
+ def get_embedding(num_embeddings, embedding_dim, padding_idx=None):
37
+ """Build sinusoidal embeddings.
38
+
39
+ This matches the implementation in tensor2tensor, but differs slightly
40
+ from the description in Section 3.5 of "Attention Is All You Need".
41
+ """
42
+ half_dim = embedding_dim // 2
43
+ emb = math.log(10000) / (half_dim - 1)
44
+ emb = torch.exp(torch.arange(half_dim, dtype=torch.float) * -emb)
45
+ emb = torch.arange(num_embeddings, dtype=torch.float).unsqueeze(1) * emb.unsqueeze(0)
46
+ emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1).view(num_embeddings, -1)
47
+ if embedding_dim % 2 == 1:
48
+ # zero pad
49
+ emb = torch.cat([emb, torch.zeros(num_embeddings, 1)], dim=1)
50
+ if padding_idx is not None:
51
+ emb[padding_idx, :] = 0
52
+ return emb
53
+
54
+ def forward(self, input, incremental_state=None, timestep=None, positions=None, **kwargs):
55
+ """Input is expected to be of size [bsz x seqlen]."""
56
+ bsz, seq_len = input.shape[:2]
57
+ max_pos = self.padding_idx + 1 + seq_len
58
+ if self.weights is None or max_pos > self.weights.size(0):
59
+ # recompute/expand embeddings if needed
60
+ self.weights = SinusoidalPositionalEmbedding.get_embedding(
61
+ max_pos,
62
+ self.embedding_dim,
63
+ self.padding_idx,
64
+ )
65
+ self.weights = self.weights.to(self._float_tensor)
66
+
67
+ if incremental_state is not None:
68
+ # positions is the same for every token when decoding a single step
69
+ pos = timestep.view(-1)[0] + 1 if timestep is not None else seq_len
70
+ return self.weights[self.padding_idx + pos, :].expand(bsz, 1, -1)
71
+
72
+ positions = make_positions(input, self.padding_idx) if positions is None else positions
73
+ return self.weights.index_select(0, positions.view(-1)).view(bsz, seq_len, -1).detach()
74
+
75
+ def max_positions(self):
76
+ """Maximum number of supported positions."""
77
+ return int(1e5) # an arbitrary large number
78
+
79
+
80
+ class TransformerFFNLayer(nn.Module):
81
+ def __init__(self, hidden_size, filter_size, padding="SAME", kernel_size=1, dropout=0., act='gelu'):
82
+ super().__init__()
83
+ self.kernel_size = kernel_size
84
+ self.dropout = dropout
85
+ self.act = act
86
+ if padding == 'SAME':
87
+ self.ffn_1 = nn.Conv1d(hidden_size, filter_size, kernel_size, padding=kernel_size // 2)
88
+ elif padding == 'LEFT':
89
+ self.ffn_1 = nn.Sequential(
90
+ nn.ConstantPad1d((kernel_size - 1, 0), 0.0),
91
+ nn.Conv1d(hidden_size, filter_size, kernel_size)
92
+ )
93
+ self.ffn_2 = Linear(filter_size, hidden_size)
94
+
95
+ def forward(self, x, incremental_state=None):
96
+ # x: T x B x C
97
+ if incremental_state is not None:
98
+ saved_state = self._get_input_buffer(incremental_state)
99
+ if 'prev_input' in saved_state:
100
+ prev_input = saved_state['prev_input']
101
+ x = torch.cat((prev_input, x), dim=0)
102
+ x = x[-self.kernel_size:]
103
+ saved_state['prev_input'] = x
104
+ self._set_input_buffer(incremental_state, saved_state)
105
+
106
+ x = self.ffn_1(x.permute(1, 2, 0)).permute(2, 0, 1)
107
+ x = x * self.kernel_size ** -0.5
108
+
109
+ if incremental_state is not None:
110
+ x = x[-1:]
111
+ if self.act == 'gelu':
112
+ x = F.gelu(x)
113
+ if self.act == 'relu':
114
+ x = F.relu(x)
115
+ x = F.dropout(x, self.dropout, training=self.training)
116
+ x = self.ffn_2(x)
117
+ return x
118
+
119
+ def _get_input_buffer(self, incremental_state):
120
+ return get_incremental_state(
121
+ self,
122
+ incremental_state,
123
+ 'f',
124
+ ) or {}
125
+
126
+ def _set_input_buffer(self, incremental_state, buffer):
127
+ set_incremental_state(
128
+ self,
129
+ incremental_state,
130
+ 'f',
131
+ buffer,
132
+ )
133
+
134
+ def clear_buffer(self, incremental_state):
135
+ if incremental_state is not None:
136
+ saved_state = self._get_input_buffer(incremental_state)
137
+ if 'prev_input' in saved_state:
138
+ del saved_state['prev_input']
139
+ self._set_input_buffer(incremental_state, saved_state)
140
+
141
+
142
+ class MultiheadAttention(nn.Module):
143
+ def __init__(self, embed_dim, num_heads, kdim=None, vdim=None, dropout=0., bias=True,
144
+ add_bias_kv=False, add_zero_attn=False, self_attention=False,
145
+ encoder_decoder_attention=False):
146
+ super().__init__()
147
+ self.embed_dim = embed_dim
148
+ self.kdim = kdim if kdim is not None else embed_dim
149
+ self.vdim = vdim if vdim is not None else embed_dim
150
+ self.qkv_same_dim = self.kdim == embed_dim and self.vdim == embed_dim
151
+
152
+ self.num_heads = num_heads
153
+ self.dropout = dropout
154
+ self.head_dim = embed_dim // num_heads
155
+ assert self.head_dim * num_heads == self.embed_dim, "embed_dim must be divisible by num_heads"
156
+ self.scaling = self.head_dim ** -0.5
157
+
158
+ self.self_attention = self_attention
159
+ self.encoder_decoder_attention = encoder_decoder_attention
160
+
161
+ assert not self.self_attention or self.qkv_same_dim, 'Self-attention requires query, key and ' \
162
+ 'value to be of the same size'
163
+
164
+ if self.qkv_same_dim:
165
+ self.in_proj_weight = Parameter(torch.Tensor(3 * embed_dim, embed_dim))
166
+ else:
167
+ self.k_proj_weight = Parameter(torch.Tensor(embed_dim, self.kdim))
168
+ self.v_proj_weight = Parameter(torch.Tensor(embed_dim, self.vdim))
169
+ self.q_proj_weight = Parameter(torch.Tensor(embed_dim, embed_dim))
170
+
171
+ if bias:
172
+ self.in_proj_bias = Parameter(torch.Tensor(3 * embed_dim))
173
+ else:
174
+ self.register_parameter('in_proj_bias', None)
175
+
176
+ self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
177
+
178
+ if add_bias_kv:
179
+ self.bias_k = Parameter(torch.Tensor(1, 1, embed_dim))
180
+ self.bias_v = Parameter(torch.Tensor(1, 1, embed_dim))
181
+ else:
182
+ self.bias_k = self.bias_v = None
183
+
184
+ self.add_zero_attn = add_zero_attn
185
+
186
+ self.reset_parameters()
187
+
188
+ self.enable_torch_version = False
189
+ if hasattr(F, "multi_head_attention_forward"):
190
+ self.enable_torch_version = True
191
+ else:
192
+ self.enable_torch_version = False
193
+ self.last_attn_probs = None
194
+
195
+ def reset_parameters(self):
196
+ if self.qkv_same_dim:
197
+ nn.init.xavier_uniform_(self.in_proj_weight)
198
+ else:
199
+ nn.init.xavier_uniform_(self.k_proj_weight)
200
+ nn.init.xavier_uniform_(self.v_proj_weight)
201
+ nn.init.xavier_uniform_(self.q_proj_weight)
202
+
203
+ nn.init.xavier_uniform_(self.out_proj.weight)
204
+ if self.in_proj_bias is not None:
205
+ nn.init.constant_(self.in_proj_bias, 0.)
206
+ nn.init.constant_(self.out_proj.bias, 0.)
207
+ if self.bias_k is not None:
208
+ nn.init.xavier_normal_(self.bias_k)
209
+ if self.bias_v is not None:
210
+ nn.init.xavier_normal_(self.bias_v)
211
+
212
+ def forward(
213
+ self,
214
+ query, key, value,
215
+ key_padding_mask=None,
216
+ incremental_state=None,
217
+ need_weights=True,
218
+ static_kv=False,
219
+ attn_mask=None,
220
+ before_softmax=False,
221
+ need_head_weights=False,
222
+ enc_dec_attn_constraint_mask=None,
223
+ reset_attn_weight=None
224
+ ):
225
+ """Input shape: Time x Batch x Channel
226
+
227
+ Args:
228
+ key_padding_mask (ByteTensor, optional): mask to exclude
229
+ keys that are pads, of shape `(batch, src_len)`, where
230
+ padding elements are indicated by 1s.
231
+ need_weights (bool, optional): return the attention weights,
232
+ averaged over heads (default: False).
233
+ attn_mask (ByteTensor, optional): typically used to
234
+ implement causal attention, where the mask prevents the
235
+ attention from looking forward in time (default: None).
236
+ before_softmax (bool, optional): return the raw attention
237
+ weights and values before the attention softmax.
238
+ need_head_weights (bool, optional): return the attention
239
+ weights for each head. Implies *need_weights*. Default:
240
+ return the average attention weights over all heads.
241
+ """
242
+ if need_head_weights:
243
+ need_weights = True
244
+
245
+ tgt_len, bsz, embed_dim = query.size()
246
+ assert embed_dim == self.embed_dim
247
+ assert list(query.size()) == [tgt_len, bsz, embed_dim]
248
+ if self.enable_torch_version and incremental_state is None and not static_kv and reset_attn_weight is None:
249
+ if self.qkv_same_dim:
250
+ return F.multi_head_attention_forward(query, key, value,
251
+ self.embed_dim, self.num_heads,
252
+ self.in_proj_weight,
253
+ self.in_proj_bias, self.bias_k, self.bias_v,
254
+ self.add_zero_attn, self.dropout,
255
+ self.out_proj.weight, self.out_proj.bias,
256
+ self.training, key_padding_mask, need_weights,
257
+ attn_mask)
258
+ else:
259
+                return F.multi_head_attention_forward(query, key, value,
+                                                      self.embed_dim, self.num_heads,
+                                                      torch.empty([0]),
+                                                      self.in_proj_bias, self.bias_k, self.bias_v,
+                                                      self.add_zero_attn, self.dropout,
+                                                      self.out_proj.weight, self.out_proj.bias,
+                                                      self.training, key_padding_mask, need_weights,
+                                                      attn_mask, use_separate_proj_weight=True,
+                                                      q_proj_weight=self.q_proj_weight,
+                                                      k_proj_weight=self.k_proj_weight,
+                                                      v_proj_weight=self.v_proj_weight)
+
+        if incremental_state is not None:
+            saved_state = self._get_input_buffer(incremental_state)
+            if 'prev_key' in saved_state:
+                # previous time steps are cached - no need to recompute
+                # key and value if they are static
+                if static_kv:
+                    assert self.encoder_decoder_attention and not self.self_attention
+                    key = value = None
+        else:
+            saved_state = None
+
+        if self.self_attention:
+            # self-attention
+            q, k, v = self.in_proj_qkv(query)
+        elif self.encoder_decoder_attention:
+            # encoder-decoder attention
+            q = self.in_proj_q(query)
+            if key is None:
+                assert value is None
+                k = v = None
+            else:
+                k = self.in_proj_k(key)
+                v = self.in_proj_v(key)
+
+        else:
+            q = self.in_proj_q(query)
+            k = self.in_proj_k(key)
+            v = self.in_proj_v(value)
+        q *= self.scaling
+
+        if self.bias_k is not None:
+            assert self.bias_v is not None
+            k = torch.cat([k, self.bias_k.repeat(1, bsz, 1)])
+            v = torch.cat([v, self.bias_v.repeat(1, bsz, 1)])
+            if attn_mask is not None:
+                attn_mask = torch.cat([attn_mask, attn_mask.new_zeros(attn_mask.size(0), 1)], dim=1)
+            if key_padding_mask is not None:
+                key_padding_mask = torch.cat(
+                    [key_padding_mask, key_padding_mask.new_zeros(key_padding_mask.size(0), 1)], dim=1)
+
+        q = q.contiguous().view(tgt_len, bsz * self.num_heads, self.head_dim).transpose(0, 1)
+        if k is not None:
+            k = k.contiguous().view(-1, bsz * self.num_heads, self.head_dim).transpose(0, 1)
+        if v is not None:
+            v = v.contiguous().view(-1, bsz * self.num_heads, self.head_dim).transpose(0, 1)
+
+        if saved_state is not None:
+            # saved states are stored with shape (bsz, num_heads, seq_len, head_dim)
+            if 'prev_key' in saved_state:
+                prev_key = saved_state['prev_key'].view(bsz * self.num_heads, -1, self.head_dim)
+                if static_kv:
+                    k = prev_key
+                else:
+                    k = torch.cat((prev_key, k), dim=1)
+            if 'prev_value' in saved_state:
+                prev_value = saved_state['prev_value'].view(bsz * self.num_heads, -1, self.head_dim)
+                if static_kv:
+                    v = prev_value
+                else:
+                    v = torch.cat((prev_value, v), dim=1)
+            if 'prev_key_padding_mask' in saved_state and saved_state['prev_key_padding_mask'] is not None:
+                prev_key_padding_mask = saved_state['prev_key_padding_mask']
+                if static_kv:
+                    key_padding_mask = prev_key_padding_mask
+                else:
+                    key_padding_mask = torch.cat((prev_key_padding_mask, key_padding_mask), dim=1)
+
+            saved_state['prev_key'] = k.view(bsz, self.num_heads, -1, self.head_dim)
+            saved_state['prev_value'] = v.view(bsz, self.num_heads, -1, self.head_dim)
+            saved_state['prev_key_padding_mask'] = key_padding_mask
+
+            self._set_input_buffer(incremental_state, saved_state)
+
+        src_len = k.size(1)
+
+        # This is part of a workaround to get around fork/join parallelism
+        # not supporting Optional types.
+        if key_padding_mask is not None and key_padding_mask.shape == torch.Size([]):
+            key_padding_mask = None
+
+        if key_padding_mask is not None:
+            assert key_padding_mask.size(0) == bsz
+            assert key_padding_mask.size(1) == src_len
+
+        if self.add_zero_attn:
+            src_len += 1
+            k = torch.cat([k, k.new_zeros((k.size(0), 1) + k.size()[2:])], dim=1)
+            v = torch.cat([v, v.new_zeros((v.size(0), 1) + v.size()[2:])], dim=1)
+            if attn_mask is not None:
+                attn_mask = torch.cat([attn_mask, attn_mask.new_zeros(attn_mask.size(0), 1)], dim=1)
+            if key_padding_mask is not None:
+                key_padding_mask = torch.cat(
+                    [key_padding_mask, torch.zeros(key_padding_mask.size(0), 1).type_as(key_padding_mask)], dim=1)
+
+        attn_weights = torch.bmm(q, k.transpose(1, 2))
+        attn_weights = self.apply_sparse_mask(attn_weights, tgt_len, src_len, bsz)
+
+        assert list(attn_weights.size()) == [bsz * self.num_heads, tgt_len, src_len]
+
+        if attn_mask is not None:
+            if len(attn_mask.shape) == 2:
+                attn_mask = attn_mask.unsqueeze(0)
+            elif len(attn_mask.shape) == 3:
+                attn_mask = attn_mask[:, None].repeat([1, self.num_heads, 1, 1]).reshape(
+                    bsz * self.num_heads, tgt_len, src_len)
+            attn_weights = attn_weights + attn_mask
+
+        if enc_dec_attn_constraint_mask is not None:  # bs x head x L_kv
+            attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
+            attn_weights = attn_weights.masked_fill(
+                enc_dec_attn_constraint_mask.unsqueeze(2).bool(),
+                -1e8,
+            )
+            attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
+
+        if key_padding_mask is not None:
+            # don't attend to padding symbols
+            attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
+            attn_weights = attn_weights.masked_fill(
+                key_padding_mask.unsqueeze(1).unsqueeze(2),
+                -1e8,
+            )
+            attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
+
+        attn_logits = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
+
+        if before_softmax:
+            return attn_weights, v
+
+        attn_weights_float = softmax(attn_weights, dim=-1)
+        attn_weights = attn_weights_float.type_as(attn_weights)
+        attn_probs = F.dropout(attn_weights_float.type_as(attn_weights), p=self.dropout, training=self.training)
+
+        if reset_attn_weight is not None:
+            if reset_attn_weight:
+                self.last_attn_probs = attn_probs.detach()
+            else:
+                assert self.last_attn_probs is not None
+                attn_probs = self.last_attn_probs
+        attn = torch.bmm(attn_probs, v)
+        assert list(attn.size()) == [bsz * self.num_heads, tgt_len, self.head_dim]
+        attn = attn.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim)
+        attn = self.out_proj(attn)
+
+        if need_weights:
+            attn_weights = attn_weights_float.view(bsz, self.num_heads, tgt_len, src_len).transpose(1, 0)
+            if not need_head_weights:
+                # average attention weights over heads
+                attn_weights = attn_weights.mean(dim=0)
+        else:
+            attn_weights = None
+
+        return attn, (attn_weights, attn_logits)
+
+    def in_proj_qkv(self, query):
+        return self._in_proj(query).chunk(3, dim=-1)
+
+    def in_proj_q(self, query):
+        if self.qkv_same_dim:
+            return self._in_proj(query, end=self.embed_dim)
+        else:
+            bias = self.in_proj_bias
+            if bias is not None:
+                bias = bias[:self.embed_dim]
+            return F.linear(query, self.q_proj_weight, bias)
+
+    def in_proj_k(self, key):
+        if self.qkv_same_dim:
+            return self._in_proj(key, start=self.embed_dim, end=2 * self.embed_dim)
+        else:
+            weight = self.k_proj_weight
+            bias = self.in_proj_bias
+            if bias is not None:
+                bias = bias[self.embed_dim:2 * self.embed_dim]
+            return F.linear(key, weight, bias)
+
+    def in_proj_v(self, value):
+        if self.qkv_same_dim:
+            return self._in_proj(value, start=2 * self.embed_dim)
+        else:
+            weight = self.v_proj_weight
+            bias = self.in_proj_bias
+            if bias is not None:
+                bias = bias[2 * self.embed_dim:]
+            return F.linear(value, weight, bias)
+
+    def _in_proj(self, input, start=0, end=None):
+        weight = self.in_proj_weight
+        bias = self.in_proj_bias
+        weight = weight[start:end, :]
+        if bias is not None:
+            bias = bias[start:end]
+        return F.linear(input, weight, bias)
+
+    def _get_input_buffer(self, incremental_state):
+        return get_incremental_state(
+            self,
+            incremental_state,
+            'attn_state',
+        ) or {}
+
+    def _set_input_buffer(self, incremental_state, buffer):
+        set_incremental_state(
+            self,
+            incremental_state,
+            'attn_state',
+            buffer,
+        )
+
+    def apply_sparse_mask(self, attn_weights, tgt_len, src_len, bsz):
+        return attn_weights
+
+    def clear_buffer(self, incremental_state=None):
+        if incremental_state is not None:
+            saved_state = self._get_input_buffer(incremental_state)
+            if 'prev_key' in saved_state:
+                del saved_state['prev_key']
+            if 'prev_value' in saved_state:
+                del saved_state['prev_value']
+            self._set_input_buffer(incremental_state, saved_state)
+
+
+class EncSALayer(nn.Module):
+    def __init__(self, c, num_heads, dropout, attention_dropout=0.1,
+                 relu_dropout=0.1, kernel_size=9, padding='SAME', act='gelu'):
+        super().__init__()
+        self.c = c
+        self.dropout = dropout
+        self.num_heads = num_heads
+        if num_heads > 0:
+            self.layer_norm1 = LayerNorm(c)
+            self.self_attn = MultiheadAttention(
+                self.c, num_heads, self_attention=True, dropout=attention_dropout, bias=False)
+        self.layer_norm2 = LayerNorm(c)
+        self.ffn = TransformerFFNLayer(
+            c, 4 * c, kernel_size=kernel_size, dropout=relu_dropout, padding=padding, act=act)
+
+    def forward(self, x, encoder_padding_mask=None, **kwargs):
+        layer_norm_training = kwargs.get('layer_norm_training', None)
+        if layer_norm_training is not None:
+            self.layer_norm1.training = layer_norm_training
+            self.layer_norm2.training = layer_norm_training
+        if self.num_heads > 0:
+            residual = x
+            x = self.layer_norm1(x)
+            x, _, = self.self_attn(
+                query=x,
+                key=x,
+                value=x,
+                key_padding_mask=encoder_padding_mask
+            )
+            x = F.dropout(x, self.dropout, training=self.training)
+            x = residual + x
+            x = x * (1 - encoder_padding_mask.float()).transpose(0, 1)[..., None]
+
+        residual = x
+        x = self.layer_norm2(x)
+        x = self.ffn(x)
+        x = F.dropout(x, self.dropout, training=self.training)
+        x = residual + x
+        x = x * (1 - encoder_padding_mask.float()).transpose(0, 1)[..., None]
+        return x
+
+
+class DecSALayer(nn.Module):
+    def __init__(self, c, num_heads, dropout, attention_dropout=0.1, relu_dropout=0.1,
+                 kernel_size=9, act='gelu'):
+        super().__init__()
+        self.c = c
+        self.dropout = dropout
+        self.layer_norm1 = LayerNorm(c)
+        self.self_attn = MultiheadAttention(
+            c, num_heads, self_attention=True, dropout=attention_dropout, bias=False
+        )
+        self.layer_norm2 = LayerNorm(c)
+        self.encoder_attn = MultiheadAttention(
+            c, num_heads, encoder_decoder_attention=True, dropout=attention_dropout, bias=False,
+        )
+        self.layer_norm3 = LayerNorm(c)
+        self.ffn = TransformerFFNLayer(
+            c, 4 * c, padding='LEFT', kernel_size=kernel_size, dropout=relu_dropout, act=act)
+
+    def forward(
+            self,
+            x,
+            encoder_out=None,
+            encoder_padding_mask=None,
+            incremental_state=None,
+            self_attn_mask=None,
+            self_attn_padding_mask=None,
+            attn_out=None,
+            reset_attn_weight=None,
+            **kwargs,
+    ):
+        layer_norm_training = kwargs.get('layer_norm_training', None)
+        if layer_norm_training is not None:
+            self.layer_norm1.training = layer_norm_training
+            self.layer_norm2.training = layer_norm_training
+            self.layer_norm3.training = layer_norm_training
+        residual = x
+        x = self.layer_norm1(x)
+        x, _ = self.self_attn(
+            query=x,
+            key=x,
+            value=x,
+            key_padding_mask=self_attn_padding_mask,
+            incremental_state=incremental_state,
+            attn_mask=self_attn_mask
+        )
+        x = F.dropout(x, self.dropout, training=self.training)
+        x = residual + x
+
+        attn_logits = None
+        if encoder_out is not None or attn_out is not None:
+            residual = x
+            x = self.layer_norm2(x)
+        if encoder_out is not None:
+            x, attn = self.encoder_attn(
+                query=x,
+                key=encoder_out,
+                value=encoder_out,
+                key_padding_mask=encoder_padding_mask,
+                incremental_state=incremental_state,
+                static_kv=True,
+                enc_dec_attn_constraint_mask=get_incremental_state(self, incremental_state,
+                                                                   'enc_dec_attn_constraint_mask'),
+                reset_attn_weight=reset_attn_weight
+            )
+            attn_logits = attn[1]
+        elif attn_out is not None:
+            x = self.encoder_attn.in_proj_v(attn_out)
+        if encoder_out is not None or attn_out is not None:
+            x = F.dropout(x, self.dropout, training=self.training)
+            x = residual + x
+
+        residual = x
+        x = self.layer_norm3(x)
+        x = self.ffn(x, incremental_state=incremental_state)
+        x = F.dropout(x, self.dropout, training=self.training)
+        x = residual + x
+        return x, attn_logits
+
+    def clear_buffer(self, input, encoder_out=None, encoder_padding_mask=None, incremental_state=None):
+        self.encoder_attn.clear_buffer(incremental_state)
+        self.ffn.clear_buffer(incremental_state)
+
+    def set_buffer(self, name, tensor, incremental_state):
+        return set_incremental_state(self, incremental_state, name, tensor)
+
+
+class TransformerEncoderLayer(nn.Module):
+    def __init__(self, hidden_size, dropout, kernel_size=9, num_heads=2):
+        super().__init__()
+        self.hidden_size = hidden_size
+        self.dropout = dropout
+        self.num_heads = num_heads
+        self.op = EncSALayer(
+            hidden_size, num_heads, dropout=dropout,
+            attention_dropout=0.0, relu_dropout=dropout,
+            kernel_size=kernel_size)
+
+    def forward(self, x, **kwargs):
+        return self.op(x, **kwargs)
+
+
+class TransformerDecoderLayer(nn.Module):
+    def __init__(self, hidden_size, dropout, kernel_size=9, num_heads=2):
+        super().__init__()
+        self.hidden_size = hidden_size
+        self.dropout = dropout
+        self.num_heads = num_heads
+        self.op = DecSALayer(
+            hidden_size, num_heads, dropout=dropout,
+            attention_dropout=0.0, relu_dropout=dropout,
+            kernel_size=kernel_size)
+
+    def forward(self, x, **kwargs):
+        return self.op(x, **kwargs)
+
+    def clear_buffer(self, *args):
+        return self.op.clear_buffer(*args)
+
+    def set_buffer(self, *args):
+        return self.op.set_buffer(*args)
+
+
+class FFTBlocks(nn.Module):
+    def __init__(self, hidden_size, num_layers, ffn_kernel_size=9, dropout=0.0,
+                 num_heads=2, use_pos_embed=True, use_last_norm=True,
+                 use_pos_embed_alpha=True):
+        super().__init__()
+        self.num_layers = num_layers
+        embed_dim = self.hidden_size = hidden_size
+        self.dropout = dropout
+        self.use_pos_embed = use_pos_embed
+        self.use_last_norm = use_last_norm
+        if use_pos_embed:
+            self.max_source_positions = DEFAULT_MAX_TARGET_POSITIONS
+            self.padding_idx = 0
+            self.pos_embed_alpha = nn.Parameter(torch.Tensor([1])) if use_pos_embed_alpha else 1
+            self.embed_positions = SinusoidalPositionalEmbedding(
+                embed_dim, self.padding_idx, init_size=DEFAULT_MAX_TARGET_POSITIONS,
+            )
+
+        self.layers = nn.ModuleList([])
+        self.layers.extend([
+            TransformerEncoderLayer(self.hidden_size, self.dropout,
+                                    kernel_size=ffn_kernel_size, num_heads=num_heads)
+            for _ in range(self.num_layers)
+        ])
+        if self.use_last_norm:
+            self.layer_norm = nn.LayerNorm(embed_dim)
+        else:
+            self.layer_norm = None
+
+    def forward(self, x, padding_mask=None, attn_mask=None, return_hiddens=False):
+        """
+        :param x: [B, T, C]
+        :param padding_mask: [B, T]
+        :return: [B, T, C] or [L, B, T, C]
+        """
+        padding_mask = x.abs().sum(-1).eq(0).data if padding_mask is None else padding_mask
+        nonpadding_mask_TB = 1 - padding_mask.transpose(0, 1).float()[:, :, None]  # [T, B, 1]
+        if self.use_pos_embed:
+            positions = self.pos_embed_alpha * self.embed_positions(x[..., 0])
+            x = x + positions
+            x = F.dropout(x, p=self.dropout, training=self.training)
+        # B x T x C -> T x B x C
+        x = x.transpose(0, 1) * nonpadding_mask_TB
+        hiddens = []
+        for layer in self.layers:
+            x = layer(x, encoder_padding_mask=padding_mask, attn_mask=attn_mask) * nonpadding_mask_TB
+            hiddens.append(x)
+        if self.use_last_norm:
+            x = self.layer_norm(x) * nonpadding_mask_TB
+        if return_hiddens:
+            x = torch.stack(hiddens, 0)  # [L, T, B, C]
+            x = x.transpose(1, 2)  # [L, B, T, C]
+        else:
+            x = x.transpose(0, 1)  # [B, T, C]
+        return x
+
+
+class FastSpeechEncoder(FFTBlocks):
+    def __init__(self, dict_size, hidden_size=256, num_layers=4, kernel_size=9, num_heads=2,
+                 dropout=0.0):
+        super().__init__(hidden_size, num_layers, kernel_size, num_heads=num_heads,
+                         use_pos_embed=False, dropout=dropout)  # use_pos_embed_alpha for compatibility
+        self.embed_tokens = Embedding(dict_size, hidden_size, 0)
+        self.embed_scale = math.sqrt(hidden_size)
+        self.padding_idx = 0
+        self.embed_positions = SinusoidalPositionalEmbedding(
+            hidden_size, self.padding_idx, init_size=DEFAULT_MAX_TARGET_POSITIONS,
+        )
+
+    def forward(self, txt_tokens, attn_mask=None):
+        """
+
+        :param txt_tokens: [B, T]
+        :return: {
+            'encoder_out': [B x T x C]
+        }
+        """
+        encoder_padding_mask = txt_tokens.eq(self.padding_idx).data
+        x = self.forward_embedding(txt_tokens)  # [B, T, H]
+        if self.num_layers > 0:
+            x = super(FastSpeechEncoder, self).forward(x, encoder_padding_mask, attn_mask=attn_mask)
+        return x
+
+    def forward_embedding(self, txt_tokens):
+        # embed tokens and positions
+        x = self.embed_scale * self.embed_tokens(txt_tokens)
+        positions = self.embed_positions(txt_tokens)
+        x = x + positions
+        x = F.dropout(x, p=self.dropout, training=self.training)
+        return x
+
+
+class FastSpeechDecoder(FFTBlocks):
+    def __init__(self, hidden_size=256, num_layers=4, kernel_size=9, num_heads=2):
+        super().__init__(hidden_size, num_layers, kernel_size, num_heads=num_heads)
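A minimal usage sketch of the text encoder defined in this file (not part of the diff). It assumes the module imports as preprocess.tools.note_transcription.modules.commons.transformer and that the rest of the file's helpers (Embedding, SinusoidalPositionalEmbedding, etc.) are present as uploaded; the vocabulary size and shapes are illustrative only.

    # hypothetical import path, inferred from the repository layout
    import torch
    from preprocess.tools.note_transcription.modules.commons.transformer import FastSpeechEncoder

    encoder = FastSpeechEncoder(dict_size=100, hidden_size=256, num_layers=4, num_heads=2)
    txt_tokens = torch.randint(1, 100, (2, 37))  # [B, T]; index 0 is the padding symbol
    hidden = encoder(txt_tokens)                 # [B, T, 256], padded positions zeroed out
    print(hidden.shape)

FFTBlocks can likewise be stacked directly on frame-level features (e.g. mel or MIDI embeddings) via its forward(x, padding_mask) interface, which is how the note-transcription model presumably consumes it.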
preprocess/tools/note_transcription/modules/commons/wavenet.py ADDED
@@ -0,0 +1,109 @@
+import torch
+from torch import nn
+from packaging import version
+
+def fused_add_tanh_sigmoid_multiply(input_a, input_b, n_channels):
+    n_channels_int = n_channels[0]
+    in_act = input_a + input_b
+    t_act = torch.tanh(in_act[:, :n_channels_int, :])
+    s_act = torch.sigmoid(in_act[:, n_channels_int:, :])
+    acts = t_act * s_act
+    return acts
+
+jit_fused_add_tanh_sigmoid_multiply = fused_add_tanh_sigmoid_multiply
+
+def script_function():
+    if version.parse(torch.__version__) >= version.parse('2.0'):
+        global jit_fused_add_tanh_sigmoid_multiply
+        jit_fused_add_tanh_sigmoid_multiply = torch.jit.script(fused_add_tanh_sigmoid_multiply)
+
+
+class WN(torch.nn.Module):
+    def __init__(self, hidden_size, kernel_size, dilation_rate, n_layers, c_cond=0,
+                 p_dropout=0, share_cond_layers=False, is_BTC=False):
+        super(WN, self).__init__()
+        assert (kernel_size % 2 == 1)
+        assert (hidden_size % 2 == 0)
+        self.is_BTC = is_BTC
+        self.hidden_size = hidden_size
+        self.kernel_size = kernel_size
+        self.dilation_rate = dilation_rate
+        self.n_layers = n_layers
+        self.gin_channels = c_cond
+        self.p_dropout = p_dropout
+        self.share_cond_layers = share_cond_layers
+
+        self.in_layers = torch.nn.ModuleList()
+        self.res_skip_layers = torch.nn.ModuleList()
+        self.drop = nn.Dropout(p_dropout)
+
+        if c_cond != 0 and not share_cond_layers:
+            cond_layer = torch.nn.Conv1d(c_cond, 2 * hidden_size * n_layers, 1)
+            self.cond_layer = torch.nn.utils.weight_norm(cond_layer, name='weight')
+
+        for i in range(n_layers):
+            dilation = dilation_rate ** i
+            padding = int((kernel_size * dilation - dilation) / 2)
+            in_layer = torch.nn.Conv1d(hidden_size, 2 * hidden_size, kernel_size,
+                                       dilation=dilation, padding=padding)
+            in_layer = torch.nn.utils.weight_norm(in_layer, name='weight')
+            self.in_layers.append(in_layer)
+
+            # last one is not necessary
+            if i < n_layers - 1:
+                res_skip_channels = 2 * hidden_size
+            else:
+                res_skip_channels = hidden_size
+
+            res_skip_layer = torch.nn.Conv1d(hidden_size, res_skip_channels, 1)
+            res_skip_layer = torch.nn.utils.weight_norm(res_skip_layer, name='weight')
+            self.res_skip_layers.append(res_skip_layer)
+
+        script_function()
+
+    def forward(self, x, nonpadding=None, cond=None):
+        if self.is_BTC:
+            x = x.transpose(1, 2)
+            cond = cond.transpose(1, 2) if cond is not None else None
+            nonpadding = nonpadding.transpose(1, 2) if nonpadding is not None else None
+        if nonpadding is None:
+            nonpadding = 1
+        output = torch.zeros_like(x)
+        n_channels_tensor = torch.IntTensor([self.hidden_size])
+
+        if cond is not None and not self.share_cond_layers:
+            cond = self.cond_layer(cond)
+
+        for i in range(self.n_layers):
+            x_in = self.in_layers[i](x)
+            x_in = self.drop(x_in)
+            if cond is not None:
+                cond_offset = i * 2 * self.hidden_size
+                cond_l = cond[:, cond_offset:cond_offset + 2 * self.hidden_size, :]
+            else:
+                cond_l = torch.zeros_like(x_in)
+
+            if version.parse(torch.__version__) >= version.parse('2.0'):
+                acts = jit_fused_add_tanh_sigmoid_multiply(x_in, cond_l, n_channels_tensor)
+            else:
+                acts = fused_add_tanh_sigmoid_multiply(x_in, cond_l, n_channels_tensor)
+
+            res_skip_acts = self.res_skip_layers[i](acts)
+            if i < self.n_layers - 1:
+                x = (x + res_skip_acts[:, :self.hidden_size, :]) * nonpadding
+                output = output + res_skip_acts[:, self.hidden_size:, :]
+            else:
+                output = output + res_skip_acts
+        output = output * nonpadding
+        if self.is_BTC:
+            output = output.transpose(1, 2)
+        return output
+
+    def remove_weight_norm(self):
+        def remove_weight_norm(m):
+            try:
+                nn.utils.remove_weight_norm(m)
+            except ValueError:  # this module didn't have weight norm
+                return
+
+        self.apply(remove_weight_norm)
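A small sketch of how this WaveNet-style block might be exercised (not part of the diff). The import path is inferred from the repository layout and the shapes are illustrative; the only requirements encoded in the module itself are an even hidden_size and an odd kernel_size.

    # hypothetical import path, inferred from the repository layout
    import torch
    from preprocess.tools.note_transcription.modules.commons.wavenet import WN

    wn = WN(hidden_size=128, kernel_size=3, dilation_rate=2, n_layers=4, c_cond=64, is_BTC=True)
    x = torch.randn(2, 100, 128)        # [B, T, C] because is_BTC=True
    cond = torch.randn(2, 100, 64)      # frame-level conditioning, projected by cond_layer
    nonpadding = torch.ones(2, 100, 1)  # 1 for real frames, 0 for padding
    y = wn(x, nonpadding=nonpadding, cond=cond)  # [B, T, C]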
preprocess/tools/note_transcription/modules/pe/__init__.py ADDED
@@ -0,0 +1 @@
+"""Pitch extractor modules for ROSVOT."""
preprocess/tools/note_transcription/modules/pe/rmvpe/__init__.py ADDED
@@ -0,0 +1,6 @@
+from .constants import *
+from .model import E2E0
+from .utils import to_local_average_f0, to_viterbi_f0
+from .inference import RMVPE
+from .spec import MelSpectrogram
+from .extractor import extract